// kernel_scheduler/multi_tenant.rs
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::debug;

7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9pub enum PriorityClass {
10 Training,
12 Inference,
14 Interactive,
16 Background,
18}
20impl PriorityClass {
21 pub fn base_priority(&self) -> i32 {
23 match self {
24 Self::Training => 5,
25 Self::Inference => 3,
26 Self::Interactive => 1,
27 Self::Background => -5,
28 }
29 }
30
31 pub fn from_str(s: &str) -> Self {
32 match s.to_lowercase().as_str() {
33 "training" => Self::Training,
34 "inference" => Self::Inference,
35 "interactive" => Self::Interactive,
36 "background" => Self::Background,
37 _ => Self::Training,
38 }
39 }
40}
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct Tenant {
45 pub id: String,
46 pub gpu_count: u32,
47 pub priority_class: PriorityClass,
48 pub cpu_weight: Option<f32>,
50}
52pub struct TenantScheduler {
58 tenants: HashMap<String, Tenant>,
59 pid_tenant_map: HashMap<u32, String>,
61 total_gpus: u32,
62}
64impl TenantScheduler {
65 pub fn new() -> Self {
66 Self {
67 tenants: HashMap::new(),
68 pid_tenant_map: HashMap::new(),
69 total_gpus: 0,
70 }
71 }
72
73 pub fn register_tenant(&mut self, tenant: Tenant) {
75 self.total_gpus += tenant.gpu_count;
76 debug!(
77 tenant_id = tenant.id,
78 gpus = tenant.gpu_count,
79 class = ?tenant.priority_class,
80 "registered tenant"
81 );
82 self.tenants.insert(tenant.id.clone(), tenant);
83 }
84
85 pub fn unregister_tenant(&mut self, tenant_id: &str) {
87 if let Some(tenant) = self.tenants.remove(tenant_id) {
88 self.total_gpus = self.total_gpus.saturating_sub(tenant.gpu_count);
89 self.pid_tenant_map.retain(|_, tid| tid != tenant_id);
90 }
91 }
92
93 pub fn assign_pid(&mut self, pid: u32, tenant_id: &str) {
95 self.pid_tenant_map.insert(pid, tenant_id.to_string());
96 }
97
98 pub fn cpu_weight_for_pid(&self, pid: u32) -> f32 {
101 let tenant_id = match self.pid_tenant_map.get(&pid) {
102 Some(tid) => tid,
103 None => return 1.0,
104 };
105 let tenant = match self.tenants.get(tenant_id) {
106 Some(t) => t,
107 None => return 1.0,
108 };
109
110 if let Some(w) = tenant.cpu_weight {
112 return w;
113 }
114
115 if self.total_gpus == 0 {
117 return 1.0;
118 }
119
120 (tenant.gpu_count as f32) / (self.total_gpus as f32) * self.tenants.len() as f32
121 }
123
124 pub fn effective_priority(&self, pid: u32, phase_priority: i32) -> i32 {
126 let weight = self.cpu_weight_for_pid(pid);
127 let class_priority = self
128 .pid_tenant_map
129 .get(&pid)
130 .and_then(|tid| self.tenants.get(tid))
131 .map(|t| t.priority_class.base_priority())
132 .unwrap_or(0);
133
134 let weighted = (phase_priority as f32 * weight) as i32;
135 weighted + class_priority
136 }
137
138 pub fn tenant_count(&self) -> usize {
139 self.tenants.len()
140 }
141
142 pub fn get_tenant_for_pid(&self, pid: u32) -> Option<&Tenant> {
143 self.pid_tenant_map
144 .get(&pid)
145 .and_then(|tid| self.tenants.get(tid))
146 }
147}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tenant record, cutting down on struct-literal boilerplate.
    fn make_tenant(id: &str, gpus: u32, class: PriorityClass, weight: Option<f32>) -> Tenant {
        Tenant {
            id: id.into(),
            gpu_count: gpus,
            priority_class: class,
            cpu_weight: weight,
        }
    }

    #[test]
    fn gpu_proportional_weight() {
        let mut sched = TenantScheduler::new();
        sched.register_tenant(make_tenant("job-a", 2, PriorityClass::Training, None));
        sched.register_tenant(make_tenant("job-b", 6, PriorityClass::Training, None));

        sched.assign_pid(100, "job-a");
        sched.assign_pid(200, "job-b");

        // job-b holds 3x the GPUs of job-a, so its weight should be ~3x.
        let ratio = sched.cpu_weight_for_pid(200) / sched.cpu_weight_for_pid(100);
        assert!((ratio - 3.0).abs() < 0.01);
    }

    #[test]
    fn priority_class_affects_priority() {
        let mut sched = TenantScheduler::new();
        sched.register_tenant(make_tenant("train", 4, PriorityClass::Training, Some(1.0)));
        sched.register_tenant(make_tenant("bg", 4, PriorityClass::Background, Some(1.0)));

        sched.assign_pid(100, "train");
        sched.assign_pid(200, "bg");

        // Equal weights and phase priority: the class alone decides ordering.
        assert!(sched.effective_priority(100, 10) > sched.effective_priority(200, 10));
    }

    #[test]
    fn unregistered_pid_gets_default_weight() {
        let sched = TenantScheduler::new();
        assert_eq!(sched.cpu_weight_for_pid(999), 1.0);
    }

    #[test]
    fn unregister_tenant_cleans_up() {
        let mut sched = TenantScheduler::new();
        sched.register_tenant(make_tenant("job-a", 4, PriorityClass::Training, None));
        sched.assign_pid(100, "job-a");
        assert_eq!(sched.tenant_count(), 1);

        sched.unregister_tenant("job-a");
        assert_eq!(sched.tenant_count(), 0);
        // The PID mapping must be dropped along with the tenant.
        assert_eq!(sched.cpu_weight_for_pid(100), 1.0);
    }
}