// zernel/commands/profile.rs
// Copyright (C) 2026 Dyber, Inc. — Proprietary

//! zernel profile — Full training pipeline profiler
//!
//! Detailed breakdown of where every millisecond goes in a training step.
//! Shows waterfall chart in the terminal.
use anyhow::{Context, Result};
use clap::Subcommand;
use std::process::Command;

// CLI surface for `zernel profile`. The `///` doc comments below double as
// clap-generated help text, so they are user-facing strings — do not edit
// them casually.
#[derive(Subcommand)]
pub enum ProfileCommands {
    /// Profile a training script (runs 10 steps with tracing)
    Run {
        /// Training script
        script: String,
        /// Number of profiling steps
        // NOTE(review): `steps` is currently only echoed by `run()`; the
        // generated profiling wrapper does not cap step count — confirm intent.
        #[arg(long, default_value = "10")]
        steps: u32,
    },
    /// Show GPU utilization timeline
    Timeline {
        /// Duration in seconds
        #[arg(long, default_value = "30")]
        duration: u32,
    },
    /// Profile CUDA operations
    Cuda {
        /// Training script
        script: String,
    },
}
34
35pub async fn run(cmd: ProfileCommands) -> Result<()> {
36    match cmd {
37        ProfileCommands::Run { script, steps } => {
38            println!("Zernel Training Profiler");
39            println!("{}", "=".repeat(60));
40            println!("  Script: {script}");
41            println!("  Steps:  {steps}");
42            println!();
43
44            // Generate a profiling wrapper script
45            let profile_code = format!(
46                r#"
47import torch
48import torch.autograd.profiler as profiler
49import time
50import sys
51import os
52
53# Enable CUDA synchronous execution for accurate timing
54os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
55
56print("Running profiler...")
57print()
58
59# Import and run the training script with profiling
60with profiler.profile(use_cuda=True, profile_memory=True, record_shapes=True) as prof:
61    # Execute the training script
62    exec(open('{script}').read())
63
64# Print profiler results
65print()
66print("=" * 70)
67print("ZERNEL TRAINING PROFILE")
68print("-" * 70)
69print()
70
71# Table sorted by CUDA time
72print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
73
74print()
75print("-" * 70)
76
77# Memory summary
78if torch.cuda.is_available():
79    print()
80    print("GPU Memory Summary:")
81    print(f"  Peak allocated: {{torch.cuda.max_memory_allocated() / 1e9:.2f}} GB")
82    print(f"  Peak reserved:  {{torch.cuda.max_memory_reserved() / 1e9:.2f}} GB")
83    print(f"  Current:        {{torch.cuda.memory_allocated() / 1e9:.2f}} GB")
84
85# Export chrome trace
86trace_file = '/tmp/zernel_profile_trace.json'
87prof.export_chrome_trace(trace_file)
88print(f)
89print(f"Chrome trace saved to: {{trace_file}}")
90print("View in Chrome: chrome://tracing → Load → select the file")
91"#
92            );
93
94            let wrapper_path = "/tmp/zernel_profile_wrapper.py";
95            std::fs::write(wrapper_path, &profile_code)?;
96
97            let status = tokio::process::Command::new("python3")
98                .arg(wrapper_path)
99                .env("CUDA_LAUNCH_BLOCKING", "1")
100                .status()
101                .await
102                .with_context(|| format!("failed to run {script}"))?;
103
104            if !status.success() {
105                // Fallback: just run with basic timing
106                println!("PyTorch profiler failed. Running basic timing analysis...");
107                println!();
108
109                let basic_code = r#"
110import torch, time
111
112if torch.cuda.is_available():
113    # Measure H2D transfer
114    data = torch.randn(256, 3, 224, 224)
115    torch.cuda.synchronize()
116    t0 = time.time()
117    for _ in range(10):
118        gpu_data = data.cuda()
119        del gpu_data
120    torch.cuda.synchronize()
121    t1 = time.time()
122    h2d_ms = (t1 - t0) / 10 * 1000
123    print(f"  H2D transfer:     {{h2d_ms:.2f}} ms/batch")
124
125    # Measure allocation
126    torch.cuda.synchronize()
127    t0 = time.time()
128    for _ in range(100):
129        x = torch.empty(1024*1024*64, device='cuda')
130        del x
131    torch.cuda.synchronize()
132    t1 = time.time()
133    alloc_us = (t1 - t0) / 100 * 1e6
134    print(f"  GPU alloc (256MB): {{alloc_us:.0f}} us")
135
136    # Measure sync overhead
137    torch.cuda.synchronize()
138    t0 = time.time()
139    for _ in range(1000):
140        torch.cuda.synchronize()
141    t1 = time.time()
142    sync_us = (t1 - t0) / 1000 * 1e6
143    print(f"  CUDA sync:        {{sync_us:.1f}} us")
144
145    print()
146    print(f"  Peak GPU memory:  {{torch.cuda.max_memory_allocated() / 1e9:.2f}} GB")
147"#;
148                let _ = tokio::process::Command::new("python3")
149                    .args(["-c", basic_code])
150                    .status()
151                    .await;
152            }
153        }
154
155        ProfileCommands::Timeline { duration } => {
156            println!("GPU Utilization Timeline ({duration}s)");
157            println!("{}", "=".repeat(60));
158            println!();
159
160            let start = std::time::Instant::now();
161            let mut samples = Vec::new();
162
163            while start.elapsed().as_secs() < duration as u64 {
164                let output = Command::new("nvidia-smi")
165                    .args([
166                        "--query-gpu=utilization.gpu,memory.used,power.draw",
167                        "--format=csv,noheader,nounits",
168                    ])
169                    .output();
170
171                if let Ok(o) = output {
172                    let stdout = String::from_utf8_lossy(&o.stdout);
173                    if let Some(line) = stdout.lines().next() {
174                        let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
175                        if f.len() >= 3 {
176                            let util: u32 = f[0].parse().unwrap_or(0);
177                            let elapsed = start.elapsed().as_secs();
178                            let bar_len = (util as usize * 40) / 100;
179                            let bar =
180                                format!("{}{}", "#".repeat(bar_len), " ".repeat(40 - bar_len));
181                            println!(
182                                "  {:>3}s [{bar}] {:>3}% {:>5}MB {:>5}W",
183                                elapsed, util, f[1], f[2]
184                            );
185                            samples.push(util);
186                        }
187                    }
188                }
189
190                tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
191            }
192
193            if !samples.is_empty() {
194                let avg: u32 = samples.iter().sum::<u32>() / samples.len() as u32;
195                let max = samples.iter().max().unwrap_or(&0);
196                let min = samples.iter().min().unwrap_or(&0);
197                println!();
198                println!("Summary: avg={avg}% min={min}% max={max}%");
199
200                if avg < 50 {
201                    println!("  → GPU underutilized. Run: zernel debug why-slow");
202                }
203            }
204        }
205
206        ProfileCommands::Cuda { script } => {
207            println!("CUDA Operation Profile: {script}");
208            println!("{}", "=".repeat(60));
209            println!();
210            println!("Running with NVIDIA Nsight Systems...");
211
212            let nsys_check = Command::new("nsys").arg("--version").output();
213            match nsys_check {
214                Ok(o) if o.status.success() => {
215                    let status = tokio::process::Command::new("nsys")
216                        .args([
217                            "profile",
218                            "--stats=true",
219                            "--output=/tmp/zernel_nsys",
220                            "python3",
221                            &script,
222                        ])
223                        .status()
224                        .await?;
225
226                    if status.success() {
227                        println!("Profile saved to: /tmp/zernel_nsys.nsys-rep");
228                        println!("View with: nsys-ui /tmp/zernel_nsys.nsys-rep");
229                    }
230                }
231                _ => {
232                    println!("  nsys not found. Using PyTorch CUDA profiler instead...");
233                    let status = tokio::process::Command::new("python3")
234                        .args(["-c", &format!(
235                            "import torch; torch.cuda.cudart().cudaProfilerStart(); exec(open('{script}').read()); torch.cuda.cudart().cudaProfilerStop()"
236                        )])
237                        .env("CUDA_LAUNCH_BLOCKING", "1")
238                        .status()
239                        .await?;
240                    let _ = status;
241                }
242            }
243        }
244    }
245    Ok(())
246}