// zernel/commands/job.rs

// Copyright (C) 2026 Dyber, Inc. — Proprietary

use crate::experiments::tracker;
use anyhow::{Context, Result};
use chrono::Utc;
use clap::Subcommand;
use rusqlite::{Connection, OptionalExtension};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, BufReader};

13#[derive(Subcommand)]
14pub enum JobCommands {
15    /// Submit a distributed training job
16    Submit {
17        /// Script to run
18        script: String,
19        /// GPUs per node
20        #[arg(long, default_value = "1")]
21        gpus_per_node: u32,
22        /// Number of nodes
23        #[arg(long, default_value = "1")]
24        nodes: u32,
25        /// Framework (pytorch, accelerate)
26        #[arg(long, default_value = "pytorch")]
27        framework: String,
28        /// Communication backend (nccl, gloo)
29        #[arg(long, default_value = "nccl")]
30        backend: String,
31        /// Execution target (local, ssh, k8s)
32        #[arg(long, default_value = "local")]
33        target: String,
34        /// SSH hosts (comma-separated or file path) for --target ssh
35        #[arg(long)]
36        hosts: Option<String>,
37        /// Container image for --target k8s
38        #[arg(long)]
39        image: Option<String>,
40        /// Kubernetes namespace for --target k8s
41        #[arg(long, default_value = "default")]
42        namespace: String,
43        /// Additional script arguments
44        #[arg(trailing_var_arg = true)]
45        args: Vec<String>,
46    },
47    /// List running and completed jobs
48    List,
49    /// Show job status and output
50    Status {
51        /// Job ID
52        id: String,
53    },
54    /// Cancel a running job
55    Cancel {
56        /// Job ID
57        id: String,
58    },
59}
60
61#[derive(Debug, Serialize, Deserialize)]
62struct Job {
63    id: String,
64    script: String,
65    status: String,
66    gpus_per_node: u32,
67    nodes: u32,
68    framework: String,
69    backend: String,
70    pid: Option<u32>,
71    submitted_at: String,
72    finished_at: Option<String>,
73    exit_code: Option<i32>,
74}
75
76fn jobs_db_path() -> PathBuf {
77    let dir = tracker::zernel_dir().join("jobs");
78    std::fs::create_dir_all(&dir).ok();
79    dir.join("jobs.db")
80}
81
82fn jobs_log_dir(job_id: &str) -> PathBuf {
83    let dir = tracker::zernel_dir().join("jobs").join(job_id);
84    std::fs::create_dir_all(&dir).ok();
85    dir
86}
87
88fn open_jobs_db() -> Result<Connection> {
89    let conn = Connection::open(jobs_db_path())?;
90    conn.execute_batch(
91        "CREATE TABLE IF NOT EXISTS jobs (
92            id TEXT PRIMARY KEY,
93            script TEXT NOT NULL,
94            status TEXT NOT NULL,
95            gpus_per_node INTEGER NOT NULL,
96            nodes INTEGER NOT NULL,
97            framework TEXT NOT NULL,
98            backend TEXT NOT NULL,
99            pid INTEGER,
100            submitted_at TEXT NOT NULL,
101            finished_at TEXT,
102            exit_code INTEGER
103        );",
104    )?;
105    Ok(conn)
106}
107
108fn generate_job_id() -> String {
109    let now = chrono::Utc::now();
110    let short = &uuid::Uuid::new_v4().to_string()[..8];
111    format!("job-{}-{}", now.format("%Y%m%d-%H%M%S"), short)
112}
113
/// Asks `nvidia-smi` how many GPUs are present.
///
/// Runs `nvidia-smi --query-gpu=count --format=csv,noheader` and parses the
/// first line of its output. Returns 0 when the command cannot be run, exits
/// unsuccessfully, prints nothing, or prints something that is not a number.
fn detect_gpu_count() -> u32 {
    let Ok(output) = std::process::Command::new("nvidia-smi")
        .args(["--query-gpu=count", "--format=csv,noheader"])
        .output()
    else {
        return 0;
    };

    if !output.status.success() {
        return 0;
    }

    let text = String::from_utf8_lossy(&output.stdout);
    match text.trim().lines().next() {
        Some(first_line) => first_line.parse().unwrap_or(0),
        None => 0,
    }
}

