// zernel/commands/doctor.rs
use anyhow::Result;
use std::process::Command;

/// A single named diagnostic: a display label paired with the function
/// that performs the check.
struct Check {
    // Human-readable name printed next to the status tag in the report.
    name: &'static str,
    // The diagnostic itself; returns the outcome plus a detail message.
    check: fn() -> CheckResult,
}
10
/// Outcome of one doctor check; the payload is a human-readable detail
/// message shown after the check name.
enum CheckResult {
    // Check succeeded.
    Pass(String),
    // Non-fatal issue; counted separately and does not trigger the
    // "some checks failed" footer.
    Warn(String),
    // Hard failure; triggers the fix-before-running footer in `run`.
    Fail(String),
    // Check could not run (e.g. a prerequisite like PyTorch is missing).
    Skip(String),
}
17
18pub async fn run() -> Result<()> {
20 println!("Zernel Doctor v{}", env!("CARGO_PKG_VERSION"));
21 println!("{}", "=".repeat(50));
22 println!();
23
24 let checks: Vec<Check> = vec![
25 Check {
26 name: "Operating System",
27 check: check_os,
28 },
29 Check {
30 name: "Python",
31 check: check_python,
32 },
33 Check {
34 name: "NVIDIA Driver",
35 check: check_nvidia_driver,
36 },
37 Check {
38 name: "CUDA Toolkit",
39 check: check_cuda,
40 },
41 Check {
42 name: "PyTorch",
43 check: check_pytorch,
44 },
45 Check {
46 name: "PyTorch CUDA",
47 check: check_pytorch_cuda,
48 },
49 Check {
50 name: "Git",
51 check: check_git,
52 },
53 Check {
54 name: "zerneld",
55 check: check_zerneld,
56 },
57 Check {
58 name: "Zernel DB",
59 check: check_zernel_db,
60 },
61 ];
62
63 let mut pass = 0;
64 let mut warn = 0;
65 let mut fail = 0;
66 let mut skip = 0;
67
68 for check in &checks {
69 let result = (check.check)();
70 match &result {
71 CheckResult::Pass(msg) => {
72 println!(" [OK] {}: {msg}", check.name);
73 pass += 1;
74 }
75 CheckResult::Warn(msg) => {
76 println!(" [WARN] {}: {msg}", check.name);
77 warn += 1;
78 }
79 CheckResult::Fail(msg) => {
80 println!(" [FAIL] {}: {msg}", check.name);
81 fail += 1;
82 }
83 CheckResult::Skip(msg) => {
84 println!(" [SKIP] {}: {msg}", check.name);
85 skip += 1;
86 }
87 }
88 }
89
90 println!();
91 println!("Results: {pass} passed, {warn} warnings, {fail} failed, {skip} skipped");
92
93 if fail > 0 {
94 println!();
95 println!("Some checks failed. Fix the issues above before running ML workloads.");
96 }
97
98 Ok(())
99}
100
/// Run `cmd` with `args` and return its trimmed stdout.
///
/// Returns `None` when the binary cannot be spawned (e.g. not installed)
/// or exits with a non-zero status; stderr is discarded either way.
fn run_cmd(cmd: &str, args: &[&str]) -> Option<String> {
    let output = Command::new(cmd).args(args).output().ok()?;
    if !output.status.success() {
        return None;
    }
    // Lossy conversion: tool output is expected to be UTF-8, but a stray
    // invalid byte shouldn't abort the whole check.
    Some(String::from_utf8_lossy(&output.stdout).trim().to_string())
}
109
110fn check_os() -> CheckResult {
111 let os = std::env::consts::OS;
112 let arch = std::env::consts::ARCH;
113 if os == "linux" {
114 if let Some(kernel) = run_cmd("uname", &["-r"]) {
115 CheckResult::Pass(format!("Linux {kernel} ({arch})"))
116 } else {
117 CheckResult::Pass(format!("Linux ({arch})"))
118 }
119 } else {
120 CheckResult::Warn(format!(
121 "{os} ({arch}) — Zernel targets Linux. Some features unavailable."
122 ))
123 }
124}
125
126fn check_python() -> CheckResult {
127 for cmd in &["python3", "python"] {
128 if let Some(version) = run_cmd(cmd, &["--version"]) {
129 return CheckResult::Pass(version);
130 }
131 }
132 CheckResult::Fail("Python not found".into())
133}
134
135fn check_nvidia_driver() -> CheckResult {
136 if let Some(output) = run_cmd(
137 "nvidia-smi",
138 &["--query-gpu=driver_version", "--format=csv,noheader"],
139 ) {
140 let version = output.lines().next().unwrap_or(&output);
141 CheckResult::Pass(format!("driver v{version}"))
142 } else {
143 CheckResult::Fail("nvidia-smi not found or no GPU detected".into())
144 }
145}
146
147fn check_cuda() -> CheckResult {
148 if let Some(output) = run_cmd("nvcc", &["--version"]) {
149 if let Some(line) = output.lines().find(|l| l.contains("release")) {
150 CheckResult::Pass(line.trim().to_string())
151 } else {
152 CheckResult::Pass(output)
153 }
154 } else {
155 CheckResult::Warn("nvcc not found — CUDA toolkit may not be installed".into())
156 }
157}
158
159fn check_pytorch() -> CheckResult {
160 if let Some(version) = run_cmd("python3", &["-c", "import torch; print(torch.__version__)"]) {
161 CheckResult::Pass(format!("PyTorch {version}"))
162 } else if let Some(version) =
163 run_cmd("python", &["-c", "import torch; print(torch.__version__)"])
164 {
165 CheckResult::Pass(format!("PyTorch {version}"))
166 } else {
167 CheckResult::Warn("PyTorch not installed".into())
168 }
169}
170
171fn check_pytorch_cuda() -> CheckResult {
172 let script = "import torch; print(torch.cuda.is_available(), torch.cuda.device_count() if torch.cuda.is_available() else 0)";
173 if let Some(output) =
174 run_cmd("python3", &["-c", script]).or_else(|| run_cmd("python", &["-c", script]))
175 {
176 let parts: Vec<&str> = output.split_whitespace().collect();
177 match parts.first().map(|s| s.as_ref()) {
178 Some("True") => {
179 let gpus = parts.get(1).unwrap_or(&"?");
180 CheckResult::Pass(format!("CUDA available — {gpus} GPU(s)"))
181 }
182 Some("False") => CheckResult::Warn("CUDA not available to PyTorch".into()),
183 _ => CheckResult::Warn(format!("unexpected output: {output}")),
184 }
185 } else {
186 CheckResult::Skip("PyTorch not installed".into())
187 }
188}
189
190fn check_git() -> CheckResult {
191 if let Some(version) = run_cmd("git", &["--version"]) {
192 CheckResult::Pass(version)
193 } else {
194 CheckResult::Warn("git not found — experiment tracking won't record commits".into())
195 }
196}
197
198fn check_zerneld() -> CheckResult {
199 let port = crate::telemetry::client::metrics_port();
200 let addr = format!("127.0.0.1:{port}");
201 let result = std::net::TcpStream::connect_timeout(
202 &addr.parse().expect("valid address"),
203 std::time::Duration::from_millis(500),
204 );
205 match result {
206 Ok(_) => CheckResult::Pass(format!("running on port {port}")),
207 Err(_) => CheckResult::Warn("not running — start with: zerneld --simulate".into()),
208 }
209}
210
211fn check_zernel_db() -> CheckResult {
212 let db_path = crate::experiments::tracker::experiments_db_path();
213 if db_path.exists() {
214 CheckResult::Pass(format!("{}", db_path.display()))
215 } else {
216 CheckResult::Pass(format!("will be created at {}", db_path.display()))
217 }
218}