// zernel/commands/profile.rs
use anyhow::{Context, Result};
9use clap::Subcommand;
10use std::process::Command;
11
// Subcommands of `zernel profile`. Each variant maps to one match arm in
// `run` below, which shells out to external tooling (python3 / nvidia-smi /
// nsys). NOTE: using `//` comments on purpose — `///` doc comments on a
// clap `Subcommand` would change the generated CLI help text.
#[derive(Subcommand)]
pub enum ProfileCommands {
    // Profile a training script with the PyTorch autograd profiler.
    Run {
        // Path to the Python training script to execute under the profiler.
        script: String,
        // Number of steps to profile (display-only in the current
        // implementation — TODO confirm the wrapper should honor it).
        #[arg(long, default_value = "10")]
        steps: u32,
    },
    // Sample GPU utilization via nvidia-smi and render a live timeline.
    Timeline {
        // Sampling window in seconds (one sample per second).
        #[arg(long, default_value = "30")]
        duration: u32,
    },
    // Profile CUDA operations with Nsight Systems (falls back to the
    // PyTorch CUDA profiler when `nsys` is not installed).
    Cuda {
        // Path to the Python script to profile.
        script: String,
    },
}
34
35pub async fn run(cmd: ProfileCommands) -> Result<()> {
36 match cmd {
37 ProfileCommands::Run { script, steps } => {
38 println!("Zernel Training Profiler");
39 println!("{}", "=".repeat(60));
40 println!(" Script: {script}");
41 println!(" Steps: {steps}");
42 println!();
43
44 let profile_code = format!(
46 r#"
47import torch
48import torch.autograd.profiler as profiler
49import time
50import sys
51import os
52
53# Enable CUDA synchronous execution for accurate timing
54os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
55
56print("Running profiler...")
57print()
58
59# Import and run the training script with profiling
60with profiler.profile(use_cuda=True, profile_memory=True, record_shapes=True) as prof:
61 # Execute the training script
62 exec(open('{script}').read())
63
64# Print profiler results
65print()
66print("=" * 70)
67print("ZERNEL TRAINING PROFILE")
68print("-" * 70)
69print()
70
71# Table sorted by CUDA time
72print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
73
74print()
75print("-" * 70)
76
77# Memory summary
78if torch.cuda.is_available():
79 print()
80 print("GPU Memory Summary:")
81 print(f" Peak allocated: {{torch.cuda.max_memory_allocated() / 1e9:.2f}} GB")
82 print(f" Peak reserved: {{torch.cuda.max_memory_reserved() / 1e9:.2f}} GB")
83 print(f" Current: {{torch.cuda.memory_allocated() / 1e9:.2f}} GB")
84
85# Export chrome trace
86trace_file = '/tmp/zernel_profile_trace.json'
87prof.export_chrome_trace(trace_file)
88print(f)
89print(f"Chrome trace saved to: {{trace_file}}")
90print("View in Chrome: chrome://tracing → Load → select the file")
91"#
92 );
93
94 let wrapper_path = "/tmp/zernel_profile_wrapper.py";
95 std::fs::write(wrapper_path, &profile_code)?;
96
97 let status = tokio::process::Command::new("python3")
98 .arg(wrapper_path)
99 .env("CUDA_LAUNCH_BLOCKING", "1")
100 .status()
101 .await
102 .with_context(|| format!("failed to run {script}"))?;
103
104 if !status.success() {
105 println!("PyTorch profiler failed. Running basic timing analysis...");
107 println!();
108
109 let basic_code = r#"
110import torch, time
111
112if torch.cuda.is_available():
113 # Measure H2D transfer
114 data = torch.randn(256, 3, 224, 224)
115 torch.cuda.synchronize()
116 t0 = time.time()
117 for _ in range(10):
118 gpu_data = data.cuda()
119 del gpu_data
120 torch.cuda.synchronize()
121 t1 = time.time()
122 h2d_ms = (t1 - t0) / 10 * 1000
123 print(f" H2D transfer: {{h2d_ms:.2f}} ms/batch")
124
125 # Measure allocation
126 torch.cuda.synchronize()
127 t0 = time.time()
128 for _ in range(100):
129 x = torch.empty(1024*1024*64, device='cuda')
130 del x
131 torch.cuda.synchronize()
132 t1 = time.time()
133 alloc_us = (t1 - t0) / 100 * 1e6
134 print(f" GPU alloc (256MB): {{alloc_us:.0f}} us")
135
136 # Measure sync overhead
137 torch.cuda.synchronize()
138 t0 = time.time()
139 for _ in range(1000):
140 torch.cuda.synchronize()
141 t1 = time.time()
142 sync_us = (t1 - t0) / 1000 * 1e6
143 print(f" CUDA sync: {{sync_us:.1f}} us")
144
145 print()
146 print(f" Peak GPU memory: {{torch.cuda.max_memory_allocated() / 1e9:.2f}} GB")
147"#;
148 let _ = tokio::process::Command::new("python3")
149 .args(["-c", basic_code])
150 .status()
151 .await;
152 }
153 }
154
155 ProfileCommands::Timeline { duration } => {
156 println!("GPU Utilization Timeline ({duration}s)");
157 println!("{}", "=".repeat(60));
158 println!();
159
160 let start = std::time::Instant::now();
161 let mut samples = Vec::new();
162
163 while start.elapsed().as_secs() < duration as u64 {
164 let output = Command::new("nvidia-smi")
165 .args([
166 "--query-gpu=utilization.gpu,memory.used,power.draw",
167 "--format=csv,noheader,nounits",
168 ])
169 .output();
170
171 if let Ok(o) = output {
172 let stdout = String::from_utf8_lossy(&o.stdout);
173 if let Some(line) = stdout.lines().next() {
174 let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
175 if f.len() >= 3 {
176 let util: u32 = f[0].parse().unwrap_or(0);
177 let elapsed = start.elapsed().as_secs();
178 let bar_len = (util as usize * 40) / 100;
179 let bar =
180 format!("{}{}", "#".repeat(bar_len), " ".repeat(40 - bar_len));
181 println!(
182 " {:>3}s [{bar}] {:>3}% {:>5}MB {:>5}W",
183 elapsed, util, f[1], f[2]
184 );
185 samples.push(util);
186 }
187 }
188 }
189
190 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
191 }
192
193 if !samples.is_empty() {
194 let avg: u32 = samples.iter().sum::<u32>() / samples.len() as u32;
195 let max = samples.iter().max().unwrap_or(&0);
196 let min = samples.iter().min().unwrap_or(&0);
197 println!();
198 println!("Summary: avg={avg}% min={min}% max={max}%");
199
200 if avg < 50 {
201 println!(" → GPU underutilized. Run: zernel debug why-slow");
202 }
203 }
204 }
205
206 ProfileCommands::Cuda { script } => {
207 println!("CUDA Operation Profile: {script}");
208 println!("{}", "=".repeat(60));
209 println!();
210 println!("Running with NVIDIA Nsight Systems...");
211
212 let nsys_check = Command::new("nsys").arg("--version").output();
213 match nsys_check {
214 Ok(o) if o.status.success() => {
215 let status = tokio::process::Command::new("nsys")
216 .args([
217 "profile",
218 "--stats=true",
219 "--output=/tmp/zernel_nsys",
220 "python3",
221 &script,
222 ])
223 .status()
224 .await?;
225
226 if status.success() {
227 println!("Profile saved to: /tmp/zernel_nsys.nsys-rep");
228 println!("View with: nsys-ui /tmp/zernel_nsys.nsys-rep");
229 }
230 }
231 _ => {
232 println!(" nsys not found. Using PyTorch CUDA profiler instead...");
233 let status = tokio::process::Command::new("python3")
234 .args(["-c", &format!(
235 "import torch; torch.cuda.cudart().cudaProfilerStart(); exec(open('{script}').read()); torch.cuda.cudart().cudaProfilerStop()"
236 )])
237 .env("CUDA_LAUNCH_BLOCKING", "1")
238 .status()
239 .await?;
240 let _ = status;
241 }
242 }
243 }
244 }
245 Ok(())
246}