133pub async fn run(cmd: JobCommands) -> Result<()> {
134    match cmd {
135        JobCommands::Submit {
136            script,
137            gpus_per_node,
138            nodes,
139            framework,
140            backend,
141            target,
142            hosts,
143            image,
144            namespace,
145            args,
146        } => {
147            let script_path = std::path::Path::new(&script);
148            if (target == "local" || target == "ssh") && !script_path.exists() {
149                anyhow::bail!("script not found: {script}");
150            }
151
152            // Dispatch to SSH or K8s backends
153            if target == "ssh" {
154                let host_list = super::job_ssh::parse_hosts(
155                    hosts
156                        .as_deref()
157                        .ok_or_else(|| anyhow::anyhow!("--hosts required for --target ssh"))?,
158                )?;
159                if host_list.len() < nodes as usize {
160                    anyhow::bail!("need {} hosts but only {} provided", nodes, host_list.len());
161                }
162                let job_id = generate_job_id();
163                let log_dir = jobs_log_dir(&job_id);
164
165                let conn = open_jobs_db()?;
166                conn.execute(
167                    "INSERT INTO jobs (id, script, status, gpus_per_node, nodes, framework, backend, submitted_at) VALUES (?1, ?2, 'running', ?3, ?4, ?5, ?6, ?7)",
168                    (&job_id, &script, gpus_per_node, nodes, &framework, &backend, Utc::now().to_rfc3339()),
169                )?;
170
171                let exit_code = super::job_ssh::run_ssh_job(
172                    &job_id,
173                    &script,
174                    &host_list[..nodes as usize],
175                    gpus_per_node,
176                    &framework,
177                    &backend,
178                    &args,
179                    &log_dir,
180                )
181                .await?;
182
183                let status = if exit_code == 0 { "done" } else { "failed" };
184                conn.execute(
185                    "UPDATE jobs SET status = ?1, finished_at = ?2, exit_code = ?3 WHERE id = ?4",
186                    (status, Utc::now().to_rfc3339(), exit_code, &job_id),
187                )?;
188                println!("\nJob {job_id}: {status} (exit {exit_code})");
189                return Ok(());
190            }
191
192            if target == "k8s" {
193                let img = image
194                    .as_deref()
195                    .ok_or_else(|| anyhow::anyhow!("--image required for --target k8s"))?;
196                let job_id = generate_job_id();
197                let log_dir = jobs_log_dir(&job_id);
198
199                let conn = open_jobs_db()?;
200                conn.execute(
201                    "INSERT INTO jobs (id, script, status, gpus_per_node, nodes, framework, backend, submitted_at) VALUES (?1, ?2, 'running', ?3, ?4, ?5, ?6, ?7)",
202                    (&job_id, &script, gpus_per_node, nodes, &framework, &backend, Utc::now().to_rfc3339()),
203                )?;
204
205                let exit_code = super::job_k8s::run_k8s_job(
206                    &job_id,
207                    &script,
208                    img,
209                    gpus_per_node,
210                    nodes,
211                    &namespace,
212                    &args,
213                    &log_dir,
214                )
215                .await?;
216
217                let status = if exit_code == 0 { "done" } else { "failed" };
218                conn.execute(
219                    "UPDATE jobs SET status = ?1, finished_at = ?2, exit_code = ?3 WHERE id = ?4",
220                    (status, Utc::now().to_rfc3339(), exit_code, &job_id),
221                )?;
222                println!("\nJob {job_id}: {status} (exit {exit_code})");
223                return Ok(());
224            }
225
226            // Local target (existing code)
227            if !script_path.exists() {
228                anyhow::bail!("script not found: {script}");
229            }
230
231            let detected_gpus = detect_gpu_count();
232            let gpus = if gpus_per_node == 1 && detected_gpus > 1 {
233                println!("Detected {detected_gpus} GPUs. Using all.");
234                detected_gpus
235            } else {
236                gpus_per_node
237            };
238
239            let total_procs = gpus * nodes;
240            let job_id = generate_job_id();
241            let log_dir = jobs_log_dir(&job_id);
242            let log_path = log_dir.join("output.log");
243
244            // Build launch command
245            let (launcher, launch_args) = match framework.as_str() {
246                "pytorch" => {
247                    let mut a = vec![
248                        format!("--nproc_per_node={gpus}"),
249                        format!("--nnodes={nodes}"),
250                        "--master_addr=localhost".into(),
251                        "--master_port=29500".into(),
252                        script.clone(),
253                    ];
254                    a.extend(args.clone());
255                    ("torchrun".to_string(), a)
256                }
257                "accelerate" | "hf" => {
258                    let mut a = vec![
259                        "launch".into(),
260                        format!("--num_processes={total_procs}"),
261                        script.clone(),
262                    ];
263                    a.extend(args.clone());
264                    ("accelerate".to_string(), a)
265                }
266                other => {
267                    anyhow::bail!("unsupported framework: {other}. Use 'pytorch' or 'accelerate'.");
268                }
269            };
270
271            println!("Zernel Job Submit");
272            println!("  Job ID:      {job_id}");
273            println!("  Script:      {script}");
274            println!("  Framework:   {framework}");
275            println!("  Backend:     {backend}");
276            println!("  GPUs/node:   {gpus}");
277            println!("  Nodes:       {nodes}");
278            println!("  Total procs: {total_procs}");
279            println!("  Launcher:    {launcher} {}", launch_args.join(" "));
280            println!("  Log:         {}", log_path.display());
281            println!();
282
283            // Set environment
284            let mut env_vars = vec![
285                ("NCCL_SOCKET_IFNAME", "eth0,en0".to_string()),
286                ("NCCL_DEBUG", "WARN".to_string()),
287            ];
288            if backend == "nccl" {
289                env_vars.push(("NCCL_P2P_DISABLE", "0".to_string()));
290            }
291
292            // Spawn process
293            let mut child = tokio::process::Command::new(&launcher)
294                .args(&launch_args)
295                .envs(env_vars)
296                .stdout(Stdio::piped())
297                .stderr(Stdio::piped())
298                .spawn()
299                .with_context(|| format!("failed to launch {launcher}. Is it installed?"))?;
300
301            let pid = child.id().unwrap_or(0);
302
303            // Record in DB
304            let conn = open_jobs_db()?;
305            conn.execute(
306                "INSERT INTO jobs (id, script, status, gpus_per_node, nodes, framework, backend, pid, submitted_at) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
307                (
308                    &job_id, &script, "running", gpus, nodes, &framework, &backend, pid,
309                    Utc::now().to_rfc3339(),
310                ),
311            )?;
312
313            println!("Job started (PID: {pid})");
314            println!();
315
316            // Capture output
317            let stdout = child.stdout.take();
318            let stderr = child.stderr.take();
319            let log_path_clone = log_path.clone();
320
321            let stdout_handle = tokio::spawn(async move {
322                let Some(stdout) = stdout else { return };
323                let mut reader = BufReader::new(stdout);
324                let mut line = String::new();
325                let mut log_file = std::fs::File::create(&log_path_clone).ok();
326
327                loop {
328                    line.clear();
329                    match reader.read_line(&mut line).await {
330                        Ok(0) => break,
331                        Ok(_) => {
332                            print!("{line}");
333                            if let Some(ref mut f) = log_file {
334                                use std::io::Write;
335                                let _ = f.write_all(line.as_bytes());
336                            }
337                        }
338                        Err(_) => break,
339                    }
340                }
341            });
342
343            let stderr_handle = tokio::spawn(async move {
344                let Some(stderr) = stderr else { return };
345                let mut reader = BufReader::new(stderr);
346                let mut line = String::new();
347                loop {
348                    line.clear();
349                    match reader.read_line(&mut line).await {
350                        Ok(0) => break,
351                        Ok(_) => eprint!("{line}"),
352                        Err(_) => break,
353                    }
354                }
355            });
356
357            let (_, _) = tokio::join!(stdout_handle, stderr_handle);
358            let status = child.wait().await?;
359            let exit_code = status.code().unwrap_or(-1);
360
361            // Update DB
362            let final_status = if status.success() { "done" } else { "failed" };
363            conn.execute(
364                "UPDATE jobs SET status = ?1, finished_at = ?2, exit_code = ?3 WHERE id = ?4",
365                (final_status, Utc::now().to_rfc3339(), exit_code, &job_id),
366            )?;
367
368            println!();
369            println!("---");
370            println!("  Status: {final_status}");
371            println!("  Exit:   {exit_code}");
372            println!("  Job ID: {job_id}");
373            println!("  Log:    {}", log_path.display());
374        }
375
376        JobCommands::List => {
377            let conn = open_jobs_db()?;
378            let mut stmt = conn.prepare(
379                "SELECT id, script, status, gpus_per_node, nodes, framework, submitted_at, exit_code FROM jobs ORDER BY submitted_at DESC LIMIT 20",
380            )?;
381
382            let mut jobs = Vec::new();
383            let mut rows = stmt.query([])?;
384            while let Some(row) = rows.next()? {
385                jobs.push(Job {
386                    id: row.get(0)?,
387                    script: row.get(1)?,
388                    status: row.get(2)?,
389                    gpus_per_node: row.get(3)?,
390                    nodes: row.get(4)?,
391                    framework: row.get(5)?,
392                    backend: String::new(),
393                    pid: None,
394                    submitted_at: row.get(6)?,
395                    finished_at: None,
396                    exit_code: row.get(7)?,
397                });
398            }
399
400            if jobs.is_empty() {
401                println!("No jobs. Submit one with: zernel job submit <script>");
402                return Ok(());
403            }
404
405            let header = format!(
406                "{:<30} {:<20} {:<10} {:>5} {:>5} {:<12} {:>6}",
407                "ID", "Script", "Status", "GPUs", "Nodes", "Framework", "Exit"
408            );
409            println!("{header}");
410            println!("{}", "-".repeat(95));
411
412            for j in &jobs {
413                let exit_str = j
414                    .exit_code
415                    .map(|e| e.to_string())
416                    .unwrap_or_else(|| "-".into());
417                println!(
418                    "{:<30} {:<20} {:<10} {:>5} {:>5} {:<12} {:>6}",
419                    j.id, j.script, j.status, j.gpus_per_node, j.nodes, j.framework, exit_str
420                );
421            }
422        }
423
424        JobCommands::Status { id } => {
425            let conn = open_jobs_db()?;
426            let mut stmt = conn.prepare(
427                "SELECT id, script, status, gpus_per_node, nodes, framework, backend, pid, submitted_at, finished_at, exit_code FROM jobs WHERE id = ?1",
428            )?;
429
430            let job = stmt
431                .query_row([&id], |row| {
432                    Ok(Job {
433                        id: row.get(0)?,
434                        script: row.get(1)?,
435                        status: row.get(2)?,
436                        gpus_per_node: row.get(3)?,
437                        nodes: row.get(4)?,
438                        framework: row.get(5)?,
439                        backend: row.get(6)?,
440                        pid: row.get(7)?,
441                        submitted_at: row.get(8)?,
442                        finished_at: row.get(9)?,
443                        exit_code: row.get(10)?,
444                    })
445                })
446                .ok();
447
448            match job {
449                Some(j) => {
450                    println!("Job: {}", j.id);
451                    println!("  Script:     {}", j.script);
452                    println!("  Status:     {}", j.status);
453                    println!("  Framework:  {}", j.framework);
454                    println!("  Backend:    {}", j.backend);
455                    println!("  GPUs/node:  {}", j.gpus_per_node);
456                    println!("  Nodes:      {}", j.nodes);
457                    if let Some(pid) = j.pid {
458                        println!("  PID:        {pid}");
459                    }
460                    println!("  Submitted:  {}", j.submitted_at);
461                    if let Some(fin) = &j.finished_at {
462                        println!("  Finished:   {fin}");
463                    }
464                    if let Some(exit) = j.exit_code {
465                        println!("  Exit code:  {exit}");
466                    }
467
468                    // Show last 10 lines of log
469                    let log_path = jobs_log_dir(&j.id).join("output.log");
470                    if log_path.exists() {
471                        println!();
472                        println!("  Last output:");
473                        if let Ok(content) = std::fs::read_to_string(&log_path) {
474                            let lines: Vec<&str> = content.lines().collect();
475                            let start = lines.len().saturating_sub(10);
476                            for line in &lines[start..] {
477                                println!("    {line}");
478                            }
479                        }
480                    }
481                }
482                None => {
483                    println!("Job not found: {id}");
484                }
485            }
486        }
487
488        JobCommands::Cancel { id } => {
489            let conn = open_jobs_db()?;
490            let pid: Option<u32> = conn
491                .query_row(
492                    "SELECT pid FROM jobs WHERE id = ?1 AND status = 'running'",
493                    [&id],
494                    |row| row.get(0),
495                )
496                .ok();
497
498            match pid {
499                Some(pid) if pid > 0 => {
500                    println!("Cancelling job {id} (PID: {pid})...");
501
502                    // Send SIGTERM (or taskkill on Windows)
503                    #[cfg(unix)]
504                    {
505                        unsafe {
506                            libc::kill(pid as i32, libc::SIGTERM);
507                        }
508                    }
509                    #[cfg(not(unix))]
510                    {
511                        let _ = std::process::Command::new("taskkill")
512                            .args(["/PID", &pid.to_string(), "/F"])
513                            .output();
514                    }
515
516                    conn.execute(
517                        "UPDATE jobs SET status = 'cancelled', finished_at = ?1 WHERE id = ?2",
518                        (Utc::now().to_rfc3339(), &id),
519                    )?;
520
521                    println!("Job cancelled.");
522                }
523                _ => {
524                    println!("No running job found with ID: {id}");
525                }
526            }
527        }
528    }
529    Ok(())
530}