1use crate::config::SchedulerConfig;
4use crate::multi_tenant::TenantScheduler;
5use crate::numa::NumaTopology;
6use crate::phase_detector::PhaseDetector;
7use crate::task_state::{WorkloadPhase, ZernelTaskState};
8use std::collections::HashMap;
9use tracing::{debug, info};
10
/// The outcome of one call to [`ZernelScheduler::schedule`] for a single task.
///
/// `Copy` because it is a small, plain value handed back to the caller on
/// every scheduling tick.
#[derive(Debug, Clone, Copy)]
pub struct SchedulingDecision {
    /// Effective priority for the task. Derived from the phase policy in
    /// `policy_for_phase` (e.g. 10 for latency-sensitive phases, -5 for
    /// GPU-bound compute), optionally adjusted by the tenant scheduler.
    pub priority: i32,
    /// Whether the task should be allowed to preempt currently running work.
    pub preempt: bool,
    /// Wakeup-latency target in microseconds, if the current phase has one
    /// (e.g. 50us for NCCL collectives); `None` means no explicit target.
    pub latency_target_us: Option<u64>,
    /// CPU the task should preferably run on, chosen via NUMA/GPU affinity
    /// when `config.numa.gpu_affinity` is enabled; `None` otherwise.
    pub preferred_cpu: Option<u32>,
}
19
20fn policy_for_phase(phase: WorkloadPhase) -> (i32, bool, Option<u64>) {
22 match phase {
23 WorkloadPhase::DataLoading => (10, true, Some(100)),
24 WorkloadPhase::GpuCompute => (-5, false, None),
25 WorkloadPhase::NcclCollective => (10, true, Some(50)),
26 WorkloadPhase::OptimizerStep => (10, true, Some(100)),
27 WorkloadPhase::Unknown => (0, false, None),
28 }
29}
30
/// Phase-aware scheduler for ML workloads.
///
/// Tracks per-task state, detects workload phases, and produces
/// [`SchedulingDecision`]s that combine phase policy, tenant fairness, and
/// NUMA/GPU CPU affinity.
pub struct ZernelScheduler {
    // Static configuration (phase detection, multi-tenant, NUMA toggles).
    config: SchedulerConfig,
    // Classifies each task's current WorkloadPhase from its runtime signals.
    phase_detector: PhaseDetector,
    // Per-PID runtime state for every registered task.
    task_states: HashMap<u32, ZernelTaskState>,
    // Detected hardware topology, used for GPU-affine CPU selection.
    numa: NumaTopology,
    // Applies per-tenant priority adjustments when multi-tenancy is enabled.
    tenant_scheduler: TenantScheduler,
    // Latest per-CPU load snapshot (CPU id -> load), fed to NUMA CPU selection.
    cpu_loads: HashMap<u32, f32>,
    // Total number of scheduling decisions produced; public for metrics/tests.
    pub decisions_made: u64,
}
46
impl ZernelScheduler {
    /// Creates a scheduler from `config`, detecting the NUMA topology and
    /// logging a summary of what was found.
    pub fn new(config: SchedulerConfig) -> Self {
        // Phase-detector config is derived from the scheduler config
        // (via an Into impl on &PhaseDetectionConfig defined elsewhere).
        let phase_config = (&config.phase_detection).into();
        let numa = NumaTopology::detect();

        info!(
            numa_nodes = numa.nodes.len(),
            total_cpus = numa.total_cpus(),
            gpus_mapped = numa.gpu_node_map.len(),
            "NUMA topology detected"
        );

        Self {
            config,
            phase_detector: PhaseDetector::new(phase_config),
            task_states: HashMap::new(),
            numa,
            tenant_scheduler: TenantScheduler::new(),
            cpu_loads: HashMap::new(),
            decisions_made: 0,
        }
    }

