// zernel_scheduler/task_state.rs
// Copyright (C) 2026 Dyber, Inc. — GPL-2.0

use serde::{Deserialize, Serialize};

5/// The detected phase of an ML workload.
6#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub enum WorkloadPhase {
8    /// High I/O, many threads, CPU-intensive preprocessing.
9    DataLoading,
10    /// CPU idle, waiting on GPU compute (cudaDeviceSynchronize).
11    GpuCompute,
12    /// Network coordination for collective operations (NCCL).
13    NcclCollective,
14    /// Short CPU burst after GPU compute — optimizer step.
15    OptimizerStep,
16    /// Unknown phase — fall back to CFS-equivalent behavior.
17    #[default]
18    Unknown,
19}

// `Unknown` is the default variant, selected via `#[default]` above.

23impl std::fmt::Display for WorkloadPhase {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::DataLoading => write!(f, "DataLoading"),
27            Self::GpuCompute => write!(f, "GpuCompute"),
28            Self::NcclCollective => write!(f, "NcclCollective"),
29            Self::OptimizerStep => write!(f, "OptimizerStep"),
30            Self::Unknown => write!(f, "Unknown"),
31        }
32    }
33}
35/// Per-task state tracked by the Zernel scheduler.
36///
37/// Maintained in BPF maps at runtime; this Rust struct is the userspace mirror.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ZernelTaskState {
40    pub pid: u32,
41    pub is_ml_process: bool,
42    pub current_phase: WorkloadPhase,
43    /// GPU utilization percentage (0-100).
44    pub gpu_utilization: u8,
45    /// Timestamp (ns) of last cudaDeviceSynchronize call.
46    pub last_gpu_sync_ns: u64,
47    /// Duration (ns) of the most recent CPU burst.
48    pub cpu_burst_duration_ns: u64,
49    /// Fraction of time spent in I/O wait (0.0 - 1.0).
50    pub io_wait_fraction: f32,
51    /// Whether NCCL shared memory is mapped for this process.
52    pub nccl_active: bool,
53    /// Recent futex wait count (high = collective coordination).
54    pub futex_wait_count: u32,
55    /// GPU ID this task is primarily using (for NUMA affinity).
56    pub gpu_id: Option<u32>,
57    /// Total time (ns) spent in each phase since tracking started.
58    pub phase_time_ns: PhaseTimeAccumulator,
59    /// Timestamp (ns) when current phase began.
60    pub phase_start_ns: u64,
61}
63/// Accumulated time spent in each phase for telemetry.
64#[derive(Debug, Clone, Default, Serialize, Deserialize)]
65pub struct PhaseTimeAccumulator {
66    pub data_loading_ns: u64,
67    pub gpu_compute_ns: u64,
68    pub nccl_collective_ns: u64,
69    pub optimizer_step_ns: u64,
70    pub unknown_ns: u64,
71}
73impl PhaseTimeAccumulator {
74    pub fn record(&mut self, phase: WorkloadPhase, duration_ns: u64) {
75        match phase {
76            WorkloadPhase::DataLoading => self.data_loading_ns += duration_ns,
77            WorkloadPhase::GpuCompute => self.gpu_compute_ns += duration_ns,
78            WorkloadPhase::NcclCollective => self.nccl_collective_ns += duration_ns,
79            WorkloadPhase::OptimizerStep => self.optimizer_step_ns += duration_ns,
80            WorkloadPhase::Unknown => self.unknown_ns += duration_ns,
81        }
82    }
83
84    pub fn total_ns(&self) -> u64 {
85        self.data_loading_ns
86            + self.gpu_compute_ns
87            + self.nccl_collective_ns
88            + self.optimizer_step_ns
89            + self.unknown_ns
90    }
91
92    /// Fraction of time in a given phase (0.0 - 1.0).
93    pub fn phase_fraction(&self, phase: WorkloadPhase) -> f64 {
94        let total = self.total_ns();
95        if total == 0 {
96            return 0.0;
97        }
98        let phase_ns = match phase {
99            WorkloadPhase::DataLoading => self.data_loading_ns,
100            WorkloadPhase::GpuCompute => self.gpu_compute_ns,
101            WorkloadPhase::NcclCollective => self.nccl_collective_ns,
102            WorkloadPhase::OptimizerStep => self.optimizer_step_ns,
103            WorkloadPhase::Unknown => self.unknown_ns,
104        };
105        phase_ns as f64 / total as f64
106    }
107}
109impl ZernelTaskState {
110    pub fn new(pid: u32) -> Self {
111        Self {
112            pid,
113            is_ml_process: false,
114            current_phase: WorkloadPhase::Unknown,
115            gpu_utilization: 0,
116            last_gpu_sync_ns: 0,
117            cpu_burst_duration_ns: 0,
118            io_wait_fraction: 0.0,
119            nccl_active: false,
120            futex_wait_count: 0,
121            gpu_id: None,
122            phase_time_ns: PhaseTimeAccumulator::default(),
123            phase_start_ns: 0,
124        }
125    }
126
127    /// Transition to a new phase, accumulating time in the old phase.
128    pub fn transition_phase(&mut self, new_phase: WorkloadPhase, now_ns: u64) {
129        if self.phase_start_ns > 0 && now_ns > self.phase_start_ns {
130            let duration = now_ns - self.phase_start_ns;
131            self.phase_time_ns.record(self.current_phase, duration);
132        }
133        self.current_phase = new_phase;
134        self.phase_start_ns = now_ns;
135    }
136}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn phase_time_accumulation() {
        let mut task = ZernelTaskState::new(1);
        // First transition only starts the clock; the next two credit
        // 4000 ns to DataLoading and 3000 ns to GpuCompute.
        let schedule = [
            (WorkloadPhase::DataLoading, 1000),
            (WorkloadPhase::GpuCompute, 5000),
            (WorkloadPhase::OptimizerStep, 8000),
        ];
        for (phase, at_ns) in schedule {
            task.transition_phase(phase, at_ns);
        }
        assert_eq!(task.phase_time_ns.data_loading_ns, 4000);
        assert_eq!(task.phase_time_ns.gpu_compute_ns, 3000);
    }

    #[test]
    fn phase_fraction_calculation() {
        let acc = PhaseTimeAccumulator {
            data_loading_ns: 200,
            gpu_compute_ns: 800,
            ..Default::default()
        };
        let gpu = acc.phase_fraction(WorkloadPhase::GpuCompute);
        let load = acc.phase_fraction(WorkloadPhase::DataLoading);
        assert!((gpu - 0.8).abs() < 0.01);
        assert!((load - 0.2).abs() < 0.01);
    }

    #[test]
    fn display_impl() {
        assert_eq!(WorkloadPhase::DataLoading.to_string(), "DataLoading");
        assert_eq!(WorkloadPhase::NcclCollective.to_string(), "NcclCollective");
    }
}