// zernel_scheduler/phase_detector.rs
use crate::task_state::{WorkloadPhase, ZernelTaskState};
use serde::{Deserialize, Serialize};

6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct PhaseDetectorConfig {
9 pub io_wait_threshold: f32,
11 pub optimizer_burst_max_ns: u64,
13 pub gpu_idle_threshold: u8,
15 pub gpu_active_threshold: u8,
17 pub phase_stability_count: u32,
19 pub nccl_detection_enabled: bool,
21}
22
23impl Default for PhaseDetectorConfig {
24 fn default() -> Self {
25 Self {
26 io_wait_threshold: 0.3,
27 optimizer_burst_max_ns: 5_000_000, gpu_idle_threshold: 10,
29 gpu_active_threshold: 80,
30 phase_stability_count: 3,
31 nccl_detection_enabled: false,
32 }
33 }
34}
35
36#[derive(Debug, Default)]
38struct PhaseStabilityTracker {
39 candidate: Option<WorkloadPhase>,
41 count: u32,
43}
44
45pub struct PhaseDetector {
47 config: PhaseDetectorConfig,
48 stability: std::collections::HashMap<u32, PhaseStabilityTracker>,
50 pub transition_count: u64,
52}
53
54impl PhaseDetector {
55 pub fn new(config: PhaseDetectorConfig) -> Self {
56 Self {
57 config,
58 stability: std::collections::HashMap::new(),
59 transition_count: 0,
60 }
61 }
62
63 pub fn detect(&mut self, state: &ZernelTaskState) -> WorkloadPhase {
66 let raw_phase = self.detect_raw(state);
67
68 if self.config.phase_stability_count <= 1 {
69 if raw_phase != state.current_phase {
71 self.transition_count += 1;
72 }
73 return raw_phase;
74 }
75
76 let tracker = self.stability.entry(state.pid).or_default();
77
78 if tracker.candidate == Some(raw_phase) {
79 tracker.count += 1;
80 } else {
81 tracker.candidate = Some(raw_phase);
82 tracker.count = 1;
83 }
84
85 if tracker.count >= self.config.phase_stability_count {
86 if raw_phase != state.current_phase {
87 self.transition_count += 1;
88 }
89 raw_phase
90 } else {
91 state.current_phase
93 }
94 }
95
96 fn detect_raw(&self, state: &ZernelTaskState) -> WorkloadPhase {
98 if !state.is_ml_process {
99 return WorkloadPhase::Unknown;
100 }
101
102 if self.config.nccl_detection_enabled && state.nccl_active && state.futex_wait_count > 0 {
104 return WorkloadPhase::NcclCollective;
105 }
106
107 if state.io_wait_fraction > self.config.io_wait_threshold
109 && state.gpu_utilization < self.config.gpu_idle_threshold
110 {
111 return WorkloadPhase::DataLoading;
112 }
113
114 if state.gpu_utilization > self.config.gpu_active_threshold
116 && state.cpu_burst_duration_ns == 0
117 {
118 return WorkloadPhase::GpuCompute;
119 }
120
121 if state.cpu_burst_duration_ns > 0
123 && state.cpu_burst_duration_ns < self.config.optimizer_burst_max_ns
124 && state.last_gpu_sync_ns > 0
125 {
126 return WorkloadPhase::OptimizerStep;
127 }
128
129 if state.io_wait_fraction > self.config.io_wait_threshold * 0.5
131 && state.gpu_utilization < self.config.gpu_active_threshold
132 && state.gpu_utilization > self.config.gpu_idle_threshold
133 {
134 return WorkloadPhase::DataLoading;
135 }
136
137 WorkloadPhase::Unknown
138 }
139
140 pub fn remove_task(&mut self, pid: u32) {
142 self.stability.remove(&pid);
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh task state for `pid`, marked as an ML process when `is_ml` is set.
    fn task(pid: u32, is_ml: bool) -> ZernelTaskState {
        let mut st = ZernelTaskState::new(pid);
        st.is_ml_process = is_ml;
        st
    }

    /// Detector with default thresholds but debouncing disabled, so every call reports
    /// the raw classification immediately.
    fn undebounced() -> PhaseDetector {
        let cfg = PhaseDetectorConfig {
            phase_stability_count: 1,
            ..Default::default()
        };
        PhaseDetector::new(cfg)
    }

    #[test]
    fn non_ml_process_is_unknown() {
        let st = task(1, false);
        assert_eq!(undebounced().detect(&st), WorkloadPhase::Unknown);
    }

    #[test]
    fn high_io_wait_low_gpu_is_data_loading() {
        let mut st = task(1, true);
        st.io_wait_fraction = 0.5;
        st.gpu_utilization = 5;
        assert_eq!(undebounced().detect(&st), WorkloadPhase::DataLoading);
    }

    #[test]
    fn high_gpu_util_is_gpu_compute() {
        let mut st = task(1, true);
        st.gpu_utilization = 95;
        st.cpu_burst_duration_ns = 0;
        assert_eq!(undebounced().detect(&st), WorkloadPhase::GpuCompute);
    }

    #[test]
    fn short_burst_after_sync_is_optimizer_step() {
        let mut st = task(1, true);
        st.cpu_burst_duration_ns = 2_000_000; // 2 ms, below the 5 ms default cap
        st.last_gpu_sync_ns = 1_000_000_000;
        assert_eq!(undebounced().detect(&st), WorkloadPhase::OptimizerStep);
    }

    #[test]
    fn nccl_detection_when_enabled() {
        let cfg = PhaseDetectorConfig {
            nccl_detection_enabled: true,
            phase_stability_count: 1,
            ..Default::default()
        };
        let mut st = task(1, true);
        st.nccl_active = true;
        st.futex_wait_count = 10;
        assert_eq!(
            PhaseDetector::new(cfg).detect(&st),
            WorkloadPhase::NcclCollective
        );
    }

    #[test]
    fn nccl_detection_off_by_default() {
        let mut st = task(1, true);
        st.nccl_active = true;
        st.futex_wait_count = 10;
        assert_ne!(undebounced().detect(&st), WorkloadPhase::NcclCollective);
    }

    #[test]
    fn stability_prevents_flapping() {
        let mut det = PhaseDetector::new(PhaseDetectorConfig {
            phase_stability_count: 3,
            ..Default::default()
        });

        let mut st = task(1, true);
        st.current_phase = WorkloadPhase::Unknown;
        // Every call below classifies raw as DataLoading.
        st.io_wait_fraction = 0.5;
        st.gpu_utilization = 5;

        // The first two observations are not enough to leave the current phase...
        for _ in 0..2 {
            assert_eq!(det.detect(&st), WorkloadPhase::Unknown);
        }
        // ...the third one is.
        assert_eq!(det.detect(&st), WorkloadPhase::DataLoading);
    }

    #[test]
    fn overlapped_prefetch_detected_as_data_loading() {
        let mut st = task(1, true);
        // Moderate I/O wait while the GPU is only partially busy.
        st.io_wait_fraction = 0.2;
        st.gpu_utilization = 50;
        assert_eq!(undebounced().detect(&st), WorkloadPhase::DataLoading);
    }

    #[test]
    fn transition_counter_increments() {
        let mut det = undebounced();
        let mut st = task(1, true);
        st.current_phase = WorkloadPhase::Unknown;
        st.io_wait_fraction = 0.5;
        st.gpu_utilization = 5;
        let _ = det.detect(&st);
        assert_eq!(det.transition_count, 1);
    }
}

#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Any in-range combination of inputs must classify without panicking.
        #[test]
        fn detect_never_panics(
            gpu_util in 0u8..=100,
            io_wait in 0.0f32..=1.0,
            cpu_burst in 0u64..=100_000_000,
            gpu_sync in 0u64..=10_000_000_000u64,
            nccl in proptest::bool::ANY,
            futex in 0u32..=1000,
            is_ml in proptest::bool::ANY,
        ) {
            let cfg = PhaseDetectorConfig {
                nccl_detection_enabled: true,
                phase_stability_count: 1,
                ..Default::default()
            };
            let mut det = PhaseDetector::new(cfg);

            let mut st = ZernelTaskState::new(1);
            st.is_ml_process = is_ml;
            st.nccl_active = nccl;
            st.futex_wait_count = futex;
            st.gpu_utilization = gpu_util;
            st.io_wait_fraction = io_wait;
            st.cpu_burst_duration_ns = cpu_burst;
            st.last_gpu_sync_ns = gpu_sync;

            // Deliberately exhaustive with no wildcard arm: adding a WorkloadPhase
            // variant makes this fail to compile until the test is revisited.
            match det.detect(&st) {
                WorkloadPhase::Unknown
                | WorkloadPhase::DataLoading
                | WorkloadPhase::GpuCompute
                | WorkloadPhase::OptimizerStep
                | WorkloadPhase::NcclCollective => {}
            }
        }

        // Non-ML processes are never classified, whatever their metrics.
        #[test]
        fn non_ml_always_unknown(
            gpu_util in 0u8..=100,
            io_wait in 0.0f32..=1.0,
        ) {
            let cfg = PhaseDetectorConfig {
                phase_stability_count: 1,
                ..Default::default()
            };
            let mut det = PhaseDetector::new(cfg);

            let mut st = ZernelTaskState::new(1);
            st.is_ml_process = false;
            st.gpu_utilization = gpu_util;
            st.io_wait_fraction = io_wait;

            assert_eq!(det.detect(&st), WorkloadPhase::Unknown);
        }
    }
}