// zernel/commands/job_k8s.rs

// Copyright (C) 2026 Dyber, Inc. — Proprietary

//! Kubernetes-based distributed training backend.
//!
//! Generates PyTorchJob YAML manifests and manages them via kubectl.

use anyhow::{Context, Result};
use std::path::Path;

/// Generate a Kubeflow PyTorchJob YAML manifest.
///
/// The manifest declares one Master replica plus `nodes - 1` Worker
/// replicas (saturating at zero), each requesting `gpus_per_node` GPUs
/// and running `python3 {script} {args...}`.
///
/// Script arguments are emitted as double-quoted YAML scalars with
/// backslashes and double quotes escaped, so arbitrary argument values
/// cannot break the manifest structure.
fn generate_pytorchjob_yaml(
    job_id: &str,
    image: &str,
    script: &str,
    gpus_per_node: u32,
    nodes: u32,
    namespace: &str,
    args: &[String],
) -> String {
    // Escape a value for a double-quoted YAML scalar: backslashes first,
    // then embedded double quotes.
    fn yaml_quote(value: &str) -> String {
        value.replace('\\', "\\\\").replace('"', "\\\"")
    }

    let script_args = if args.is_empty() {
        String::new()
    } else {
        // 16-space indent so every arg aligns with the other `command`
        // sequence entries (`- python3`, `- {script}`); a shallower indent
        // would break the args out of the command list.
        format!(
            "\n{}",
            args.iter()
                .map(|a| format!("                - \"{}\"", yaml_quote(a)))
                .collect::<Vec<_>>()
                .join("\n")
        )
    };

    format!(
        r#"apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: {job_id}
  namespace: {namespace}
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: pytorch
              image: {image}
              command:
                - python3
                - {script}{script_args}
              resources:
                limits:
                  nvidia.com/gpu: {gpus_per_node}
              env:
                - name: NCCL_DEBUG
                  value: "WARN"
    Worker:
      replicas: {workers}
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: pytorch
              image: {image}
              command:
                - python3
                - {script}{script_args}
              resources:
                limits:
                  nvidia.com/gpu: {gpus_per_node}
              env:
                - name: NCCL_DEBUG
                  value: "WARN"
"#,
        job_id = job_id,
        namespace = namespace,
        image = image,
        script = script,
        gpus_per_node = gpus_per_node,
        workers = nodes.saturating_sub(1),
        script_args = script_args,
    )
}
85/// Submit a distributed training job to Kubernetes via PyTorchJob CRD.
86#[allow(clippy::too_many_arguments)]
87pub async fn run_k8s_job(
88    job_id: &str,
89    script: &str,
90    image: &str,
91    gpus_per_node: u32,
92    nodes: u32,
93    namespace: &str,
94    args: &[String],
95    log_dir: &Path,
96) -> Result<i32> {
97    // Check kubectl
98    let kubectl_check = std::process::Command::new("kubectl")
99        .args(["version", "--client", "--short"])
100        .output();
101
102    match kubectl_check {
103        Ok(o) if o.status.success() => {}
104        _ => anyhow::bail!("kubectl not found. Install Kubernetes CLI first."),
105    }
106
107    // Generate manifest
108    let yaml =
109        generate_pytorchjob_yaml(job_id, image, script, gpus_per_node, nodes, namespace, args);
110    let manifest_path = log_dir.join("pytorchjob.yaml");
111    std::fs::write(&manifest_path, &yaml)?;
112
113    println!("Kubernetes PyTorchJob");
114    println!("  Image:     {image}");
115    println!("  Nodes:     {nodes} (1 master + {} workers)", nodes - 1);
116    println!("  GPUs/node: {gpus_per_node}");
117    println!("  Namespace: {namespace}");
118    println!("  Manifest:  {}", manifest_path.display());
119    println!();
120
121    // Apply manifest
122    println!("Applying PyTorchJob...");
123    let apply = tokio::process::Command::new("kubectl")
124        .args(["apply", "-f"])
125        .arg(&manifest_path)
126        .status()
127        .await
128        .with_context(|| "kubectl apply failed")?;
129
130    if !apply.success() {
131        anyhow::bail!("kubectl apply failed for {}", manifest_path.display());
132    }
133
134    println!("Job submitted: {job_id}");
135    println!();
136
137    // Poll pod status
138    println!("Waiting for pods...");
139    let mut last_status = String::new();
140    for _ in 0..120 {
141        // 10 min timeout
142        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
143
144        let status_output = tokio::process::Command::new("kubectl")
145            .args([
146                "get",
147                "pods",
148                "-l",
149                &format!("training.kubeflow.org/job-name={job_id}"),
150                "-n",
151                namespace,
152                "--no-headers",
153                "-o",
154                "custom-columns=NAME:.metadata.name,STATUS:.status.phase",
155            ])
156            .output()
157            .await;
158
159        if let Ok(output) = status_output {
160            let status = String::from_utf8_lossy(&output.stdout).to_string();
161            if status != last_status {
162                print!("{status}");
163                last_status = status.clone();
164            }
165
166            if status.contains("Running") || status.contains("Succeeded") {
167                break;
168            }
169            if status.contains("Failed") || status.contains("Error") {
170                println!("Job failed. Check: kubectl logs -l training.kubeflow.org/job-name={job_id} -n {namespace}");
171                return Ok(1);
172            }
173        }
174    }
175
176    // Stream master logs
177    println!();
178    println!("--- Master logs ---");
179    let log_status = tokio::process::Command::new("kubectl")
180        .args([
181            "logs",
182            "-f",
183            "-l",
184            &format!(
185                "training.kubeflow.org/job-name={job_id},training.kubeflow.org/replica-type=master"
186            ),
187            "-n",
188            namespace,
189        ])
190        .status()
191        .await
192        .unwrap_or_default();
193
194    Ok(if log_status.success() { 0 } else { 1 })
195}
197/// Cancel a Kubernetes job.
198pub async fn cancel_k8s_job(job_id: &str, namespace: &str) -> Result<()> {
199    println!("Deleting PyTorchJob {job_id}...");
200    let status = tokio::process::Command::new("kubectl")
201        .args(["delete", "pytorchjob", job_id, "-n", namespace])
202        .status()
203        .await?;
204
205    if status.success() {
206        println!("Job deleted.");
207    } else {
208        println!("Failed to delete job (may already be cleaned up).");
209    }
210    Ok(())
211}
#[cfg(test)]
mod tests {
    use super::*;

    /// The generated manifest must carry the core PyTorchJob fields,
    /// the GPU request, the derived worker count, and the script args.
    #[test]
    fn generates_valid_yaml() {
        let args = vec![String::from("--epochs"), String::from("10")];
        let yaml = generate_pytorchjob_yaml(
            "test-job",
            "myimage:latest",
            "train.py",
            4,
            3,
            "default",
            &args,
        );

        let expected_fragments = [
            "kind: PyTorchJob",
            "test-job",
            "myimage:latest",
            "nvidia.com/gpu: 4",
            "replicas: 2", // 3 nodes - 1 master = 2 workers
            "--epochs",
        ];
        for fragment in expected_fragments {
            assert!(yaml.contains(fragment), "manifest missing: {fragment}");
        }
    }
}