use crate::validation;
use anyhow::{Context, Result};
use chrono::Utc;
use clap::Subcommand;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};

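/// Subcommands under `zernel model`: save models into a local registry,
/// list what has been saved, and deploy a saved model to a serving target.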
#[derive(Subcommand)]
pub enum ModelCommands {
    /// Save a model file or directory into the local registry
    Save {
        /// Path to the model file or directory to save
        path: String,
        /// Model name (defaults to the file or directory name)
        #[arg(long)]
        name: Option<String>,
        /// Tag for this model version (defaults to "latest")
        #[arg(long)]
        tag: Option<String>,
    },
    /// List models saved in the local registry
    List,
    /// Deploy a saved model to a serving target
    Deploy {
        /// Model to deploy, as `name` or `name:tag`
        model: String,
        /// Deployment target: local, docker, or sagemaker
        #[arg(long, default_value = "local")]
        target: String,
        /// Port for the inference server
        #[arg(long, default_value = "8080")]
        port: u16,
        /// Docker registry to push the built image to (docker target)
        #[arg(long)]
        registry: Option<String>,
        /// AWS region (sagemaker target)
        #[arg(long, default_value = "us-east-1")]
        region: String,
        /// SageMaker instance type (sagemaker target)
        #[arg(long, default_value = "ml.g5.xlarge")]
        instance_type: String,
    },
}

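/// One saved model, as recorded in `registry.json`.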
#[derive(Debug, Serialize, Deserialize)]
struct ModelEntry {
    name: String,
    version: String,
    tag: String,
    source_path: String,
    saved_at: String,
    git_commit: Option<String>,
    size_bytes: u64,
}

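// On-disk layout: saved model files are copied under ~/.zernel/models/<name>/<tag>/,
// and metadata for every entry is kept in ~/.zernel/models/registry.json.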
fn registry_dir() -> PathBuf {
    let dir = dirs::home_dir()
        .unwrap_or_else(|| PathBuf::from("."))
        .join(".zernel")
        .join("models");
    std::fs::create_dir_all(&dir).ok();
    dir
}

fn registry_file() -> PathBuf {
    registry_dir().join("registry.json")
}

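// The registry is read leniently: a missing or unparseable registry.json is
// treated as an empty registry rather than an error.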
fn load_registry() -> Vec<ModelEntry> {
    let path = registry_file();
    if path.exists() {
        let data = std::fs::read_to_string(&path).unwrap_or_default();
        serde_json::from_str(&data).unwrap_or_default()
    } else {
        Vec::new()
    }
}

fn save_registry(entries: &[ModelEntry]) -> Result<()> {
    let data = serde_json::to_string_pretty(entries)?;
    std::fs::write(registry_file(), data)?;
    Ok(())
}

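// Total size in bytes of a file or a directory tree; entries that cannot be
// read are counted as zero rather than failing the walk.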
fn dir_size(path: &Path) -> u64 {
    if path.is_file() {
        return path.metadata().map(|m| m.len()).unwrap_or(0);
    }
    walkdir(path)
}

fn walkdir(path: &Path) -> u64 {
    let mut total = 0u64;
    if let Ok(entries) = std::fs::read_dir(path) {
        for entry in entries.flatten() {
            let Ok(ft) = entry.file_type() else {
                continue;
            };
            if ft.is_file() {
                total += entry.metadata().map(|m| m.len()).unwrap_or(0);
            } else if ft.is_dir() {
                total += walkdir(&entry.path());
            }
        }
    }
    total
}

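// Human-readable size string using binary (1024-based) units.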
fn format_size(bytes: u64) -> String {
    if bytes < 1024 {
        format!("{bytes} B")
    } else if bytes < 1024 * 1024 {
        format!("{:.1} KB", bytes as f64 / 1024.0)
    } else if bytes < 1024 * 1024 * 1024 {
        format!("{:.1} MB", bytes as f64 / (1024.0 * 1024.0))
    } else {
        format!("{:.1} GB", bytes as f64 / (1024.0 * 1024.0 * 1024.0))
    }
}

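/// Dispatches a `zernel model` subcommand.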
pub async fn run(cmd: ModelCommands) -> Result<()> {
    match cmd {
        ModelCommands::Save { path, name, tag } => {
            let source = Path::new(&path);
            if !source.exists() {
                anyhow::bail!("path does not exist: {path}");
            }

            let model_name = name.unwrap_or_else(|| {
                source
                    .file_name()
                    .map(|s| s.to_string_lossy().to_string())
                    .unwrap_or_else(|| "unnamed".into())
            });
            let model_tag = tag.unwrap_or_else(|| "latest".into());

            validation::validate_name(&model_name)?;
            validation::validate_tag(&model_tag)?;

            let git_commit = std::process::Command::new("git")
                .args(["rev-parse", "--short", "HEAD"])
                .output()
                .ok()
                .and_then(|o| {
                    if o.status.success() {
                        Some(String::from_utf8_lossy(&o.stdout).trim().to_string())
                    } else {
                        None
                    }
                });

            let size = dir_size(source);

            let dest = registry_dir().join(&model_name).join(&model_tag);
            std::fs::create_dir_all(&dest)?;

            if source.is_file() {
                let fname = source
                    .file_name()
                    .ok_or_else(|| anyhow::anyhow!("source path has no filename"))?;
                std::fs::copy(source, dest.join(fname))
                    .with_context(|| format!("failed to copy {path}"))?;
            } else {
                copy_dir_recursive(source, &dest)?;
            }

            let mut registry = load_registry();
            registry.retain(|e| !(e.name == model_name && e.tag == model_tag));

            let version = format!(
                "{}.0.0",
                registry.iter().filter(|e| e.name == model_name).count() + 1
            );

            registry.push(ModelEntry {
                name: model_name.clone(),
                version: version.clone(),
                tag: model_tag.clone(),
                source_path: path.clone(),
                saved_at: Utc::now().to_rfc3339(),
                git_commit,
                size_bytes: size,
            });

            save_registry(&registry)?;

            println!("Saved model: {model_name}:{model_tag}");
            println!(" Version: {version}");
            println!(" Size: {}", format_size(size));
            println!(" Path: {}", dest.display());
        }
        ModelCommands::List => {
            let registry = load_registry();
            if registry.is_empty() {
                println!("No models saved. Use `zernel model save <path>` to register one.");
                return Ok(());
            }

            let header = format!(
                "{:<20} {:<10} {:<12} {:>10} {:>10}",
                "Name", "Version", "Tag", "Size", "Saved"
            );
            println!("{header}");
            println!("{}", "-".repeat(70));

            for entry in &registry {
                // Show only the date part of the RFC 3339 timestamp; fall back to
                // the full string if a hand-edited entry is shorter than expected.
                let saved = entry.saved_at.get(..10).unwrap_or(&entry.saved_at);
                println!(
                    "{:<20} {:<10} {:<12} {:>10} {}",
                    entry.name,
                    entry.version,
                    entry.tag,
                    format_size(entry.size_bytes),
                    saved,
                );
            }
        }
        ModelCommands::Deploy {
            model,
            target,
            port,
            registry: docker_registry,
            region,
            instance_type,
        } => {
            let (name, tag) = model.split_once(':').unwrap_or((&model, "latest"));

            let registry = load_registry();
            let entry = registry
                .iter()
                .find(|e| e.name == name && e.tag == tag)
                .ok_or_else(|| anyhow::anyhow!("model not found: {name}:{tag}"))?;

            let model_path = registry_dir().join(name).join(tag);

            println!(
                "Deploying {name}:{tag} (v{}) to {target} on port {port}",
                entry.version
            );
            println!(" Source: {}", model_path.display());
            match target.as_str() {
                "local" => {
                    let vllm_check = std::process::Command::new("python3")
                        .args(["-c", "import vllm; print(vllm.__version__)"])
                        .output();

                    match vllm_check {
                        Ok(output) if output.status.success() => {
                            let version =
                                String::from_utf8_lossy(&output.stdout).trim().to_string();
                            println!(" vLLM: v{version}");
                        }
                        _ => {
                            println!();
                            println!(" vLLM not found. Install with: pip install vllm");
                            println!(
                                " Or run manually: python3 -m vllm.entrypoints.openai.api_server --model {} --port {port}",
                                model_path.display()
                            );
                            return Ok(());
                        }
                    }

                    println!();
                    println!("Starting vLLM inference server...");
                    println!(" URL: http://localhost:{port}/v1");
                    println!(" Press Ctrl+C to stop");
                    println!();

                    let status = tokio::process::Command::new("python3")
                        .args([
                            "-m",
                            "vllm.entrypoints.openai.api_server",
                            "--model",
                            &model_path.to_string_lossy(),
                            "--port",
                            &port.to_string(),
                        ])
                        .status()
                        .await
                        .with_context(|| "failed to start vLLM server")?;

                    if !status.success() {
                        anyhow::bail!("vLLM exited with code {}", status.code().unwrap_or(-1));
                    }
                }
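                // The docker target writes a Dockerfile next to the saved model,
                // builds an image from the model directory, and optionally pushes
                // it to the registry given via --registry.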
                "docker" => {
                    println!();

                    let dockerfile_content = format!(
                        "FROM vllm/vllm-openai:latest\n\
                         COPY . /model\n\
                         EXPOSE {port}\n\
                         ENTRYPOINT [\"python3\", \"-m\", \"vllm.entrypoints.openai.api_server\", \
                         \"--model\", \"/model\", \"--port\", \"{port}\"]\n"
                    );

                    let dockerfile_path = model_path.join("Dockerfile.zernel");
                    std::fs::write(&dockerfile_path, &dockerfile_content)?;
                    println!(" Generated: {}", dockerfile_path.display());

                    let image_tag = format!("zernel-{name}:{tag}");
                    println!(" Building Docker image: {image_tag}");

                    let build_status = std::process::Command::new("docker")
                        .args(["build", "-t", &image_tag, "-f"])
                        .arg(&dockerfile_path)
                        .arg(&model_path)
                        .status()?;

                    if !build_status.success() {
                        anyhow::bail!("docker build failed");
                    }

                    println!(" Image built: {image_tag}");

                    if let Some(ref reg) = docker_registry {
                        let remote_tag = format!("{reg}/{image_tag}");
                        println!(" Pushing to {remote_tag}...");

                        std::process::Command::new("docker")
                            .args(["tag", &image_tag, &remote_tag])
                            .status()?;

                        let push_status = std::process::Command::new("docker")
                            .args(["push", &remote_tag])
                            .status()?;

                        if push_status.success() {
                            println!(" Pushed: {remote_tag}");
                        } else {
                            anyhow::bail!("docker push failed");
                        }
                    }

                    println!();
                    println!("Run locally: docker run --gpus all -p {port}:{port} {image_tag}");
                }

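                // The sagemaker target syncs the model to S3 and drives the AWS CLI;
                // endpoint creation is asynchronous, so only a describe-endpoint hint
                // is printed at the end.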
                "sagemaker" => {
                    println!();
                    println!(" Region: {region}");
                    println!(" Instance: {instance_type}");

                    let aws_check = std::process::Command::new("aws")
                        .args(["sts", "get-caller-identity"])
                        .output();

                    match aws_check {
                        Ok(output) if output.status.success() => {
                            println!(" AWS: authenticated");
                        }
                        _ => {
                            println!();
                            println!(" AWS CLI not configured. Run: aws configure");
                            return Ok(());
                        }
                    }

                    let s3_path = format!("s3://zernel-models/{name}/{tag}/");
                    println!(" Uploading to {s3_path}...");

                    let sync_status = std::process::Command::new("aws")
                        .args([
                            "s3",
                            "sync",
                            &model_path.to_string_lossy(),
                            &s3_path,
                            "--region",
                            &region,
                        ])
                        .status()?;

                    if !sync_status.success() {
                        anyhow::bail!("aws s3 sync failed");
                    }

                    let sm_model_name = format!("zernel-{name}-{tag}");

                    println!(" Creating SageMaker model: {sm_model_name}");
                    // NOTE: the execution role below is a placeholder; a real deployment
                    // needs an ARN of the form arn:aws:iam::<account-id>:role/<RoleName>.
                    let create_status = std::process::Command::new("aws")
                        .args([
                            "sagemaker", "create-model",
                            "--model-name", &sm_model_name,
                            "--primary-container",
                            &format!("Image=763104351884.dkr.ecr.{region}.amazonaws.com/huggingface-pytorch-inference:2.1.0-transformers4.37.0-gpu-py310-cu118-ubuntu20.04,ModelDataUrl={s3_path}"),
                            "--execution-role-arn", "arn:aws:iam::role/SageMakerExecutionRole",
                            "--region", &region,
                        ])
                        .status()?;

                    if !create_status.success() {
                        println!(" SageMaker model creation failed.");
                        println!(" You may need to configure the IAM role and container image.");
                        println!(
                            " Manual: aws sagemaker create-model --model-name {sm_model_name} ..."
                        );
                        return Ok(());
                    }

                    println!(" Creating endpoint config...");
                    let _ = std::process::Command::new("aws")
                        .args([
                            "sagemaker", "create-endpoint-config",
                            "--endpoint-config-name", &format!("{sm_model_name}-config"),
                            "--production-variants",
                            &format!("VariantName=default,ModelName={sm_model_name},InstanceType={instance_type},InitialInstanceCount=1"),
                            "--region", &region,
                        ])
                        .status();

                    println!(" Creating endpoint...");
                    let _ = std::process::Command::new("aws")
                        .args([
                            "sagemaker",
                            "create-endpoint",
                            "--endpoint-name",
                            &sm_model_name,
                            "--endpoint-config-name",
                            &format!("{sm_model_name}-config"),
                            "--region",
                            &region,
                        ])
                        .status();

                    println!();
                    println!(" Endpoint: {sm_model_name}");
                    println!(" Check status: aws sagemaker describe-endpoint --endpoint-name {sm_model_name} --region {region}");
                }

                other => {
                    println!();
                    println!("Unknown target: '{other}'");
                    println!("Available: local, docker, sagemaker");
                }
            }
        }
    }
    Ok(())
}

fn copy_dir_recursive(src: &Path, dst: &Path) -> Result<()> {
    std::fs::create_dir_all(dst)?;
    for entry in std::fs::read_dir(src)? {
        let entry = entry?;
        let ft = entry.file_type()?;
        let dest_path = dst.join(entry.file_name());
        if ft.is_file() {
            std::fs::copy(entry.path(), &dest_path)?;
        } else if ft.is_dir() {
            copy_dir_recursive(&entry.path(), &dest_path)?;
        }
    }
    Ok(())
}
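
// A small sanity check for the size formatting above, added as a sketch; the
// save and deploy paths shell out to external tools (git, docker, aws) and are
// left to integration tests.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn format_size_uses_binary_units() {
        assert_eq!(format_size(512), "512 B");
        assert_eq!(format_size(2048), "2.0 KB");
        assert_eq!(format_size(5 * 1024 * 1024), "5.0 MB");
        assert_eq!(format_size(3 * 1024 * 1024 * 1024), "3.0 GB");
    }
}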