1use anyhow::{Context, Result};
6use clap::Subcommand;
7use std::process::Command;
8
// Subcommands of `zernel debug` (clap derive).
//
// NOTE(review): plain `//` comments are used deliberately — `///` doc
// comments on a clap `Subcommand` enum become user-visible `--help` text,
// which would change the CLI's output.
#[derive(Subcommand)]
pub enum DebugCommands {
    // Diagnose training slowness: GPU util, CPU load, system memory, tips.
    WhySlow,
    // Print per-GPU memory usage (via PyTorch) plus OOM mitigation tips.
    Oom,
    // Re-run a training script under torch autograd anomaly detection.
    Nan {
        // Path to the Python script to execute.
        script: String,
    },
    // Print NCCL/debug env vars useful for diagnosing distributed hangs.
    Hang,
    // Inspect a checkpoint file: keys, tensor shapes/dtypes, file size.
    Checkpoint {
        // Path to the checkpoint file (loaded with torch.load).
        path: String,
    },
    // Run a script unbuffered with CUDA_LAUNCH_BLOCKING for clearer traces.
    Trace {
        // Path to the Python script to execute.
        script: String,
        // Everything after the script path is forwarded verbatim to it.
        #[arg(trailing_var_arg = true)]
        args: Vec<String>,
    },
}
35
36fn run_python(code: &str) -> Result<String> {
37 let output = Command::new("python3")
38 .args(["-c", code])
39 .output()
40 .with_context(|| "python3 not found")?;
41 Ok(String::from_utf8_lossy(&output.stdout).to_string()
42 + &String::from_utf8_lossy(&output.stderr))
43}
44
45pub async fn run(cmd: DebugCommands) -> Result<()> {
46 match cmd {
47 DebugCommands::WhySlow => {
48 println!("Zernel Performance Diagnosis");
49 println!("{}", "=".repeat(60));
50 println!();
51
52 println!("[1/4] GPU Utilization...");
54 if let Ok(output) = Command::new("nvidia-smi")
55 .args([
56 "--query-gpu=index,utilization.gpu,memory.used,memory.total",
57 "--format=csv,noheader,nounits",
58 ])
59 .output()
60 {
61 let data = String::from_utf8_lossy(&output.stdout);
62 for line in data.lines() {
63 let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
64 if f.len() >= 4 {
65 let util: u32 = f[1].parse().unwrap_or(0);
66 let diagnosis = if util < 30 {
67 "LOW — likely data loading bottleneck"
68 } else if util < 70 {
69 "MEDIUM — possible data pipeline or CPU bottleneck"
70 } else {
71 "GOOD — GPU well utilized"
72 };
73 println!(" GPU {}: {}% — {}", f[0], util, diagnosis);
74 }
75 }
76 } else {
77 println!(" nvidia-smi not available");
78 }
79 println!();
80
81 println!("[2/4] CPU Utilization...");
83 #[cfg(target_os = "linux")]
84 {
85 if let Ok(content) = std::fs::read_to_string("/proc/loadavg") {
86 let parts: Vec<&str> = content.split_whitespace().collect();
87 if let Some(load) = parts.first() {
88 let cores = std::thread::available_parallelism()
89 .map(|n| n.get())
90 .unwrap_or(1);
91 let load_val: f64 = load.parse().unwrap_or(0.0);
92 let pct = (load_val / cores as f64 * 100.0) as u32;
93 let diagnosis = if pct > 90 {
94 "HIGH — CPU may be bottleneck (data preprocessing?)"
95 } else {
96 "OK"
97 };
98 println!(" Load: {load} ({cores} cores, {pct}% utilized) — {diagnosis}");
99 }
100 }
101 }
102 #[cfg(not(target_os = "linux"))]
103 println!(" (CPU check requires Linux)");
104 println!();
105
106 println!("[3/4] System Memory...");
108 #[cfg(target_os = "linux")]
109 {
110 if let Ok(content) = std::fs::read_to_string("/proc/meminfo") {
111 let mut total = 0u64;
112 let mut available = 0u64;
113 for line in content.lines() {
114 if line.starts_with("MemTotal:") {
115 total = line
116 .split_whitespace()
117 .nth(1)
118 .and_then(|s| s.parse().ok())
119 .unwrap_or(0);
120 }
121 if line.starts_with("MemAvailable:") {
122 available = line
123 .split_whitespace()
124 .nth(1)
125 .and_then(|s| s.parse().ok())
126 .unwrap_or(0);
127 }
128 }
129 let used_pct = if total > 0 {
130 ((total - available) * 100 / total) as u32
131 } else {
132 0
133 };
134 let diagnosis = if used_pct > 90 {
135 "HIGH — may cause swap/OOM"
136 } else {
137 "OK"
138 };
139 println!(
140 " {used_pct}% used ({} / {} GB) — {diagnosis}",
141 (total - available) / 1048576,
142 total / 1048576
143 );
144 }
145 }
146 #[cfg(not(target_os = "linux"))]
147 println!(" (memory check requires Linux)");
148 println!();
149
150 println!("[4/4] Recommendations");
152 println!(" - If GPU util < 50%: increase DataLoader num_workers, use pin_memory=True");
153 println!(
154 " - If GPU memory near limit: reduce batch size or use gradient checkpointing"
155 );
156 println!(" - If CPU is bottleneck: move preprocessing to GPU or use faster storage");
157 println!(" - Run: zernel bench dataloader — to measure data pipeline throughput");
158 println!(" - Run: zernel gpu top — to monitor GPU usage in real-time");
159 }
160
161 DebugCommands::Oom => {
162 println!("GPU OOM Debugger");
163 println!("{}", "=".repeat(60));
164 println!();
165
166 let out = run_python(
167 "import torch; \
168 for i in range(torch.cuda.device_count()): \
169 total=torch.cuda.get_device_properties(i).total_mem/(1024**3); \
170 reserved=torch.cuda.memory_reserved(i)/(1024**3); \
171 allocated=torch.cuda.memory_allocated(i)/(1024**3); \
172 free=total-reserved; \
173 print(f'GPU {i}: {allocated:.1f}/{total:.1f} GB allocated, {free:.1f} GB free')"
174 )?;
175 println!("{out}");
176
177 println!("Tips to fix OOM:");
178 println!(" 1. Reduce batch_size");
179 println!(" 2. Use torch.cuda.amp (mixed precision) — halves memory");
180 println!(" 3. Use gradient_checkpointing_enable() — trades compute for memory");
181 println!(" 4. Use DeepSpeed ZeRO Stage 2/3 — shards optimizer states");
182 println!(" 5. Use model.to(dtype=torch.bfloat16)");
183 println!(" 6. Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True");
184 }
185
186 DebugCommands::Nan { script } => {
187 println!("NaN Detector — running {script} with anomaly detection...");
188 let code = format!(
189 "import torch; torch.autograd.set_detect_anomaly(True); exec(open('{script}').read())"
190 );
191 let status = Command::new("python3").args(["-c", &code]).status()?;
192 if !status.success() {
193 println!("Script exited with error — check output above for NaN source.");
194 }
195 }
196
197 DebugCommands::Hang => {
198 println!("NCCL Hang Detector");
199 println!();
200 println!("Set these environment variables before training:");
201 println!(" export NCCL_DEBUG=INFO");
202 println!(" export NCCL_DEBUG_SUBSYS=ALL");
203 println!(" export TORCH_DISTRIBUTED_DEBUG=DETAIL");
204 println!(" export NCCL_ASYNC_ERROR_HANDLING=1");
205 println!(" export NCCL_TIMEOUT=300 # 5 min timeout");
206 println!();
207 println!("Then run: zernel run train.py");
208 println!("If it hangs, check which rank is stuck in the NCCL logs.");
209 }
210
211 DebugCommands::Checkpoint { path } => {
212 println!("Verifying checkpoint: {path}");
213 let out = run_python(&format!(
214 "import torch, os; \
215 ckpt=torch.load('{path}', map_location='cpu', weights_only=False); \
216 if isinstance(ckpt, dict): \
217 print(f'Type: dict with {{len(ckpt)}} keys'); \
218 for k in list(ckpt.keys())[:20]: \
219 v=ckpt[k]; \
220 if hasattr(v,'shape'): print(f' {{k}}: {{v.dtype}} {{list(v.shape)}}'); \
221 else: print(f' {{k}}: {{type(v).__name__}}'); \
222 if len(ckpt)>20: print(f' ... and {{len(ckpt)-20}} more keys'); \
223 else: print(f'Type: {{type(ckpt).__name__}}'); \
224 size=os.path.getsize('{path}')/(1024**3); \
225 print(f'Size: {{size:.2f}} GB')"
226 ))?;
227 println!("{out}");
228 }
229
230 DebugCommands::Trace { script, args } => {
231 println!("Running {script} with enhanced tracing...");
232 let mut cmd = Command::new("python3");
233 cmd.args(["-u", &script]);
234 cmd.args(&args);
235 cmd.env("CUDA_LAUNCH_BLOCKING", "1");
236 cmd.env("TORCH_SHOW_CPP_STACKTRACES", "1");
237 let status = cmd.status()?;
238 if !status.success() {
239 println!("Script exited with code {}", status.code().unwrap_or(-1));
240 }
241 }
242 }
243 Ok(())
244}