zernel/commands/doctor.rs

// Copyright (C) 2026 Dyber, Inc. — Proprietary
2
3use anyhow::Result;
4use std::process::Command;
5
6struct Check {
7    name: &'static str,
8    check: fn() -> CheckResult,
9}
10
/// Outcome of a single doctor check; each variant carries the message shown
/// next to the check's name.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CheckResult {
    /// The requirement is satisfied; printed as `[OK]`.
    Pass(String),
    /// Usable, but something is missing or degraded; printed as `[WARN]`.
    Warn(String),
    /// The requirement is not met; printed as `[FAIL]`.
    Fail(String),
    /// The check could not be evaluated; printed as `[SKIP]`.
    Skip(String),
}
17
18/// Diagnose the Zernel environment.
19pub async fn run() -> Result<()> {
20    println!("Zernel Doctor v{}", env!("CARGO_PKG_VERSION"));
21    println!("{}", "=".repeat(50));
22    println!();
23
24    let checks: Vec<Check> = vec![
25        Check {
26            name: "Operating System",
27            check: check_os,
28        },
29        Check {
30            name: "Python",
31            check: check_python,
32        },
33        Check {
34            name: "NVIDIA Driver",
35            check: check_nvidia_driver,
36        },
37        Check {
38            name: "CUDA Toolkit",
39            check: check_cuda,
40        },
41        Check {
42            name: "PyTorch",
43            check: check_pytorch,
44        },
45        Check {
46            name: "PyTorch CUDA",
47            check: check_pytorch_cuda,
48        },
49        Check {
50            name: "Git",
51            check: check_git,
52        },
53        Check {
54            name: "zerneld",
55            check: check_zerneld,
56        },
57        Check {
58            name: "Zernel DB",
59            check: check_zernel_db,
60        },
61    ];
62
63    let mut pass = 0;
64    let mut warn = 0;
65    let mut fail = 0;
66    let mut skip = 0;
67
68    for check in &checks {
69        let result = (check.check)();
70        match &result {
71            CheckResult::Pass(msg) => {
72                println!("  [OK]   {}: {msg}", check.name);
73                pass += 1;
74            }
75            CheckResult::Warn(msg) => {
76                println!("  [WARN] {}: {msg}", check.name);
77                warn += 1;
78            }
79            CheckResult::Fail(msg) => {
80                println!("  [FAIL] {}: {msg}", check.name);
81                fail += 1;
82            }
83            CheckResult::Skip(msg) => {
84                println!("  [SKIP] {}: {msg}", check.name);
85                skip += 1;
86            }
87        }
88    }
89
90    println!();
91    println!("Results: {pass} passed, {warn} warnings, {fail} failed, {skip} skipped");
92
93    if fail > 0 {
94        println!();
95        println!("Some checks failed. Fix the issues above before running ML workloads.");
96    }
97
98    Ok(())
99}
100
/// Run `cmd` with `args` and return its stdout with surrounding whitespace trimmed.
///
/// Returns `None` if the program cannot be spawned (e.g. it is not on `PATH`)
/// or if it exits with a non-success status. Stdout is decoded lossily, so
/// invalid UTF-8 becomes U+FFFD instead of an error. Stderr is captured by
/// `Command::output` and discarded. Blocks until the child exits, with no timeout.
fn run_cmd(cmd: &str, args: &[&str]) -> Option<String> {
    let finished = Command::new(cmd).args(args).output().ok()?;
    if !finished.status.success() {
        return None;
    }
    let text = String::from_utf8_lossy(&finished.stdout);
    Some(text.trim().to_owned())
}
109
110fn check_os() -> CheckResult {
111    let os = std::env::consts::OS;
112    let arch = std::env::consts::ARCH;
113    if os == "linux" {
114        if let Some(kernel) = run_cmd("uname", &["-r"]) {
115            CheckResult::Pass(format!("Linux {kernel} ({arch})"))
116        } else {
117            CheckResult::Pass(format!("Linux ({arch})"))
118        }
119    } else {
120        CheckResult::Warn(format!(
121            "{os} ({arch}) — Zernel targets Linux. Some features unavailable."
122        ))
123    }
124}
125
126fn check_python() -> CheckResult {
127    for cmd in &["python3", "python"] {
128        if let Some(version) = run_cmd(cmd, &["--version"]) {
129            return CheckResult::Pass(version);
130        }
131    }
132    CheckResult::Fail("Python not found".into())
133}
134
135fn check_nvidia_driver() -> CheckResult {
136    if let Some(output) = run_cmd(
137        "nvidia-smi",
138        &["--query-gpu=driver_version", "--format=csv,noheader"],
139    ) {
140        let version = output.lines().next().unwrap_or(&output);
141        CheckResult::Pass(format!("driver v{version}"))
142    } else {
143        CheckResult::Fail("nvidia-smi not found or no GPU detected".into())
144    }
145}
146
147fn check_cuda() -> CheckResult {
148    if let Some(output) = run_cmd("nvcc", &["--version"]) {
149        if let Some(line) = output.lines().find(|l| l.contains("release")) {
150            CheckResult::Pass(line.trim().to_string())
151        } else {
152            CheckResult::Pass(output)
153        }
154    } else {
155        CheckResult::Warn("nvcc not found — CUDA toolkit may not be installed".into())
156    }
157}
158
159fn check_pytorch() -> CheckResult {
160    if let Some(version) = run_cmd("python3", &["-c", "import torch; print(torch.__version__)"]) {
161        CheckResult::Pass(format!("PyTorch {version}"))
162    } else if let Some(version) =
163        run_cmd("python", &["-c", "import torch; print(torch.__version__)"])
164    {
165        CheckResult::Pass(format!("PyTorch {version}"))
166    } else {
167        CheckResult::Warn("PyTorch not installed".into())
168    }
169}
170
171fn check_pytorch_cuda() -> CheckResult {
172    let script = "import torch; print(torch.cuda.is_available(), torch.cuda.device_count() if torch.cuda.is_available() else 0)";
173    if let Some(output) =
174        run_cmd("python3", &["-c", script]).or_else(|| run_cmd("python", &["-c", script]))
175    {
176        let parts: Vec<&str> = output.split_whitespace().collect();
177        match parts.first().map(|s| s.as_ref()) {
178            Some("True") => {
179                let gpus = parts.get(1).unwrap_or(&"?");
180                CheckResult::Pass(format!("CUDA available — {gpus} GPU(s)"))
181            }
182            Some("False") => CheckResult::Warn("CUDA not available to PyTorch".into()),
183            _ => CheckResult::Warn(format!("unexpected output: {output}")),
184        }
185    } else {
186        CheckResult::Skip("PyTorch not installed".into())
187    }
188}
189
190fn check_git() -> CheckResult {
191    if let Some(version) = run_cmd("git", &["--version"]) {
192        CheckResult::Pass(version)
193    } else {
194        CheckResult::Warn("git not found — experiment tracking won't record commits".into())
195    }
196}
197
198fn check_zerneld() -> CheckResult {
199    let port = crate::telemetry::client::metrics_port();
200    let addr = format!("127.0.0.1:{port}");
201    let result = std::net::TcpStream::connect_timeout(
202        &addr.parse().expect("valid address"),
203        std::time::Duration::from_millis(500),
204    );
205    match result {
206        Ok(_) => CheckResult::Pass(format!("running on port {port}")),
207        Err(_) => CheckResult::Warn("not running — start with: zerneld --simulate".into()),
208    }
209}
210
211fn check_zernel_db() -> CheckResult {
212    let db_path = crate::experiments::tracker::experiments_db_path();
213    if db_path.exists() {
214        CheckResult::Pass(format!("{}", db_path.display()))
215    } else {
216        CheckResult::Pass(format!("will be created at {}", db_path.display()))
217    }
218}