// zernel_scheduler/scheduler.rs

// Copyright (C) 2026 Dyber, Inc. — GPL-2.0

use std::collections::HashMap;

use tracing::{debug, info};

use crate::config::SchedulerConfig;
use crate::multi_tenant::TenantScheduler;
use crate::numa::NumaTopology;
use crate::phase_detector::PhaseDetector;
use crate::task_state::{WorkloadPhase, ZernelTaskState};

/// Scheduling priority assigned to a task based on its detected phase.
///
/// `Default` yields the neutral decision (priority 0, no preemption, no
/// latency target, no CPU preference) that is handed out for untracked
/// tasks.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SchedulingDecision {
    /// Relative scheduling priority. Judging by `policy_for_phase`,
    /// higher appears to mean more urgent (DataLoading=10 vs
    /// GpuCompute=-5) — confirm against the consumer of this value.
    pub priority: i32,
    /// Whether the task should be allowed to preempt running work.
    pub preempt: bool,
    /// Wakeup-latency target in microseconds, for latency-sensitive phases.
    pub latency_target_us: Option<u64>,
    /// NUMA-preferred CPU id, set only when GPU affinity is enabled.
    pub preferred_cpu: Option<u32>,
}

20/// Phase-based scheduling policy table.
21fn policy_for_phase(phase: WorkloadPhase) -> (i32, bool, Option<u64>) {
22    match phase {
23        WorkloadPhase::DataLoading => (10, true, Some(100)),
24        WorkloadPhase::GpuCompute => (-5, false, None),
25        WorkloadPhase::NcclCollective => (10, true, Some(50)),
26        WorkloadPhase::OptimizerStep => (10, true, Some(100)),
27        WorkloadPhase::Unknown => (0, false, None),
28    }
29}
30
/// The Zernel ML-aware scheduler.
///
/// Integrates phase detection, NUMA-aware CPU selection, and
/// multi-tenant GPU-proportional scheduling.
pub struct ZernelScheduler {
    // Static configuration; not mutated by this module after construction.
    config: SchedulerConfig,
    // Classifies tracked tasks into workload phases from their metrics.
    phase_detector: PhaseDetector,
    // Per-task tracked state, keyed by PID.
    task_states: HashMap<u32, ZernelTaskState>,
    // Host NUMA topology, detected once in `new`.
    numa: NumaTopology,
    // Tenant-aware priority adjustment, consulted when multi-tenant is enabled.
    tenant_scheduler: TenantScheduler,
    /// Per-CPU load estimates (cpu_id -> load fraction 0.0-1.0).
    cpu_loads: HashMap<u32, f32>,
    /// Total scheduling decisions made.
    pub decisions_made: u64,
}

impl ZernelScheduler {
    /// Construct a scheduler from `config`, detecting the host NUMA
    /// topology up front so later CPU selection is a pure lookup.
    pub fn new(config: SchedulerConfig) -> Self {
        // Derive the phase-detector config (via `Into` on a borrow)
        // before `config` is moved into `Self` below.
        let phase_config = (&config.phase_detection).into();
        let numa = NumaTopology::detect();

        info!(
            numa_nodes = numa.nodes.len(),
            total_cpus = numa.total_cpus(),
            gpus_mapped = numa.gpu_node_map.len(),
            "NUMA topology detected"
        );

        Self {
            config,
            phase_detector: PhaseDetector::new(phase_config),
            task_states: HashMap::new(),
            numa,
            tenant_scheduler: TenantScheduler::new(),
            cpu_loads: HashMap::new(),
            decisions_made: 0,
        }
    }

    /// Register a new task for tracking.
    ///
    /// Overwrites any existing state already stored for `pid`.
    pub fn register_task(&mut self, pid: u32, is_ml: bool, gpu_id: Option<u32>) {
        let mut state = ZernelTaskState::new(pid);
        state.is_ml_process = is_ml;
        state.gpu_id = gpu_id;
        self.task_states.insert(pid, state);
        info!(pid, is_ml, ?gpu_id, "registered task");
    }

