zernel/commands/
autopilot.rs

1// Copyright (C) 2026 Dyber, Inc. — Proprietary
2
3//! zernel autopilot — Autonomous training optimizer
4//!
5//! Monitors training in real-time and automatically fixes problems:
6//! - Detects GPU underutilization → suggests increasing DataLoader workers
7//! - Detects memory pressure → suggests gradient checkpointing
8//! - Detects NaN gradients → stops early and reports the layer
9//! - Detects data bottleneck → suggests prefetching
10//! - Tracks loss curve and detects divergence
11
12use anyhow::{Context, Result};
13use clap::Subcommand;
14use std::process::Stdio;
15use tokio::io::{AsyncBufReadExt, BufReader};
16
// CLI surface for `zernel autopilot`. Note: the `///` doc comments below are
// consumed by clap as user-visible help text, so they are part of the
// program's runtime output and must not be reworded casually.
#[derive(Subcommand)]
pub enum AutopilotCommands {
    /// Run a training script with autonomous monitoring and optimization
    Run {
        /// Training script
        script: String,
        /// Additional arguments
        // `trailing_var_arg` makes clap capture everything after the script
        // path verbatim, so flags intended for the script are not parsed here.
        #[arg(trailing_var_arg = true)]
        args: Vec<String>,
    },
    /// Analyze a running training job
    Analyze,
}
30
/// Mutable monitoring state accumulated while a training run is observed.
struct AutopilotState {
    /// Most recently seen training step (parsed from metrics, or loss count).
    step: u64,
    /// Loss values in the order they were parsed from the child's stdout.
    losses: Vec<f64>,
    /// GPU utilization samples (percent), one per periodic `nvidia-smi` poll.
    gpu_utils: Vec<u32>,
    /// Human-readable warnings raised during the run (also printed live).
    warnings: Vec<String>,
    /// One-shot markers ("low_gpu", "high_gpu") so each recommendation fires once.
    interventions: Vec<String>,
}

impl AutopilotState {
    /// Creates an empty state with no recorded metrics.
    fn new() -> Self {
        Self {
            step: 0,
            losses: Vec::new(),
            gpu_utils: Vec::new(),
            warnings: Vec::new(),
            interventions: Vec::new(),
        }
    }

    /// Inspects the loss history for NaN/Inf and for divergence (recent
    /// 5-step average exceeding the preceding 5-step average by 50%).
    ///
    /// NaN/Inf detection runs on every call; the divergence heuristic needs
    /// at least 10 samples before it can compare two moving-average windows.
    fn check_loss_divergence(&mut self) {
        // NaN/Inf must be reported immediately — the original gated this
        // behind the 10-sample minimum, so a NaN at step 1 went unnoticed.
        if let Some(last) = self.losses.last() {
            if last.is_nan() || last.is_infinite() {
                let msg = format!("Step {}: NaN/Inf detected in loss! Stopping.", self.step);
                self.warnings.push(msg.clone());
                println!("  🛑 AUTOPILOT: {msg}");
                println!("  Fix: reduce learning rate, check data for corrupted samples,");
                println!("       or enable gradient clipping: torch.nn.utils.clip_grad_norm_");
                // Averaging over a NaN would poison the divergence check anyway.
                return;
            }
        }

        if self.losses.len() < 10 {
            return;
        }
        let recent = &self.losses[self.losses.len() - 5..];
        let earlier = &self.losses[self.losses.len() - 10..self.losses.len() - 5];
        let recent_avg: f64 = recent.iter().sum::<f64>() / 5.0;
        let earlier_avg: f64 = earlier.iter().sum::<f64>() / 5.0;

        // A 50% jump in the 5-step moving average is treated as divergence.
        if recent_avg > earlier_avg * 1.5 {
            let msg = format!(
                "Step {}: Loss diverging ({:.4} → {:.4}). Consider reducing learning rate.",
                self.step, earlier_avg, recent_avg
            );
            self.warnings.push(msg.clone());
            println!("  ⚠ AUTOPILOT: {msg}");
        }
    }

