// zernel_scheduler/main.rs

// Copyright (C) 2026 Dyber, Inc. — GPL-2.0
//
// Zernel sched_ext ML Scheduler (v3)
//
// Full pipeline:
//   1. Load sched_ext BPF scheduler into kernel
//   2. Discover GPU processes (nvidia-smi) and register them
//   3. Poll GPU utilization and update task state
//   4. Run phase detection on tracked tasks
//   5. Write detected phases to BPF phase_map
//   6. Write CPU affinity hints to BPF cpu_affinity_map
//   7. Adjust GPU power profiles per detected phase
//   8. Read ring buffer events for observability
14#![allow(dead_code)]
15
16mod config;
17mod multi_tenant;
18mod numa;
19mod phase_detector;
20mod scheduler;
21mod task_state;
22mod telemetry;
23
24use anyhow::Result;
25use clap::Parser;
26use config::SchedulerConfig;
27use std::collections::HashMap;
28use std::path::PathBuf;
29use tracing::{debug, info, warn};
30#[cfg(feature = "bpf")]
31use libbpf_rs::skel::{SkelBuilder, OpenSkel, Skel};
32#[cfg(feature = "bpf")]
33use libbpf_rs::MapCore;
34
/// Fallback configuration path used when neither `--config` nor the
/// `ZERNEL_SCHEDULER_CONFIG` environment variable is supplied.
const DEFAULT_CONFIG_PATH: &str = "/etc/zernel/scheduler.toml";
36
/// Command-line arguments for the zernel scheduler daemon (parsed by clap).
#[derive(Parser)]
#[command(name = "zernel-scheduler")]
#[command(about = "Zernel sched_ext ML-Aware CPU Scheduler")]
#[command(version)]
struct Args {
    /// Path to the TOML configuration file; also settable via the
    /// ZERNEL_SCHEDULER_CONFIG environment variable.
    #[arg(long, default_value = DEFAULT_CONFIG_PATH, env = "ZERNEL_SCHEDULER_CONFIG")]
    config: PathBuf,
    /// Print the effective configuration as TOML and exit.
    #[arg(long)]
    dump_config: bool,
    /// Run a simulated workload lifecycle instead of attaching the BPF scheduler.
    #[arg(long)]
    demo: bool,
}
49
/// Generated libbpf-rs skeleton for the zernel sched_ext BPF program;
/// the build script emits it into OUT_DIR at compile time.
#[cfg(feature = "bpf")]
mod skel {
    include!(concat!(env!("OUT_DIR"), "/zernel_sched.skel.rs"));
}
54
55// ── GPU process discovery via nvidia-smi ──────────────────────
56
/// Discover PIDs currently using NVIDIA GPUs, as `(pid, gpu_id)` pairs.
///
/// Parses `nvidia-smi --query-compute-apps=pid,gpu_uuid` CSV output.
/// Returns an empty vector if nvidia-smi is missing or exits non-zero.
/// The gpu_uuid column is required to be present but its value is not
/// used: every process is reported as GPU 0 (single-GPU assumption).
fn discover_gpu_processes() -> Vec<(u32, u32)> {
    let out = match std::process::Command::new("nvidia-smi")
        .args(["--query-compute-apps=pid,gpu_uuid", "--format=csv,noheader,nounits"])
        .output()
    {
        Ok(o) if o.status.success() => o,
        _ => return Vec::new(),
    };

    String::from_utf8_lossy(&out.stdout)
        .lines()
        .filter_map(|line| {
            let mut fields = line.split(',').map(str::trim);
            let pid = fields.next()?.parse::<u32>().ok()?;
            // Require the gpu_uuid column so malformed lines are skipped.
            fields.next()?;
            Some((pid, 0)) // GPU 0 for single-GPU systems
        })
        .collect()
}
81
/// Get the current GPU utilization percentage from nvidia-smi.
///
/// Returns `None` when nvidia-smi is unavailable, exits non-zero, or
/// prints something unparsable; only the first GPU's line is read.
fn get_gpu_utilization() -> Option<u8> {
    let output = std::process::Command::new("nvidia-smi")
        .args(["--query-gpu=utilization.gpu", "--format=csv,noheader,nounits"])
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let text = String::from_utf8_lossy(&output.stdout);
    let first_line = text.trim().lines().next()?;
    first_line.trim().parse::<u8>().ok()
}
94
/// Get the current GPU board power draw in watts from nvidia-smi.
///
/// Returns `None` when nvidia-smi is unavailable, exits non-zero, or
/// the reading does not parse; only the first GPU's line is read.
fn get_gpu_power() -> Option<f32> {
    let raw = std::process::Command::new("nvidia-smi")
        .args(["--query-gpu=power.draw", "--format=csv,noheader,nounits"])
        .output()
        .ok()
        .filter(|o| o.status.success())
        .map(|o| String::from_utf8_lossy(&o.stdout).into_owned())?;
    raw.trim().lines().next()?.trim().parse().ok()
}
107
/// Get GPU max clocks for power management:
/// `(max graphics clock MHz, max memory clock MHz, max power limit W)`.
///
/// Returns `None` when nvidia-smi is unavailable, exits non-zero, or
/// any field fails to parse.
fn get_gpu_max_clocks() -> Option<(u32, u32, u32)> {
    let output = std::process::Command::new("nvidia-smi")
        .args([
            "--query-gpu=clocks.max.graphics,clocks.max.memory,power.max_limit",
            "--format=csv,noheader,nounits",
        ])
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    // Use only the first line (GPU 0), consistent with get_gpu_utilization /
    // get_gpu_power: on multi-GPU hosts nvidia-smi prints one line per
    // device, and splitting the whole blob on ',' would smear the next
    // device's fields into the third column.
    let line = stdout.trim().lines().next()?;
    let fields: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
    if fields.len() >= 3 {
        let g = fields[0].parse().ok()?;
        let m = fields[1].parse().ok()?;
        // power.max_limit is reported as a float in watts; truncate to whole watts.
        let p = fields[2].parse::<f32>().ok()? as u32;
        Some((g, m, p))
    } else {
        None
    }
}
131
132/// Apply GPU power profile for a phase.
133fn apply_gpu_power_profile(phase: &task_state::WorkloadPhase, max_clocks: (u32, u32, u32)) {
134    let (max_g, _max_m, max_p) = max_clocks;
135    let (target_g, target_p) = match phase {
136        task_state::WorkloadPhase::DataLoading => (max_g / 3, (max_p as f32 * 0.78) as u32),
137        task_state::WorkloadPhase::GpuCompute => (max_g, max_p),
138        task_state::WorkloadPhase::NcclCollective => (max_g / 2, (max_p as f32 * 0.85) as u32),
139        task_state::WorkloadPhase::OptimizerStep => (max_g, max_p),
140        task_state::WorkloadPhase::Unknown => return,
141    };
142
143    // Lock GPU clocks (nvidia-smi -lgc works on consumer cards, -ac does not)
144    let _ = std::process::Command::new("nvidia-smi")
145        .args(["-i", "0", "-lgc", &target_g.to_string()])
146        .output();
147
148    // Set power limit
149    let _ = std::process::Command::new("nvidia-smi")
150        .args(["-i", "0", "-pl", &target_p.to_string()])
151        .output();
152
153    debug!(phase = %phase, gpu_clock = target_g, power_limit = target_p, "GPU power profile applied");
154}
155
156/// Set CPU frequency for all cores (phase-aware power management).
157/// During GPU compute phases, CPU is mostly idle — drop to minimum frequency.
158/// During data loading phases, CPU needs full speed for preprocessing.
159fn apply_cpu_power_profile(phase: &task_state::WorkloadPhase, num_cpus: usize) {
160    let freq_khz = match phase {
161        task_state::WorkloadPhase::DataLoading => 3600000,    // full speed for I/O + preprocessing
162        task_state::WorkloadPhase::GpuCompute => 1200000,     // minimum — CPU idle during GPU work
163        task_state::WorkloadPhase::NcclCollective => 1200000,  // minimum — waiting on network
164        task_state::WorkloadPhase::OptimizerStep => 3600000,   // full speed for CPU burst
165        task_state::WorkloadPhase::Unknown => 3600000,         // default to full
166    };
167
168    for i in 0..num_cpus {
169        let path = format!("/sys/devices/system/cpu/cpu{}/cpufreq/scaling_max_freq", i);
170        let _ = std::fs::write(&path, freq_khz.to_string());
171    }
172
173    debug!(phase = %phase, freq_mhz = freq_khz / 1000, "CPU frequency set");
174}
175
/// Reset every core's scaling_max_freq cap back to the 3.6 GHz maximum.
/// sysfs write failures are deliberately ignored (best effort).
fn reset_cpu_frequency(num_cpus: usize) {
    (0..num_cpus).for_each(|cpu| {
        let _ = std::fs::write(
            format!("/sys/devices/system/cpu/cpu{}/cpufreq/scaling_max_freq", cpu),
            "3600000",
        );
    });
}
183
/// Check if a process is alive by probing its /proc/<pid> entry
/// (Linux-specific; always false where procfs is absent).
fn process_alive(pid: u32) -> bool {
    std::fs::metadata(format!("/proc/{}", pid)).is_ok()
}
188
189#[tokio::main]
190async fn main() -> Result<()> {
191    tracing_subscriber::fmt()
192        .with_env_filter(
193            std::env::var("ZERNEL_LOG").unwrap_or_else(|_| "zernel_scheduler=info".into()),
194        )
195        .init();
196
197    let args = Args::parse();
198    info!("Zernel scheduler v{}", env!("CARGO_PKG_VERSION"));
199
200    let config = SchedulerConfig::load(&args.config)?;
201    info!(config = ?args.config, "loaded configuration");
202
203    if args.dump_config {
204        println!("{}", config.to_toml()?);
205        return Ok(());
206    }
207
208    // ── BPF sched_ext attachment ──────────────────────────────
209    #[cfg(feature = "bpf")]
210    let mut _open_object = std::mem::MaybeUninit::uninit();
211    #[cfg(feature = "bpf")]
212    let _skel_hold;
213    #[cfg(feature = "bpf")]
214    let _link_hold;
215
216    #[cfg(feature = "bpf")]
217    if !args.demo {
218        info!("attempting BPF sched_ext attachment");
219
220        let skel_builder = skel::ZernelSchedSkelBuilder::default();
221        let open_skel = skel_builder.open(&mut _open_object)
222            .expect("failed to open BPF skeleton");
223        let mut loaded = open_skel.load()
224            .expect("failed to load BPF skeleton");
225
226        let link = loaded.maps.zernel_ops.attach_struct_ops()
227            .expect("failed to attach struct_ops scheduler — is CONFIG_SCHED_CLASS_EXT=y?");
228        info!("sched_ext scheduler ATTACHED — zernel is now the kernel scheduler");
229
230        if let Ok(state) = std::fs::read_to_string("/sys/kernel/sched_ext/state") {
231            info!(state = state.trim(), "sched_ext kernel state");
232        }
233
234        _link_hold = Some(link);
235        _skel_hold = Some(loaded);
236    } else {
237        _link_hold = None;
238        _skel_hold = None;
239    }
240
241    #[cfg(not(feature = "bpf"))]
242    {
243        info!("running in userspace-only mode (BPF feature disabled)");
244    }
245
246    let mut sched = scheduler::ZernelScheduler::new(config.clone());
247
248    if args.demo {
249        run_demo(&mut sched);
250        return Ok(());
251    }
252
253    // ── Discover GPU max clocks for power management ──────────
254    let gpu_max_clocks = get_gpu_max_clocks();
255    if let Some((g, m, p)) = gpu_max_clocks {
256        info!(
257            max_graphics = g, max_memory = m, max_power = p,
258            "GPU max clocks detected — power management enabled"
259        );
260        // Enable persistence mode for clock control
261        let _ = std::process::Command::new("nvidia-smi")
262            .args(["-i", "0", "-pm", "1"])
263            .output();
264    } else {
265        warn!("could not detect GPU max clocks — power management disabled");
266    }
267
268    // Track the dominant phase across all GPU tasks for power management
269    let mut current_power_phase = task_state::WorkloadPhase::Unknown;
270    let mut gpu_poll_counter = 0u64;
271    let gpu_poll_interval = config.general.gpu_poll_interval_ms / config.general.phase_eval_interval_ms;
272
273    let eval_interval = tokio::time::Duration::from_millis(config.general.phase_eval_interval_ms);
274    let mut interval = tokio::time::interval(eval_interval);
275
276    info!(
277        interval_ms = config.general.phase_eval_interval_ms,
278        gpu_poll_every = gpu_poll_interval,
279        "entering continuous scheduling loop (Ctrl+C to stop)"
280    );
281
282    loop {
283        tokio::select! {
284            _ = interval.tick() => {
285                let now_ns = std::time::SystemTime::now()
286                    .duration_since(std::time::UNIX_EPOCH)
287                    .unwrap_or_default()
288                    .as_nanos() as u64;
289
290                gpu_poll_counter += 1;
291
292                // ── Periodic GPU process discovery + metrics update ──
293                if gpu_poll_counter % gpu_poll_interval == 0 {
294                    // Discover new GPU processes
295                    let gpu_procs = discover_gpu_processes();
296                    for (pid, gpu_id) in &gpu_procs {
297                        if !sched.task_states().contains_key(pid) {
298                            sched.register_task(*pid, true, Some(*gpu_id));
299                            info!(pid, gpu_id, "discovered GPU process");
300                        }
301                    }
302
303                    // Clean up dead processes
304                    let dead_pids: Vec<u32> = sched.task_states().keys()
305                        .filter(|pid| !process_alive(**pid))
306                        .copied()
307                        .collect();
308                    for pid in dead_pids {
309                        sched.unregister_task(pid);
310                        debug!(pid, "cleaned up dead process");
311                    }
312
313                    // Update GPU utilization for all tracked tasks
314                    if let Some(gpu_util) = get_gpu_utilization() {
315                        let pids: Vec<u32> = sched.task_states().keys().copied().collect();
316                        for pid in &pids {
317                            sched.update_task(*pid, scheduler::TaskUpdate {
318                                gpu_utilization: Some(gpu_util),
319                                ..Default::default()
320                            });
321                        }
322                    }
323                }
324
325                // ── Run phase detection on all tracked tasks ──
326                let pids: Vec<u32> = sched.task_states().keys().copied().collect();
327                let mut phase_counts: HashMap<task_state::WorkloadPhase, u32> = HashMap::new();
328
329                for pid in &pids {
330                    let decision = sched.schedule(*pid, now_ns);
331
332                    // Get the detected phase
333                    if let Some(state) = sched.task_states().get(pid) {
334                        *phase_counts.entry(state.current_phase).or_insert(0) += 1;
335
336                        // ── Write phase to BPF phase_map ──
337                        #[cfg(feature = "bpf")]
338                        if !args.demo {
339                            let phase_val: u32 = match state.current_phase {
340                                task_state::WorkloadPhase::DataLoading => 0,
341                                task_state::WorkloadPhase::GpuCompute => 1,
342                                task_state::WorkloadPhase::NcclCollective => 2,
343                                task_state::WorkloadPhase::OptimizerStep => 3,
344                                task_state::WorkloadPhase::Unknown => 255,
345                            };
346                            let key = pid.to_ne_bytes();
347                            let val = phase_val.to_ne_bytes();
348                            let _ = _skel_hold.as_ref().unwrap().maps.phase_map
349                                .update(&key, &val, libbpf_rs::MapFlags::ANY);
350                        }
351
352                        // ── Write CPU affinity for data-loading tasks ──
353                        #[cfg(feature = "bpf")]
354                        if let Some(cpu) = decision.preferred_cpu {
355                            if state.current_phase == task_state::WorkloadPhase::DataLoading {
356                                let key = pid.to_ne_bytes();
357                                let val = (cpu as i32).to_ne_bytes();
358                                let _ = _skel_hold.as_ref().unwrap().maps.cpu_affinity_map
359                                    .update(&key, &val, libbpf_rs::MapFlags::ANY);
360                            }
361                        }
362                    }
363                }
364
365                // ─��� GPU power management: apply dominant phase ──
366                if let Some(max_clocks) = gpu_max_clocks {
367                    // Find the dominant phase among GPU tasks
368                    let dominant = phase_counts.iter()
369                        .max_by_key(|(_, count)| *count)
370                        .map(|(phase, _)| *phase)
371                        .unwrap_or(task_state::WorkloadPhase::Unknown);
372
373                    if dominant != current_power_phase && dominant != task_state::WorkloadPhase::Unknown {
374                        apply_gpu_power_profile(&dominant, max_clocks);
375                        let num_cpus = sched.numa_topology().total_cpus();
376                        apply_cpu_power_profile(&dominant, num_cpus);
377                        info!(
378                            from = %current_power_phase, to = %dominant,
379                            "power phase transition (GPU + CPU)"
380                        );
381                        current_power_phase = dominant;
382                    }
383                }
384
385                // ─�� Periodic telemetry ──
386                if sched.decisions_made > 0 && sched.decisions_made % 500 == 0 {
387                    let telem = telemetry::export_telemetry(&sched);
388                    let power = get_gpu_power().unwrap_or(0.0);
389                    info!(
390                        tasks = telem.total_tracked_tasks,
391                        ml = telem.ml_tasks,
392                        decisions = telem.decisions_made,
393                        transitions = telem.phase_transitions,
394                        gpu_power_w = format!("{:.1}", power),
395                        power_phase = %current_power_phase,
396                        "scheduling telemetry"
397                    );
398                }
399            }
400            _ = tokio::signal::ctrl_c() => {
401                let telem = telemetry::export_telemetry(&sched);
402                info!(
403                    decisions = telem.decisions_made,
404                    transitions = telem.phase_transitions,
405                    "shutting down — final telemetry"
406                );
407
408                // Reset GPU + CPU power on exit
409                if gpu_max_clocks.is_some() {
410                    let _ = std::process::Command::new("nvidia-smi")
411                        .args(["-i", "0", "-rgc"]).output();
412                    let _ = std::process::Command::new("nvidia-smi")
413                        .args(["-i", "0", "-pl", "115"]).output();
414                    let num_cpus = sched.numa_topology().total_cpus();
415                    reset_cpu_frequency(num_cpus);
416                    info!("GPU + CPU power reset to defaults");
417                }
418                break;
419            }
420        }
421    }
422
423    Ok(())
424}
425
426fn run_demo(sched: &mut scheduler::ZernelScheduler) {
427    info!("--- demo: simulating ML workload lifecycle ---");
428
429    sched.register_task(1000, true, Some(0));
430
431    sched.update_task(1000, scheduler::TaskUpdate {
432        io_wait_fraction: Some(0.6), gpu_utilization: Some(5), ..Default::default()
433    });
434    let d = sched.schedule(1000, 1_000_000);
435    info!(phase = "DataLoading", priority = d.priority, cpu = ?d.preferred_cpu, "decision");
436
437    sched.update_task(1000, scheduler::TaskUpdate {
438        io_wait_fraction: Some(0.01), gpu_utilization: Some(96),
439        cpu_burst_duration_ns: Some(0), ..Default::default()
440    });
441    let d = sched.schedule(1000, 5_000_000);
442    info!(phase = "GpuCompute", priority = d.priority, "decision");
443
444    sched.update_task(1000, scheduler::TaskUpdate {
445        gpu_utilization: Some(10), cpu_burst_duration_ns: Some(2_000_000),
446        last_gpu_sync_ns: Some(4_900_000), ..Default::default()
447    });
448    let d = sched.schedule(1000, 8_000_000);
449    info!(phase = "OptimizerStep", priority = d.priority, "decision");
450
451    let telem = telemetry::export_telemetry(sched);
452    info!(tasks = telem.total_tracked_tasks, decisions = telem.decisions_made, "demo complete");
453}