zernel/commands/
gpu.rs

1// Copyright (C) 2026 Dyber, Inc. — Proprietary
2
3//! zernel gpu — GPU management CLI (nvidia-smi replacement)
4
5use anyhow::{Context, Result};
6use clap::Subcommand;
7
8use std::process::Command;
9
10#[derive(Subcommand)]
11pub enum GpuCommands {
12    /// Show clean GPU overview (default if no subcommand)
13    Status,
14    /// Real-time GPU process viewer (like htop for GPUs)
15    Top,
16    /// Show GPU memory usage by process
17    Mem,
18    /// Kill all processes on a GPU
19    Kill {
20        /// GPU index (0, 1, 2, ...)
21        gpu: u32,
22    },
23    /// Reserve GPUs for exclusive use
24    Lock {
25        /// GPU indices (comma-separated: 0,1,2)
26        gpus: String,
27        /// Job or user to lock for
28        #[arg(long, rename_all = "kebab-case")]
29        for_job: Option<String>,
30    },
31    /// Release GPU reservation
32    Unlock {
33        /// GPU indices (comma-separated)
34        gpus: String,
35    },
36    /// Monitor GPU temperature with alerts
37    Temp {
38        /// Alert threshold in Celsius
39        #[arg(long, default_value = "85")]
40        alert: u32,
41    },
42    /// Set GPU power limits
43    Power {
44        /// Power limit (e.g., 300W)
45        #[arg(long)]
46        limit: Option<String>,
47    },
48    /// GPU health check (ECC errors, throttling, PCIe)
49    Health,
50}
51
52fn query_nvidia_smi(fields: &str) -> Result<String> {
53    let output = Command::new("nvidia-smi")
54        .args(["--query-gpu", fields, "--format=csv,noheader,nounits"])
55        .output()
56        .with_context(|| "nvidia-smi not found — is the NVIDIA driver installed?")?;
57    if !output.status.success() {
58        anyhow::bail!(
59            "nvidia-smi failed: {}",
60            String::from_utf8_lossy(&output.stderr)
61        );
62    }
63    Ok(String::from_utf8_lossy(&output.stdout).to_string())
64}
65
66fn query_nvidia_smi_procs() -> Result<String> {
67    let output = Command::new("nvidia-smi")
68        .args([
69            "--query-compute-apps",
70            "gpu_uuid,pid,process_name,used_gpu_memory",
71            "--format=csv,noheader,nounits",
72        ])
73        .output()
74        .with_context(|| "nvidia-smi not found")?;
75    Ok(String::from_utf8_lossy(&output.stdout).to_string())
76}
77
/// Entry point for `zernel gpu`: dispatches on the parsed subcommand.
/// All GPU data comes from shelling out to `nvidia-smi` via the query
/// helpers above; output is printed directly to stdout.
///
/// `Top` and `Temp` loop forever (exit via Ctrl+C); every other arm runs
/// once and falls through to `Ok(())`.
///
/// NOTE(review): the nvidia-smi invocations are blocking `std::process`
/// calls inside an async fn. Fine for a CLI, but they stall the executor
/// thread while nvidia-smi runs — confirm this is acceptable if `run` is
/// ever called from a shared runtime.
pub async fn run(cmd: GpuCommands) -> Result<()> {
    match cmd {
        // One-shot table: one row per GPU with a utilization bar.
        GpuCommands::Status => {
            let data = query_nvidia_smi(
                "index,name,utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw,power.limit",
            )?;

            println!("Zernel GPU Status");
            println!("{}", "=".repeat(80));
            println!(
                "{:<5} {:<22} {:>5} {:>12} {:>6} {:>8} {:>10}",
                "GPU", "Name", "Util", "Memory", "Temp", "Power", "Limit"
            );
            println!("{}", "-".repeat(80));

            for line in data.lines() {
                // CSV fields in the order requested above; rows with fewer
                // than 8 fields (e.g. blank trailing line) are skipped.
                let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                if f.len() >= 8 {
                    // Unparseable utilization (e.g. "N/A") renders as 0%.
                    let util_pct: u32 = f[2].parse().unwrap_or(0);
                    let bar = gpu_bar(util_pct, 10);
                    println!(
                        "{:<5} {:<22} {} {:>5}/{:<5} MB {:>3}°C {:>6}W/{:<4}W",
                        f[0], f[1], bar, f[3], f[4], f[5], f[6], f[7]
                    );
                }
            }
            println!();
        }

        // Full-screen refresh loop (every 2 s): per-GPU bars plus a process
        // table. Never returns; the user exits with Ctrl+C.
        GpuCommands::Top => {
            println!("Zernel GPU Top — Press Ctrl+C to exit");
            println!();
            loop {
                // Clear screen (ANSI: erase display + cursor home).
                print!("\x1B[2J\x1B[H");

                let gpu_data = query_nvidia_smi(
                    "index,name,utilization.gpu,memory.used,memory.total,temperature.gpu",
                )?;
                let proc_data = query_nvidia_smi_procs()?;

                println!("Zernel GPU Top");
                println!("{}", "=".repeat(80));

                for line in gpu_data.lines() {
                    let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                    if f.len() >= 6 {
                        let util: u32 = f[2].parse().unwrap_or(0);
                        println!(
                            "GPU {} ({}) {} {}/{}MB {}°C",
                            f[0],
                            f[1],
                            gpu_bar(util, 20),
                            f[3],
                            f[4],
                            f[5]
                        );
                    }
                }

                println!();
                println!("{:<8} {:<30} {:>10}", "PID", "Process", "GPU Mem (MB)");
                println!("{}", "-".repeat(50));

                // Empty stdout from --query-compute-apps means no compute
                // processes are currently using any GPU.
                if proc_data.trim().is_empty() {
                    println!("  (no GPU processes running)");
                } else {
                    for line in proc_data.lines() {
                        // Fields: gpu_uuid, pid, process_name, used_gpu_memory.
                        let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                        if f.len() >= 4 {
                            println!("{:<8} {:<30} {:>10}", f[1], f[2], f[3]);
                        }
                    }
                }

                tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
            }
        }

        // One-shot per-process memory table with a grand total.
        GpuCommands::Mem => {
            let proc_data = query_nvidia_smi_procs()?;

            println!("GPU Memory by Process");
            println!("{}", "=".repeat(60));
            println!(
                "{:<8} {:<30} {:>10} {:>8}",
                "PID", "Process", "GPU Mem", "GPU"
            );
            println!("{}", "-".repeat(60));

            if proc_data.trim().is_empty() {
                println!("  (no GPU processes running)");
            } else {
                let mut total: u64 = 0;
                for line in proc_data.lines() {
                    let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                    if f.len() >= 4 {
                        // f[3] is used memory in MB (nounits). "N/A" counts as 0.
                        let mem: u64 = f[3].parse().unwrap_or(0);
                        total += mem;
                        // Note: the GPU column (f[0]) prints the raw GPU UUID,
                        // not the index — the query only exposes the UUID.
                        println!("{:<8} {:<30} {:>7} MB {:>8}", f[1], f[2], f[3], f[0]);
                    }
                }
                println!("{}", "-".repeat(60));
                println!("{:<38} {:>7} MB", "TOTAL", total);
            }
        }

        // Send SIGTERM (Unix) / taskkill (elsewhere) to every compute process
        // on the given GPU index. Best-effort: kill errors are ignored.
        GpuCommands::Kill { gpu } => {
            let proc_data = query_nvidia_smi_procs()?;
            let gpu_str = gpu.to_string();
            let mut killed = 0;

            // Get GPU UUID for this index — the process query identifies
            // GPUs by UUID, not by index.
            let uuid_data = query_nvidia_smi("index,uuid")?;
            let target_uuid: Option<String> = uuid_data.lines().find_map(|line| {
                let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                if f.len() >= 2 && f[0] == gpu_str {
                    Some(f[1].to_string())
                } else {
                    None
                }
            });

            let Some(uuid) = target_uuid else {
                println!("GPU {gpu} not found");
                return Ok(());
            };

            for line in proc_data.lines() {
                let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                if f.len() >= 2 && f[0].contains(&uuid) {
                    if let Ok(pid) = f[1].parse::<u32>() {
                        println!("Killing PID {pid} ({})...", f.get(2).unwrap_or(&""));
                        #[cfg(unix)]
                        // SAFETY: libc::kill is safe to call with any pid and a
                        // valid signal; on a stale/invalid pid it simply returns
                        // an error code, which is deliberately ignored here.
                        // NOTE(review): SIGTERM is polite — a hung process may
                        // survive; confirm whether a SIGKILL fallback is wanted.
                        unsafe {
                            libc::kill(pid as i32, libc::SIGTERM);
                        }
                        #[cfg(not(unix))]
                        {
                            let _ = Command::new("taskkill")
                                .args(["/PID", &pid.to_string(), "/F"])
                                .output();
                        }
                        killed += 1;
                    }
                }
            }

            if killed == 0 {
                println!("No processes found on GPU {gpu}");
            } else {
                println!("Killed {killed} process(es) on GPU {gpu}");
            }
        }

        // Record a reservation as one lock file per GPU index.
        // NOTE(review): nothing in this file enforces the lock — it appears
        // to be advisory bookkeeping; confirm enforcement lives elsewhere.
        GpuCommands::Lock { gpus, for_job } => {
            let job_name = for_job.unwrap_or_else(|| "manual".into());
            // Write lock file
            let lock_dir = crate::experiments::tracker::zernel_dir().join("gpu-locks");
            std::fs::create_dir_all(&lock_dir)?;
            for gpu in gpus.split(',') {
                let gpu = gpu.trim();
                let lock_file = lock_dir.join(format!("gpu-{gpu}.lock"));
                // Overwrites any existing lock: last writer wins.
                std::fs::write(&lock_file, &job_name)?;
                println!("GPU {gpu} locked for: {job_name}");
            }
            println!("Set CUDA_VISIBLE_DEVICES={gpus} in your training script.");
        }

        // Remove the lock files written by Lock; unlocked GPUs are reported,
        // not treated as errors.
        GpuCommands::Unlock { gpus } => {
            let lock_dir = crate::experiments::tracker::zernel_dir().join("gpu-locks");
            for gpu in gpus.split(',') {
                let gpu = gpu.trim();
                let lock_file = lock_dir.join(format!("gpu-{gpu}.lock"));
                if lock_file.exists() {
                    std::fs::remove_file(&lock_file)?;
                    println!("GPU {gpu} unlocked");
                } else {
                    println!("GPU {gpu} was not locked");
                }
            }
        }

        // Poll temperatures every 5 s forever, tagging readings at or above
        // the threshold with "ALERT". Exit via Ctrl+C.
        GpuCommands::Temp { alert } => {
            println!("Monitoring GPU temperatures (alert at {alert}°C)...");
            loop {
                let data = query_nvidia_smi("index,temperature.gpu")?;
                for line in data.lines() {
                    let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                    if f.len() >= 2 {
                        let temp: u32 = f[1].parse().unwrap_or(0);
                        let indicator = if temp >= alert { "ALERT" } else { "ok" };
                        println!("GPU {}: {}°C [{}]", f[0], temp, indicator);
                    }
                }
                println!();
                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
            }
        }

        // With --limit: set the power cap on all GPUs (needs root).
        // Without: print current draw/limit/max per GPU.
        GpuCommands::Power { limit } => {
            if let Some(limit_str) = limit {
                // Accept "300", "300W", or "300w".
                let watts = limit_str.trim_end_matches('W').trim_end_matches('w');
                let status = Command::new("nvidia-smi").args(["-pl", watts]).status()?;
                if status.success() {
                    println!("Power limit set to {watts}W across all GPUs");
                } else {
                    println!("Failed to set power limit (requires root)");
                }
            } else {
                let data = query_nvidia_smi("index,power.draw,power.limit,power.max_limit")?;
                println!("GPU Power Status");
                println!("{:<5} {:>10} {:>10} {:>10}", "GPU", "Draw", "Limit", "Max");
                for line in data.lines() {
                    let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                    if f.len() >= 4 {
                        println!("{:<5} {:>8}W {:>8}W {:>8}W", f[0], f[1], f[2], f[3]);
                    }
                }
            }
        }

        // One-shot health report: ECC error counters, throttle bitmask, and
        // current PCIe link gen/width per GPU.
        GpuCommands::Health => {
            println!("Zernel GPU Health Check");
            println!("{}", "=".repeat(60));

            let data = query_nvidia_smi(
                "index,name,ecc.errors.corrected.volatile.total,ecc.errors.uncorrected.volatile.total,clocks_throttle_reasons.active,pcie.link.gen.current,pcie.link.width.current",
            )?;

            for line in data.lines() {
                let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                if f.len() >= 7 {
                    println!("GPU {} ({})", f[0], f[1]);
                    let ecc_corr = f[2];
                    let ecc_uncorr = f[3];
                    let throttle = f[4];
                    let pcie_gen = f[5];
                    let pcie_width = f[6];

                    // "N/A" (ECC unsupported/disabled) is treated as healthy.
                    let ecc_status = if ecc_uncorr != "0" && ecc_uncorr != "N/A" {
                        "FAIL — uncorrected ECC errors"
                    } else if ecc_corr != "0" && ecc_corr != "N/A" {
                        "WARN — corrected ECC errors"
                    } else {
                        "OK"
                    };

                    // Throttle reasons are a hex bitmask; any non-zero value
                    // means at least one throttle reason is active.
                    let throttle_status =
                        if throttle.contains("0x") && throttle != "0x0000000000000000" {
                            "WARN — throttling active"
                        } else {
                            "OK"
                        };

                    println!("  ECC:        {ecc_status}");
                    println!("  Throttling: {throttle_status}");
                    println!("  PCIe:       Gen{pcie_gen} x{pcie_width}");
                    println!();
                }
            }
        }
    }
    Ok(())
}
343
/// Render a colored ASCII utilization bar, e.g. `[#####     ]  50%`.
///
/// `pct` is a percentage (normally 0–100) and `width` the bar's inner width
/// in characters. The fill is clamped so a reading above 100% (driver glitch
/// or bad parse upstream) can no longer overflow the bracket; the numeric
/// suffix still shows the raw value.
///
/// NOTE(review): the color mapping treats HIGH utilization as green and low
/// as red — apparently deliberate for a training rig, where an idle GPU is
/// wasted capacity; confirm this matches the intended UX.
fn gpu_bar(pct: u32, width: usize) -> String {
    // Clamp to 100% for the fill so `filled` never exceeds `width`.
    let filled = (pct.min(100) as usize * width) / 100;
    let empty = width.saturating_sub(filled);
    let color = if pct > 90 {
        "\x1b[32m" // green: fully busy
    } else if pct > 70 {
        "\x1b[33m" // yellow: mostly busy
    } else {
        "\x1b[31m" // red: underutilized/idle
    };
    format!(
        "{color}[{}{}]\x1b[0m {:>3}%",
        "#".repeat(filled),
        " ".repeat(empty),
        pct
    )
}