    /// Registers a task so future `schedule` calls produce phase-aware
    /// decisions for it.
    ///
    /// NOTE(review): re-registering an existing pid silently replaces its
    /// accumulated state — confirm that is the intended behavior for pid reuse.
    pub fn register_task(&mut self, pid: u32, is_ml: bool, gpu_id: Option<u32>) {
        let mut state = ZernelTaskState::new(pid);
        state.is_ml_process = is_ml;
        state.gpu_id = gpu_id;
        self.task_states.insert(pid, state);
        info!(pid, is_ml, ?gpu_id, "registered task");
    }

    /// Removes all tracked state for `pid`, both here and in the phase
    /// detector. Safe to call for unknown pids (removals are no-ops).
    pub fn unregister_task(&mut self, pid: u32) {
        self.task_states.remove(&pid);
        self.phase_detector.remove_task(pid);
        debug!(pid, "unregistered task");
    }

    /// Applies a partial update of runtime signals to `pid`'s state.
    ///
    /// Only the `Some` fields of `update` are written; `None` fields leave
    /// the existing values untouched. Updates for unregistered pids are
    /// silently ignored.
    pub fn update_task(&mut self, pid: u32, update: TaskUpdate) {
        if let Some(state) = self.task_states.get_mut(&pid) {
            if let Some(v) = update.gpu_utilization {
                state.gpu_utilization = v;
            }
            if let Some(v) = update.io_wait_fraction {
                state.io_wait_fraction = v;
            }
            if let Some(v) = update.cpu_burst_duration_ns {
                state.cpu_burst_duration_ns = v;
            }
            if let Some(v) = update.last_gpu_sync_ns {
                state.last_gpu_sync_ns = v;
            }
            if let Some(v) = update.nccl_active {
                state.nccl_active = v;
            }
            if let Some(v) = update.futex_wait_count {
                state.futex_wait_count = v;
            }
            if let Some(v) = update.gpu_id {
                state.gpu_id = Some(v);
            }
        }
    }

    /// Replaces the per-CPU load snapshot used for NUMA-aware CPU selection.
    pub fn update_cpu_loads(&mut self, loads: HashMap<u32, f32>) {
        self.cpu_loads = loads;
    }

    /// Produces a scheduling decision for `pid` at timestamp `now_ns`.
    ///
    /// Unknown pids get a neutral default decision (priority 0, no preempt,
    /// no latency target, no CPU preference) and do NOT increment
    /// `decisions_made`. For known pids this: re-detects the workload phase
    /// (recording a transition if it changed), looks up the phase policy,
    /// applies tenant priority adjustment when multi-tenancy is enabled, and
    /// picks a preferred CPU when GPU affinity is enabled.
    pub fn schedule(&mut self, pid: u32, now_ns: u64) -> SchedulingDecision {
        let Some(state) = self.task_states.get_mut(&pid) else {
            return SchedulingDecision {
                priority: 0,
                preempt: false,
                latency_target_us: None,
                preferred_cpu: None,
            };
        };

        // `state` mutably borrows only the `task_states` field, so the
        // sibling-field calls below (phase_detector, tenant_scheduler, numa)
        // remain legal while it is live.
        let new_phase = self.phase_detector.detect(state);
        if new_phase != state.current_phase {
            state.transition_phase(new_phase, now_ns);
        }

        let (base_priority, preempt, latency_target_us) = policy_for_phase(new_phase);

        // Tenant fairness may raise or lower the phase's base priority.
        let priority = if self.config.multi_tenant.enabled {
            self.tenant_scheduler.effective_priority(pid, base_priority)
        } else {
            base_priority
        };

        // Pick a CPU near the task's GPU (if any), weighted by current loads.
        let preferred_cpu = if self.config.numa.gpu_affinity {
            let gpu_id = state.gpu_id;
            Some(self.numa.select_cpu(gpu_id, &self.cpu_loads))
        } else {
            None
        };

        self.decisions_made += 1;

        let decision = SchedulingDecision {
            priority,
            preempt,
            latency_target_us,
            preferred_cpu,
        };

        debug!(
            pid,
            phase = %new_phase,
            priority = decision.priority,
            ?preferred_cpu,
            "scheduling decision"
        );

        decision
    }

