zernel_scheduler/
phase_detector.rs

1// Copyright (C) 2026 Dyber, Inc. — GPL-2.0
2
3use crate::task_state::{WorkloadPhase, ZernelTaskState};
4use serde::{Deserialize, Serialize};
5
6/// Thresholds for phase detection heuristics.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct PhaseDetectorConfig {
9    /// I/O wait fraction above which we classify as DataLoading.
10    pub io_wait_threshold: f32,
11    /// Duration (ns) below which a CPU burst after GpuCompute is an OptimizerStep.
12    pub optimizer_burst_max_ns: u64,
13    /// GPU utilization below which we suspect the GPU is idle.
14    pub gpu_idle_threshold: u8,
15    /// GPU utilization above which we consider GPU actively computing.
16    pub gpu_active_threshold: u8,
17    /// Number of consecutive identical phase samples before committing transition.
18    pub phase_stability_count: u32,
19    /// Enable NCCL collective detection.
20    pub nccl_detection_enabled: bool,
21}
22
23impl Default for PhaseDetectorConfig {
24    fn default() -> Self {
25        Self {
26            io_wait_threshold: 0.3,
27            optimizer_burst_max_ns: 5_000_000, // 5ms
28            gpu_idle_threshold: 10,
29            gpu_active_threshold: 80,
30            phase_stability_count: 3,
31            nccl_detection_enabled: false,
32        }
33    }
34}
35
36/// Tracks per-task phase stability to avoid flapping.
37#[derive(Debug, Default)]
38struct PhaseStabilityTracker {
39    /// The candidate phase we're observing.
40    candidate: Option<WorkloadPhase>,
41    /// How many consecutive times we've seen this candidate.
42    count: u32,
43}
44
45/// Detects the current workload phase for a given task based on runtime metrics.
46pub struct PhaseDetector {
47    config: PhaseDetectorConfig,
48    /// Per-pid stability tracking to debounce phase transitions.
49    stability: std::collections::HashMap<u32, PhaseStabilityTracker>,
50    /// Phase transition counters for telemetry.
51    pub transition_count: u64,
52}
53
54impl PhaseDetector {
55    pub fn new(config: PhaseDetectorConfig) -> Self {
56        Self {
57            config,
58            stability: std::collections::HashMap::new(),
59            transition_count: 0,
60        }
61    }
62
63    /// Classify the workload phase based on current task state.
64    /// Uses stability tracking to prevent rapid phase flapping.
65    pub fn detect(&mut self, state: &ZernelTaskState) -> WorkloadPhase {
66        let raw_phase = self.detect_raw(state);
67
68        if self.config.phase_stability_count <= 1 {
69            // Stability disabled — immediate transitions
70            if raw_phase != state.current_phase {
71                self.transition_count += 1;
72            }
73            return raw_phase;
74        }
75
76        let tracker = self.stability.entry(state.pid).or_default();
77
78        if tracker.candidate == Some(raw_phase) {
79            tracker.count += 1;
80        } else {
81            tracker.candidate = Some(raw_phase);
82            tracker.count = 1;
83        }
84
85        if tracker.count >= self.config.phase_stability_count {
86            if raw_phase != state.current_phase {
87                self.transition_count += 1;
88            }
89            raw_phase
90        } else {
91            // Not stable yet — keep current phase
92            state.current_phase
93        }
94    }
95
96    /// Raw phase classification without stability tracking.
97    fn detect_raw(&self, state: &ZernelTaskState) -> WorkloadPhase {
98        if !state.is_ml_process {
99            return WorkloadPhase::Unknown;
100        }
101
102        // NCCL collective detection (highest priority — on critical path)
103        if self.config.nccl_detection_enabled && state.nccl_active && state.futex_wait_count > 0 {
104            return WorkloadPhase::NcclCollective;
105        }
106
107        // High I/O wait + low GPU util → data loading
108        if state.io_wait_fraction > self.config.io_wait_threshold
109            && state.gpu_utilization < self.config.gpu_idle_threshold
110        {
111            return WorkloadPhase::DataLoading;
112        }
113
114        // GPU highly utilized + low CPU burst → GPU compute phase
115        if state.gpu_utilization > self.config.gpu_active_threshold
116            && state.cpu_burst_duration_ns == 0
117        {
118            return WorkloadPhase::GpuCompute;
119        }
120
121        // Short CPU burst right after GPU sync → optimizer step
122        if state.cpu_burst_duration_ns > 0
123            && state.cpu_burst_duration_ns < self.config.optimizer_burst_max_ns
124            && state.last_gpu_sync_ns > 0
125        {
126            return WorkloadPhase::OptimizerStep;
127        }
128
129        // Medium I/O + medium GPU → probably data loading with prefetch overlap
130        if state.io_wait_fraction > self.config.io_wait_threshold * 0.5
131            && state.gpu_utilization < self.config.gpu_active_threshold
132            && state.gpu_utilization > self.config.gpu_idle_threshold
133        {
134            return WorkloadPhase::DataLoading;
135        }
136
137        WorkloadPhase::Unknown
138    }
139
140    /// Remove stability tracking for a task.
141    pub fn remove_task(&mut self, pid: u32) {
142        self.stability.remove(&pid);
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh task state for `pid`, with the ML flag set as requested.
    fn task(pid: u32, is_ml: bool) -> ZernelTaskState {
        let mut state = ZernelTaskState::new(pid);
        state.is_ml_process = is_ml;
        state
    }

    /// ML task (pid 1) with the given I/O-wait fraction and GPU utilization.
    fn ml_task_with(io_wait_fraction: f32, gpu_utilization: u8) -> ZernelTaskState {
        let mut state = task(1, true);
        state.io_wait_fraction = io_wait_fraction;
        state.gpu_utilization = gpu_utilization;
        state
    }

    /// ML task (pid 1) that looks like it is blocked in an NCCL collective.
    fn nccl_busy_task() -> ZernelTaskState {
        let mut state = task(1, true);
        state.nccl_active = true;
        state.futex_wait_count = 10;
        state
    }

    /// Detector that commits every raw classification immediately (stability count
    /// of 1), with NCCL detection switched on or off as requested.
    fn immediate(nccl_detection_enabled: bool) -> PhaseDetector {
        PhaseDetector::new(PhaseDetectorConfig {
            nccl_detection_enabled,
            phase_stability_count: 1,
            ..Default::default()
        })
    }

    #[test]
    fn non_ml_process_is_unknown() {
        let phase = immediate(false).detect(&task(1, false));
        assert_eq!(phase, WorkloadPhase::Unknown);
    }

    #[test]
    fn high_io_wait_low_gpu_is_data_loading() {
        let phase = immediate(false).detect(&ml_task_with(0.5, 5));
        assert_eq!(phase, WorkloadPhase::DataLoading);
    }

    #[test]
    fn high_gpu_util_is_gpu_compute() {
        let mut state = task(1, true);
        state.cpu_burst_duration_ns = 0;
        state.gpu_utilization = 95;
        assert_eq!(immediate(false).detect(&state), WorkloadPhase::GpuCompute);
    }

    #[test]
    fn short_burst_after_sync_is_optimizer_step() {
        let mut state = task(1, true);
        state.last_gpu_sync_ns = 1_000_000_000;
        state.cpu_burst_duration_ns = 2_000_000; // 2 ms
        assert_eq!(immediate(false).detect(&state), WorkloadPhase::OptimizerStep);
    }

    #[test]
    fn nccl_detection_when_enabled() {
        let phase = immediate(true).detect(&nccl_busy_task());
        assert_eq!(phase, WorkloadPhase::NcclCollective);
    }

    #[test]
    fn nccl_detection_off_by_default() {
        // NCCL detection is disabled, so the collective must not be reported.
        let phase = immediate(false).detect(&nccl_busy_task());
        assert_ne!(phase, WorkloadPhase::NcclCollective);
    }

    #[test]
    fn stability_prevents_flapping() {
        let mut det = PhaseDetector::new(PhaseDetectorConfig {
            phase_stability_count: 3,
            ..Default::default()
        });
        let mut state = ml_task_with(0.5, 5);
        state.current_phase = WorkloadPhase::Unknown;

        // Two samples are not enough to commit; the third one is.
        let observed: Vec<WorkloadPhase> = (0..3).map(|_| det.detect(&state)).collect();
        assert_eq!(
            observed,
            [
                WorkloadPhase::Unknown,
                WorkloadPhase::Unknown,
                WorkloadPhase::DataLoading,
            ]
        );
    }

    #[test]
    fn overlapped_prefetch_detected_as_data_loading() {
        // Moderate I/O alongside moderate GPU use (prefetch overlap).
        let phase = immediate(false).detect(&ml_task_with(0.2, 50));
        assert_eq!(phase, WorkloadPhase::DataLoading);
    }

    #[test]
    fn transition_counter_increments() {
        let mut det = immediate(false);
        let mut state = ml_task_with(0.5, 5);
        state.current_phase = WorkloadPhase::Unknown;
        det.detect(&state);
        assert_eq!(det.transition_count, 1);
    }
}
261
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Detector that skips debouncing, so each call returns the raw classification.
    fn undebounced(nccl_detection_enabled: bool) -> PhaseDetector {
        PhaseDetector::new(PhaseDetectorConfig {
            phase_stability_count: 1,
            nccl_detection_enabled,
            ..Default::default()
        })
    }

    proptest! {
        /// `detect` yields a phase, and never panics, for any metric combination.
        /// NCCL detection is enabled so the NCCL rule is exercised as well.
        #[test]
        fn detect_never_panics(
            gpu_util in 0u8..=100,
            io_wait in 0.0f32..=1.0,
            cpu_burst in 0u64..=100_000_000,
            gpu_sync in 0u64..=10_000_000_000u64,
            nccl in any::<bool>(),
            futex in 0u32..=1000,
            is_ml in any::<bool>(),
        ) {
            let state = {
                let mut s = ZernelTaskState::new(1);
                s.is_ml_process = is_ml;
                s.nccl_active = nccl;
                s.futex_wait_count = futex;
                s.gpu_utilization = gpu_util;
                s.io_wait_fraction = io_wait;
                s.cpu_burst_duration_ns = cpu_burst;
                s.last_gpu_sync_ns = gpu_sync;
                s
            };

            // Exhaustive match with no wildcard: adding a variant to WorkloadPhase
            // fails compilation here until this property is revisited.
            match undebounced(true).detect(&state) {
                WorkloadPhase::DataLoading
                | WorkloadPhase::GpuCompute
                | WorkloadPhase::NcclCollective
                | WorkloadPhase::OptimizerStep
                | WorkloadPhase::Unknown => {}
            }
        }

        /// A task not flagged as an ML process is always classified `Unknown`.
        #[test]
        fn non_ml_always_unknown(
            gpu_util in 0u8..=100,
            io_wait in 0.0f32..=1.0,
        ) {
            let mut state = ZernelTaskState::new(1);
            state.is_ml_process = false;
            state.gpu_utilization = gpu_util;
            state.io_wait_fraction = io_wait;

            prop_assert_eq!(undebounced(false).detect(&state), WorkloadPhase::Unknown);
        }
    }
}