zernel/commands/optimize.rs

// Copyright (C) 2026 Dyber, Inc. — Proprietary

//! zernel optimize — ML training optimization tools
//!
//! Auto-detects and applies optimizations:
//! - Mixed precision (AMP) with auto-detection of GPU capability
//! - Smart batch size calculation based on GPU memory and model size
//! - Gradient checkpointing for memory-constrained training
//! - Data pipeline bottleneck detection and fix recommendations
//! - Full optimization scan with one command
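//!
//! Illustrative invocations (subcommand names follow from the clap
//! definitions below; flags are the ones declared on each variant):
//!
//! ```text
//! zernel optimize scan train.py           # full environment + script report
//! zernel optimize batch-size 1.3B --amp   # batch size for a 1.3B-param model
//! zernel optimize precision train.py      # mixed-precision analysis
//! ```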

use anyhow::Result;
use clap::Subcommand;
use std::process::Command;

#[derive(Subcommand)]
pub enum OptimizeCommands {
    /// Analyze model for mixed precision and auto-generate AMP wrapper
    Precision {
        /// Training script to analyze and wrap
        script: String,
    },
    /// Calculate optimal batch size for your GPU and model
    BatchSize {
        /// Model name or parameter count (e.g. "125M", "1.3B", "gpt2", "llama-7b")
        model: String,
        /// Sequence length (default: 512)
        #[arg(long, default_value = "512")]
        seq_len: u32,
        /// Use mixed precision (halves memory per parameter)
        #[arg(long)]
        amp: bool,
        /// GPU ID (default: 0)
        #[arg(long, default_value = "0")]
        gpu: u32,
    },
    /// Analyze and enable gradient checkpointing for large models
    Checkpoint {
        /// Training script or model path
        script: String,
        /// Target memory usage fraction (0.0-1.0, default: 0.85)
        #[arg(long, default_value = "0.85")]
        target_mem: f32,
    },
    /// Profile data pipeline and detect bottlenecks
    DataPipeline {
        /// Training script to profile
        script: String,
        /// Number of steps to profile (default: 20)
        #[arg(long, default_value = "20")]
        steps: u32,
    },
    /// Configure CUDA memory allocator for optimal performance
    Memory,
    /// Configure NUMA-aware data placement
    Numa,
    /// Full optimization scan — analyzes everything and generates a report
    Scan {
        /// Training script to analyze (optional — scans GPU environment if omitted)
        script: Option<String>,
    },
    /// Auto-optimize: generate a wrapper script with all optimizations applied
    Auto {
        /// Training script to optimize
        script: String,
        /// Output optimized wrapper script
        #[arg(long, short, default_value = "train_optimized.py")]
        output: String,
    },
}
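
// Clap derives kebab-case subcommand names from the variant names above
// (BatchSize -> `batch-size`, DataPipeline -> `data-pipeline`), which is
// what the command list printed by `run_scan` below refers to.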

pub async fn run(cmd: OptimizeCommands) -> Result<()> {
    match cmd {
        OptimizeCommands::Precision { script } => run_precision(&script).await,
        OptimizeCommands::BatchSize { model, seq_len, amp, gpu } => {
            run_batch_size(&model, seq_len, amp, gpu).await
        }
        OptimizeCommands::Checkpoint { script, target_mem } => {
            run_checkpoint(&script, target_mem).await
        }
        OptimizeCommands::DataPipeline { script, steps } => {
            run_data_pipeline(&script, steps).await
        }
        OptimizeCommands::Memory => run_memory().await,
        OptimizeCommands::Numa => run_numa().await,
        OptimizeCommands::Scan { script } => run_scan(script.as_deref()).await,
        OptimizeCommands::Auto { script, output } => run_auto(&script, &output).await,
    }
}

async fn run_precision(script: &str) -> Result<()> {
    println!("Mixed Precision Analyzer");
    println!("{}", "=".repeat(60));
    println!("Script: {script}");
    println!();

    let output = Command::new("python3")
        .args(["-c", r#"
import torch, sys, os

# 1. Check GPU capabilities
if not torch.cuda.is_available():
    print("ERROR: No CUDA GPU available")
    sys.exit(1)

cap = torch.cuda.get_device_capability()
name = torch.cuda.get_device_name()
mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

print(f"GPU: {name}")
print(f"Compute capability: {cap[0]}.{cap[1]}")
print(f"Memory: {mem_gb:.1f} GB")
print()

# Determine best precision
if cap[0] >= 9:
    best = "FP8 (Hopper+)"
    dtype = "torch.float8_e4m3fn"
    speedup = "3-4x"
elif cap[0] >= 8:
    best = "BF16 (Ampere+)"
    dtype = "torch.bfloat16"
    speedup = "1.5-2x"
elif cap[0] >= 7:
    best = "FP16 (Volta+)"
    dtype = "torch.float16"
    speedup = "1.5-2x"
else:
    best = "FP32 only"
    dtype = "torch.float32"
    speedup = "1x (no speedup)"

print(f"Recommended precision: {best}")
print(f"Expected speedup: {speedup}")
135print(f"Memory savings: {2 if cap[0] >= 7 else 1}x (FP32: {mem_gb:.1f}GB -> {'~' + str(round(mem_gb*2, 1)) if cap[0] >= 7 else str(round(mem_gb, 1))}GB effective)")
print()

# 2. Scan script for existing AMP usage
script_path = os.environ['ZERNEL_ARG_SCRIPT']
has_autocast = False
has_gradscaler = False
has_bf16 = False
has_fp16 = False

if os.path.exists(script_path):
    with open(script_path) as f:
        content = f.read()
    has_autocast = "autocast" in content
    has_gradscaler = "GradScaler" in content
    has_bf16 = "bfloat16" in content
    has_fp16 = "float16" in content and "bfloat" not in content

    if has_autocast:
        print("STATUS: Script already uses torch.autocast (AMP enabled)")
        if has_bf16:
            print("  Using: BF16 precision")
        elif has_fp16:
            print("  Using: FP16 precision")
            if cap[0] >= 8:
                print("  TIP: Switch to BF16 for better numerical stability on this GPU")
        if not has_gradscaler and has_fp16:
            print("  WARNING: Using FP16 without GradScaler - may have loss scaling issues")
            print("  Add: scaler = torch.cuda.amp.GradScaler()")
    else:
        print("STATUS: Script does NOT use mixed precision")
        print()
        print("To enable AMP, wrap your training loop:")
        print()
        if cap[0] >= 8:
            print("  # Add at top of script:")
            print("  torch.backends.cuda.matmul.allow_tf32 = True")
            print("  torch.backends.cudnn.allow_tf32 = True")
            print()
            print("  # Wrap forward pass:")
            print("  with torch.autocast('cuda', dtype=torch.bfloat16):")
            print("      output = model(input)")
            print("      loss = loss_fn(output, target)")
            print("  loss.backward()")
            print("  optimizer.step()")
        else:
            print("  scaler = torch.cuda.amp.GradScaler()")
            print()
            print("  with torch.autocast('cuda', dtype=torch.float16):")
            print("      output = model(input)")
            print("      loss = loss_fn(output, target)")
            print("  scaler.scale(loss).backward()")
            print("  scaler.step(optimizer)")
            print("  scaler.update()")
        print()
        print(f"Or use: zernel optimize auto {script_path} to generate an optimized wrapper")
else:
    print(f"Script not found: {script_path}")
    print("Showing general recommendations based on GPU capability.")
    print()
    if cap[0] >= 8:
        print("  Best: torch.autocast('cuda', dtype=torch.bfloat16)")
        print("  Also: torch.backends.cuda.matmul.allow_tf32 = True")
    elif cap[0] >= 7:
        print("  Best: torch.autocast('cuda', dtype=torch.float16)")
        print("  Need: torch.cuda.amp.GradScaler() for loss scaling")
"#])
        .env("ZERNEL_ARG_SCRIPT", script)
        .output()?;

    print!("{}", String::from_utf8_lossy(&output.stdout));
    if !output.status.success() {
        eprint!("{}", String::from_utf8_lossy(&output.stderr));
    }
    Ok(())
}

async fn run_batch_size(model: &str, seq_len: u32, amp: bool, gpu: u32) -> Result<()> {
    println!("Smart Batch Size Calculator");
    println!("{}", "=".repeat(60));
    println!();

    let output = Command::new("python3")
        .args(["-c", r#"
import torch, sys, math, os

if not torch.cuda.is_available():
    print("ERROR: No CUDA GPU available")
    sys.exit(1)

gpu_id = int(os.environ['ZERNEL_ARG_GPU'])
props = torch.cuda.get_device_properties(gpu_id)
total_mem = props.total_memory
name = props.name
mem_gb = total_mem / 1e9

print(f"GPU {gpu_id}: {name} ({mem_gb:.1f} GB)")

# Parse model size
model_str = os.environ['ZERNEL_ARG_MODEL'].lower().strip()
param_count = None

# Known models
known = {
    "gpt2": 124e6, "gpt2-medium": 355e6, "gpt2-large": 774e6, "gpt2-xl": 1.5e9,
    "llama-7b": 7e9, "llama-13b": 13e9, "llama-70b": 70e9,
    "mistral-7b": 7.2e9, "phi-2": 2.7e9, "phi-3": 3.8e9,
    "bert-base": 110e6, "bert-large": 340e6,
    "t5-small": 60e6, "t5-base": 220e6, "t5-large": 770e6,
    "vit-base": 86e6, "vit-large": 307e6,
    "resnet50": 25.6e6, "resnet101": 44.5e6,
}

if model_str in known:
    param_count = known[model_str]
    print(f"Model: {model_str} ({param_count/1e6:.0f}M params)")
elif model_str.endswith('b'):
    param_count = float(model_str[:-1]) * 1e9
    print(f"Model: {param_count/1e9:.1f}B params")
elif model_str.endswith('m'):
    param_count = float(model_str[:-1]) * 1e6
    print(f"Model: {param_count/1e6:.0f}M params")
else:
    try:
        param_count = float(model_str)
        print(f"Model: {param_count/1e6:.0f}M params")
    except ValueError:
262        print(f"Unknown model: {model_str}")
263        print("Use a known name (gpt2, llama-7b, etc.) or param count (125M, 1.3B)")
264        sys.exit(1)
265
266seq_len = int(os.environ['ZERNEL_ARG_SEQ_LEN'])
267use_amp = os.environ['ZERNEL_ARG_AMP'] == "1"
268bytes_per_param = 2 if use_amp else 4
269
270print(f"Sequence length: {seq_len}")
271print(f"Precision: {'FP16/BF16' if use_amp else 'FP32'}")
272print()
273
274# Memory estimation (rough but practical):
275# Model params: param_count * bytes_per_param
276# Gradients: same as model
277# Optimizer state (AdamW): 2x model size (momentum + variance) in FP32
278# Activations: depends on model architecture, roughly:
279#   For transformers: ~12 * n_layers * hidden_dim * seq_len * batch_size * bytes_per_param
280#   Simplified: ~6 * param_count * seq_len / hidden_dim * batch_size * bytes_per_param / 1024
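#
# Worked example (illustrative): 1.3B params with AdamW in FP32 gives
#   weights   = 1.3e9 * 4 B ~ 5.2 GB
#   gradients = 5.2 GB
#   optimizer = 1.3e9 * 8 B ~ 10.4 GB
#   static total ~ 20.8 GB before any activations, so a 16 GB card
#   cannot hold it without AMP, checkpointing, or offloading.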

model_mem = param_count * bytes_per_param
grad_mem = model_mem
opt_mem = param_count * 8  # AdamW always FP32: momentum + variance = 2 * 4 bytes
static_mem = model_mem + grad_mem + opt_mem
cuda_overhead = 0.5e9  # ~500MB CUDA context

available = total_mem - static_mem - cuda_overhead
if available <= 0:
    print("WARNING: Model doesn't fit in GPU memory!")
    print(f"  Model + gradients + optimizer: {static_mem/1e9:.1f} GB")
    print(f"  GPU memory: {mem_gb:.1f} GB")
    print(f"  Shortfall: {(static_mem + cuda_overhead - total_mem)/1e9:.1f} GB")
    print()
    print("Recommendations:")
    print(f"  1. Use mixed precision: zernel optimize batch-size {model_str} --amp")
    print("  2. Use gradient checkpointing: zernel optimize checkpoint <script>")
    print("  3. Use model parallelism or offloading")
    sys.exit(0)

# Estimate activation memory per sample
# Rough heuristic: 4-6 bytes per param per sample (varies by architecture)
if param_count > 1e9:
    act_per_sample = param_count * 4 * seq_len / 2048  # scale with seq_len
else:
    act_per_sample = param_count * 6 * seq_len / 512

act_per_sample *= (0.5 if use_amp else 1.0)

max_batch = max(1, int(available / act_per_sample))
# Round down to power of 2 for efficiency
optimal_batch = 2 ** int(math.log2(max_batch)) if max_batch >= 2 else 1
safe_batch = max(1, optimal_batch // 2)  # 50% margin for safety
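# Illustrative: max_batch=37 -> optimal_batch=32 (2**int(log2(37))) -> safe_batch=16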

print("Memory breakdown:")
print(f"  Model weights:    {model_mem/1e9:.2f} GB")
print(f"  Gradients:        {grad_mem/1e9:.2f} GB")
print(f"  Optimizer (AdamW):{opt_mem/1e9:.2f} GB")
print(f"  CUDA overhead:    {cuda_overhead/1e9:.2f} GB")
print(f"  Available for activations: {available/1e9:.2f} GB")
print()
print("Recommended batch sizes:")
print(f"  Maximum:   {max_batch} (uses ~100% GPU memory — risky)")
324print(f"  Optimal:   {optimal_batch} (power-of-2, ~{optimal_batch * act_per_sample / available * 100:.0f}% memory)")
325print(f"  Safe:      {safe_batch} (50% margin for variable-length sequences)")
326print()
327
328if not use_amp and param_count > 100e6:
329    amp_available = total_mem - (param_count * 2 + param_count * 2 + opt_mem + cuda_overhead)
330    amp_act = act_per_sample * 0.5
331    amp_max = max(1, int(amp_available / amp_act))
332    amp_optimal = 2 ** int(math.log2(amp_max)) if amp_max >= 2 else 1
333    print(f"With mixed precision (--amp):")
334    print(f"  Optimal batch size: {amp_optimal} ({amp_optimal/optimal_batch:.1f}x larger)")
335    print(f"  Run: zernel optimize batch-size {model_str} --seq-len {seq_len} --amp")
336"#])
337        .env("ZERNEL_ARG_MODEL", model)
338        .env("ZERNEL_ARG_SEQ_LEN", seq_len.to_string())
339        .env("ZERNEL_ARG_AMP", if amp { "1" } else { "0" })
340        .env("ZERNEL_ARG_GPU", gpu.to_string())
341        .output()?;
342
343    print!("{}", String::from_utf8_lossy(&output.stdout));
344    if !output.status.success() {
345        eprint!("{}", String::from_utf8_lossy(&output.stderr));
346    }
347    Ok(())
348}
349
350async fn run_checkpoint(script: &str, target_mem: f32) -> Result<()> {
351    println!("Gradient Checkpointing Analyzer");
352    println!("{}", "=".repeat(60));
353    println!("Script: {script}");
354    println!("Target memory usage: {:.0}%", target_mem * 100.0);
355    println!();
356
357    let output = Command::new("python3")
358        .args(["-c", r#"
359import torch, sys, os
360
361if not torch.cuda.is_available():
362    print("ERROR: No CUDA GPU available")
363    sys.exit(1)
364
365mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
366name = torch.cuda.get_device_name()
367target = float(os.environ['ZERNEL_ARG_TARGET_MEM'])
368
369print(f"GPU: {name} ({mem_gb:.1f} GB)")
370print(f"Target usage: {target*100:.0f}% ({mem_gb*target:.1f} GB)")
371print()
372
373# Check if script exists and scan for model definition
374script_path = os.environ['ZERNEL_ARG_SCRIPT']
375has_checkpoint = False
376model_type = "unknown"
377
378if os.path.exists(script_path):
379    with open(script_path) as f:
380        content = f.read()
    has_checkpoint = "checkpoint_sequential" in content or ("checkpoint" in content and "gradient" in content.lower())

    if "transformers" in content or "AutoModel" in content:
        model_type = "huggingface"
    elif "nn.Sequential" in content or "nn.Module" in content:
        model_type = "pytorch"

if has_checkpoint:
    print("STATUS: Script already uses gradient checkpointing")
else:
    print("STATUS: Gradient checkpointing NOT enabled")
    print()
    print("Gradient checkpointing trades compute for memory:")
    print("  - Saves ~60-70% activation memory")
    print("  - Costs ~30% extra compute (recomputes activations in backward)")
    print("  - Allows ~2-3x larger batch sizes or models")
    print()

    if model_type == "huggingface":
        print("For HuggingFace models:")
        print("  model.gradient_checkpointing_enable()")
        print()
        print("Or in TrainingArguments:")
        print("  TrainingArguments(gradient_checkpointing=True, ...)")
    else:
        print("For PyTorch models:")
        print("  from torch.utils.checkpoint import checkpoint_sequential")
        print()
        print("  # In your model's forward():")
        print("  # Instead of: x = self.layers(x)")
        print("  # Use:        x = checkpoint_sequential(self.layers, segments, x)")
        print()
        print("  # Or per-layer:")
        print("  from torch.utils.checkpoint import checkpoint")
        print("  for layer in self.layers:")
        print("      x = checkpoint(layer, x, use_reentrant=False)")

    print()
    print("When to use gradient checkpointing:")
    print(f"  - Model + optimizer > {mem_gb*0.6:.1f} GB (60% of GPU memory)")
    print("  - You're getting OOM errors")
    print("  - You want to increase batch size without more GPUs")
    print()
    print("When NOT to use it:")
    print("  - You have plenty of GPU memory headroom")
    print("  - Training is already compute-bound (GPU util > 90%)")
    print("  - The 30% slowdown isn't acceptable")
"#])
        .env("ZERNEL_ARG_SCRIPT", script)
        .env("ZERNEL_ARG_TARGET_MEM", target_mem.to_string())
        .output()?;

    print!("{}", String::from_utf8_lossy(&output.stdout));
    if !output.status.success() {
        eprint!("{}", String::from_utf8_lossy(&output.stderr));
    }
    Ok(())
}

async fn run_data_pipeline(script: &str, steps: u32) -> Result<()> {
    println!("Data Pipeline Profiler");
    println!("{}", "=".repeat(60));
    println!("Script: {script}");
    println!("Profiling: {steps} steps");
    println!();

    let output = Command::new("python3")
        .args(["-c", r#"
import torch, time, sys, os, statistics

if not torch.cuda.is_available():
    print("ERROR: No CUDA GPU available")
    sys.exit(1)

print(f"GPU: {torch.cuda.get_device_name()}")
print(f"CPU cores: {os.cpu_count()}")
print()

# Profile different DataLoader configurations
from torch.utils.data import DataLoader, TensorDataset

# Create a dummy in-memory dataset (10k samples of 224x224x3 "images")
n_samples = 10000
data = torch.randn(n_samples, 3, 224, 224)
labels = torch.randint(0, 1000, (n_samples,))
dataset = TensorDataset(data, labels)

configs = [
    {"num_workers": 0, "pin_memory": False, "persistent_workers": False, "label": "Default (0 workers)"},
    {"num_workers": 2, "pin_memory": False, "persistent_workers": False, "label": "2 workers"},
    {"num_workers": 4, "pin_memory": True,  "persistent_workers": False, "label": "4 workers + pin_memory"},
    {"num_workers": 4, "pin_memory": True,  "persistent_workers": True,  "label": "4 workers + pin + persistent"},
    {"num_workers": 8, "pin_memory": True,  "persistent_workers": True,  "label": "8 workers + pin + persistent"},
]

n_steps = int(os.environ['ZERNEL_ARG_STEPS'])
batch_size = 64

print(f"Profiling {n_steps} batches (batch_size={batch_size}) per config...")
print()
print(f"  {'Config':<40} {'Time(ms)':>10} {'Throughput':>12} {'vs Best':>10}")
print(f"  {'-'*75}")

best_time = float('inf')
results = []

for cfg in configs:
    kw = {k: v for k, v in cfg.items() if k != 'label'}
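    # DataLoader raises ValueError when persistent_workers=True is combined
    # with num_workers=0, so skip that invalid configuration up front.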
    if kw.get('persistent_workers') and kw.get('num_workers', 0) == 0:
        continue
    try:
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, **kw)
        it = iter(loader)

        # Warmup
        for _ in range(3):
            try:
                batch = next(it)
            except StopIteration:
                it = iter(loader)
                batch = next(it)

        # Benchmark
        times = []
        for _ in range(n_steps):
            t0 = time.perf_counter()
            try:
                batch = next(it)
            except StopIteration:
                it = iter(loader)
                batch = next(it)
            # Simulate GPU transfer
            x = batch[0].to('cuda', non_blocking=kw.get('pin_memory', False))
            torch.cuda.synchronize()
            times.append((time.perf_counter() - t0) * 1000)

        avg = statistics.mean(times)
        throughput = batch_size * 1000 / avg
        best_time = min(best_time, avg)
        results.append((cfg['label'], avg, throughput))

        del loader
    except Exception:
        results.append((cfg['label'], 0, 0))

for label, avg, throughput in results:
    if avg > 0:
        ratio = best_time / avg
        marker = " <-- best" if abs(avg - best_time) < 0.01 else ""
        print(f"  {label:<40} {avg:>8.2f}ms {throughput:>9.0f} s/s  {ratio:>7.2f}x{marker}")
    else:
        print(f"  {label:<40} {'FAILED':>10}")

print()

# Recommendations
best_cfg = None
best_avg = float('inf')
for label, avg, _ in results:
    if 0 < avg < best_avg:
        best_avg = avg
        best_cfg = label

if best_cfg:
    print(f"Recommendation: Use '{best_cfg}'")
    print()
    print("Add to your DataLoader:")
    print("  DataLoader(")
    print("      dataset,")
    print(f"      batch_size={batch_size},")
    if "8 worker" in best_cfg:
        print("      num_workers=8,")
    elif "4 worker" in best_cfg:
        print("      num_workers=4,")
    elif "2 worker" in best_cfg:
        print("      num_workers=2,")
    if "pin" in best_cfg:
        print("      pin_memory=True,")
    if "persistent" in best_cfg:
        print("      persistent_workers=True,")
    print("      prefetch_factor=2,")
    print("  )")

    # Check if data loading is the bottleneck
    typical_gpu_step = 50  # ms for a typical forward+backward
    if best_avg > typical_gpu_step * 0.3:
        print()
        print("WARNING: Data loading may be a bottleneck!")
        print(f"  Data load time ({best_avg:.1f}ms) is >{typical_gpu_step*0.3:.0f}ms (30% of typical GPU step)")
        print("  Consider: larger prefetch_factor, SSD storage, or data preprocessing")
"#])
        .env("ZERNEL_ARG_SCRIPT", script)
        .env("ZERNEL_ARG_STEPS", steps.to_string())
        .output()?;

    print!("{}", String::from_utf8_lossy(&output.stdout));
    if !output.status.success() {
        eprint!("{}", String::from_utf8_lossy(&output.stderr));
    }
    Ok(())
}

async fn run_memory() -> Result<()> {
    println!("CUDA Memory Allocator Configuration");
    println!("{}", "=".repeat(60));

    let code = r#"
import os, torch
if torch.cuda.is_available():
    conf = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '(not set)')
    print(f"  PYTORCH_CUDA_ALLOC_CONF: {conf}")
    props = torch.cuda.get_device_properties(0)
    print(f"  GPU memory: {props.total_memory / 1e9:.1f} GB")
    print()
    print("Recommended settings:")
    print("  For training:")
    print("    export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,garbage_collection_threshold:0.6,max_split_size_mb:512")
    print("  For inference:")
    print("    export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128")
600"#;
601    let output = Command::new("python3").args(["-c", code]).output()?;
602    print!("{}", String::from_utf8_lossy(&output.stdout));
603    Ok(())
604}
605
606async fn run_numa() -> Result<()> {
607    println!("NUMA-Aware Data Placement");
608    println!("{}", "=".repeat(60));
609
610    #[cfg(target_os = "linux")]
611    {
612        if let Ok(entries) = std::fs::read_dir("/sys/devices/system/node") {
613            let mut node_count = 0;
614            for entry in entries.flatten() {
615                let name = entry.file_name().to_string_lossy().to_string();
616                if name.starts_with("node") {
617                    node_count += 1;
618                    let cpulist_path = entry.path().join("cpulist");
619                    if let Ok(cpus) = std::fs::read_to_string(&cpulist_path) {
620                        println!("  {name}: CPUs {}", cpus.trim());
621                    }
622                }
623            }
624            println!();
625            if node_count > 1 {
626                println!("Multi-NUMA system ({node_count} nodes). Use:");
627                println!("  numactl --cpunodebind=0 --membind=0 python3 train.py");
628            } else {
629                println!("Single NUMA node — no placement optimization needed.");
630            }
631        }
632    }
633
634    #[cfg(not(target_os = "linux"))]
635    println!("  NUMA detection requires Linux.");
636
637    Ok(())
638}
639
640async fn run_scan(script: Option<&str>) -> Result<()> {
641    println!("Zernel Optimization Scan");
642    println!("{}", "=".repeat(60));
643    println!();
644
645    let script_arg = script.unwrap_or("(none)");
646    let output = Command::new("python3")
647        .args(["-c", r#"
648import torch, os, sys
649
650issues = []
651recommendations = []
652
653print("Environment:")
654print(f"  PyTorch: {torch.__version__}")
655print(f"  CUDA: {torch.cuda.is_available()}")
656
657if torch.cuda.is_available():
658    cap = torch.cuda.get_device_capability()
659    name = torch.cuda.get_device_name()
660    mem = torch.cuda.get_device_properties(0).total_memory / 1e9
661    print(f"  GPU: {name} ({mem:.1f} GB, sm_{cap[0]}{cap[1]})")
662
663    # Check TF32
664    if cap[0] >= 8:
665        tf32_matmul = torch.backends.cuda.matmul.allow_tf32
666        tf32_cudnn = torch.backends.cudnn.allow_tf32
667        if not tf32_matmul:
668            issues.append("TF32 disabled for matmul (free 3x speedup on Ampere+)")
            recommendations.append("torch.backends.cuda.matmul.allow_tf32 = True")
        if not tf32_cudnn:
            issues.append("TF32 disabled for cuDNN")
            recommendations.append("torch.backends.cudnn.allow_tf32 = True")

    # Check CUDA alloc config
    alloc_conf = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
    if 'expandable_segments' not in alloc_conf:
        issues.append("CUDA allocator not optimized (expandable_segments not set)")
        recommendations.append("export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")

    # Check precision recommendation
    if cap[0] >= 8:
        recommendations.append("Use BF16: torch.autocast('cuda', dtype=torch.bfloat16)")
    elif cap[0] >= 7:
        recommendations.append("Use FP16: torch.autocast('cuda', dtype=torch.float16) + GradScaler()")

# Check CPU
cpu_count = os.cpu_count()
print(f"  CPUs: {cpu_count}")
if cpu_count and cpu_count >= 4:
    recommendations.append(f"DataLoader: num_workers={min(cpu_count, 8)}, pin_memory=True, persistent_workers=True")

# Check script if provided
script_path = os.environ['ZERNEL_ARG_SCRIPT']
if script_path != "(none)" and os.path.exists(script_path):
    with open(script_path) as f:
        content = f.read()
    print(f"  Script: {script_path}")

699    if "autocast" not in content and "amp" not in content.lower():
700        issues.append("No mixed precision (AMP) detected in script")
701    if "num_workers" not in content:
702        issues.append("DataLoader num_workers not set (defaults to 0 — single-threaded)")
703    if "pin_memory" not in content:
704        issues.append("DataLoader pin_memory not set (slower CPU→GPU transfers)")
705    if "gradient_checkpointing" not in content and "checkpoint_sequential" not in content:
706        recommendations.append("Consider gradient checkpointing for larger models/batches")
707
708print()
709if issues:
710    print(f"ISSUES FOUND ({len(issues)}):")
711    for i, issue in enumerate(issues, 1):
712        print(f"  {i}. {issue}")
713else:
714    print("No issues found!")
715
716print()
717if recommendations:
718    print(f"RECOMMENDATIONS ({len(recommendations)}):")
719    for i, rec in enumerate(recommendations, 1):
720        print(f"  {i}. {rec}")
721
722print()
723print("Run individual optimizations:")
724print("  zernel optimize precision <script>     # Mixed precision analysis")
725print("  zernel optimize batch-size <model>      # Optimal batch size")
726print("  zernel optimize checkpoint <script>     # Gradient checkpointing")
727print("  zernel optimize data-pipeline <script>  # Data loading profiler")
728print("  zernel optimize auto <script>           # Generate optimized wrapper")
729"#])
730        .env("ZERNEL_ARG_SCRIPT", script_arg)
731        .output()?;
732
733    print!("{}", String::from_utf8_lossy(&output.stdout));
734    if !output.status.success() {
735        eprint!("{}", String::from_utf8_lossy(&output.stderr));
736    }
737    Ok(())
738}
739
740async fn run_auto(script: &str, output_path: &str) -> Result<()> {
741    println!("Zernel Auto-Optimizer");
742    println!("{}", "=".repeat(60));
743    println!("Input:  {script}");
744    println!("Output: {output_path}");
745    println!();
746
747    // Pass args via env vars to avoid format string escaping issues
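    // include_str! embeds the sibling optimize_auto.py into the binary at
    // compile time, so the main path needs no runtime access to that file.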
    let output = Command::new("python3")
        .args(["-c", include_str!("optimize_auto.py")])
        .env("ZERNEL_OPT_SCRIPT", script)
        .env("ZERNEL_OPT_OUTPUT", output_path)
        .output();

    match output {
        Ok(out) => {
            print!("{}", String::from_utf8_lossy(&out.stdout));
            if !out.status.success() {
                eprint!("{}", String::from_utf8_lossy(&out.stderr));
            }
        }
        Err(_) => {
            // Fallback if python3 could not be spawned: emit a minimal wrapper inline.
            println!("Generating optimized wrapper for {script}...");
            let original = std::fs::read_to_string(script)?;
            let mut wrapper = String::from("#!/usr/bin/env python3\n# Auto-optimized by Zernel\n");
766            wrapper.push_str("import torch, os\n");
767            wrapper.push_str("torch.backends.cuda.matmul.allow_tf32 = True\n");
768            wrapper.push_str("torch.backends.cudnn.allow_tf32 = True\n");
769            wrapper.push_str("os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')\n\n");
770            wrapper.push_str("# Original script:\n");
771            wrapper.push_str(&original);
772            std::fs::write(output_path, &wrapper)?;
773            println!("Written to: {output_path}");
774        }
775    }
776    Ok(())
777}