// kernel_scheduler/multi_tenant.rs

// Copyright (C) 2026 Dyber, Inc. — GPL-2.0

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::debug;

7/// Priority class for multi-tenant scheduling.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9pub enum PriorityClass {
10    /// Training jobs — highest resource allocation.
11    Training,
12    /// Inference serving — latency-sensitive but lower throughput needs.
13    Inference,
14    /// Interactive notebooks — responsive but best-effort.
15    Interactive,
16    /// Background — batch jobs, preprocessing, lowest priority.
17    Background,
18}
19
20impl PriorityClass {
21    /// Base priority modifier for this class.
22    pub fn base_priority(&self) -> i32 {
23        match self {
24            Self::Training => 5,
25            Self::Inference => 3,
26            Self::Interactive => 1,
27            Self::Background => -5,
28        }
29    }
30
31    pub fn from_str(s: &str) -> Self {
32        match s.to_lowercase().as_str() {
33            "training" => Self::Training,
34            "inference" => Self::Inference,
35            "interactive" => Self::Interactive,
36            "background" => Self::Background,
37            _ => Self::Training,
38        }
39    }
40}
41
42/// A tenant represents a user or job group sharing GPU server resources.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct Tenant {
45    pub id: String,
46    pub gpu_count: u32,
47    pub priority_class: PriorityClass,
48    /// Explicit CPU weight override (None = auto from GPU count).
49    pub cpu_weight: Option<f32>,
50}
51
52/// Multi-tenant scheduler that enforces GPU-proportional CPU allocation.
53///
54/// If Job A has 2 GPUs and Job B has 6 GPUs, Job B gets 3x the CPU scheduling
55/// weight for data loading. This prevents one tenant's data loading from
56/// starving another's NCCL collectives.
57pub struct TenantScheduler {
58    tenants: HashMap<String, Tenant>,
59    /// Maps pid -> tenant_id for quick lookup.
60    pid_tenant_map: HashMap<u32, String>,
61    total_gpus: u32,
62}
63
64impl TenantScheduler {
65    pub fn new() -> Self {
66        Self {
67            tenants: HashMap::new(),
68            pid_tenant_map: HashMap::new(),
69            total_gpus: 0,
70        }
71    }
72
73    /// Register a new tenant.
74    pub fn register_tenant(&mut self, tenant: Tenant) {
75        self.total_gpus += tenant.gpu_count;
76        debug!(
77            tenant_id = tenant.id,
78            gpus = tenant.gpu_count,
79            class = ?tenant.priority_class,
80            "registered tenant"
81        );
82        self.tenants.insert(tenant.id.clone(), tenant);
83    }
84
85    /// Remove a tenant.
86    pub fn unregister_tenant(&mut self, tenant_id: &str) {
87        if let Some(tenant) = self.tenants.remove(tenant_id) {
88            self.total_gpus = self.total_gpus.saturating_sub(tenant.gpu_count);
89            self.pid_tenant_map.retain(|_, tid| tid != tenant_id);
90        }
91    }
92
93    /// Associate a process with a tenant.
94    pub fn assign_pid(&mut self, pid: u32, tenant_id: &str) {
95        self.pid_tenant_map.insert(pid, tenant_id.to_string());
96    }
97
98    /// Get the CPU weight for a given pid.
99    /// Weight is proportional to the tenant's GPU share.
100    pub fn cpu_weight_for_pid(&self, pid: u32) -> f32 {
101        let tenant_id = match self.pid_tenant_map.get(&pid) {
102            Some(tid) => tid,
103            None => return 1.0,
104        };
105        let tenant = match self.tenants.get(tenant_id) {
106            Some(t) => t,
107            None => return 1.0,
108        };
109
110        // Explicit override
111        if let Some(w) = tenant.cpu_weight {
112            return w;
113        }
114
115        // GPU-proportional weight
116        if self.total_gpus == 0 {
117            return 1.0;
118        }
119
120        (tenant.gpu_count as f32) / (self.total_gpus as f32) * self.tenants.len() as f32
121        // normalize so average weight = 1.0
122    }
123
124    /// Compute effective priority for a pid, combining phase priority with tenant weight.
125    pub fn effective_priority(&self, pid: u32, phase_priority: i32) -> i32 {
126        let weight = self.cpu_weight_for_pid(pid);
127        let class_priority = self
128            .pid_tenant_map
129            .get(&pid)
130            .and_then(|tid| self.tenants.get(tid))
131            .map(|t| t.priority_class.base_priority())
132            .unwrap_or(0);
133
134        let weighted = (phase_priority as f32 * weight) as i32;
135        weighted + class_priority
136    }
137
138    pub fn tenant_count(&self) -> usize {
139        self.tenants.len()
140    }
141
142    pub fn get_tenant_for_pid(&self, pid: u32) -> Option<&Tenant> {
143        self.pid_tenant_map
144            .get(&pid)
145            .and_then(|tid| self.tenants.get(tid))
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for test tenants.
    fn tenant(id: &str, gpus: u32, class: PriorityClass, weight: Option<f32>) -> Tenant {
        Tenant {
            id: id.into(),
            gpu_count: gpus,
            priority_class: class,
            cpu_weight: weight,
        }
    }

    #[test]
    fn gpu_proportional_weight() {
        let mut sched = TenantScheduler::new();
        sched.register_tenant(tenant("job-a", 2, PriorityClass::Training, None));
        sched.register_tenant(tenant("job-b", 6, PriorityClass::Training, None));

        sched.assign_pid(100, "job-a");
        sched.assign_pid(200, "job-b");

        let w_a = sched.cpu_weight_for_pid(100);
        let w_b = sched.cpu_weight_for_pid(200);
        // With a 2:6 GPU split, job-b's weight must be 3x job-a's.
        assert!((w_b / w_a - 3.0).abs() < 0.01);
    }

    #[test]
    fn priority_class_affects_priority() {
        let mut sched = TenantScheduler::new();
        sched.register_tenant(tenant("train", 4, PriorityClass::Training, Some(1.0)));
        sched.register_tenant(tenant("bg", 4, PriorityClass::Background, Some(1.0)));

        sched.assign_pid(100, "train");
        sched.assign_pid(200, "bg");

        // Same phase priority and weight: the class modifier decides.
        assert!(sched.effective_priority(100, 10) > sched.effective_priority(200, 10));
    }

    #[test]
    fn unregistered_pid_gets_default_weight() {
        assert_eq!(TenantScheduler::new().cpu_weight_for_pid(999), 1.0);
    }

    #[test]
    fn unregister_tenant_cleans_up() {
        let mut sched = TenantScheduler::new();
        sched.register_tenant(tenant("job-a", 4, PriorityClass::Training, None));
        sched.assign_pid(100, "job-a");
        assert_eq!(sched.tenant_count(), 1);

        sched.unregister_tenant("job-a");
        assert_eq!(sched.tenant_count(), 0);
        // Pid mapping was removed along with the tenant.
        assert_eq!(sched.cpu_weight_for_pid(100), 1.0);
    }
}