use anyhow::{Context, Result};
use clap::Subcommand;
use serde::{Deserialize, Serialize};
use std::io::Write;
use std::path::PathBuf;
use std::process::Command;
10
11#[derive(Subcommand)]
12pub enum ClusterCommands {
13 Add {
15 host: String,
17 #[arg(long, default_value = "8")]
19 gpus: u32,
20 #[arg(long, default_value = "root")]
22 user: String,
23 },
24 Remove {
26 host: String,
28 },
29 Status,
31 Ssh {
33 host: String,
35 },
36 Sync {
38 path: String,
40 #[arg(long, default_value = "~/")]
42 to: String,
43 },
44 Run {
46 command: String,
48 #[arg(long)]
50 on: Option<String>,
51 },
52 Drain {
54 host: String,
56 },
57}
58
/// One entry of the cluster registry (`nodes.json`), serialized as a JSON object.
///
/// Field names and order define the on-disk format, so renaming or reordering
/// them changes (and may break reading of) existing registries.
#[derive(Debug, Serialize, Deserialize)]
struct ClusterNode {
    /// SSH hostname used to reach the node; also the registry key.
    host: String,
    /// Login user for SSH and rsync.
    user: String,
    /// GPU count supplied via `cluster add --gpus` (not detected).
    gpus: u32,
    /// Node state; this module writes "active" on add and "draining" on drain.
    status: String,
}
66
67fn cluster_file() -> PathBuf {
68 let dir = crate::experiments::tracker::zernel_dir().join("cluster");
69 std::fs::create_dir_all(&dir).ok();
70 dir.join("nodes.json")
71}
72
73fn load_nodes() -> Vec<ClusterNode> {
74 let path = cluster_file();
75 if path.exists() {
76 std::fs::read_to_string(&path)
77 .ok()
78 .and_then(|s| serde_json::from_str(&s).ok())
79 .unwrap_or_default()
80 } else {
81 Vec::new()
82 }
83}
84
85fn save_nodes(nodes: &[ClusterNode]) -> Result<()> {
86 std::fs::write(cluster_file(), serde_json::to_string_pretty(nodes)?)?;
87 Ok(())
88}
89
/// Builds (but does not spawn) a non-interactive `ssh` invocation that runs
/// `cmd` as `user` on `host`; callers pick `.output()` or `.status()`.
///
/// Options passed to ssh:
/// - `BatchMode=yes`: never prompt for passwords or passphrases, fail instead.
/// - `StrictHostKeyChecking=no`: connect to hosts whose keys are not yet known.
///   NOTE(review): per ssh_config(5) this also lets connections proceed when a
///   host key has *changed*, weakening MITM protection; `accept-new` would be safer.
/// - `ConnectTimeout=5`: give up on unreachable hosts after 5 seconds.
fn ssh_cmd(user: &str, host: &str, cmd: &str) -> Command {
    let destination = format!("{user}@{host}");
    let mut ssh = Command::new("ssh");
    for option in ["BatchMode=yes", "StrictHostKeyChecking=no", "ConnectTimeout=5"] {
        ssh.arg("-o").arg(option);
    }
    ssh.arg(destination).arg(cmd);
    ssh
}
104
105pub async fn run(cmd: ClusterCommands) -> Result<()> {
106 match cmd {
107 ClusterCommands::Add { host, gpus, user } => {
108 let mut nodes = load_nodes();
109 nodes.retain(|n| n.host != host);
110
111 print!("Testing SSH to {user}@{host}... ");
113 let test = ssh_cmd(&user, &host, "echo ok").output();
114 match test {
115 Ok(o) if o.status.success() => println!("OK"),
116 _ => {
117 println!("FAILED");
118 println!("Ensure passwordless SSH is configured:");
119 println!(" ssh-copy-id {user}@{host}");
120 return Ok(());
121 }
122 }
123
124 nodes.push(ClusterNode {
125 host: host.clone(),
126 user,
127 gpus,
128 status: "active".into(),
129 });
130 save_nodes(&nodes)?;
131 println!(
132 "Added {host} ({gpus} GPUs) to cluster ({} total nodes)",
133 nodes.len()
134 );
135 }
136
137 ClusterCommands::Remove { host } => {
138 let mut nodes = load_nodes();
139 let before = nodes.len();
140 nodes.retain(|n| n.host != host);
141 save_nodes(&nodes)?;
142 if nodes.len() < before {
143 println!("Removed {host} from cluster");
144 } else {
145 println!("Node {host} not found in cluster");
146 }
147 }
148
149 ClusterCommands::Status => {
150 let nodes = load_nodes();
151 if nodes.is_empty() {
152 println!("No nodes in cluster. Add one: zernel cluster add <host> --gpus 8");
153 return Ok(());
154 }
155
156 println!("Zernel Cluster Status");
157 println!("{}", "=".repeat(70));
158 println!(
159 "{:<20} {:<8} {:>5} {:>10} {:>10} {:>8}",
160 "Host", "Status", "GPUs", "GPU Util", "Memory", "Temp"
161 );
162 println!("{}", "-".repeat(70));
163
164 for node in &nodes {
165 let info = ssh_cmd(&node.user, &node.host,
167 "nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu --format=csv,noheader,nounits 2>/dev/null | head -1"
168 ).output();
169
170 match info {
171 Ok(o) if o.status.success() => {
172 let data = String::from_utf8_lossy(&o.stdout);
173 let f: Vec<&str> = data.trim().split(',').map(|s| s.trim()).collect();
174 if f.len() >= 4 {
175 println!(
176 "{:<20} {:<8} {:>5} {:>8}% {:>5}/{:<4}MB {:>5}°C",
177 node.host, "online", node.gpus, f[0], f[1], f[2], f[3]
178 );
179 } else {
180 println!("{:<20} {:<8} {:>5}", node.host, "online", node.gpus);
181 }
182 }
183 _ => {
184 println!("{:<20} {:<8} {:>5}", node.host, "offline", node.gpus);
185 }
186 }
187 }
188
189 let total_gpus: u32 = nodes.iter().map(|n| n.gpus).sum();
190 println!();
191 println!("Total: {} nodes, {} GPUs", nodes.len(), total_gpus);
192 }
193
194 ClusterCommands::Ssh { host } => {
195 let nodes = load_nodes();
196 let node = nodes.iter().find(|n| n.host == host);
197 match node {
198 Some(n) => {
199 let status = Command::new("ssh")
200 .args([&format!("{}@{}", n.user, n.host)])
201 .status()?;
202 let _ = status;
203 }
204 None => println!("Node {host} not in cluster. Add it: zernel cluster add {host}"),
205 }
206 }
207
208 ClusterCommands::Sync { path, to } => {
209 let nodes = load_nodes();
210 if nodes.is_empty() {
211 println!("No nodes in cluster");
212 return Ok(());
213 }
214
215 for node in &nodes {
216 print!("Syncing to {}@{}:{}... ", node.user, node.host, to);
217 let status = Command::new("rsync")
218 .args([
219 "-avz",
220 "--progress",
221 &path,
222 &format!("{}@{}:{}", node.user, node.host, to),
223 ])
224 .output();
225 match status {
226 Ok(o) if o.status.success() => println!("OK"),
227 _ => println!("FAILED"),
228 }
229 }
230 }
231
232 ClusterCommands::Run { command, on } => {
233 let nodes = load_nodes();
234 let targets: Vec<&ClusterNode> = if let Some(ref host) = on {
235 nodes.iter().filter(|n| n.host == *host).collect()
236 } else {
237 nodes.iter().collect()
238 };
239
240 for node in targets {
241 println!("--- {}@{} ---", node.user, node.host);
242 let output = ssh_cmd(&node.user, &node.host, &command).output();
243 match output {
244 Ok(o) => {
245 print!("{}", String::from_utf8_lossy(&o.stdout));
246 if !o.stderr.is_empty() {
247 eprint!("{}", String::from_utf8_lossy(&o.stderr));
248 }
249 }
250 Err(e) => println!(" ERROR: {e}"),
251 }
252 println!();
253 }
254 }
255
256 ClusterCommands::Drain { host } => {
257 println!("Draining {host}...");
258 let mut nodes = load_nodes();
259 if let Some(node) = nodes.iter_mut().find(|n| n.host == host) {
260 node.status = "draining".into();
261 save_nodes(&nodes)?;
262 println!(" Status set to 'draining'");
263 println!(" New jobs will not be scheduled to this node");
264 println!(" Existing jobs will complete normally");
265 } else {
266 println!("Node {host} not found");
267 }
268 }
269 }
270 Ok(())
271}