    /// Remove a task from tracking.
    ///
    /// Also drops the phase detector's per-task history. Safe to call for
    /// PIDs that were never registered (both removals are no-ops then).
    pub fn unregister_task(&mut self, pid: u32) {
        self.task_states.remove(&pid);
        self.phase_detector.remove_task(pid);
        debug!(pid, "unregistered task");
    }

    /// Update task metrics (called from BPF ringbuf events or polling).
    ///
    /// Only the `Some` fields of `update` are applied; `None` fields leave
    /// the stored value untouched. Updates for unknown PIDs are silently
    /// dropped.
    pub fn update_task(&mut self, pid: u32, update: TaskUpdate) {
        if let Some(state) = self.task_states.get_mut(&pid) {
            if let Some(v) = update.gpu_utilization {
                state.gpu_utilization = v;
            }
            if let Some(v) = update.io_wait_fraction {
                state.io_wait_fraction = v;
            }
            if let Some(v) = update.cpu_burst_duration_ns {
                state.cpu_burst_duration_ns = v;
            }
            if let Some(v) = update.last_gpu_sync_ns {
                state.last_gpu_sync_ns = v;
            }
            if let Some(v) = update.nccl_active {
                state.nccl_active = v;
            }
            if let Some(v) = update.futex_wait_count {
                state.futex_wait_count = v;
            }
            if let Some(v) = update.gpu_id {
                // A `None` in the update means "no new information", so the
                // stored GPU binding is never cleared here, only replaced.
                state.gpu_id = Some(v);
            }
        }
    }

    /// Update CPU load estimates (called periodically from /proc/stat or BPF).
    ///
    /// Replaces the whole map: CPUs absent from `loads` lose their old
    /// estimate rather than keeping a stale one.
    pub fn update_cpu_loads(&mut self, loads: HashMap<u32, f32>) {
        self.cpu_loads = loads;
    }

    /// Run phase detection and return a scheduling decision for a task.
    ///
    /// Unknown PIDs get a neutral decision (priority 0, no preemption, no
    /// latency target, no preferred CPU) and do NOT increment
    /// `decisions_made`.
    pub fn schedule(&mut self, pid: u32, now_ns: u64) -> SchedulingDecision {
        let Some(state) = self.task_states.get_mut(&pid) else {
            return SchedulingDecision {
                priority: 0,
                preempt: false,
                latency_target_us: None,
                preferred_cpu: None,
            };
        };

        // Phase detection. NOTE(review): `state` mutably borrows
        // `self.task_states`; the calls below touch only sibling fields
        // (`phase_detector`, `config`, `tenant_scheduler`, `numa`,
        // `cpu_loads`), which the borrow checker accepts as disjoint
        // field borrows — moving any of this into a `&mut self` helper
        // would break that.
        let new_phase = self.phase_detector.detect(state);
        if new_phase != state.current_phase {
            state.transition_phase(new_phase, now_ns);
        }

        // Base policy from phase
        let (base_priority, preempt, latency_target_us) = policy_for_phase(new_phase);

        // Multi-tenant priority adjustment
        let priority = if self.config.multi_tenant.enabled {
            self.tenant_scheduler.effective_priority(pid, base_priority)
        } else {
            base_priority
        };

        // NUMA-aware CPU selection. `gpu_id` is `Copy`, so copying it out
        // ends the borrow of `state` before `self.numa` is consulted.
        let preferred_cpu = if self.config.numa.gpu_affinity {
            let gpu_id = state.gpu_id;
            Some(self.numa.select_cpu(gpu_id, &self.cpu_loads))
        } else {
            None
        };

        self.decisions_made += 1;

        let decision = SchedulingDecision {
            priority,
            preempt,
            latency_target_us,
            preferred_cpu,
        };

        debug!(
            pid,
            phase = %new_phase,
            priority = decision.priority,
            ?preferred_cpu,
            "scheduling decision"
        );

        decision
    }

