zernel/commands/
job_ssh.rs1use anyhow::{Context, Result};
9use std::process::Stdio;
10use tokio::io::{AsyncBufReadExt, BufReader};
11
/// Parses a host specification into a list of host names.
///
/// If `spec` names an existing path it is read as a hostfile: one host per
/// line, surrounding whitespace trimmed, blank lines and lines starting with
/// `#` skipped. Otherwise `spec` is treated as a comma-separated list; entries
/// are trimmed and empty entries (from `a,,b`, a trailing comma, or an empty
/// spec) are dropped so that no empty host name is ever returned.
///
/// The returned list may be empty; callers decide whether that is an error.
///
/// # Errors
///
/// Returns an error if `spec` names an existing path that cannot be read as
/// UTF-8 text.
pub fn parse_hosts(spec: &str) -> Result<Vec<String>> {
    let path = std::path::Path::new(spec);

    let hosts: Vec<String> = if path.exists() {
        let content = std::fs::read_to_string(path)?;
        content
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty() && !line.starts_with('#'))
            .map(str::to_string)
            .collect()
    } else {
        spec.split(',')
            .map(str::trim)
            // An empty entry is never a valid host; keep this consistent with
            // the hostfile branch above.
            .filter(|host| !host.is_empty())
            .map(str::to_string)
            .collect()
    };

    Ok(hosts)
}
27
28#[allow(clippy::too_many_arguments)]
30pub async fn run_ssh_job(
31 job_id: &str,
32 script: &str,
33 hosts: &[String],
34 gpus_per_node: u32,
35 framework: &str,
36 backend: &str,
37 args: &[String],
38 log_dir: &std::path::Path,
39) -> Result<i32> {
40 let master_addr = &hosts[0];
41 let master_port = 29500;
42 let num_nodes = hosts.len();
43
44 println!("SSH Multi-Node Launch");
45 println!(" Master: {master_addr}:{master_port}");
46 println!(" Nodes: {num_nodes}");
47 println!(" Hosts: {}", hosts.join(", "));
48 println!();
49
50 let mut handles = Vec::new();
51
52 for (rank, host) in hosts.iter().enumerate() {
53 let job_dir = format!("/tmp/zernel-{job_id}");
54
55 let setup = tokio::process::Command::new("ssh")
57 .args([
58 "-o",
59 "BatchMode=yes",
60 "-o",
61 "StrictHostKeyChecking=no",
62 host,
63 &format!("mkdir -p {job_dir}"),
64 ])
65 .status()
66 .await
67 .with_context(|| format!("SSH to {host} failed — is passwordless SSH configured?"))?;
68
69 if !setup.success() {
70 anyhow::bail!("failed to create directory on {host}");
71 }
72
73 let scp = tokio::process::Command::new("scp")
74 .args([
75 "-o",
76 "BatchMode=yes",
77 "-o",
78 "StrictHostKeyChecking=no",
79 script,
80 &format!("{host}:{job_dir}/"),
81 ])
82 .status()
83 .await
84 .with_context(|| format!("SCP to {host} failed"))?;
85
86 if !scp.success() {
87 anyhow::bail!("failed to copy script to {host}");
88 }
89
90 let script_name = std::path::Path::new(script)
92 .file_name()
93 .map(|f| f.to_string_lossy().to_string())
94 .unwrap_or_else(|| script.to_string());
95
96 let remote_cmd = match framework {
97 "pytorch" => {
98 let mut cmd_parts = vec![
99 "cd".to_string(),
100 job_dir.clone(),
101 "&&".into(),
102 "torchrun".into(),
103 format!("--nproc_per_node={gpus_per_node}"),
104 format!("--nnodes={num_nodes}"),
105 format!("--node_rank={rank}"),
106 format!("--master_addr={master_addr}"),
107 format!("--master_port={master_port}"),
108 script_name,
109 ];
110 cmd_parts.extend(args.iter().cloned());
111 cmd_parts.join(" ")
112 }
113 "accelerate" | "hf" => {
114 let total = gpus_per_node * num_nodes as u32;
115 let mut cmd_parts = vec![
116 "cd".to_string(),
117 job_dir.clone(),
118 "&&".into(),
119 "accelerate".into(),
120 "launch".into(),
121 format!("--num_processes={total}"),
122 format!("--machine_rank={rank}"),
123 format!("--main_process_ip={master_addr}"),
124 format!("--main_process_port={master_port}"),
125 script_name,
126 ];
127 cmd_parts.extend(args.iter().cloned());
128 cmd_parts.join(" ")
129 }
130 _ => anyhow::bail!("unsupported framework: {framework}"),
131 };
132
133 let env_prefix = format!(
135 "NCCL_SOCKET_IFNAME=eth0 NCCL_DEBUG=WARN{}",
136 if backend == "nccl" {
137 " NCCL_P2P_DISABLE=0"
138 } else {
139 ""
140 }
141 );
142
143 let full_cmd = format!("{env_prefix} {remote_cmd}");
144
145 let host_clone = host.clone();
147 let log_file = log_dir.join(format!("rank-{rank}.log"));
148
149 let handle = tokio::spawn(async move {
150 let mut child = tokio::process::Command::new("ssh")
151 .args([
152 "-o",
153 "BatchMode=yes",
154 "-o",
155 "StrictHostKeyChecking=no",
156 &host_clone,
157 &full_cmd,
158 ])
159 .stdout(Stdio::piped())
160 .stderr(Stdio::piped())
161 .spawn()
162 .expect("failed to spawn SSH");
163
164 let stdout = child.stdout.take();
165 let stderr = child.stderr.take();
166
167 let rank_clone = rank;
169 let log_file_clone = log_file.clone();
170
171 let out_handle = tokio::spawn(async move {
172 if let Some(stdout) = stdout {
173 let mut reader = BufReader::new(stdout);
174 let mut line = String::new();
175 let mut log = std::fs::File::create(&log_file_clone).ok();
176 loop {
177 line.clear();
178 match reader.read_line(&mut line).await {
179 Ok(0) => break,
180 Ok(_) => {
181 print!("[rank {rank_clone}] {line}");
182 if let Some(ref mut f) = log {
183 use std::io::Write;
184 let _ = write!(f, "{line}");
185 }
186 }
187 Err(_) => break,
188 }
189 }
190 }
191 });
192
193 let err_handle = tokio::spawn(async move {
194 if let Some(stderr) = stderr {
195 let mut reader = BufReader::new(stderr);
196 let mut line = String::new();
197 loop {
198 line.clear();
199 match reader.read_line(&mut line).await {
200 Ok(0) => break,
201 Ok(_) => eprint!("[rank {rank}] {line}"),
202 Err(_) => break,
203 }
204 }
205 }
206 });
207
208 let (_, _) = tokio::join!(out_handle, err_handle);
209 let status = child
210 .wait()
211 .await
212 .unwrap_or_else(|_| std::process::ExitStatus::default());
213 status.code().unwrap_or(-1)
214 });
215
216 handles.push(handle);
217 }
218
219 let mut exit_code = 0;
221 for handle in handles {
222 let code = handle.await.unwrap_or(-1);
223 if code != 0 {
224 exit_code = code;
225 }
226 }
227
228 Ok(exit_code)
229}
230
231pub async fn cancel_ssh_job(job_id: &str, hosts: &[String]) -> Result<()> {
233 for host in hosts {
234 println!(" Killing job on {host}...");
235 let _ = tokio::process::Command::new("ssh")
236 .args([
237 "-o",
238 "BatchMode=yes",
239 "-o",
240 "StrictHostKeyChecking=no",
241 host,
242 &format!("pkill -f zernel-{job_id} || true"),
243 ])
244 .status()
245 .await;
246 }
247 Ok(())
248}