    /// Polls `nvidia-smi` for GPU utilization and emits one-shot advice when
    /// the rolling 5-sample average is very low (<30%) or very high (>95%).
    ///
    /// Does nothing (silently) when `nvidia-smi` is missing or unparseable —
    /// monitoring is best-effort and must never interrupt training.
    fn check_gpu_utilization(&mut self) {
        let Ok(output) = std::process::Command::new("nvidia-smi")
            .args([
                "--query-gpu=utilization.gpu",
                "--format=csv,noheader,nounits",
            ])
            .output()
        else {
            return;
        };
        let Ok(stdout) = String::from_utf8(output.stdout) else {
            return;
        };

        // nvidia-smi prints one line per GPU. The original parsed the whole
        // stdout as a single u32, which failed silently on multi-GPU hosts;
        // average across all reported GPUs instead.
        let per_gpu: Vec<u32> = stdout
            .lines()
            .filter_map(|l| l.trim().parse::<u32>().ok())
            .collect();
        if per_gpu.is_empty() {
            return;
        }
        let util = per_gpu.iter().sum::<u32>() / per_gpu.len() as u32;
        self.gpu_utils.push(util);

        if self.gpu_utils.len() < 5 {
            return;
        }
        let recent: u32 = self.gpu_utils[self.gpu_utils.len() - 5..]
            .iter()
            .sum::<u32>()
            / 5;

        if recent < 30 && !self.interventions.contains(&"low_gpu".to_string()) {
            let msg = format!(
                "Step {}: GPU utilization low ({}%). Data pipeline is the bottleneck.",
                self.step, recent
            );
            println!("  ⚡ AUTOPILOT: {msg}");
            println!("    → Increase num_workers in DataLoader");
            println!("    → Enable pin_memory=True");
            println!("    → Use prefetch_factor=4");
            self.interventions.push("low_gpu".to_string());
        }

        if recent > 95 && !self.interventions.contains(&"high_gpu".to_string()) {
            println!(
                "  ✓ AUTOPILOT: GPU utilization excellent ({}%). Training is GPU-bound.",
                recent
            );
            self.interventions.push("high_gpu".to_string());
        }
    }

