zernel/commands/
model.rs

1// Copyright (C) 2026 Dyber, Inc. — Proprietary
2
3use crate::validation;
4use anyhow::{Context, Result};
5use chrono::Utc;
6use clap::Subcommand;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9
// CLI surface for the `zernel model` subcommand family.
// NOTE: the `///` doc comments below double as clap-generated help text,
// so they are user-visible strings — edit with care.
#[derive(Subcommand)]
pub enum ModelCommands {
    /// Save a model checkpoint to the registry
    Save {
        /// Path to the checkpoint directory or file
        path: String,
        /// Model name
        #[arg(long)]
        name: Option<String>,
        /// Tag (e.g., production, staging)
        #[arg(long)]
        tag: Option<String>,
    },
    /// List models in the registry
    List,
    /// Deploy a model for inference
    Deploy {
        /// Model name:tag
        model: String,
        /// Deployment target (local, docker, sagemaker)
        #[arg(long, default_value = "local")]
        target: String,
        /// Port for inference server
        #[arg(long, default_value = "8080")]
        port: u16,
        /// Docker registry to push to (for --target docker)
        #[arg(long)]
        registry: Option<String>,
        /// AWS region (for --target sagemaker)
        #[arg(long, default_value = "us-east-1")]
        region: String,
        /// SageMaker instance type
        #[arg(long, default_value = "ml.g5.xlarge")]
        instance_type: String,
    },
}
46
/// One saved checkpoint in the on-disk registry, serialized to
/// `registry.json` under the registry directory.
#[derive(Debug, Serialize, Deserialize)]
struct ModelEntry {
    /// Model name (validated before saving to prevent path traversal).
    name: String,
    /// Synthetic version string of the form "N.0.0" (see `run`'s Save arm).
    version: String,
    /// Tag such as "latest", "production", or "staging".
    tag: String,
    /// Original filesystem path the checkpoint was saved from.
    source_path: String,
    /// RFC 3339 timestamp recorded at save time.
    saved_at: String,
    /// Short git commit hash at save time, if run inside a git repo.
    git_commit: Option<String>,
    /// Total size of the checkpoint (file or directory tree) in bytes.
    size_bytes: u64,
}
57
58fn registry_dir() -> PathBuf {
59    let dir = dirs::home_dir()
60        .unwrap_or_else(|| PathBuf::from("."))
61        .join(".zernel")
62        .join("models");
63    std::fs::create_dir_all(&dir).ok();
64    dir
65}
66
67fn registry_file() -> PathBuf {
68    registry_dir().join("registry.json")
69}
70
71fn load_registry() -> Vec<ModelEntry> {
72    let path = registry_file();
73    if path.exists() {
74        let data = std::fs::read_to_string(&path).unwrap_or_default();
75        serde_json::from_str(&data).unwrap_or_default()
76    } else {
77        Vec::new()
78    }
79}
80
81fn save_registry(entries: &[ModelEntry]) -> Result<()> {
82    let data = serde_json::to_string_pretty(entries)?;
83    std::fs::write(registry_file(), data)?;
84    Ok(())
85}
86
87fn dir_size(path: &Path) -> u64 {
88    if path.is_file() {
89        return path.metadata().map(|m| m.len()).unwrap_or(0);
90    }
91    walkdir(path)
92}
93
/// Recursively sum the sizes of all regular files under `path`.
/// Entries whose type or metadata cannot be read contribute 0; an
/// unreadable directory yields 0. Symlinks are not followed.
fn walkdir(path: &Path) -> u64 {
    let Ok(entries) = std::fs::read_dir(path) else {
        return 0;
    };
    entries
        .flatten()
        .filter_map(|entry| {
            let kind = entry.file_type().ok()?;
            if kind.is_file() {
                Some(entry.metadata().map_or(0, |m| m.len()))
            } else if kind.is_dir() {
                Some(walkdir(&entry.path()))
            } else {
                // Symlinks / special files are skipped, as in the original.
                None
            }
        })
        .sum()
}
110
/// Render a byte count as a human-readable size: whole bytes below 1 KiB,
/// otherwise one decimal place in the largest fitting unit (KB/MB/GB).
fn format_size(bytes: u64) -> String {
    const KB: f64 = 1024.0;
    const MB: f64 = KB * 1024.0;
    const GB: f64 = MB * 1024.0;

    // Boundaries (<= 2^30) are exactly representable in f64, so these
    // comparisons match the original integer thresholds.
    let b = bytes as f64;
    if b < KB {
        format!("{bytes} B")
    } else if b < MB {
        format!("{:.1} KB", b / KB)
    } else if b < GB {
        format!("{:.1} MB", b / MB)
    } else {
        format!("{:.1} GB", b / GB)
    }
}
122
123pub async fn run(cmd: ModelCommands) -> Result<()> {
124    match cmd {
125        ModelCommands::Save { path, name, tag } => {
126            let source = Path::new(&path);
127            if !source.exists() {
128                anyhow::bail!("path does not exist: {path}");
129            }
130
131            let model_name = name.unwrap_or_else(|| {
132                source
133                    .file_name()
134                    .map(|s| s.to_string_lossy().to_string())
135                    .unwrap_or_else(|| "unnamed".into())
136            });
137            let model_tag = tag.unwrap_or_else(|| "latest".into());
138
139            // Validate name and tag to prevent path traversal
140            validation::validate_name(&model_name)?;
141            validation::validate_tag(&model_tag)?;
142
143            let git_commit = std::process::Command::new("git")
144                .args(["rev-parse", "--short", "HEAD"])
145                .output()
146                .ok()
147                .and_then(|o| {
148                    if o.status.success() {
149                        Some(String::from_utf8_lossy(&o.stdout).trim().to_string())
150                    } else {
151                        None
152                    }
153                });
154
155            let size = dir_size(source);
156
157            // Copy checkpoint to registry
158            let dest = registry_dir().join(&model_name).join(&model_tag);
159            std::fs::create_dir_all(&dest)?;
160
161            if source.is_file() {
162                let fname = source
163                    .file_name()
164                    .ok_or_else(|| anyhow::anyhow!("source path has no filename"))?;
165                std::fs::copy(source, dest.join(fname))
166                    .with_context(|| format!("failed to copy {path}"))?;
167            } else {
168                copy_dir_recursive(source, &dest)?;
169            }
170
171            let mut registry = load_registry();
172            // Remove old entry with same name:tag
173            registry.retain(|e| !(e.name == model_name && e.tag == model_tag));
174
175            let version = format!(
176                "{}.0.0",
177                registry.iter().filter(|e| e.name == model_name).count() + 1
178            );
179
180            registry.push(ModelEntry {
181                name: model_name.clone(),
182                version: version.clone(),
183                tag: model_tag.clone(),
184                source_path: path.clone(),
185                saved_at: Utc::now().to_rfc3339(),
186                git_commit,
187                size_bytes: size,
188            });
189
190            save_registry(&registry)?;
191
192            println!("Saved model: {model_name}:{model_tag}");
193            println!("  Version: {version}");
194            println!("  Size:    {}", format_size(size));
195            println!("  Path:    {}", dest.display());
196        }
197        ModelCommands::List => {
198            let registry = load_registry();
199            if registry.is_empty() {
200                println!("No models saved. Use `zernel model save <path>` to register one.");
201                return Ok(());
202            }
203
204            let header = format!(
205                "{:<20} {:<10} {:<12} {:>10} {:>10}",
206                "Name", "Version", "Tag", "Size", "Saved"
207            );
208            println!("{header}");
209            println!("{}", "-".repeat(70));
210
211            for entry in &registry {
212                let saved = &entry.saved_at[..10]; // just date
213                println!(
214                    "{:<20} {:<10} {:<12} {:>10} {}",
215                    entry.name,
216                    entry.version,
217                    entry.tag,
218                    format_size(entry.size_bytes),
219                    saved,
220                );
221            }
222        }
223        ModelCommands::Deploy {
224            model,
225            target,
226            port,
227            registry: docker_registry,
228            region,
229            instance_type,
230        } => {
231            let (name, tag) = model.split_once(':').unwrap_or((&model, "latest"));
232
233            let registry = load_registry();
234            let entry = registry
235                .iter()
236                .find(|e| e.name == name && e.tag == tag)
237                .ok_or_else(|| anyhow::anyhow!("model not found: {name}:{tag}"))?;
238
239            let model_path = registry_dir().join(name).join(tag);
240
241            println!(
242                "Deploying {name}:{tag} (v{}) to {target} on port {port}",
243                entry.version
244            );
245            println!("  Source: {}", model_path.display());
246
247            match target.as_str() {
248                "local" => {
249                    // Check vllm is installed
250                    let vllm_check = std::process::Command::new("python3")
251                        .args(["-c", "import vllm; print(vllm.__version__)"])
252                        .output();
253
254                    match vllm_check {
255                        Ok(output) if output.status.success() => {
256                            let version =
257                                String::from_utf8_lossy(&output.stdout).trim().to_string();
258                            println!("  vLLM:   v{version}");
259                        }
260                        _ => {
261                            println!();
262                            println!("  vLLM not found. Install with: pip install vllm");
263                            println!(
264                                "  Or run manually: python3 -m vllm.entrypoints.openai.api_server --model {} --port {port}",
265                                model_path.display()
266                            );
267                            return Ok(());
268                        }
269                    }
270
271                    println!();
272                    println!("Starting vLLM inference server...");
273                    println!("  URL: http://localhost:{port}/v1");
274                    println!("  Press Ctrl+C to stop");
275                    println!();
276
277                    // Launch vLLM
278                    let status = tokio::process::Command::new("python3")
279                        .args([
280                            "-m",
281                            "vllm.entrypoints.openai.api_server",
282                            "--model",
283                            &model_path.to_string_lossy(),
284                            "--port",
285                            &port.to_string(),
286                        ])
287                        .status()
288                        .await
289                        .with_context(|| "failed to start vLLM server")?;
290
291                    if !status.success() {
292                        anyhow::bail!("vLLM exited with code {}", status.code().unwrap_or(-1));
293                    }
294                }
295                "docker" => {
296                    println!();
297
298                    // Generate Dockerfile
299                    let dockerfile_content = format!(
300                        "FROM vllm/vllm-openai:latest\n\
301                         COPY . /model\n\
302                         EXPOSE {port}\n\
303                         ENTRYPOINT [\"python3\", \"-m\", \"vllm.entrypoints.openai.api_server\", \
304                         \"--model\", \"/model\", \"--port\", \"{port}\"]\n"
305                    );
306
307                    let dockerfile_path = model_path.join("Dockerfile.zernel");
308                    std::fs::write(&dockerfile_path, &dockerfile_content)?;
309                    println!("  Generated: {}", dockerfile_path.display());
310
311                    let image_tag = format!("zernel-{name}:{tag}");
312                    println!("  Building Docker image: {image_tag}");
313
314                    let build_status = std::process::Command::new("docker")
315                        .args(["build", "-t", &image_tag, "-f"])
316                        .arg(&dockerfile_path)
317                        .arg(&model_path)
318                        .status()?;
319
320                    if !build_status.success() {
321                        anyhow::bail!("docker build failed");
322                    }
323
324                    println!("  Image built: {image_tag}");
325
326                    if let Some(ref reg) = docker_registry {
327                        let remote_tag = format!("{reg}/{image_tag}");
328                        println!("  Pushing to {remote_tag}...");
329
330                        std::process::Command::new("docker")
331                            .args(["tag", &image_tag, &remote_tag])
332                            .status()?;
333
334                        let push_status = std::process::Command::new("docker")
335                            .args(["push", &remote_tag])
336                            .status()?;
337
338                        if push_status.success() {
339                            println!("  Pushed: {remote_tag}");
340                        } else {
341                            anyhow::bail!("docker push failed");
342                        }
343                    }
344
345                    println!();
346                    println!("Run locally: docker run --gpus all -p {port}:{port} {image_tag}");
347                }
348
349                "sagemaker" => {
350                    println!();
351                    println!("  Region:   {region}");
352                    println!("  Instance: {instance_type}");
353
354                    // Check AWS CLI
355                    let aws_check = std::process::Command::new("aws")
356                        .args(["sts", "get-caller-identity"])
357                        .output();
358
359                    match aws_check {
360                        Ok(output) if output.status.success() => {
361                            let identity = String::from_utf8_lossy(&output.stdout);
362                            println!("  AWS:      authenticated");
363                            let _ = identity;
364                        }
365                        _ => {
366                            println!();
367                            println!("  AWS CLI not configured. Run: aws configure");
368                            return Ok(());
369                        }
370                    }
371
372                    let s3_path = format!("s3://zernel-models/{name}/{tag}/");
373                    println!("  Uploading to {s3_path}...");
374
375                    let sync_status = std::process::Command::new("aws")
376                        .args([
377                            "s3",
378                            "sync",
379                            &model_path.to_string_lossy(),
380                            &s3_path,
381                            "--region",
382                            &region,
383                        ])
384                        .status()?;
385
386                    if !sync_status.success() {
387                        anyhow::bail!("aws s3 sync failed");
388                    }
389
390                    let sm_model_name = format!("zernel-{name}-{tag}");
391
392                    println!("  Creating SageMaker model: {sm_model_name}");
393                    let create_status = std::process::Command::new("aws")
394                        .args([
395                            "sagemaker", "create-model",
396                            "--model-name", &sm_model_name,
397                            "--primary-container",
398                            &format!("Image=763104351884.dkr.ecr.{region}.amazonaws.com/huggingface-pytorch-inference:2.1.0-transformers4.37.0-gpu-py310-cu118-ubuntu20.04,ModelDataUrl={s3_path}"),
399                            "--execution-role-arn", "arn:aws:iam::role/SageMakerExecutionRole",
400                            "--region", &region,
401                        ])
402                        .status()?;
403
404                    if !create_status.success() {
405                        println!("  SageMaker model creation failed.");
406                        println!("  You may need to configure the IAM role and container image.");
407                        println!(
408                            "  Manual: aws sagemaker create-model --model-name {sm_model_name} ..."
409                        );
410                        return Ok(());
411                    }
412
413                    println!("  Creating endpoint config...");
414                    let _ = std::process::Command::new("aws")
415                        .args([
416                            "sagemaker", "create-endpoint-config",
417                            "--endpoint-config-name", &format!("{sm_model_name}-config"),
418                            "--production-variants",
419                            &format!("VariantName=default,ModelName={sm_model_name},InstanceType={instance_type},InitialInstanceCount=1"),
420                            "--region", &region,
421                        ])
422                        .status();
423
424                    println!("  Creating endpoint...");
425                    let _ = std::process::Command::new("aws")
426                        .args([
427                            "sagemaker",
428                            "create-endpoint",
429                            "--endpoint-name",
430                            &sm_model_name,
431                            "--endpoint-config-name",
432                            &format!("{sm_model_name}-config"),
433                            "--region",
434                            &region,
435                        ])
436                        .status();
437
438                    println!();
439                    println!("  Endpoint: {sm_model_name}");
440                    println!("  Check status: aws sagemaker describe-endpoint --endpoint-name {sm_model_name} --region {region}");
441                }
442
443                other => {
444                    println!();
445                    println!("Unknown target: '{other}'");
446                    println!("Available: local, docker, sagemaker");
447                }
448            }
449        }
450    }
451    Ok(())
452}
453
454fn copy_dir_recursive(src: &Path, dst: &Path) -> Result<()> {
455    std::fs::create_dir_all(dst)?;
456    for entry in std::fs::read_dir(src)? {
457        let entry = entry?;
458        let ft = entry.file_type()?;
459        let dest_path = dst.join(entry.file_name());
460        if ft.is_file() {
461            std::fs::copy(entry.path(), &dest_path)?;
462        } else if ft.is_dir() {
463            copy_dir_recursive(&entry.path(), &dest_path)?;
464        }
465    }
466    Ok(())
467}