// kernel/commands/job_k8s.rs
use anyhow::{Context, Result};
use std::path::Path;

/// Render a Kubeflow `PyTorchJob` manifest as a YAML string.
///
/// The job runs one Master replica plus `nodes - 1` Worker replicas
/// (saturating at zero, so `nodes == 0` cannot underflow), each requesting
/// `gpus_per_node` GPUs and executing `python3 <script> <args...>`.
///
/// `args` values are escaped for YAML double-quoted scalars; `job_id`,
/// `namespace`, `image`, and `script` are interpolated verbatim and are
/// expected to be valid Kubernetes identifiers / plain YAML scalars.
fn generate_pytorchjob_yaml(
    job_id: &str,
    image: &str,
    script: &str,
    gpus_per_node: u32,
    nodes: u32,
    namespace: &str,
    args: &[String],
) -> String {
    // Escape backslash and double-quote so an arbitrary argument value
    // cannot break out of its double-quoted YAML scalar and corrupt the
    // manifest (e.g. `--msg=say "hi"`).
    fn yaml_escape(s: &str) -> String {
        s.replace('\\', "\\\\").replace('"', "\\\"")
    }

    // One YAML list item per argument, indented to match the `command:`
    // list items in the template below.
    let script_args = if args.is_empty() {
        String::new()
    } else {
        format!(
            "\n{}",
            args.iter()
                .map(|a| format!("            - \"{}\"", yaml_escape(a)))
                .collect::<Vec<_>>()
                .join("\n")
        )
    };

    format!(
        r#"apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: {job_id}
  namespace: {namespace}
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - name: pytorch
            image: {image}
            command:
            - python3
            - {script}{script_args}
            resources:
              limits:
                nvidia.com/gpu: {gpus_per_node}
            env:
            - name: NCCL_DEBUG
              value: "WARN"
    Worker:
      replicas: {workers}
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - name: pytorch
            image: {image}
            command:
            - python3
            - {script}{script_args}
            resources:
              limits:
                nvidia.com/gpu: {gpus_per_node}
            env:
            - name: NCCL_DEBUG
              value: "WARN"
"#,
        workers = nodes.saturating_sub(1),
    )
}
84
85#[allow(clippy::too_many_arguments)]
87pub async fn run_k8s_job(
88 job_id: &str,
89 script: &str,
90 image: &str,
91 gpus_per_node: u32,
92 nodes: u32,
93 namespace: &str,
94 args: &[String],
95 log_dir: &Path,
96) -> Result<i32> {
97 let kubectl_check = std::process::Command::new("kubectl")
99 .args(["version", "--client", "--short"])
100 .output();
101
102 match kubectl_check {
103 Ok(o) if o.status.success() => {}
104 _ => anyhow::bail!("kubectl not found. Install Kubernetes CLI first."),
105 }
106
107 let yaml =
109 generate_pytorchjob_yaml(job_id, image, script, gpus_per_node, nodes, namespace, args);
110 let manifest_path = log_dir.join("pytorchjob.yaml");
111 std::fs::write(&manifest_path, &yaml)?;
112
113 println!("Kubernetes PyTorchJob");
114 println!(" Image: {image}");
115 println!(" Nodes: {nodes} (1 master + {} workers)", nodes - 1);
116 println!(" GPUs/node: {gpus_per_node}");
117 println!(" Namespace: {namespace}");
118 println!(" Manifest: {}", manifest_path.display());
119 println!();
120
121 println!("Applying PyTorchJob...");
123 let apply = tokio::process::Command::new("kubectl")
124 .args(["apply", "-f"])
125 .arg(&manifest_path)
126 .status()
127 .await
128 .with_context(|| "kubectl apply failed")?;
129
130 if !apply.success() {
131 anyhow::bail!("kubectl apply failed for {}", manifest_path.display());
132 }
133
134 println!("Job submitted: {job_id}");
135 println!();
136
137 println!("Waiting for pods...");
139 let mut last_status = String::new();
140 for _ in 0..120 {
141 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
143
144 let status_output = tokio::process::Command::new("kubectl")
145 .args([
146 "get",
147 "pods",
148 "-l",
149 &format!("training.kubeflow.org/job-name={job_id}"),
150 "-n",
151 namespace,
152 "--no-headers",
153 "-o",
154 "custom-columns=NAME:.metadata.name,STATUS:.status.phase",
155 ])
156 .output()
157 .await;
158
159 if let Ok(output) = status_output {
160 let status = String::from_utf8_lossy(&output.stdout).to_string();
161 if status != last_status {
162 print!("{status}");
163 last_status = status.clone();
164 }
165
166 if status.contains("Running") || status.contains("Succeeded") {
167 break;
168 }
169 if status.contains("Failed") || status.contains("Error") {
170 println!("Job failed. Check: kubectl logs -l training.kubeflow.org/job-name={job_id} -n {namespace}");
171 return Ok(1);
172 }
173 }
174 }
175
176 println!();
178 println!("--- Master logs ---");
179 let log_status = tokio::process::Command::new("kubectl")
180 .args([
181 "logs",
182 "-f",
183 "-l",
184 &format!(
185 "training.kubeflow.org/job-name={job_id},training.kubeflow.org/replica-type=master"
186 ),
187 "-n",
188 namespace,
189 ])
190 .status()
191 .await
192 .unwrap_or_default();
193
194 Ok(if log_status.success() { 0 } else { 1 })
195}
196
197pub async fn cancel_k8s_job(job_id: &str, namespace: &str) -> Result<()> {
199 println!("Deleting PyTorchJob {job_id}...");
200 let status = tokio::process::Command::new("kubectl")
201 .args(["delete", "pytorchjob", job_id, "-n", namespace])
202 .status()
203 .await?;
204
205 if status.success() {
206 println!("Job deleted.");
207 } else {
208 println!("Failed to delete job (may already be cleaned up).");
209 }
210 Ok(())
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: the rendered manifest must carry the job identity,
    // container image, GPU request, worker count (nodes - 1), and the
    // user-supplied script arguments.
    #[test]
    fn generates_valid_yaml() {
        let args = vec!["--epochs".to_string(), "10".to_string()];
        let manifest = generate_pytorchjob_yaml(
            "test-job",
            "myimage:latest",
            "train.py",
            4,
            3,
            "default",
            &args,
        );

        for expected in [
            "kind: PyTorchJob",
            "test-job",
            "myimage:latest",
            "nvidia.com/gpu: 4",
            "replicas: 2",
            "--epochs",
        ] {
            assert!(manifest.contains(expected), "manifest missing: {expected}");
        }
    }
}