// zernel/commands/debug.rs
// Copyright (C) 2026 Dyber, Inc. — Proprietary

//! zernel debug — ML training debugger
use anyhow::{Context, Result};
use clap::Subcommand;
use std::process::Command;

9#[derive(Subcommand)]
10pub enum DebugCommands {
11    /// Analyze why training is slow
12    WhySlow,
13    /// Trace GPU out-of-memory errors
14    Oom,
15    /// Detect NaN gradients and trace to the source
16    Nan {
17        /// Training script to run with NaN detection
18        script: String,
19    },
20    /// Detect NCCL deadlocks / straggler ranks
21    Hang,
22    /// Verify a checkpoint file
23    Checkpoint {
24        /// Path to checkpoint file
25        path: String,
26    },
27    /// Run a script with enhanced tracing
28    Trace {
29        /// Script to trace
30        script: String,
31        #[arg(trailing_var_arg = true)]
32        args: Vec<String>,
33    },
34}
35
36fn run_python(code: &str) -> Result<String> {
37    let output = Command::new("python3")
38        .args(["-c", code])
39        .output()
40        .with_context(|| "python3 not found")?;
41    Ok(String::from_utf8_lossy(&output.stdout).to_string()
42        + &String::from_utf8_lossy(&output.stderr))
43}
44
45pub async fn run(cmd: DebugCommands) -> Result<()> {
46    match cmd {
47        DebugCommands::WhySlow => {
48            println!("Zernel Performance Diagnosis");
49            println!("{}", "=".repeat(60));
50            println!();
51
52            // Check GPU utilization
53            println!("[1/4] GPU Utilization...");
54            if let Ok(output) = Command::new("nvidia-smi")
55                .args([
56                    "--query-gpu=index,utilization.gpu,memory.used,memory.total",
57                    "--format=csv,noheader,nounits",
58                ])
59                .output()
60            {
61                let data = String::from_utf8_lossy(&output.stdout);
62                for line in data.lines() {
63                    let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
64                    if f.len() >= 4 {
65                        let util: u32 = f[1].parse().unwrap_or(0);
66                        let diagnosis = if util < 30 {
67                            "LOW — likely data loading bottleneck"
68                        } else if util < 70 {
69                            "MEDIUM — possible data pipeline or CPU bottleneck"
70                        } else {
71                            "GOOD — GPU well utilized"
72                        };
73                        println!("  GPU {}: {}% — {}", f[0], util, diagnosis);
74                    }
75                }
76            } else {
77                println!("  nvidia-smi not available");
78            }
79            println!();
80
81            // Check CPU
82            println!("[2/4] CPU Utilization...");
83            #[cfg(target_os = "linux")]
84            {
85                if let Ok(content) = std::fs::read_to_string("/proc/loadavg") {
86                    let parts: Vec<&str> = content.split_whitespace().collect();
87                    if let Some(load) = parts.first() {
88                        let cores = std::thread::available_parallelism()
89                            .map(|n| n.get())
90                            .unwrap_or(1);
91                        let load_val: f64 = load.parse().unwrap_or(0.0);
92                        let pct = (load_val / cores as f64 * 100.0) as u32;
93                        let diagnosis = if pct > 90 {
94                            "HIGH — CPU may be bottleneck (data preprocessing?)"
95                        } else {
96                            "OK"
97                        };
98                        println!("  Load: {load} ({cores} cores, {pct}% utilized) — {diagnosis}");
99                    }
100                }
101            }
102            #[cfg(not(target_os = "linux"))]
103            println!("  (CPU check requires Linux)");
104            println!();
105
106            // Check memory
107            println!("[3/4] System Memory...");
108            #[cfg(target_os = "linux")]
109            {
110                if let Ok(content) = std::fs::read_to_string("/proc/meminfo") {
111                    let mut total = 0u64;
112                    let mut available = 0u64;
113                    for line in content.lines() {
114                        if line.starts_with("MemTotal:") {
115                            total = line
116                                .split_whitespace()
117                                .nth(1)
118                                .and_then(|s| s.parse().ok())
119                                .unwrap_or(0);
120                        }
121                        if line.starts_with("MemAvailable:") {
122                            available = line
123                                .split_whitespace()
124                                .nth(1)
125                                .and_then(|s| s.parse().ok())
126                                .unwrap_or(0);
127                        }
128                    }
129                    let used_pct = if total > 0 {
130                        ((total - available) * 100 / total) as u32
131                    } else {
132                        0
133                    };
134                    let diagnosis = if used_pct > 90 {
135                        "HIGH — may cause swap/OOM"
136                    } else {
137                        "OK"
138                    };
139                    println!(
140                        "  {used_pct}% used ({} / {} GB) — {diagnosis}",
141                        (total - available) / 1048576,
142                        total / 1048576
143                    );
144                }
145            }
146            #[cfg(not(target_os = "linux"))]
147            println!("  (memory check requires Linux)");
148            println!();
149
150            // Check I/O
151            println!("[4/4] Recommendations");
152            println!("  - If GPU util < 50%: increase DataLoader num_workers, use pin_memory=True");
153            println!(
154                "  - If GPU memory near limit: reduce batch size or use gradient checkpointing"
155            );
156            println!("  - If CPU is bottleneck: move preprocessing to GPU or use faster storage");
157            println!("  - Run: zernel bench dataloader  — to measure data pipeline throughput");
158            println!("  - Run: zernel gpu top  — to monitor GPU usage in real-time");
159        }
160
161        DebugCommands::Oom => {
162            println!("GPU OOM Debugger");
163            println!("{}", "=".repeat(60));
164            println!();
165
166            let out = run_python(
167                "import torch; \
168                 for i in range(torch.cuda.device_count()): \
169                     total=torch.cuda.get_device_properties(i).total_mem/(1024**3); \
170                     reserved=torch.cuda.memory_reserved(i)/(1024**3); \
171                     allocated=torch.cuda.memory_allocated(i)/(1024**3); \
172                     free=total-reserved; \
173                     print(f'GPU {i}: {allocated:.1f}/{total:.1f} GB allocated, {free:.1f} GB free')"
174            )?;
175            println!("{out}");
176
177            println!("Tips to fix OOM:");
178            println!("  1. Reduce batch_size");
179            println!("  2. Use torch.cuda.amp (mixed precision) — halves memory");
180            println!("  3. Use gradient_checkpointing_enable() — trades compute for memory");
181            println!("  4. Use DeepSpeed ZeRO Stage 2/3 — shards optimizer states");
182            println!("  5. Use model.to(dtype=torch.bfloat16)");
183            println!("  6. Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True");
184        }
185
186        DebugCommands::Nan { script } => {
187            println!("NaN Detector — running {script} with anomaly detection...");
188            let code = format!(
189                "import torch; torch.autograd.set_detect_anomaly(True); exec(open('{script}').read())"
190            );
191            let status = Command::new("python3").args(["-c", &code]).status()?;
192            if !status.success() {
193                println!("Script exited with error — check output above for NaN source.");
194            }
195        }
196
197        DebugCommands::Hang => {
198            println!("NCCL Hang Detector");
199            println!();
200            println!("Set these environment variables before training:");
201            println!("  export NCCL_DEBUG=INFO");
202            println!("  export NCCL_DEBUG_SUBSYS=ALL");
203            println!("  export TORCH_DISTRIBUTED_DEBUG=DETAIL");
204            println!("  export NCCL_ASYNC_ERROR_HANDLING=1");
205            println!("  export NCCL_TIMEOUT=300  # 5 min timeout");
206            println!();
207            println!("Then run: zernel run train.py");
208            println!("If it hangs, check which rank is stuck in the NCCL logs.");
209        }
210
211        DebugCommands::Checkpoint { path } => {
212            println!("Verifying checkpoint: {path}");
213            let out = run_python(&format!(
214                "import torch, os; \
215                 ckpt=torch.load('{path}', map_location='cpu', weights_only=False); \
216                 if isinstance(ckpt, dict): \
217                     print(f'Type: dict with {{len(ckpt)}} keys'); \
218                     for k in list(ckpt.keys())[:20]: \
219                         v=ckpt[k]; \
220                         if hasattr(v,'shape'): print(f'  {{k}}: {{v.dtype}} {{list(v.shape)}}'); \
221                         else: print(f'  {{k}}: {{type(v).__name__}}'); \
222                     if len(ckpt)>20: print(f'  ... and {{len(ckpt)-20}} more keys'); \
223                 else: print(f'Type: {{type(ckpt).__name__}}'); \
224                 size=os.path.getsize('{path}')/(1024**3); \
225                 print(f'Size: {{size:.2f}} GB')"
226            ))?;
227            println!("{out}");
228        }
229
230        DebugCommands::Trace { script, args } => {
231            println!("Running {script} with enhanced tracing...");
232            let mut cmd = Command::new("python3");
233            cmd.args(["-u", &script]);
234            cmd.args(&args);
235            cmd.env("CUDA_LAUNCH_BLOCKING", "1");
236            cmd.env("TORCH_SHOW_CPP_STACKTRACES", "1");
237            let status = cmd.status()?;
238            if !status.success() {
239                println!("Script exited with code {}", status.code().unwrap_or(-1));
240            }
241        }
242    }
243    Ok(())
244}