1use anyhow::{Context, Result};
6use clap::Subcommand;
7use std::process::Command;
8
// Subcommands for `zernel serve`: manage and benchmark local model
// inference servers.
//
// NOTE(review): plain `//` comments are used deliberately here — `///` doc
// comments on clap-derive items are turned into CLI help text by clap, which
// would change the program's user-visible output.
#[derive(Subcommand)]
pub enum ServeCommands {
    // Launch an inference server for a model.
    Start {
        // Model path or identifier (positional argument).
        model: String,
        // Backend engine; "auto" picks one from the model files on disk.
        #[arg(long, default_value = "auto")]
        engine: String,
        // TCP port the server listens on.
        #[arg(long, default_value = "8080")]
        port: u16,
        // Replica count; values > 1 are forwarded as vLLM --tensor-parallel-size.
        #[arg(long, default_value = "1")]
        replicas: u32,
        // Quantization scheme; "none" disables (forwarded as vLLM --quantization).
        #[arg(long, default_value = "none")]
        quantize: String,
    },
    // List local ports with an active server.
    List,
    // Stop the server serving the given model (best-effort, via pkill).
    Stop {
        model: String,
    },
    // Show logs for a model's server (placeholder for now).
    Logs {
        model: String,
    },
    // Load-test an OpenAI-compatible endpoint.
    Benchmark {
        // Base URL of the server under test (positional argument).
        #[arg(default_value = "http://localhost:8080")]
        url: String,
        // Target requests per second.
        #[arg(long, default_value = "10")]
        qps: u32,
        // Test duration in seconds.
        #[arg(long, default_value = "30")]
        duration: u32,
    },
}
53
/// Pick an inference engine from what exists on disk at `model_path`.
///
/// Heuristics, checked in order:
/// 1. A directory containing `config.json` (Hugging Face snapshot layout) → `"vllm"`.
/// 2. `.onnx` extension (any case) → `"onnx"`.
/// 3. `.engine` / `.plan` extension (any case, TensorRT artifacts) → `"trt"`.
/// 4. Anything else (e.g. a bare model id that is not a local path) → `"vllm"`.
fn detect_engine(model_path: &str) -> &'static str {
    let path = std::path::Path::new(model_path);

    // Hugging Face model directories ship a config.json alongside the weights.
    if path.join("config.json").exists() {
        return "vllm";
    }

    // Extension checks are case-insensitive so e.g. `MODEL.ONNX` is recognized
    // too (the previous comparison was byte-exact and missed uppercase names).
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(|e| e.to_ascii_lowercase());

    match ext.as_deref() {
        Some("onnx") => "onnx",
        Some("engine") | Some("plan") => "trt",
        // Default: vLLM can load most HF-style models by id or path.
        _ => "vllm",
    }
}
75
76pub async fn run(cmd: ServeCommands) -> Result<()> {
77 match cmd {
78 ServeCommands::Start {
79 model,
80 engine,
81 port,
82 replicas,
83 quantize,
84 } => {
85 let selected_engine = if engine == "auto" {
86 detect_engine(&model)
87 } else {
88 engine.as_str()
89 };
90
91 println!("Zernel Serve");
92 println!(" Model: {model}");
93 println!(" Engine: {selected_engine}");
94 println!(" Port: {port}");
95 println!(" Replicas: {replicas}");
96 println!(" Quantize: {quantize}");
97 println!();
98
99 match selected_engine {
100 "vllm" => {
101 let mut args = vec![
102 "-m".into(),
103 "vllm.entrypoints.openai.api_server".into(),
104 "--model".into(),
105 model.clone(),
106 "--port".into(),
107 port.to_string(),
108 ];
109
110 if replicas > 1 {
111 args.extend(["--tensor-parallel-size".into(), replicas.to_string()]);
112 }
113
114 if quantize != "none" {
115 args.extend(["--quantization".into(), quantize]);
116 }
117
118 println!("Starting vLLM server...");
119 println!(" URL: http://localhost:{port}/v1");
120 println!(" Docs: http://localhost:{port}/docs");
121 println!(" Press Ctrl+C to stop");
122 println!();
123
124 let status = tokio::process::Command::new("python3")
125 .args(&args)
126 .status()
127 .await
128 .with_context(|| "failed to start vLLM — install with: pip install vllm")?;
129
130 if !status.success() {
131 anyhow::bail!("vLLM exited with code {}", status.code().unwrap_or(-1));
132 }
133 }
134
135 "trt" => {
136 println!("Starting TensorRT server...");
137 let status = tokio::process::Command::new("tritonserver")
138 .args([
139 "--model-repository",
140 &model,
141 "--http-port",
142 &port.to_string(),
143 ])
144 .status()
145 .await
146 .with_context(|| "tritonserver not found — install NVIDIA Triton")?;
147 let _ = status;
148 }
149
150 "onnx" => {
151 println!("Starting ONNX Runtime server...");
152 let code = format!(
153 "import onnxruntime as ort; \
154 from fastapi import FastAPI; import uvicorn; \
155 app=FastAPI(); session=ort.InferenceSession('{model}'); \
156 uvicorn.run(app, host='0.0.0.0', port={port})"
157 );
158 let status = tokio::process::Command::new("python3")
159 .args(["-c", &code])
160 .status()
161 .await?;
162 let _ = status;
163 }
164
165 other => {
166 println!("Unknown engine: {other}");
167 println!("Available: vllm, trt, onnx");
168 }
169 }
170 }
171
172 ServeCommands::List => {
173 println!("Running Inference Servers");
174 println!("{}", "=".repeat(60));
175
176 for port in [8080, 8081, 8082, 8000, 5000] {
178 let check = std::net::TcpStream::connect_timeout(
179 &format!("127.0.0.1:{port}").parse().expect("valid addr"),
180 std::time::Duration::from_millis(200),
181 );
182 if check.is_ok() {
183 println!(" :{port} — active");
184 }
185 }
186 }
187
188 ServeCommands::Stop { model } => {
189 println!("Stopping server for: {model}");
191 let output = Command::new("pkill")
192 .args(["-f", &format!("vllm.*{model}")])
193 .output();
194 match output {
195 Ok(o) if o.status.success() => println!("Server stopped."),
196 _ => println!("No server found for {model}. Try: zernel serve list"),
197 }
198 }
199
200 ServeCommands::Logs { model } => {
201 println!("Showing logs for model: {model}");
202 println!("(inference log streaming coming in future release)");
203 println!("For now: journalctl -u zernel-serve-{model} -f");
204 }
205
206 ServeCommands::Benchmark { url, qps, duration } => {
207 println!("Load Testing: {url}");
208 println!(" QPS: {qps}");
209 println!(" Duration: {duration}s");
210 println!();
211
212 let code = format!(
213 "import requests, time, concurrent.futures, statistics; \
214 url='{url}/v1/models'; latencies=[]; errors=0; \
215 start=time.time(); \
216 with concurrent.futures.ThreadPoolExecutor(max_workers={qps}) as ex: \
217 while time.time()-start < {duration}: \
218 futs=[ex.submit(requests.get, url) for _ in range({qps})]; \
219 for f in futs: \
220 try: \
221 r=f.result(timeout=5); latencies.append(r.elapsed.total_seconds()*1000); \
222 except: errors+=1; \
223 time.sleep(1); \
224 if latencies: \
225 print(f'Requests: {{len(latencies)}}'); \
226 print(f'Errors: {{errors}}'); \
227 print(f'p50: {{statistics.median(latencies):.1f}} ms'); \
228 print(f'p99: {{sorted(latencies)[int(len(latencies)*0.99)]:.1f}} ms'); \
229 print(f'Mean: {{statistics.mean(latencies):.1f}} ms'); \
230 print(f'Throughput: {{len(latencies)/{duration}:.0f}} req/s')"
231 );
232 let output = Command::new("python3").args(["-c", &code]).output()?;
233 print!("{}", String::from_utf8_lossy(&output.stdout));
234 if !output.status.success() {
235 print!("{}", String::from_utf8_lossy(&output.stderr));
236 }
237 }
238 }
239 Ok(())
240}