zernel_scheduler/
task_state.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub enum WorkloadPhase {
8 DataLoading,
10 GpuCompute,
12 NcclCollective,
14 OptimizerStep,
16 #[default]
18 Unknown,
19}
20
21impl std::fmt::Display for WorkloadPhase {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 Self::DataLoading => write!(f, "DataLoading"),
27 Self::GpuCompute => write!(f, "GpuCompute"),
28 Self::NcclCollective => write!(f, "NcclCollective"),
29 Self::OptimizerStep => write!(f, "OptimizerStep"),
30 Self::Unknown => write!(f, "Unknown"),
31 }
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ZernelTaskState {
40 pub pid: u32,
41 pub is_ml_process: bool,
42 pub current_phase: WorkloadPhase,
43 pub gpu_utilization: u8,
45 pub last_gpu_sync_ns: u64,
47 pub cpu_burst_duration_ns: u64,
49 pub io_wait_fraction: f32,
51 pub nccl_active: bool,
53 pub futex_wait_count: u32,
55 pub gpu_id: Option<u32>,
57 pub phase_time_ns: PhaseTimeAccumulator,
59 pub phase_start_ns: u64,
61}
62
63#[derive(Debug, Clone, Default, Serialize, Deserialize)]
65pub struct PhaseTimeAccumulator {
66 pub data_loading_ns: u64,
67 pub gpu_compute_ns: u64,
68 pub nccl_collective_ns: u64,
69 pub optimizer_step_ns: u64,
70 pub unknown_ns: u64,
71}
72
73impl PhaseTimeAccumulator {
74 pub fn record(&mut self, phase: WorkloadPhase, duration_ns: u64) {
75 match phase {
76 WorkloadPhase::DataLoading => self.data_loading_ns += duration_ns,
77 WorkloadPhase::GpuCompute => self.gpu_compute_ns += duration_ns,
78 WorkloadPhase::NcclCollective => self.nccl_collective_ns += duration_ns,
79 WorkloadPhase::OptimizerStep => self.optimizer_step_ns += duration_ns,
80 WorkloadPhase::Unknown => self.unknown_ns += duration_ns,
81 }
82 }
83
84 pub fn total_ns(&self) -> u64 {
85 self.data_loading_ns
86 + self.gpu_compute_ns
87 + self.nccl_collective_ns
88 + self.optimizer_step_ns
89 + self.unknown_ns
90 }
91
92 pub fn phase_fraction(&self, phase: WorkloadPhase) -> f64 {
94 let total = self.total_ns();
95 if total == 0 {
96 return 0.0;
97 }
98 let phase_ns = match phase {
99 WorkloadPhase::DataLoading => self.data_loading_ns,
100 WorkloadPhase::GpuCompute => self.gpu_compute_ns,
101 WorkloadPhase::NcclCollective => self.nccl_collective_ns,
102 WorkloadPhase::OptimizerStep => self.optimizer_step_ns,
103 WorkloadPhase::Unknown => self.unknown_ns,
104 };
105 phase_ns as f64 / total as f64
106 }
107}
108
109impl ZernelTaskState {
110 pub fn new(pid: u32) -> Self {
111 Self {
112 pid,
113 is_ml_process: false,
114 current_phase: WorkloadPhase::Unknown,
115 gpu_utilization: 0,
116 last_gpu_sync_ns: 0,
117 cpu_burst_duration_ns: 0,
118 io_wait_fraction: 0.0,
119 nccl_active: false,
120 futex_wait_count: 0,
121 gpu_id: None,
122 phase_time_ns: PhaseTimeAccumulator::default(),
123 phase_start_ns: 0,
124 }
125 }
126
127 pub fn transition_phase(&mut self, new_phase: WorkloadPhase, now_ns: u64) {
129 if self.phase_start_ns > 0 && now_ns > self.phase_start_ns {
130 let duration = now_ns - self.phase_start_ns;
131 self.phase_time_ns.record(self.current_phase, duration);
132 }
133 self.current_phase = new_phase;
134 self.phase_start_ns = now_ns;
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn phase_time_accumulation() {
        let mut state = ZernelTaskState::new(1);
        state.transition_phase(WorkloadPhase::DataLoading, 1000);
        state.transition_phase(WorkloadPhase::GpuCompute, 5000);
        state.transition_phase(WorkloadPhase::OptimizerStep, 8000);

        assert_eq!(state.phase_time_ns.data_loading_ns, 4000);
        assert_eq!(state.phase_time_ns.gpu_compute_ns, 3000);
        // The initial Unknown phase had no start time, so nothing is credited to it.
        assert_eq!(state.phase_time_ns.unknown_ns, 0);
        assert_eq!(state.current_phase, WorkloadPhase::OptimizerStep);
        assert_eq!(state.phase_start_ns, 8000);
    }

    #[test]
    fn transition_ignores_non_advancing_clock() {
        let mut state = ZernelTaskState::new(1);
        state.transition_phase(WorkloadPhase::DataLoading, 5000);
        // Clock went backwards: the phase still switches, but no time is recorded.
        state.transition_phase(WorkloadPhase::GpuCompute, 3000);
        assert_eq!(state.phase_time_ns.total_ns(), 0);
        assert_eq!(state.current_phase, WorkloadPhase::GpuCompute);
        // Same timestamp: zero-length phase, nothing recorded.
        state.transition_phase(WorkloadPhase::NcclCollective, 3000);
        assert_eq!(state.phase_time_ns.total_ns(), 0);
        state.transition_phase(WorkloadPhase::OptimizerStep, 3250);
        assert_eq!(state.phase_time_ns.nccl_collective_ns, 250);
        assert_eq!(state.phase_time_ns.total_ns(), 250);
    }

    #[test]
    fn record_accumulates() {
        let mut acc = PhaseTimeAccumulator::default();
        acc.record(WorkloadPhase::OptimizerStep, 10);
        acc.record(WorkloadPhase::OptimizerStep, 15);
        acc.record(WorkloadPhase::Unknown, 5);
        assert_eq!(acc.optimizer_step_ns, 25);
        assert_eq!(acc.unknown_ns, 5);
        assert_eq!(acc.total_ns(), 30);
    }

    #[test]
    fn phase_fraction_calculation() {
        let mut acc = PhaseTimeAccumulator::default();
        acc.data_loading_ns = 200;
        acc.gpu_compute_ns = 800;
        assert!((acc.phase_fraction(WorkloadPhase::GpuCompute) - 0.8).abs() < 0.01);
        assert!((acc.phase_fraction(WorkloadPhase::DataLoading) - 0.2).abs() < 0.01);
        assert_eq!(acc.phase_fraction(WorkloadPhase::NcclCollective), 0.0);
    }

    #[test]
    fn phase_fraction_is_zero_when_nothing_recorded() {
        let acc = PhaseTimeAccumulator::default();
        assert_eq!(acc.total_ns(), 0);
        assert_eq!(acc.phase_fraction(WorkloadPhase::GpuCompute), 0.0);
    }

    #[test]
    fn default_phase_is_unknown() {
        assert_eq!(WorkloadPhase::default(), WorkloadPhase::Unknown);
    }

    #[test]
    fn display_impl() {
        let cases = [
            (WorkloadPhase::DataLoading, "DataLoading"),
            (WorkloadPhase::GpuCompute, "GpuCompute"),
            (WorkloadPhase::NcclCollective, "NcclCollective"),
            (WorkloadPhase::OptimizerStep, "OptimizerStep"),
            (WorkloadPhase::Unknown, "Unknown"),
        ];
        for &(phase, name) in cases.iter() {
            assert_eq!(phase.to_string(), name);
        }
    }
}