    /// Prints the end-of-run report: steps, warnings, interventions, loss
    /// trajectory, and average GPU utilization.
    fn print_summary(&self) {
        println!();
        println!("Zernel Autopilot Summary");
        println!("{}", "=".repeat(50));
        println!("  Steps monitored: {}", self.step);
        println!("  Warnings:        {}", self.warnings.len());
        println!("  Interventions:   {}", self.interventions.len());

        if !self.losses.is_empty() {
            let first = self.losses[0];
            let last = self.losses[self.losses.len() - 1];
            // Guard against a zero first loss to avoid division by zero.
            let improvement = if first > 0.0 {
                (1.0 - last / first) * 100.0
            } else {
                0.0
            };
            println!(
                "  Loss:            {:.4} → {:.4} ({:.1}% improvement)",
                first, last, improvement
            );
        }

        if !self.gpu_utils.is_empty() {
            // Accumulate in u64 so very long runs cannot overflow the sum
            // (the original summed in u32).
            let avg = self.gpu_utils.iter().map(|&u| u as u64).sum::<u64>()
                / self.gpu_utils.len() as u64;
            println!("  Avg GPU util:    {}%", avg);
        }

        if !self.warnings.is_empty() {
            println!();
            println!("  Warnings:");
            for w in &self.warnings {
                println!("    - {w}");
            }
        }
    }
}
159
160pub async fn run(cmd: AutopilotCommands) -> Result<()> {
161    match cmd {
162        AutopilotCommands::Run { script, args } => {
163            println!("Zernel Autopilot");
164            println!("{}", "=".repeat(50));
165            println!("  Script:  {script}");
166            println!("  Mode:    autonomous monitoring + optimization");
167            println!();
168            println!("Autopilot will:");
169            println!("  - Monitor GPU utilization and suggest DataLoader changes");
170            println!("  - Track loss curve and detect divergence/NaN");
171            println!("  - Alert on memory pressure");
172            println!("  - Report optimization opportunities");
173            println!();
174
175            let mut state = AutopilotState::new();
176            let extractor = crate::experiments::tracker::MetricExtractor::new();
177
178            let python = if std::process::Command::new("python3")
179                .arg("--version")
180                .output()
181                .is_ok()
182            {
183                "python3"
184            } else {
185                "python"
186            };
187
188            let mut child = tokio::process::Command::new(python)
189                .arg(&script)
190                .args(&args)
191                .stdout(Stdio::piped())
192                .stderr(Stdio::piped())
193                .spawn()
194                .with_context(|| format!("failed to launch {script}"))?;
195
196            let stdout = child
197                .stdout
198                .take()
199                .ok_or_else(|| anyhow::anyhow!("no stdout"))?;
200
201            let mut reader = BufReader::new(stdout);
202            let mut line = String::new();
203            let mut check_interval = 0u64;
204
205            loop {
206                line.clear();
207                match reader.read_line(&mut line).await {
208                    Ok(0) => break,
209                    Ok(_) => {
210                        let trimmed = line.trim_end();
211                        println!("{trimmed}");
212
213                        // Extract metrics
214                        let metrics = extractor.extract_from_line(trimmed);
215                        if let Some(&loss) = metrics.get("loss") {
216                            state.losses.push(loss);
217                            state.step += 1;
218                            state.check_loss_divergence();
219                        }
220                        if let Some(&step) = metrics.get("step") {
221                            state.step = step as u64;
222                        }
223
224                        // Periodic GPU check (every 10 lines)
225                        check_interval += 1;
226                        if check_interval.is_multiple_of(10) {
227                            state.check_gpu_utilization();
228                        }
229                    }
230                    Err(_) => break,
231                }
232            }
233
234            let status = child.wait().await?;
235            state.print_summary();
236
237            if !status.success() {
238                println!();
239                println!(
240                    "Training failed with exit code {}",
241                    status.code().unwrap_or(-1)
242                );
243            }
244        }
245
246        AutopilotCommands::Analyze => {
247            println!("Zernel Autopilot — Live Analysis");
248            println!("{}", "=".repeat(50));
249            println!();
250
251            // Check current state
252            if let Ok(output) = std::process::Command::new("nvidia-smi")
253                .args([
254                    "--query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw",
255                    "--format=csv,noheader,nounits",
256                ])
257                .output()
258            {
259                let stdout = String::from_utf8_lossy(&output.stdout);
260                for line in stdout.lines() {
261                    let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
262                    if f.len() >= 5 {
263                        let util: u32 = f[0].parse().unwrap_or(0);
264                        let mem_used: u64 = f[1].parse().unwrap_or(0);
265                        let mem_total: u64 = f[2].parse().unwrap_or(1);
266                        let mem_pct = mem_used * 100 / mem_total;
267
268                        println!("  GPU Utilization: {}%", util);
269                        println!("  Memory: {}% ({}/{}MB)", mem_pct, mem_used, mem_total);
270                        println!("  Temperature: {}°C", f[3]);
271                        println!("  Power: {}W", f[4]);
272                        println!();
273
274                        if util < 30 {
275                            println!("  ⚡ RECOMMENDATION: GPU is underutilized.");
276                            println!("    → Increase DataLoader num_workers");
277                            println!("    → Enable pin_memory=True");
278                        } else if util > 90 {
279                            println!("  ✓ GPU utilization is excellent.");
280                        }
281
282                        if mem_pct > 90 {
283                            println!("  ⚠ RECOMMENDATION: GPU memory nearly full.");
284                            println!("    → Enable gradient checkpointing");
285                            println!("    → Use mixed precision (BF16/FP16)");
286                            println!("    → Reduce batch size");
287                        }
288                    }
289                }
290            } else {
291                println!("  nvidia-smi not available.");
292            }
293        }
294    }
295    Ok(())
296}