zernel_scheduler/
config.rs

1// Copyright (C) 2026 Dyber, Inc. — GPL-2.0
2
3use crate::phase_detector::PhaseDetectorConfig;
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
8/// Top-level scheduler configuration, loaded from /etc/zernel/scheduler.toml.
9#[derive(Debug, Clone, Default, Serialize, Deserialize)]
10#[serde(default)]
11pub struct SchedulerConfig {
12    #[serde(default)]
13    pub general: GeneralConfig,
14    #[serde(default)]
15    pub phase_detection: PhaseDetectionConfig,
16    #[serde(default)]
17    pub numa: NumaConfig,
18    #[serde(default)]
19    pub multi_tenant: MultiTenantConfig,
20    #[serde(default)]
21    pub telemetry: TelemetryConfig,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25#[serde(default)]
26pub struct GeneralConfig {
27    /// How often (ms) to re-evaluate task phases.
28    pub phase_eval_interval_ms: u64,
29    /// How often (ms) to poll GPU utilization.
30    pub gpu_poll_interval_ms: u64,
31    /// Maximum number of tracked tasks.
32    pub max_tracked_tasks: usize,
33    /// Log level override (trace, debug, info, warn, error).
34    pub log_level: String,
35}
36
37impl Default for GeneralConfig {
38    fn default() -> Self {
39        Self {
40            phase_eval_interval_ms: 100,
41            gpu_poll_interval_ms: 500,
42            max_tracked_tasks: 65536,
43            log_level: "info".into(),
44        }
45    }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(default)]
50pub struct PhaseDetectionConfig {
51    pub io_wait_threshold: f32,
52    pub optimizer_burst_max_ns: u64,
53    pub gpu_idle_threshold: u8,
54    pub gpu_active_threshold: u8,
55    /// Number of consecutive samples before committing a phase transition.
56    pub phase_stability_count: u32,
57    /// Enable NCCL collective detection (requires BPF probes).
58    pub nccl_detection_enabled: bool,
59}
60
61impl Default for PhaseDetectionConfig {
62    fn default() -> Self {
63        Self {
64            io_wait_threshold: 0.3,
65            optimizer_burst_max_ns: 5_000_000,
66            gpu_idle_threshold: 10,
67            gpu_active_threshold: 80,
68            phase_stability_count: 3,
69            nccl_detection_enabled: false,
70        }
71    }
72}
73
74impl From<&PhaseDetectionConfig> for PhaseDetectorConfig {
75    fn from(cfg: &PhaseDetectionConfig) -> Self {
76        PhaseDetectorConfig {
77            io_wait_threshold: cfg.io_wait_threshold,
78            optimizer_burst_max_ns: cfg.optimizer_burst_max_ns,
79            gpu_idle_threshold: cfg.gpu_idle_threshold,
80            gpu_active_threshold: cfg.gpu_active_threshold,
81            phase_stability_count: cfg.phase_stability_count,
82            nccl_detection_enabled: cfg.nccl_detection_enabled,
83        }
84    }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
88#[serde(default)]
89pub struct NumaConfig {
90    /// Enable NUMA-aware CPU selection for ML tasks.
91    pub enabled: bool,
92    /// Prefer CPUs on the same NUMA node as the task's GPU.
93    pub gpu_affinity: bool,
94    /// Prefer CPUs on the same NUMA node as the task's memory.
95    pub memory_affinity: bool,
96}
97
98impl Default for NumaConfig {
99    fn default() -> Self {
100        Self {
101            enabled: true,
102            gpu_affinity: true,
103            memory_affinity: true,
104        }
105    }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109#[serde(default)]
110pub struct MultiTenantConfig {
111    /// Enable multi-tenant scheduling (GPU-proportional CPU allocation).
112    pub enabled: bool,
113    /// Default priority class for new tasks.
114    pub default_priority_class: String,
115}
116
117impl Default for MultiTenantConfig {
118    fn default() -> Self {
119        Self {
120            enabled: false,
121            default_priority_class: "normal".into(),
122        }
123    }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127#[serde(default)]
128pub struct TelemetryConfig {
129    /// Expose scheduler metrics on this port (0 = disabled).
130    pub metrics_port: u16,
131    /// Push interval (ms) for telemetry export.
132    pub push_interval_ms: u64,
133}
134
135impl Default for TelemetryConfig {
136    fn default() -> Self {
137        Self {
138            metrics_port: 9093,
139            push_interval_ms: 1000,
140        }
141    }
142}
143
// `SchedulerConfig` gets `Default` from `#[derive(Default)]`; each section type implements `Default` by hand above.
145
146impl SchedulerConfig {
147    /// Load config from a TOML file, falling back to defaults for missing fields.
148    pub fn load(path: &Path) -> Result<Self> {
149        if path.exists() {
150            let content = std::fs::read_to_string(path)?;
151            let config: Self = toml::from_str(&content)?;
152            Ok(config)
153        } else {
154            Ok(Self::default())
155        }
156    }
157
158    /// Generate a default config file.
159    pub fn to_toml(&self) -> Result<String> {
160        Ok(toml::to_string_pretty(self)?)
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration serializes and contains keys from more than one section.
    #[test]
    fn default_config_serializes() {
        let rendered = SchedulerConfig::default().to_toml().unwrap();
        for key in ["phase_eval_interval_ms", "gpu_affinity"] {
            assert!(rendered.contains(key));
        }
    }

    /// Loading from a path that does not exist yields the built-in defaults.
    #[test]
    fn load_missing_file_returns_defaults() {
        let missing = Path::new("/nonexistent/path.toml");
        let loaded = SchedulerConfig::load(missing).unwrap();
        assert_eq!(loaded.general.phase_eval_interval_ms, 100);
    }

    /// Keys and sections omitted from the TOML input fall back to their defaults.
    #[test]
    fn partial_toml_fills_defaults() {
        let source = r#"
[general]
phase_eval_interval_ms = 200
"#;
        let SchedulerConfig { general, numa, .. } =
            toml::from_str::<SchedulerConfig>(source).unwrap();
        assert_eq!(general.phase_eval_interval_ms, 200); // explicitly set
        assert_eq!(general.gpu_poll_interval_ms, 500); // default within a present section
        assert!(numa.enabled); // default for an absent section
    }
}