// zernel/commands/serve.rs

// Copyright (C) 2026 Dyber, Inc. — Proprietary

//! zernel serve — Unified inference server
use anyhow::{Context, Result};
use clap::Subcommand;
use std::process::Command;
8
// CLI surface for `zernel serve`. Each variant is one subcommand.
// NOTE: the `///` doc comments below are emitted by clap as user-visible
// help text — treat them as runtime strings, not ordinary comments.
#[derive(Subcommand)]
pub enum ServeCommands {
    /// Start serving a model
    Start {
        // Positional argument: a local path (used for engine auto-detection
        // in `detect_engine`) or a model name passed through to the engine.
        /// Path to model or model name
        model: String,
        /// Inference engine (auto, vllm, trt, onnx)
        #[arg(long, default_value = "auto")]
        engine: String,
        /// Port to serve on
        #[arg(long, default_value = "8080")]
        port: u16,
        /// Number of GPU replicas
        #[arg(long, default_value = "1")]
        replicas: u32,
        /// Quantization (none, int8, int4)
        #[arg(long, default_value = "none")]
        quantize: String,
    },
    /// List running inference servers
    List,
    /// Stop an inference server
    Stop {
        /// Model name or port
        model: String,
    },
    /// Show server logs
    Logs {
        /// Model name
        model: String,
    },
    /// Load test an inference endpoint
    Benchmark {
        // Defaults to the default `serve start` port on localhost.
        /// URL to benchmark
        #[arg(default_value = "http://localhost:8080")]
        url: String,
        /// Queries per second
        #[arg(long, default_value = "10")]
        qps: u32,
        /// Duration in seconds
        #[arg(long, default_value = "30")]
        duration: u32,
    },
}
53
/// Pick an inference engine from the model path's on-disk layout.
///
/// Heuristics, in priority order:
/// 1. A directory containing `config.json` (HuggingFace layout) → `"vllm"`.
/// 2. A `.onnx` file → `"onnx"`.
/// 3. A serialized TensorRT `.engine` / `.plan` file → `"trt"`.
/// 4. Anything else — including hub-style model names that don't exist
///    locally — defaults to `"vllm"`, the most versatile engine.
///
/// Extension matching is ASCII case-insensitive, so `MODEL.ONNX` is
/// recognized as well.
fn detect_engine(model_path: &str) -> &'static str {
    let path = std::path::Path::new(model_path);

    // HuggingFace checkpoints are directories containing a config.json.
    if path.join("config.json").exists() {
        return "vllm";
    }

    // Otherwise classify by file extension, ignoring ASCII case.
    match path
        .extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| ext.to_ascii_lowercase())
        .as_deref()
    {
        Some("onnx") => "onnx",
        Some("engine") | Some("plan") => "trt",
        // Default to vLLM (most versatile).
        _ => "vllm",
    }
}
75
76pub async fn run(cmd: ServeCommands) -> Result<()> {
77    match cmd {
78        ServeCommands::Start {
79            model,
80            engine,
81            port,
82            replicas,
83            quantize,
84        } => {
85            let selected_engine = if engine == "auto" {
86                detect_engine(&model)
87            } else {
88                engine.as_str()
89            };
90
91            println!("Zernel Serve");
92            println!("  Model:    {model}");
93            println!("  Engine:   {selected_engine}");
94            println!("  Port:     {port}");
95            println!("  Replicas: {replicas}");
96            println!("  Quantize: {quantize}");
97            println!();
98
99            match selected_engine {
100                "vllm" => {
101                    let mut args = vec![
102                        "-m".into(),
103                        "vllm.entrypoints.openai.api_server".into(),
104                        "--model".into(),
105                        model.clone(),
106                        "--port".into(),
107                        port.to_string(),
108                    ];
109
110                    if replicas > 1 {
111                        args.extend(["--tensor-parallel-size".into(), replicas.to_string()]);
112                    }
113
114                    if quantize != "none" {
115                        args.extend(["--quantization".into(), quantize]);
116                    }
117
118                    println!("Starting vLLM server...");
119                    println!("  URL: http://localhost:{port}/v1");
120                    println!("  Docs: http://localhost:{port}/docs");
121                    println!("  Press Ctrl+C to stop");
122                    println!();
123
124                    let status = tokio::process::Command::new("python3")
125                        .args(&args)
126                        .status()
127                        .await
128                        .with_context(|| "failed to start vLLM — install with: pip install vllm")?;
129
130                    if !status.success() {
131                        anyhow::bail!("vLLM exited with code {}", status.code().unwrap_or(-1));
132                    }
133                }
134
135                "trt" => {
136                    println!("Starting TensorRT server...");
137                    let status = tokio::process::Command::new("tritonserver")
138                        .args([
139                            "--model-repository",
140                            &model,
141                            "--http-port",
142                            &port.to_string(),
143                        ])
144                        .status()
145                        .await
146                        .with_context(|| "tritonserver not found — install NVIDIA Triton")?;
147                    let _ = status;
148                }
149
150                "onnx" => {
151                    println!("Starting ONNX Runtime server...");
152                    let code = format!(
153                        "import onnxruntime as ort; \
154                         from fastapi import FastAPI; import uvicorn; \
155                         app=FastAPI(); session=ort.InferenceSession('{model}'); \
156                         uvicorn.run(app, host='0.0.0.0', port={port})"
157                    );
158                    let status = tokio::process::Command::new("python3")
159                        .args(["-c", &code])
160                        .status()
161                        .await?;
162                    let _ = status;
163                }
164
165                other => {
166                    println!("Unknown engine: {other}");
167                    println!("Available: vllm, trt, onnx");
168                }
169            }
170        }
171
172        ServeCommands::List => {
173            println!("Running Inference Servers");
174            println!("{}", "=".repeat(60));
175
176            // Check common inference ports
177            for port in [8080, 8081, 8082, 8000, 5000] {
178                let check = std::net::TcpStream::connect_timeout(
179                    &format!("127.0.0.1:{port}").parse().expect("valid addr"),
180                    std::time::Duration::from_millis(200),
181                );
182                if check.is_ok() {
183                    println!("  :{port} — active");
184                }
185            }
186        }
187
188        ServeCommands::Stop { model } => {
189            // Try to find the process serving this model
190            println!("Stopping server for: {model}");
191            let output = Command::new("pkill")
192                .args(["-f", &format!("vllm.*{model}")])
193                .output();
194            match output {
195                Ok(o) if o.status.success() => println!("Server stopped."),
196                _ => println!("No server found for {model}. Try: zernel serve list"),
197            }
198        }
199
200        ServeCommands::Logs { model } => {
201            println!("Showing logs for model: {model}");
202            println!("(inference log streaming coming in future release)");
203            println!("For now: journalctl -u zernel-serve-{model} -f");
204        }
205
206        ServeCommands::Benchmark { url, qps, duration } => {
207            println!("Load Testing: {url}");
208            println!("  QPS: {qps}");
209            println!("  Duration: {duration}s");
210            println!();
211
212            let code = format!(
213                "import requests, time, concurrent.futures, statistics; \
214                 url='{url}/v1/models'; latencies=[]; errors=0; \
215                 start=time.time(); \
216                 with concurrent.futures.ThreadPoolExecutor(max_workers={qps}) as ex: \
217                     while time.time()-start < {duration}: \
218                         futs=[ex.submit(requests.get, url) for _ in range({qps})]; \
219                         for f in futs: \
220                             try: \
221                                 r=f.result(timeout=5); latencies.append(r.elapsed.total_seconds()*1000); \
222                             except: errors+=1; \
223                         time.sleep(1); \
224                 if latencies: \
225                     print(f'Requests:  {{len(latencies)}}'); \
226                     print(f'Errors:    {{errors}}'); \
227                     print(f'p50:       {{statistics.median(latencies):.1f}} ms'); \
228                     print(f'p99:       {{sorted(latencies)[int(len(latencies)*0.99)]:.1f}} ms'); \
229                     print(f'Mean:      {{statistics.mean(latencies):.1f}} ms'); \
230                     print(f'Throughput: {{len(latencies)/{duration}:.0f}} req/s')"
231            );
232            let output = Command::new("python3").args(["-c", &code]).output()?;
233            print!("{}", String::from_utf8_lossy(&output.stdout));
234            if !output.status.success() {
235                print!("{}", String::from_utf8_lossy(&output.stderr));
236            }
237        }
238    }
239    Ok(())
240}