    /// Mutable access to the tenant scheduler (e.g. to register tenants or
    /// adjust shares).
    pub fn tenant_scheduler_mut(&mut self) -> &mut TenantScheduler {
        &mut self.tenant_scheduler
    }

    /// Read-only view of all tracked task states, keyed by pid.
    pub fn task_states(&self) -> &HashMap<u32, ZernelTaskState> {
        &self.task_states
    }

    /// The NUMA topology detected at construction time.
    pub fn numa_topology(&self) -> &NumaTopology {
        &self.numa
    }

    /// Total number of phase transitions observed by the phase detector.
    pub fn phase_transition_count(&self) -> u64 {
        self.phase_detector.transition_count
    }

    /// The configuration this scheduler was constructed with.
    pub fn config(&self) -> &SchedulerConfig {
        &self.config
    }
}
195
/// Partial update of a task's runtime signals.
///
/// Each field mirrors one on `ZernelTaskState`; in
/// [`ZernelScheduler::update_task`] only `Some` fields are applied, so
/// `TaskUpdate::default()` (all `None`) is a no-op update.
#[derive(Debug, Default)]
pub struct TaskUpdate {
    // GPU utilization percentage reported for the task.
    pub gpu_utilization: Option<u8>,
    // Fraction of time spent waiting on I/O.
    pub io_wait_fraction: Option<f32>,
    // Duration of the most recent CPU burst, in nanoseconds.
    pub cpu_burst_duration_ns: Option<u64>,
    // Timestamp of the last GPU synchronization, in nanoseconds.
    pub last_gpu_sync_ns: Option<u64>,
    // Whether an NCCL collective is currently in flight.
    pub nccl_active: Option<bool>,
    // Observed futex wait count (replaces, does not accumulate).
    pub futex_wait_count: Option<u32>,
    // GPU the task is bound to; `Some` overwrites any previous binding.
    pub gpu_id: Option<u32>,
}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Scheduler with a phase detector that stabilizes after a single sample,
    /// so one `schedule` call suffices to see a phase take effect.
    fn make_scheduler() -> ZernelScheduler {
        let mut config = SchedulerConfig::default();
        config.phase_detection.phase_stability_count = 1;
        ZernelScheduler::new(config)
    }

    #[test]
    fn register_and_schedule() {
        let mut scheduler = make_scheduler();
        scheduler.register_task(100, true, Some(0));

        let update = TaskUpdate {
            gpu_utilization: Some(95),
            ..Default::default()
        };
        scheduler.update_task(100, update);

        // High GPU utilization => GpuCompute policy: deprioritized, no preempt.
        let decision = scheduler.schedule(100, 1000);
        assert_eq!(decision.priority, -5);
        assert!(!decision.preempt);
    }

    #[test]
    fn data_loading_gets_high_priority() {
        let mut scheduler = make_scheduler();
        scheduler.register_task(100, true, Some(0));

        // Heavy I/O wait with an idle GPU should classify as data loading.
        let update = TaskUpdate {
            io_wait_fraction: Some(0.5),
            gpu_utilization: Some(5),
            ..Default::default()
        };
        scheduler.update_task(100, update);

        let decision = scheduler.schedule(100, 1000);
        assert_eq!(decision.priority, 10);
        assert!(decision.preempt);
    }

    #[test]
    fn preferred_cpu_set_with_numa() {
        let mut scheduler = make_scheduler();
        scheduler.register_task(100, true, Some(0));

        // Default config has GPU affinity enabled, so a CPU must be chosen.
        assert!(scheduler.schedule(100, 1000).preferred_cpu.is_some());
    }

    #[test]
    fn unknown_pid_gets_default() {
        let mut scheduler = make_scheduler();

        // Never-registered pid falls back to the neutral decision.
        assert_eq!(scheduler.schedule(999, 1000).priority, 0);
    }

    #[test]
    fn decisions_counter_increments() {
        let mut scheduler = make_scheduler();
        scheduler.register_task(100, true, None);

        for tick in [1000, 2000] {
            scheduler.schedule(100, tick);
        }
        assert_eq!(scheduler.decisions_made, 2);
    }
}