zernel_scheduler/
config.rs1use crate::phase_detector::PhaseDetectorConfig;
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
8#[derive(Debug, Clone, Default, Serialize, Deserialize)]
10#[serde(default)]
11pub struct SchedulerConfig {
12 #[serde(default)]
13 pub general: GeneralConfig,
14 #[serde(default)]
15 pub phase_detection: PhaseDetectionConfig,
16 #[serde(default)]
17 pub numa: NumaConfig,
18 #[serde(default)]
19 pub multi_tenant: MultiTenantConfig,
20 #[serde(default)]
21 pub telemetry: TelemetryConfig,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25#[serde(default)]
26pub struct GeneralConfig {
27 pub phase_eval_interval_ms: u64,
29 pub gpu_poll_interval_ms: u64,
31 pub max_tracked_tasks: usize,
33 pub log_level: String,
35}
36
37impl Default for GeneralConfig {
38 fn default() -> Self {
39 Self {
40 phase_eval_interval_ms: 100,
41 gpu_poll_interval_ms: 500,
42 max_tracked_tasks: 65536,
43 log_level: "info".into(),
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(default)]
50pub struct PhaseDetectionConfig {
51 pub io_wait_threshold: f32,
52 pub optimizer_burst_max_ns: u64,
53 pub gpu_idle_threshold: u8,
54 pub gpu_active_threshold: u8,
55 pub phase_stability_count: u32,
57 pub nccl_detection_enabled: bool,
59}
60
61impl Default for PhaseDetectionConfig {
62 fn default() -> Self {
63 Self {
64 io_wait_threshold: 0.3,
65 optimizer_burst_max_ns: 5_000_000,
66 gpu_idle_threshold: 10,
67 gpu_active_threshold: 80,
68 phase_stability_count: 3,
69 nccl_detection_enabled: false,
70 }
71 }
72}
73
74impl From<&PhaseDetectionConfig> for PhaseDetectorConfig {
75 fn from(cfg: &PhaseDetectionConfig) -> Self {
76 PhaseDetectorConfig {
77 io_wait_threshold: cfg.io_wait_threshold,
78 optimizer_burst_max_ns: cfg.optimizer_burst_max_ns,
79 gpu_idle_threshold: cfg.gpu_idle_threshold,
80 gpu_active_threshold: cfg.gpu_active_threshold,
81 phase_stability_count: cfg.phase_stability_count,
82 nccl_detection_enabled: cfg.nccl_detection_enabled,
83 }
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
88#[serde(default)]
89pub struct NumaConfig {
90 pub enabled: bool,
92 pub gpu_affinity: bool,
94 pub memory_affinity: bool,
96}
97
98impl Default for NumaConfig {
99 fn default() -> Self {
100 Self {
101 enabled: true,
102 gpu_affinity: true,
103 memory_affinity: true,
104 }
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109#[serde(default)]
110pub struct MultiTenantConfig {
111 pub enabled: bool,
113 pub default_priority_class: String,
115}
116
117impl Default for MultiTenantConfig {
118 fn default() -> Self {
119 Self {
120 enabled: false,
121 default_priority_class: "normal".into(),
122 }
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127#[serde(default)]
128pub struct TelemetryConfig {
129 pub metrics_port: u16,
131 pub push_interval_ms: u64,
133}
134
135impl Default for TelemetryConfig {
136 fn default() -> Self {
137 Self {
138 metrics_port: 9093,
139 push_interval_ms: 1000,
140 }
141 }
142}
143
144impl SchedulerConfig {
147 pub fn load(path: &Path) -> Result<Self> {
149 if path.exists() {
150 let content = std::fs::read_to_string(path)?;
151 let config: Self = toml::from_str(&content)?;
152 Ok(config)
153 } else {
154 Ok(Self::default())
155 }
156 }
157
158 pub fn to_toml(&self) -> Result<String> {
160 Ok(toml::to_string_pretty(self)?)
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration renders to TOML and the output mentions keys
    /// from more than one section.
    #[test]
    fn default_config_serializes() {
        let rendered = SchedulerConfig::default().to_toml().unwrap();
        assert!(rendered.contains("phase_eval_interval_ms"));
        assert!(rendered.contains("gpu_affinity"));
    }

    /// Loading from a path that does not exist yields the built-in defaults.
    #[test]
    fn load_missing_file_returns_defaults() {
        let missing = Path::new("/nonexistent/path.toml");
        let loaded = SchedulerConfig::load(missing).unwrap();
        assert_eq!(loaded.general.phase_eval_interval_ms, 100);
    }

    /// Keys and sections absent from the input fall back to their defaults,
    /// while keys that are present override them.
    #[test]
    fn partial_toml_fills_defaults() {
        let input = r#"
[general]
phase_eval_interval_ms = 200
"#;
        let parsed: SchedulerConfig = toml::from_str(input).unwrap();

        // Explicitly set key.
        assert_eq!(parsed.general.phase_eval_interval_ms, 200);
        // Sibling key in the same section, not set in the input.
        assert_eq!(parsed.general.gpu_poll_interval_ms, 500);
        // Entire section missing from the input.
        assert!(parsed.numa.enabled);
    }
}