1use crate::experiments::tracker;
4use anyhow::{Context, Result};
5use chrono::Utc;
6use clap::Subcommand;
7use rusqlite::Connection;
8use serde::{Deserialize, Serialize};
9use std::path::PathBuf;
10use std::process::Stdio;
11use tokio::io::{AsyncBufReadExt, BufReader};
12
// CLI surface for `zernel job …`.
// Plain `//` comments are used here instead of `///` on purpose: doc
// comments on clap-derived items become runtime help text, and this
// change must not alter the generated CLI help.
#[derive(Subcommand)]
pub enum JobCommands {
    // Submit a job and stream its output until it finishes.
    Submit {
        // Path to the training script. Must exist locally for the
        // "local" and "ssh" targets (checked in `run`).
        script: String,
        // GPUs per node; the default of 1 is treated as "unset" for
        // local runs and auto-expanded to all detected GPUs.
        #[arg(long, default_value = "1")]
        gpus_per_node: u32,
        // Number of nodes to launch on.
        #[arg(long, default_value = "1")]
        nodes: u32,
        // Launch framework: "pytorch" (torchrun) or "accelerate"/"hf".
        #[arg(long, default_value = "pytorch")]
        framework: String,
        // Collective backend name; "nccl" enables NCCL-specific env vars.
        #[arg(long, default_value = "nccl")]
        backend: String,
        // Execution target: "local", "ssh", or "k8s".
        #[arg(long, default_value = "local")]
        target: String,
        // Host list, required when --target ssh (parsed by job_ssh::parse_hosts).
        #[arg(long)]
        hosts: Option<String>,
        // Container image, required when --target k8s.
        #[arg(long)]
        image: Option<String>,
        // Kubernetes namespace for --target k8s.
        #[arg(long, default_value = "default")]
        namespace: String,
        // Remaining args are passed through to the training script.
        #[arg(trailing_var_arg = true)]
        args: Vec<String>,
    },
    // List the 20 most recently submitted jobs.
    List,
    // Show details and recent log output for one job.
    Status {
        id: String,
    },
    // Terminate a running job by PID and mark it cancelled.
    Cancel {
        id: String,
    },
}
60
/// In-memory mirror of a row in the `jobs` SQLite table (see
/// `open_jobs_db` for the schema). Serde derives allow JSON output.
#[derive(Debug, Serialize, Deserialize)]
struct Job {
    // Unique id, e.g. "job-20240101-120000-ab12cd34" (see generate_job_id).
    id: String,
    // Script path as given on the command line.
    script: String,
    // Lifecycle state: "running", "done", "failed", or "cancelled".
    status: String,
    gpus_per_node: u32,
    nodes: u32,
    framework: String,
    backend: String,
    // Launcher PID for local jobs; None for remote jobs (ssh/k8s) or
    // when the column was not selected (see the List query).
    pid: Option<u32>,
    // RFC 3339 timestamps.
    submitted_at: String,
    finished_at: Option<String>,
    // Process exit code; None while the job is still running.
    exit_code: Option<i32>,
}
75
76fn jobs_db_path() -> PathBuf {
77 let dir = tracker::zernel_dir().join("jobs");
78 std::fs::create_dir_all(&dir).ok();
79 dir.join("jobs.db")
80}
81
82fn jobs_log_dir(job_id: &str) -> PathBuf {
83 let dir = tracker::zernel_dir().join("jobs").join(job_id);
84 std::fs::create_dir_all(&dir).ok();
85 dir
86}
87
88fn open_jobs_db() -> Result<Connection> {
89 let conn = Connection::open(jobs_db_path())?;
90 conn.execute_batch(
91 "CREATE TABLE IF NOT EXISTS jobs (
92 id TEXT PRIMARY KEY,
93 script TEXT NOT NULL,
94 status TEXT NOT NULL,
95 gpus_per_node INTEGER NOT NULL,
96 nodes INTEGER NOT NULL,
97 framework TEXT NOT NULL,
98 backend TEXT NOT NULL,
99 pid INTEGER,
100 submitted_at TEXT NOT NULL,
101 finished_at TEXT,
102 exit_code INTEGER
103 );",
104 )?;
105 Ok(conn)
106}
107
108fn generate_job_id() -> String {
109 let now = chrono::Utc::now();
110 let short = &uuid::Uuid::new_v4().to_string()[..8];
111 format!("job-{}-{}", now.format("%Y%m%d-%H%M%S"), short)
112}
113
/// Ask `nvidia-smi` how many GPUs are visible on this host.
///
/// Returns 0 when the tool is missing, exits non-zero, or produces
/// unparseable output — callers treat 0 as "no GPUs detected".
fn detect_gpu_count() -> u32 {
    let result = std::process::Command::new("nvidia-smi")
        .args(["--query-gpu=count", "--format=csv,noheader"])
        .output();

    let output = match result {
        Ok(o) if o.status.success() => o,
        _ => return 0,
    };

    // nvidia-smi prints one line per GPU, each containing the total
    // count, so parsing the first line is sufficient.
    String::from_utf8_lossy(&output.stdout)
        .trim()
        .lines()
        .next()
        .and_then(|line| line.parse().ok())
        .unwrap_or(0)
}
132
/// Execute a `zernel job` subcommand.
///
/// `Submit` dispatches on `--target`:
/// * `"ssh"` — multi-node launch via `super::job_ssh` (requires `--hosts`);
/// * `"k8s"` — containerized launch via `super::job_k8s` (requires `--image`);
/// * anything else — local launch of `torchrun` / `accelerate launch`.
///
/// Every submit path inserts a `running` row into the jobs DB *before*
/// starting work and updates status/exit code afterwards, so `job list`
/// and `job status` are accurate while the job executes.
pub async fn run(cmd: JobCommands) -> Result<()> {
    match cmd {
        JobCommands::Submit {
            script,
            gpus_per_node,
            nodes,
            framework,
            backend,
            target,
            hosts,
            image,
            namespace,
            args,
        } => {
            // For local/ssh the script must exist on this machine; for
            // k8s it is expected to exist inside the container image.
            let script_path = std::path::Path::new(&script);
            if (target == "local" || target == "ssh") && !script_path.exists() {
                anyhow::bail!("script not found: {script}");
            }

            if target == "ssh" {
                let host_list = super::job_ssh::parse_hosts(
                    hosts
                        .as_deref()
                        .ok_or_else(|| anyhow::anyhow!("--hosts required for --target ssh"))?,
                )?;
                if host_list.len() < nodes as usize {
                    anyhow::bail!("need {} hosts but only {} provided", nodes, host_list.len());
                }
                let job_id = generate_job_id();
                let log_dir = jobs_log_dir(&job_id);

                // Record the job as running before launching. No PID is
                // stored: the processes live on remote hosts.
                let conn = open_jobs_db()?;
                conn.execute(
                    "INSERT INTO jobs (id, script, status, gpus_per_node, nodes, framework, backend, submitted_at) VALUES (?1, ?2, 'running', ?3, ?4, ?5, ?6, ?7)",
                    (&job_id, &script, gpus_per_node, nodes, &framework, &backend, Utc::now().to_rfc3339()),
                )?;

                // Only the first `nodes` hosts are used, even if more
                // were supplied via --hosts.
                let exit_code = super::job_ssh::run_ssh_job(
                    &job_id,
                    &script,
                    &host_list[..nodes as usize],
                    gpus_per_node,
                    &framework,
                    &backend,
                    &args,
                    &log_dir,
                )
                .await?;

                let status = if exit_code == 0 { "done" } else { "failed" };
                conn.execute(
                    "UPDATE jobs SET status = ?1, finished_at = ?2, exit_code = ?3 WHERE id = ?4",
                    (status, Utc::now().to_rfc3339(), exit_code, &job_id),
                )?;
                println!("\nJob {job_id}: {status} (exit {exit_code})");
                return Ok(());
            }

            if target == "k8s" {
                let img = image
                    .as_deref()
                    .ok_or_else(|| anyhow::anyhow!("--image required for --target k8s"))?;
                let job_id = generate_job_id();
                let log_dir = jobs_log_dir(&job_id);

                // Same bookkeeping as the ssh path: insert first, then run.
                let conn = open_jobs_db()?;
                conn.execute(
                    "INSERT INTO jobs (id, script, status, gpus_per_node, nodes, framework, backend, submitted_at) VALUES (?1, ?2, 'running', ?3, ?4, ?5, ?6, ?7)",
                    (&job_id, &script, gpus_per_node, nodes, &framework, &backend, Utc::now().to_rfc3339()),
                )?;

                let exit_code = super::job_k8s::run_k8s_job(
                    &job_id,
                    &script,
                    img,
                    gpus_per_node,
                    nodes,
                    &namespace,
                    &args,
                    &log_dir,
                )
                .await?;

                let status = if exit_code == 0 { "done" } else { "failed" };
                conn.execute(
                    "UPDATE jobs SET status = ?1, finished_at = ?2, exit_code = ?3 WHERE id = ?4",
                    (status, Utc::now().to_rfc3339(), exit_code, &job_id),
                )?;
                println!("\nJob {job_id}: {status} (exit {exit_code})");
                return Ok(());
            }

            // ---- local launch (everything below runs on this host) ----
            // NOTE(review): an unrecognized --target value (not local/ssh/k8s)
            // falls through to a local run rather than erroring — confirm
            // that is intended. For target == "local" this existence check
            // is redundant (already done above) but harmless.
            if !script_path.exists() {
                anyhow::bail!("script not found: {script}");
            }

            // The default --gpus-per-node=1 is treated as "unset": when more
            // GPUs are visible locally, auto-expand to use all of them. An
            // explicit value > 1 is always respected.
            let detected_gpus = detect_gpu_count();
            let gpus = if gpus_per_node == 1 && detected_gpus > 1 {
                println!("Detected {detected_gpus} GPUs. Using all.");
                detected_gpus
            } else {
                gpus_per_node
            };

            let total_procs = gpus * nodes;
            let job_id = generate_job_id();
            let log_dir = jobs_log_dir(&job_id);
            let log_path = log_dir.join("output.log");

            // Build the launcher command line for the chosen framework.
            // Script args are appended after the script path in both cases.
            let (launcher, launch_args) = match framework.as_str() {
                "pytorch" => {
                    // torchrun single-master setup; port 29500 is the
                    // conventional torchrun default.
                    let mut a = vec![
                        format!("--nproc_per_node={gpus}"),
                        format!("--nnodes={nodes}"),
                        "--master_addr=localhost".into(),
                        "--master_port=29500".into(),
                        script.clone(),
                    ];
                    a.extend(args.clone());
                    ("torchrun".to_string(), a)
                }
                "accelerate" | "hf" => {
                    let mut a = vec![
                        "launch".into(),
                        format!("--num_processes={total_procs}"),
                        script.clone(),
                    ];
                    a.extend(args.clone());
                    ("accelerate".to_string(), a)
                }
                other => {
                    anyhow::bail!("unsupported framework: {other}. Use 'pytorch' or 'accelerate'.");
                }
            };

            println!("Zernel Job Submit");
            println!("  Job ID:      {job_id}");
            println!("  Script:      {script}");
            println!("  Framework:   {framework}");
            println!("  Backend:     {backend}");
            println!("  GPUs/node:   {gpus}");
            println!("  Nodes:       {nodes}");
            println!("  Total procs: {total_procs}");
            println!("  Launcher:    {launcher} {}", launch_args.join(" "));
            println!("  Log:         {}", log_path.display());
            println!();

            // NCCL environment defaults. NOTE(review): NCCL_SOCKET_IFNAME is
            // pinned to "eth0,en0" — confirm these interface names cover the
            // deployment hosts. NCCL_P2P_DISABLE=0 explicitly enables
            // peer-to-peer transfers for the nccl backend.
            let mut env_vars = vec![
                ("NCCL_SOCKET_IFNAME", "eth0,en0".to_string()),
                ("NCCL_DEBUG", "WARN".to_string()),
            ];
            if backend == "nccl" {
                env_vars.push(("NCCL_P2P_DISABLE", "0".to_string()));
            }

            // Both pipes are captured so output can be mirrored/persisted.
            let mut child = tokio::process::Command::new(&launcher)
                .args(&launch_args)
                .envs(env_vars)
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .spawn()
                .with_context(|| format!("failed to launch {launcher}. Is it installed?"))?;

            // 0 is a sentinel for "PID unavailable"; `Cancel` skips pid == 0.
            let pid = child.id().unwrap_or(0);

            // Insert after a successful spawn so the stored PID is real.
            let conn = open_jobs_db()?;
            conn.execute(
                "INSERT INTO jobs (id, script, status, gpus_per_node, nodes, framework, backend, pid, submitted_at) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
                (
                    &job_id, &script, "running", gpus, nodes, &framework, &backend, pid,
                    Utc::now().to_rfc3339(),
                ),
            )?;

            println!("Job started (PID: {pid})");
            println!();

            // Pump the child's stdout/stderr concurrently. stdout is mirrored
            // to the console AND written to output.log; stderr is mirrored to
            // the console only. NOTE(review): stderr is not persisted to the
            // log file — confirm that is intentional.
            let stdout = child.stdout.take();
            let stderr = child.stderr.take();
            let log_path_clone = log_path.clone();

            let stdout_handle = tokio::spawn(async move {
                let Some(stdout) = stdout else { return };
                let mut reader = BufReader::new(stdout);
                let mut line = String::new();
                // Best-effort log file: if creation fails, output is still
                // mirrored to the console.
                let mut log_file = std::fs::File::create(&log_path_clone).ok();

                loop {
                    line.clear();
                    match reader.read_line(&mut line).await {
                        Ok(0) => break, // EOF: child closed its stdout
                        Ok(_) => {
                            print!("{line}");
                            if let Some(ref mut f) = log_file {
                                use std::io::Write;
                                let _ = f.write_all(line.as_bytes());
                            }
                        }
                        Err(_) => break,
                    }
                }
            });

            let stderr_handle = tokio::spawn(async move {
                let Some(stderr) = stderr else { return };
                let mut reader = BufReader::new(stderr);
                let mut line = String::new();
                loop {
                    line.clear();
                    match reader.read_line(&mut line).await {
                        Ok(0) => break,
                        Ok(_) => eprint!("{line}"),
                        Err(_) => break,
                    }
                }
            });

            // Drain both pipes fully before reaping the child, so no trailing
            // output is lost.
            let (_, _) = tokio::join!(stdout_handle, stderr_handle);
            let status = child.wait().await?;
            // -1 when the process was killed by a signal (no exit code).
            let exit_code = status.code().unwrap_or(-1);

            let final_status = if status.success() { "done" } else { "failed" };
            conn.execute(
                "UPDATE jobs SET status = ?1, finished_at = ?2, exit_code = ?3 WHERE id = ?4",
                (final_status, Utc::now().to_rfc3339(), exit_code, &job_id),
            )?;

            println!();
            println!("---");
            println!("  Status:  {final_status}");
            println!("  Exit:    {exit_code}");
            println!("  Job ID:  {job_id}");
            println!("  Log:     {}", log_path.display());
        }

        JobCommands::List => {
            let conn = open_jobs_db()?;
            // Only a subset of columns is selected; the rest are filled with
            // placeholders below since the listing does not display them.
            let mut stmt = conn.prepare(
                "SELECT id, script, status, gpus_per_node, nodes, framework, submitted_at, exit_code FROM jobs ORDER BY submitted_at DESC LIMIT 20",
            )?;

            let mut jobs = Vec::new();
            let mut rows = stmt.query([])?;
            while let Some(row) = rows.next()? {
                jobs.push(Job {
                    id: row.get(0)?,
                    script: row.get(1)?,
                    status: row.get(2)?,
                    gpus_per_node: row.get(3)?,
                    nodes: row.get(4)?,
                    framework: row.get(5)?,
                    // Placeholders: not part of the SELECT above.
                    backend: String::new(),
                    pid: None,
                    submitted_at: row.get(6)?,
                    finished_at: None,
                    exit_code: row.get(7)?,
                });
            }

            if jobs.is_empty() {
                println!("No jobs. Submit one with: zernel job submit <script>");
                return Ok(());
            }

            let header = format!(
                "{:<30} {:<20} {:<10} {:>5} {:>5} {:<12} {:>6}",
                "ID", "Script", "Status", "GPUs", "Nodes", "Framework", "Exit"
            );
            println!("{header}");
            println!("{}", "-".repeat(95));

            for j in &jobs {
                let exit_str = j
                    .exit_code
                    .map(|e| e.to_string())
                    .unwrap_or_else(|| "-".into());
                println!(
                    "{:<30} {:<20} {:<10} {:>5} {:>5} {:<12} {:>6}",
                    j.id, j.script, j.status, j.gpus_per_node, j.nodes, j.framework, exit_str
                );
            }
        }

        JobCommands::Status { id } => {
            let conn = open_jobs_db()?;
            let mut stmt = conn.prepare(
                "SELECT id, script, status, gpus_per_node, nodes, framework, backend, pid, submitted_at, finished_at, exit_code FROM jobs WHERE id = ?1",
            )?;

            // `.ok()` collapses both "no such row" and genuine DB errors
            // into None, which is reported as "Job not found" below.
            let job = stmt
                .query_row([&id], |row| {
                    Ok(Job {
                        id: row.get(0)?,
                        script: row.get(1)?,
                        status: row.get(2)?,
                        gpus_per_node: row.get(3)?,
                        nodes: row.get(4)?,
                        framework: row.get(5)?,
                        backend: row.get(6)?,
                        pid: row.get(7)?,
                        submitted_at: row.get(8)?,
                        finished_at: row.get(9)?,
                        exit_code: row.get(10)?,
                    })
                })
                .ok();

            match job {
                Some(j) => {
                    println!("Job: {}", j.id);
                    println!("  Script:    {}", j.script);
                    println!("  Status:    {}", j.status);
                    println!("  Framework: {}", j.framework);
                    println!("  Backend:   {}", j.backend);
                    println!("  GPUs/node: {}", j.gpus_per_node);
                    println!("  Nodes:     {}", j.nodes);
                    if let Some(pid) = j.pid {
                        println!("  PID:       {pid}");
                    }
                    println!("  Submitted: {}", j.submitted_at);
                    if let Some(fin) = &j.finished_at {
                        println!("  Finished:  {fin}");
                    }
                    if let Some(exit) = j.exit_code {
                        println!("  Exit code: {exit}");
                    }

                    // Tail the last 10 lines of the job's stdout log, if any.
                    let log_path = jobs_log_dir(&j.id).join("output.log");
                    if log_path.exists() {
                        println!();
                        println!("  Last output:");
                        if let Ok(content) = std::fs::read_to_string(&log_path) {
                            let lines: Vec<&str> = content.lines().collect();
                            let start = lines.len().saturating_sub(10);
                            for line in &lines[start..] {
                                println!("    {line}");
                            }
                        }
                    }
                }
                None => {
                    println!("Job not found: {id}");
                }
            }
        }

        JobCommands::Cancel { id } => {
            let conn = open_jobs_db()?;
            // Only jobs still marked 'running' can be cancelled; remote
            // (ssh/k8s) jobs have no stored PID and thus cannot be.
            let pid: Option<u32> = conn
                .query_row(
                    "SELECT pid FROM jobs WHERE id = ?1 AND status = 'running'",
                    [&id],
                    |row| row.get(0),
                )
                .ok();

            match pid {
                // pid == 0 is the "PID unavailable" sentinel from Submit.
                Some(pid) if pid > 0 => {
                    println!("Cancelling job {id} (PID: {pid})...");

                    #[cfg(unix)]
                    {
                        // SAFETY: kill(2) is safe to call with any pid value;
                        // the worst case is an ESRCH error for a dead process.
                        unsafe {
                            libc::kill(pid as i32, libc::SIGTERM);
                        }
                    }
                    #[cfg(not(unix))]
                    {
                        // Windows fallback: forceful termination via taskkill.
                        let _ = std::process::Command::new("taskkill")
                            .args(["/PID", &pid.to_string(), "/F"])
                            .output();
                    }

                    // Marked cancelled unconditionally — the signal's success
                    // is not checked.
                    conn.execute(
                        "UPDATE jobs SET status = 'cancelled', finished_at = ?1 WHERE id = ?2",
                        (Utc::now().to_rfc3339(), &id),
                    )?;

                    println!("Job cancelled.");
                }
                _ => {
                    println!("No running job found with ID: {id}");
                }
            }
        }
    }
    Ok(())
}