    /// Mutable access to the tenant scheduler (the original comment said
    /// "a reference"; this is `&mut`, e.g. for tenant registration).
    pub fn tenant_scheduler_mut(&mut self) -> &mut TenantScheduler {
        &mut self.tenant_scheduler
    }

    /// Get current state snapshot for telemetry export.
    pub fn task_states(&self) -> &HashMap<u32, ZernelTaskState> {
        &self.task_states
    }

    /// The NUMA topology detected at construction time.
    pub fn numa_topology(&self) -> &NumaTopology {
        &self.numa
    }

    /// Number of phase transitions recorded by the phase detector.
    pub fn phase_transition_count(&self) -> u64 {
        self.phase_detector.transition_count
    }

    /// The configuration this scheduler was built with.
    pub fn config(&self) -> &SchedulerConfig {
        &self.config
    }
}

/// Partial update for task metrics.
///
/// Each field is `Some` only when the corresponding metric was observed;
/// `None` fields leave the stored task state untouched (see
/// `ZernelScheduler::update_task`). All fields are `Copy`, so the struct
/// itself derives `Clone`/`Copy` — callers can fan one update out to
/// several consumers without moves.
#[derive(Debug, Default, Clone, Copy)]
pub struct TaskUpdate {
    /// Observed GPU utilization (tests use 0-100; presumably a
    /// percentage — confirm with the metric producer).
    pub gpu_utilization: Option<u8>,
    /// Fraction of time spent in I/O wait (tests use values in 0.0-1.0).
    pub io_wait_fraction: Option<f32>,
    /// Duration of the most recent CPU burst, in nanoseconds.
    pub cpu_burst_duration_ns: Option<u64>,
    /// Timestamp of the last GPU synchronization, in nanoseconds.
    pub last_gpu_sync_ns: Option<u64>,
    /// Whether an NCCL collective is currently active for the task.
    pub nccl_active: Option<bool>,
    /// Futex wait count observed for the task.
    pub futex_wait_count: Option<u32>,
    /// GPU the task is bound to; `Some` replaces the stored binding,
    /// `None` means "no new information".
    pub gpu_id: Option<u32>,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Scheduler whose phase detector switches immediately (stability
    /// count of 1), so tests need no warm-up iterations.
    fn make_scheduler() -> ZernelScheduler {
        let mut config = SchedulerConfig::default();
        config.phase_detection.phase_stability_count = 1;
        ZernelScheduler::new(config)
    }

    #[test]
    fn register_and_schedule() {
        let mut sched = make_scheduler();
        sched.register_task(100, true, Some(0));

        let update = TaskUpdate {
            gpu_utilization: Some(95),
            ..Default::default()
        };
        sched.update_task(100, update);

        // Saturated GPU => GpuCompute: deprioritized, no preemption.
        let decision = sched.schedule(100, 1000);
        assert_eq!(decision.priority, -5);
        assert!(!decision.preempt);
    }

    #[test]
    fn data_loading_gets_high_priority() {
        let mut sched = make_scheduler();
        sched.register_task(100, true, Some(0));

        let update = TaskUpdate {
            io_wait_fraction: Some(0.5),
            gpu_utilization: Some(5),
            ..Default::default()
        };
        sched.update_task(100, update);

        // Heavy I/O wait + idle GPU => DataLoading: boosted and preemptive.
        let decision = sched.schedule(100, 1000);
        assert_eq!(decision.priority, 10);
        assert!(decision.preempt);
    }

    #[test]
    fn preferred_cpu_set_with_numa() {
        let mut sched = make_scheduler();
        sched.register_task(100, true, Some(0));
        assert!(sched.schedule(100, 1000).preferred_cpu.is_some());
    }

    #[test]
    fn unknown_pid_gets_default() {
        let mut sched = make_scheduler();
        assert_eq!(sched.schedule(999, 1000).priority, 0);
    }

    #[test]
    fn decisions_counter_increments() {
        let mut sched = make_scheduler();
        sched.register_task(100, true, None);
        for now_ns in [1000, 2000] {
            sched.schedule(100, now_ns);
        }
        assert_eq!(sched.decisions_made, 2);
    }
}