zernel/commands/
job_ssh.rs

1// Copyright (C) 2026 Dyber, Inc. — Proprietary
2
3//! SSH-based multi-node distributed training backend.
4//!
5//! Distributes training across multiple nodes via passwordless SSH.
6//! Each node runs torchrun with the correct --node_rank and --master_addr.
7
8use anyhow::{Context, Result};
9use std::process::Stdio;
10use tokio::io::{AsyncBufReadExt, BufReader};
11
/// Parse a hosts specification into a list of hostnames.
///
/// `spec` is interpreted in one of two ways:
///
/// * If it names an existing path, the file is read with one host per line.
///   Surrounding whitespace is trimmed; blank lines and lines starting with
///   `#` are skipped.
/// * Otherwise it is treated as a comma-separated list such as
///   `"host1,host2,host3"`. Surrounding whitespace is trimmed and empty
///   entries (for example from a trailing or doubled comma) are dropped, so
///   an empty string is never returned as a hostname.
///
/// The returned list may be empty; callers decide whether that is an error.
///
/// # Errors
///
/// Returns an error if `spec` names an existing path that cannot be read.
pub fn parse_hosts(spec: &str) -> Result<Vec<String>> {
    let path = std::path::Path::new(spec);

    let hosts: Vec<String> = if path.exists() {
        std::fs::read_to_string(path)?
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty() && !line.starts_with('#'))
            .map(str::to_owned)
            .collect()
    } else {
        spec.split(',')
            .map(str::trim)
            // "a,,b" or "a,b," would otherwise produce "" as a hostname.
            .filter(|host| !host.is_empty())
            .map(str::to_owned)
            .collect()
    };

    Ok(hosts)
}
27
28/// Launch a distributed training job across multiple nodes via SSH.
29#[allow(clippy::too_many_arguments)]
30pub async fn run_ssh_job(
31    job_id: &str,
32    script: &str,
33    hosts: &[String],
34    gpus_per_node: u32,
35    framework: &str,
36    backend: &str,
37    args: &[String],
38    log_dir: &std::path::Path,
39) -> Result<i32> {
40    let master_addr = &hosts[0];
41    let master_port = 29500;
42    let num_nodes = hosts.len();
43
44    println!("SSH Multi-Node Launch");
45    println!("  Master:  {master_addr}:{master_port}");
46    println!("  Nodes:   {num_nodes}");
47    println!("  Hosts:   {}", hosts.join(", "));
48    println!();
49
50    let mut handles = Vec::new();
51
52    for (rank, host) in hosts.iter().enumerate() {
53        let job_dir = format!("/tmp/zernel-{job_id}");
54
55        // 1. Create remote working directory + copy script
56        let setup = tokio::process::Command::new("ssh")
57            .args([
58                "-o",
59                "BatchMode=yes",
60                "-o",
61                "StrictHostKeyChecking=no",
62                host,
63                &format!("mkdir -p {job_dir}"),
64            ])
65            .status()
66            .await
67            .with_context(|| format!("SSH to {host} failed — is passwordless SSH configured?"))?;
68
69        if !setup.success() {
70            anyhow::bail!("failed to create directory on {host}");
71        }
72
73        let scp = tokio::process::Command::new("scp")
74            .args([
75                "-o",
76                "BatchMode=yes",
77                "-o",
78                "StrictHostKeyChecking=no",
79                script,
80                &format!("{host}:{job_dir}/"),
81            ])
82            .status()
83            .await
84            .with_context(|| format!("SCP to {host} failed"))?;
85
86        if !scp.success() {
87            anyhow::bail!("failed to copy script to {host}");
88        }
89
90        // 2. Build remote launch command
91        let script_name = std::path::Path::new(script)
92            .file_name()
93            .map(|f| f.to_string_lossy().to_string())
94            .unwrap_or_else(|| script.to_string());
95
96        let remote_cmd = match framework {
97            "pytorch" => {
98                let mut cmd_parts = vec![
99                    "cd".to_string(),
100                    job_dir.clone(),
101                    "&&".into(),
102                    "torchrun".into(),
103                    format!("--nproc_per_node={gpus_per_node}"),
104                    format!("--nnodes={num_nodes}"),
105                    format!("--node_rank={rank}"),
106                    format!("--master_addr={master_addr}"),
107                    format!("--master_port={master_port}"),
108                    script_name,
109                ];
110                cmd_parts.extend(args.iter().cloned());
111                cmd_parts.join(" ")
112            }
113            "accelerate" | "hf" => {
114                let total = gpus_per_node * num_nodes as u32;
115                let mut cmd_parts = vec![
116                    "cd".to_string(),
117                    job_dir.clone(),
118                    "&&".into(),
119                    "accelerate".into(),
120                    "launch".into(),
121                    format!("--num_processes={total}"),
122                    format!("--machine_rank={rank}"),
123                    format!("--main_process_ip={master_addr}"),
124                    format!("--main_process_port={master_port}"),
125                    script_name,
126                ];
127                cmd_parts.extend(args.iter().cloned());
128                cmd_parts.join(" ")
129            }
130            _ => anyhow::bail!("unsupported framework: {framework}"),
131        };
132
133        // Set NCCL environment
134        let env_prefix = format!(
135            "NCCL_SOCKET_IFNAME=eth0 NCCL_DEBUG=WARN{}",
136            if backend == "nccl" {
137                " NCCL_P2P_DISABLE=0"
138            } else {
139                ""
140            }
141        );
142
143        let full_cmd = format!("{env_prefix} {remote_cmd}");
144
145        // 3. Launch via SSH
146        let host_clone = host.clone();
147        let log_file = log_dir.join(format!("rank-{rank}.log"));
148
149        let handle = tokio::spawn(async move {
150            let mut child = tokio::process::Command::new("ssh")
151                .args([
152                    "-o",
153                    "BatchMode=yes",
154                    "-o",
155                    "StrictHostKeyChecking=no",
156                    &host_clone,
157                    &full_cmd,
158                ])
159                .stdout(Stdio::piped())
160                .stderr(Stdio::piped())
161                .spawn()
162                .expect("failed to spawn SSH");
163
164            let stdout = child.stdout.take();
165            let stderr = child.stderr.take();
166
167            // Stream output with rank prefix
168            let rank_clone = rank;
169            let log_file_clone = log_file.clone();
170
171            let out_handle = tokio::spawn(async move {
172                if let Some(stdout) = stdout {
173                    let mut reader = BufReader::new(stdout);
174                    let mut line = String::new();
175                    let mut log = std::fs::File::create(&log_file_clone).ok();
176                    loop {
177                        line.clear();
178                        match reader.read_line(&mut line).await {
179                            Ok(0) => break,
180                            Ok(_) => {
181                                print!("[rank {rank_clone}] {line}");
182                                if let Some(ref mut f) = log {
183                                    use std::io::Write;
184                                    let _ = write!(f, "{line}");
185                                }
186                            }
187                            Err(_) => break,
188                        }
189                    }
190                }
191            });
192
193            let err_handle = tokio::spawn(async move {
194                if let Some(stderr) = stderr {
195                    let mut reader = BufReader::new(stderr);
196                    let mut line = String::new();
197                    loop {
198                        line.clear();
199                        match reader.read_line(&mut line).await {
200                            Ok(0) => break,
201                            Ok(_) => eprint!("[rank {rank}] {line}"),
202                            Err(_) => break,
203                        }
204                    }
205                }
206            });
207
208            let (_, _) = tokio::join!(out_handle, err_handle);
209            let status = child
210                .wait()
211                .await
212                .unwrap_or_else(|_| std::process::ExitStatus::default());
213            status.code().unwrap_or(-1)
214        });
215
216        handles.push(handle);
217    }
218
219    // Wait for all nodes
220    let mut exit_code = 0;
221    for handle in handles {
222        let code = handle.await.unwrap_or(-1);
223        if code != 0 {
224            exit_code = code;
225        }
226    }
227
228    Ok(exit_code)
229}
230
231/// Cancel an SSH job by killing processes on all hosts.
232pub async fn cancel_ssh_job(job_id: &str, hosts: &[String]) -> Result<()> {
233    for host in hosts {
234        println!("  Killing job on {host}...");
235        let _ = tokio::process::Command::new("ssh")
236            .args([
237                "-o",
238                "BatchMode=yes",
239                "-o",
240                "StrictHostKeyChecking=no",
241                host,
242                &format!("pkill -f zernel-{job_id} || true"),
243            ])
244            .status()
245            .await;
246    }
247    Ok(())
248}