use anyhow::Result;
use clap::Subcommand;
use std::process::Command;

#[derive(Subcommand)]
pub enum OptimizeCommands {
    Precision {
        script: String,
    },
    BatchSize {
        model: String,
        #[arg(long, default_value = "512")]
        seq_len: u32,
        #[arg(long)]
        amp: bool,
        #[arg(long, default_value = "0")]
        gpu: u32,
    },
    Checkpoint {
        script: String,
        #[arg(long, default_value = "0.85")]
        target_mem: f32,
    },
    DataPipeline {
        script: String,
        #[arg(long, default_value = "20")]
        steps: u32,
    },
    Memory,
    Numa,
    Scan {
        script: Option<String>,
    },
    Auto {
        script: String,
        #[arg(long, short, default_value = "train_optimized.py")]
        output: String,
    },
}
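
// Example invocations (illustrative; assumes the CLI is exposed as `zernel optimize ...`,
// matching the help text printed by `run_scan` below; `train.py` and `gpt2` are placeholders):
//
//   zernel optimize precision train.py            # mixed precision analysis
//   zernel optimize batch-size gpt2 --amp         # optimal batch size estimate
//   zernel optimize checkpoint train.py           # gradient checkpointing advice
//   zernel optimize data-pipeline train.py        # DataLoader profiling
//   zernel optimize auto train.py -o train_optimized.py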

pub async fn run(cmd: OptimizeCommands) -> Result<()> {
    match cmd {
        OptimizeCommands::Precision { script } => run_precision(&script).await,
        OptimizeCommands::BatchSize { model, seq_len, amp, gpu } => {
            run_batch_size(&model, seq_len, amp, gpu).await
        }
        OptimizeCommands::Checkpoint { script, target_mem } => {
            run_checkpoint(&script, target_mem).await
        }
        OptimizeCommands::DataPipeline { script, steps } => {
            run_data_pipeline(&script, steps).await
        }
        OptimizeCommands::Memory => run_memory().await,
        OptimizeCommands::Numa => run_numa().await,
        OptimizeCommands::Scan { script } => run_scan(script.as_deref()).await,
        OptimizeCommands::Auto { script, output } => run_auto(&script, &output).await,
    }
}

async fn run_precision(script: &str) -> Result<()> {
    println!("Mixed Precision Analyzer");
    println!("{}", "=".repeat(60));
    println!("Script: {script}");
    println!();

    let output = Command::new("python3")
        .args(["-c", r#"
import torch, sys, os, ast, re

# 1. Check GPU capabilities
if not torch.cuda.is_available():
    print("ERROR: No CUDA GPU available")
    sys.exit(1)

cap = torch.cuda.get_device_capability()
name = torch.cuda.get_device_name()
mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

print(f"GPU: {name}")
print(f"Compute capability: {cap[0]}.{cap[1]}")
print(f"Memory: {mem_gb:.1f} GB")
print()

# Determine best precision
if cap[0] >= 9:
    best = "FP8 (Hopper+)"
    dtype = "torch.float8_e4m3fn"
    speedup = "3-4x"
elif cap[0] >= 8:
    best = "BF16 (Ampere+)"
    dtype = "torch.bfloat16"
    speedup = "1.5-2x"
elif cap[0] >= 7:
    best = "FP16 (Volta+)"
    dtype = "torch.float16"
    speedup = "1.5-2x"
else:
    best = "FP32 only"
    dtype = "torch.float32"
    speedup = "1x (no speedup)"

print(f"Recommended precision: {best}")
print(f"Expected speedup: {speedup}")
print(f"Memory savings: {2 if cap[0] >= 7 else 1}x (FP32: {mem_gb:.1f}GB -> {'~' + str(round(mem_gb*2, 1)) if cap[0] >= 7 else str(round(mem_gb, 1))}GB effective)")
print()

# 2. Scan script for existing AMP usage
script_path = os.environ['ZERNEL_ARG_SCRIPT']
has_autocast = False
has_gradscaler = False
has_bf16 = False
has_fp16 = False

if os.path.exists(script_path):
    with open(script_path) as f:
        content = f.read()
    has_autocast = "autocast" in content
    has_gradscaler = "GradScaler" in content
    has_bf16 = "bfloat16" in content
    has_fp16 = "float16" in content and "bfloat" not in content

    if has_autocast:
        print("STATUS: Script already uses torch.autocast (AMP enabled)")
        if has_bf16:
            print("  Using: BF16 precision")
        elif has_fp16:
            print("  Using: FP16 precision")
            if cap[0] >= 8:
                print("  TIP: Switch to BF16 for better numerical stability on this GPU")
        if not has_gradscaler and has_fp16:
            print("  WARNING: Using FP16 without GradScaler - may have loss scaling issues")
            print("  Add: scaler = torch.cuda.amp.GradScaler()")
    else:
        print("STATUS: Script does NOT use mixed precision")
        print()
        print("To enable AMP, wrap your training loop:")
        print()
        if cap[0] >= 8:
            print("  # Add at top of script:")
            print("  torch.backends.cuda.matmul.allow_tf32 = True")
            print("  torch.backends.cudnn.allow_tf32 = True")
            print()
            print("  # Wrap forward pass:")
            print("  with torch.autocast('cuda', dtype=torch.bfloat16):")
            print("      output = model(input)")
            print("      loss = loss_fn(output, target)")
            print("  loss.backward()")
            print("  optimizer.step()")
        else:
            print("  scaler = torch.cuda.amp.GradScaler()")
            print()
            print("  with torch.autocast('cuda', dtype=torch.float16):")
            print("      output = model(input)")
            print("      loss = loss_fn(output, target)")
            print("  scaler.scale(loss).backward()")
            print("  scaler.step(optimizer)")
            print("  scaler.update()")
        print()
        print(f"Or use: zernel optimize auto {script_path} to generate an optimized wrapper")
else:
    print(f"Script not found: {script_path}")
    print("Showing general recommendations based on GPU capability.")
    print()
    if cap[0] >= 8:
        print("  Best: torch.autocast('cuda', dtype=torch.bfloat16)")
        print("  Also: torch.backends.cuda.matmul.allow_tf32 = True")
    elif cap[0] >= 7:
        print("  Best: torch.autocast('cuda', dtype=torch.float16)")
        print("  Need: torch.cuda.amp.GradScaler() for loss scaling")
"#])
        .env("ZERNEL_ARG_SCRIPT", script)
        .output()?;

    print!("{}", String::from_utf8_lossy(&output.stdout));
    if !output.status.success() {
        eprint!("{}", String::from_utf8_lossy(&output.stderr));
    }
    Ok(())
}

async fn run_batch_size(model: &str, seq_len: u32, amp: bool, gpu: u32) -> Result<()> {
    println!("Smart Batch Size Calculator");
    println!("{}", "=".repeat(60));
    println!();

    let output = Command::new("python3")
        .args(["-c", r#"
import torch, sys, math, os

if not torch.cuda.is_available():
    print("ERROR: No CUDA GPU available")
    sys.exit(1)

gpu_id = int(os.environ['ZERNEL_ARG_GPU'])
props = torch.cuda.get_device_properties(gpu_id)
total_mem = props.total_memory
name = props.name
mem_gb = total_mem / 1e9

print(f"GPU {gpu_id}: {name} ({mem_gb:.1f} GB)")

# Parse model size
model_str = os.environ['ZERNEL_ARG_MODEL'].lower().strip()
param_count = None

# Known models
known = {
    "gpt2": 124e6, "gpt2-medium": 355e6, "gpt2-large": 774e6, "gpt2-xl": 1.5e9,
    "llama-7b": 7e9, "llama-13b": 13e9, "llama-70b": 70e9,
    "mistral-7b": 7.2e9, "phi-2": 2.7e9, "phi-3": 3.8e9,
    "bert-base": 110e6, "bert-large": 340e6,
    "t5-small": 60e6, "t5-base": 220e6, "t5-large": 770e6,
    "vit-base": 86e6, "vit-large": 307e6,
    "resnet50": 25.6e6, "resnet101": 44.5e6,
}

if model_str in known:
    param_count = known[model_str]
    print(f"Model: {model_str} ({param_count/1e6:.0f}M params)")
elif model_str.endswith('b'):
    param_count = float(model_str[:-1]) * 1e9
    print(f"Model: {param_count/1e9:.1f}B params")
elif model_str.endswith('m'):
    param_count = float(model_str[:-1]) * 1e6
    print(f"Model: {param_count/1e6:.0f}M params")
else:
    try:
        param_count = float(model_str)
        print(f"Model: {param_count/1e6:.0f}M params")
    except ValueError:
        print(f"Unknown model: {model_str}")
        print("Use a known name (gpt2, llama-7b, etc.) or param count (125M, 1.3B)")
        sys.exit(1)

seq_len = int(os.environ['ZERNEL_ARG_SEQ_LEN'])
use_amp = os.environ['ZERNEL_ARG_AMP'] == "1"
bytes_per_param = 2 if use_amp else 4

print(f"Sequence length: {seq_len}")
print(f"Precision: {'FP16/BF16' if use_amp else 'FP32'}")
print()

# Memory estimation (rough but practical):
#   Model params: param_count * bytes_per_param
#   Gradients: same as model
#   Optimizer state (AdamW): 2x model size (momentum + variance) in FP32
#   Activations: depends on model architecture, roughly:
#     For transformers: ~12 * n_layers * hidden_dim * seq_len * batch_size * bytes_per_param
#     Simplified: ~6 * param_count * seq_len / hidden_dim * batch_size * bytes_per_param / 1024
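#
# Worked example (illustrative numbers only; assumes gpt2 at 124M params, FP32,
# and a 24 GB GPU, which is not read from the code below):
#   model weights ~= 124e6 * 4 bytes ~= 0.50 GB
#   gradients     ~= 0.50 GB
#   AdamW state   ~= 124e6 * 8 bytes ~= 0.99 GB
#   static total  ~= 2.0 GB; after ~0.5 GB CUDA context, ~21.5 GB remains for activations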

model_mem = param_count * bytes_per_param
grad_mem = model_mem
opt_mem = param_count * 8  # AdamW always FP32: momentum + variance = 2 * 4 bytes
static_mem = model_mem + grad_mem + opt_mem
cuda_overhead = 0.5e9  # ~500MB CUDA context

available = total_mem - static_mem - cuda_overhead
if available <= 0:
    print("WARNING: Model doesn't fit in GPU memory!")
    print(f"  Model + gradients + optimizer: {static_mem/1e9:.1f} GB")
    print(f"  GPU memory: {mem_gb:.1f} GB")
    print(f"  Shortfall: {(static_mem + cuda_overhead - total_mem)/1e9:.1f} GB")
    print()
    print("Recommendations:")
    print(f"  1. Use mixed precision: zernel optimize batch-size {model_str} --amp")
    print("  2. Use gradient checkpointing: zernel optimize checkpoint <script>")
    print("  3. Use model parallelism or offloading")
    sys.exit(0)

# Estimate activation memory per sample
# Rough heuristic: 4-6 bytes per param per sample (varies by architecture)
if param_count > 1e9:
    act_per_sample = param_count * 4 * seq_len / 2048  # scale with seq_len
else:
    act_per_sample = param_count * 6 * seq_len / 512

act_per_sample *= (0.5 if use_amp else 1.0)

max_batch = max(1, int(available / act_per_sample))
# Round down to a power of 2 for efficiency
optimal_batch = 2 ** int(math.log2(max_batch)) if max_batch >= 2 else 1
safe_batch = max(1, optimal_batch // 2)  # 50% margin for safety
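
# Continuing the illustrative example above (gpt2, 124M params, seq_len=512, FP32, 24 GB GPU):
#   act_per_sample ~= 124e6 * 6 * 512/512 ~= 0.74 GB per sample
#   max_batch      ~= 21.5 / 0.74         ~= 29
#   optimal_batch  =  16 (largest power of 2 below 29), safe_batch = 8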

print(f"Memory breakdown:")
print(f"  Model weights:     {model_mem/1e9:.2f} GB")
print(f"  Gradients:         {grad_mem/1e9:.2f} GB")
print(f"  Optimizer (AdamW): {opt_mem/1e9:.2f} GB")
print(f"  CUDA overhead:     {cuda_overhead/1e9:.2f} GB")
print(f"  Available for activations: {available/1e9:.2f} GB")
print()
print(f"Recommended batch sizes:")
print(f"  Maximum: {max_batch} (uses ~100% GPU memory - risky)")
print(f"  Optimal: {optimal_batch} (power-of-2, ~{optimal_batch * act_per_sample / available * 100:.0f}% memory)")
print(f"  Safe:    {safe_batch} (50% margin for variable-length sequences)")
print()

if not use_amp and param_count > 100e6:
    amp_available = total_mem - (param_count * 2 + param_count * 2 + opt_mem + cuda_overhead)
    amp_act = act_per_sample * 0.5
    amp_max = max(1, int(amp_available / amp_act))
    amp_optimal = 2 ** int(math.log2(amp_max)) if amp_max >= 2 else 1
    print(f"With mixed precision (--amp):")
    print(f"  Optimal batch size: {amp_optimal} ({amp_optimal/optimal_batch:.1f}x larger)")
    print(f"  Run: zernel optimize batch-size {model_str} --seq-len {seq_len} --amp")
"#])
        .env("ZERNEL_ARG_MODEL", model)
        .env("ZERNEL_ARG_SEQ_LEN", seq_len.to_string())
        .env("ZERNEL_ARG_AMP", if amp { "1" } else { "0" })
        .env("ZERNEL_ARG_GPU", gpu.to_string())
        .output()?;

    print!("{}", String::from_utf8_lossy(&output.stdout));
    if !output.status.success() {
        eprint!("{}", String::from_utf8_lossy(&output.stderr));
    }
    Ok(())
}

async fn run_checkpoint(script: &str, target_mem: f32) -> Result<()> {
    println!("Gradient Checkpointing Analyzer");
    println!("{}", "=".repeat(60));
    println!("Script: {script}");
    println!("Target memory usage: {:.0}%", target_mem * 100.0);
    println!();

    let output = Command::new("python3")
        .args(["-c", r#"
import torch, sys, os

if not torch.cuda.is_available():
    print("ERROR: No CUDA GPU available")
    sys.exit(1)

mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
name = torch.cuda.get_device_name()
target = float(os.environ['ZERNEL_ARG_TARGET_MEM'])

print(f"GPU: {name} ({mem_gb:.1f} GB)")
print(f"Target usage: {target*100:.0f}% ({mem_gb*target:.1f} GB)")
print()

# Check if script exists and scan for model definition
script_path = os.environ['ZERNEL_ARG_SCRIPT']
has_checkpoint = False
model_type = "unknown"

if os.path.exists(script_path):
    with open(script_path) as f:
        content = f.read()
    has_checkpoint = "checkpoint_sequential" in content or ("checkpoint" in content and "gradient" in content.lower())

    if "transformers" in content or "AutoModel" in content:
        model_type = "huggingface"
    elif "nn.Sequential" in content or "nn.Module" in content:
        model_type = "pytorch"

if has_checkpoint:
    print("STATUS: Script already uses gradient checkpointing")
else:
    print("STATUS: Gradient checkpointing NOT enabled")
    print()
    print("Gradient checkpointing trades compute for memory:")
    print("  - Saves ~60-70% activation memory")
    print("  - Costs ~30% extra compute (recomputes activations in backward)")
    print("  - Allows ~2-3x larger batch sizes or models")
    print()

    if model_type == "huggingface":
        print("For HuggingFace models:")
        print("  model.gradient_checkpointing_enable()")
        print()
        print("Or in TrainingArguments:")
        print("  TrainingArguments(gradient_checkpointing=True, ...)")
    else:
        print("For PyTorch models:")
        print("  from torch.utils.checkpoint import checkpoint_sequential")
        print()
        print("  # In your model's forward():")
        print("  # Instead of: x = self.layers(x)")
        print("  # Use:        x = checkpoint_sequential(self.layers, segments, x)")
        print()
        print("  # Or per-layer:")
        print("  from torch.utils.checkpoint import checkpoint")
        print("  for layer in self.layers:")
        print("      x = checkpoint(layer, x, use_reentrant=False)")

    print()
    print("When to use gradient checkpointing:")
    print(f"  - Model + optimizer > {mem_gb*0.6:.1f} GB (60% of GPU memory)")
    print("  - You're getting OOM errors")
    print("  - You want to increase batch size without more GPUs")
    print()
    print("When NOT to use it:")
    print("  - You have plenty of GPU memory headroom")
    print("  - Training is already compute-bound (GPU util > 90%)")
    print("  - The 30% slowdown isn't acceptable")
"#])
        .env("ZERNEL_ARG_SCRIPT", script)
        .env("ZERNEL_ARG_TARGET_MEM", target_mem.to_string())
        .output()?;

    print!("{}", String::from_utf8_lossy(&output.stdout));
    if !output.status.success() {
        eprint!("{}", String::from_utf8_lossy(&output.stderr));
    }
    Ok(())
}

async fn run_data_pipeline(script: &str, steps: u32) -> Result<()> {
    println!("Data Pipeline Profiler");
    println!("{}", "=".repeat(60));
    println!("Script: {script}");
    println!("Profiling: {steps} steps");
    println!();

    let output = Command::new("python3")
        .args(["-c", r#"
import torch, time, sys, os, statistics

if not torch.cuda.is_available():
    print("ERROR: No CUDA GPU available")
    sys.exit(1)

print(f"GPU: {torch.cuda.get_device_name()}")
print(f"CPU cores: {os.cpu_count()}")
print()

# Profile different DataLoader configurations
from torch.utils.data import DataLoader, TensorDataset

# Create dummy dataset (10k samples of 224x224x3 "images")
n_samples = 10000
data = torch.randn(n_samples, 3, 224, 224)
labels = torch.randint(0, 1000, (n_samples,))
dataset = TensorDataset(data, labels)

configs = [
    {"num_workers": 0, "pin_memory": False, "persistent_workers": False, "label": "Default (0 workers)"},
    {"num_workers": 2, "pin_memory": False, "persistent_workers": False, "label": "2 workers"},
    {"num_workers": 4, "pin_memory": True, "persistent_workers": False, "label": "4 workers + pin_memory"},
    {"num_workers": 4, "pin_memory": True, "persistent_workers": True, "label": "4 workers + pin + persistent"},
    {"num_workers": 8, "pin_memory": True, "persistent_workers": True, "label": "8 workers + pin + persistent"},
]

n_steps = int(os.environ['ZERNEL_ARG_STEPS'])
batch_size = 64

print(f"Profiling {n_steps} batches (batch_size={batch_size}) per config...")
print()
print(f"  {'Config':<40} {'Time(ms)':>10} {'Throughput':>12} {'vs Best':>10}")
print(f"  {'-'*75}")

best_time = float('inf')
results = []

for cfg in configs:
    kw = {k: v for k, v in cfg.items() if k != 'label'}
    if kw.get('persistent_workers') and kw.get('num_workers', 0) == 0:
        continue
    try:
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, **kw)
        it = iter(loader)

        # Warmup
        for _ in range(3):
            try:
                batch = next(it)
            except StopIteration:
                it = iter(loader)
                batch = next(it)

        # Benchmark
        times = []
        for _ in range(n_steps):
            t0 = time.perf_counter()
            try:
                batch = next(it)
            except StopIteration:
                it = iter(loader)
                batch = next(it)
            # Simulate GPU transfer
            x = batch[0].to('cuda', non_blocking=kw.get('pin_memory', False))
            torch.cuda.synchronize()
            times.append((time.perf_counter() - t0) * 1000)

        avg = statistics.mean(times)
        throughput = batch_size * 1000 / avg
        best_time = min(best_time, avg)
        results.append((cfg['label'], avg, throughput))

        del loader
    except Exception:
        results.append((cfg['label'], 0, 0))

for label, avg, throughput in results:
    if avg > 0:
        ratio = best_time / avg
        marker = " <-- best" if abs(avg - best_time) < 0.01 else ""
        print(f"  {label:<40} {avg:>8.2f}ms {throughput:>9.0f} s/s {ratio:>7.2f}x{marker}")
    else:
        print(f"  {label:<40} {'FAILED':>10}")

print()

# Recommendations
best_cfg = None
best_avg = float('inf')
for label, avg, _ in results:
    if 0 < avg < best_avg:
        best_avg = avg
        best_cfg = label

if best_cfg:
    print(f"Recommendation: Use '{best_cfg}'")
    print()
    print("Add to your DataLoader:")
    print("  DataLoader(")
    print("      dataset,")
    print(f"      batch_size={batch_size},")
    if "8 worker" in best_cfg:
        print("      num_workers=8,")
    elif "4 worker" in best_cfg:
        print("      num_workers=4,")
    elif "2 worker" in best_cfg:
        print("      num_workers=2,")
    if "pin" in best_cfg:
        print("      pin_memory=True,")
    if "persistent" in best_cfg:
        print("      persistent_workers=True,")
    if not best_cfg.startswith("Default"):
        print("      prefetch_factor=2,")
    print("  )")
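
    # Note: prefetch_factor is the number of batches each worker loads ahead of time
    # (PyTorch's default is 2). It is only accepted when num_workers > 0, which is why
    # it is skipped above when the single-process configuration comes out on top.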

    # Check if data loading is the bottleneck
    typical_gpu_step = 50  # ms for a typical forward+backward
    if best_avg > typical_gpu_step * 0.3:
        print()
        print("WARNING: Data loading may be a bottleneck!")
        print(f"  Data load time ({best_avg:.1f}ms) is >{typical_gpu_step*0.3:.0f}ms (30% of typical GPU step)")
        print("  Consider: larger prefetch_factor, SSD storage, or data preprocessing")
"#])
        .env("ZERNEL_ARG_SCRIPT", script)
        .env("ZERNEL_ARG_STEPS", steps.to_string())
        .output()?;

    print!("{}", String::from_utf8_lossy(&output.stdout));
    if !output.status.success() {
        eprint!("{}", String::from_utf8_lossy(&output.stderr));
    }
    Ok(())
}

async fn run_memory() -> Result<()> {
    println!("CUDA Memory Allocator Configuration");
    println!("{}", "=".repeat(60));

    let code = r#"
import os, torch
if torch.cuda.is_available():
    conf = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '(not set)')
    print(f"  PYTORCH_CUDA_ALLOC_CONF: {conf}")
    props = torch.cuda.get_device_properties(0)
    print(f"  GPU memory: {props.total_memory / 1e9:.1f} GB")
    print()
    print("Recommended settings:")
    print("  For training:")
    print("    export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,garbage_collection_threshold:0.6,max_split_size_mb:512")
    print("  For inference:")
    print("    export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128")
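    # Brief summary of these caching-allocator options (the specific values above are
    # starting points, not measured optima):
    #   expandable_segments:True          - use expandable memory segments to reduce fragmentation
    #   garbage_collection_threshold:0.6  - reclaim cached blocks once usage passes 60% of capacity
    #   max_split_size_mb:N               - don't split cached blocks larger than N MB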
600"#;
601 let output = Command::new("python3").args(["-c", code]).output()?;
602 print!("{}", String::from_utf8_lossy(&output.stdout));
603 Ok(())
604}

async fn run_numa() -> Result<()> {
    println!("NUMA-Aware Data Placement");
    println!("{}", "=".repeat(60));

    #[cfg(target_os = "linux")]
    {
        if let Ok(entries) = std::fs::read_dir("/sys/devices/system/node") {
            let mut node_count = 0;
            for entry in entries.flatten() {
                let name = entry.file_name().to_string_lossy().to_string();
                if name.starts_with("node") {
                    node_count += 1;
                    let cpulist_path = entry.path().join("cpulist");
                    if let Ok(cpus) = std::fs::read_to_string(&cpulist_path) {
                        println!("  {name}: CPUs {}", cpus.trim());
                    }
                }
            }
            println!();
            if node_count > 1 {
                println!("Multi-NUMA system ({node_count} nodes). Use:");
                println!("  numactl --cpunodebind=0 --membind=0 python3 train.py");
            } else {
                println!("Single NUMA node - no placement optimization needed.");
            }
        }
    }

    #[cfg(not(target_os = "linux"))]
    println!("  NUMA detection requires Linux.");

    Ok(())
}

async fn run_scan(script: Option<&str>) -> Result<()> {
    println!("Zernel Optimization Scan");
    println!("{}", "=".repeat(60));
    println!();

    let script_arg = script.unwrap_or("(none)");
    let output = Command::new("python3")
        .args(["-c", r#"
import torch, os, sys

issues = []
recommendations = []

print("Environment:")
print(f"  PyTorch: {torch.__version__}")
print(f"  CUDA: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability()
    name = torch.cuda.get_device_name()
    mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"  GPU: {name} ({mem:.1f} GB, sm_{cap[0]}{cap[1]})")

    # Check TF32
    if cap[0] >= 8:
        tf32_matmul = torch.backends.cuda.matmul.allow_tf32
        tf32_cudnn = torch.backends.cudnn.allow_tf32
        if not tf32_matmul:
            issues.append("TF32 disabled for matmul (up to ~3x matmul throughput on Ampere+)")
            recommendations.append("torch.backends.cuda.matmul.allow_tf32 = True")
        if not tf32_cudnn:
            issues.append("TF32 disabled for cuDNN")
            recommendations.append("torch.backends.cudnn.allow_tf32 = True")

    # Check CUDA alloc config
    alloc_conf = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
    if 'expandable_segments' not in alloc_conf:
        issues.append("CUDA allocator not optimized (expandable_segments not set)")
        recommendations.append("export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")

    # Check precision recommendation
    if cap[0] >= 8:
        recommendations.append("Use BF16: torch.autocast('cuda', dtype=torch.bfloat16)")
    elif cap[0] >= 7:
        recommendations.append("Use FP16: torch.autocast('cuda', dtype=torch.float16) + GradScaler()")

# Check CPU
cpu_count = os.cpu_count()
print(f"  CPUs: {cpu_count}")
if cpu_count and cpu_count >= 4:
    recommendations.append(f"DataLoader: num_workers={min(cpu_count, 8)}, pin_memory=True, persistent_workers=True")

# Check script if provided
script_path = os.environ['ZERNEL_ARG_SCRIPT']
if script_path != "(none)" and os.path.exists(script_path):
    with open(script_path) as f:
        content = f.read()
    print(f"  Script: {script_path}")

    if "autocast" not in content and "amp" not in content.lower():
        issues.append("No mixed precision (AMP) detected in script")
    if "num_workers" not in content:
        issues.append("DataLoader num_workers not set (defaults to 0, single-threaded loading)")
    if "pin_memory" not in content:
        issues.append("DataLoader pin_memory not set (slower CPU→GPU transfers)")
    if "gradient_checkpointing" not in content and "checkpoint_sequential" not in content:
        recommendations.append("Consider gradient checkpointing for larger models/batches")

print()
if issues:
    print(f"ISSUES FOUND ({len(issues)}):")
    for i, issue in enumerate(issues, 1):
        print(f"  {i}. {issue}")
else:
    print("No issues found!")

print()
if recommendations:
    print(f"RECOMMENDATIONS ({len(recommendations)}):")
    for i, rec in enumerate(recommendations, 1):
        print(f"  {i}. {rec}")

print()
print("Run individual optimizations:")
print("  zernel optimize precision <script>      # Mixed precision analysis")
print("  zernel optimize batch-size <model>      # Optimal batch size")
print("  zernel optimize checkpoint <script>     # Gradient checkpointing")
print("  zernel optimize data-pipeline <script>  # Data loading profiler")
print("  zernel optimize auto <script>           # Generate optimized wrapper")
"#])
        .env("ZERNEL_ARG_SCRIPT", script_arg)
        .output()?;

    print!("{}", String::from_utf8_lossy(&output.stdout));
    if !output.status.success() {
        eprint!("{}", String::from_utf8_lossy(&output.stderr));
    }
    Ok(())
}

async fn run_auto(script: &str, output_path: &str) -> Result<()> {
    println!("Zernel Auto-Optimizer");
    println!("{}", "=".repeat(60));
    println!("Input:  {script}");
    println!("Output: {output_path}");
    println!();

    let output = Command::new("python3")
        .args(["-c", include_str!("optimize_auto.py")])
        .env("ZERNEL_OPT_SCRIPT", script)
        .env("ZERNEL_OPT_OUTPUT", output_path)
        .output();

    match output {
        Ok(out) => {
            print!("{}", String::from_utf8_lossy(&out.stdout));
            if !out.status.success() {
                eprint!("{}", String::from_utf8_lossy(&out.stderr));
            }
        }
        Err(_) => {
            // Fallback when python3 cannot be spawned: generate a minimal wrapper in Rust.
            println!("Generating optimized wrapper for {script}...");
            let original = std::fs::read_to_string(script)?;
            let mut wrapper = String::from("#!/usr/bin/env python3\n# Auto-optimized by Zernel\n");
            wrapper.push_str("import torch, os\n");
            wrapper.push_str("torch.backends.cuda.matmul.allow_tf32 = True\n");
            wrapper.push_str("torch.backends.cudnn.allow_tf32 = True\n");
            wrapper.push_str("os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')\n\n");
            wrapper.push_str("# Original script:\n");
            wrapper.push_str(&original);
            std::fs::write(output_path, &wrapper)?;
            println!("Written to: {output_path}");
        }
    }
    Ok(())
}