1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// NUMA layout of the host: the memory nodes and which node each GPU hangs off.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NumaTopology {
    /// The NUMA nodes known to this topology.
    pub nodes: Vec<NumaNode>,
    /// Maps a GPU index to the `node_id` of the NUMA node it is attached to.
    /// GPUs without a known node simply have no entry.
    pub gpu_node_map: HashMap<u32, u32>,
}
13
/// A single NUMA node: its id, the CPUs local to it, and its memory size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NumaNode {
    /// Identifier of the node (the numeric suffix of the sysfs `nodeN` directory
    /// when detected on Linux).
    pub node_id: u32,
    /// Ids of the CPUs that belong to this node.
    pub cpu_ids: Vec<u32>,
    /// Total memory of the node in megabytes; `0` when the size is unknown.
    pub memory_mb: u64,
}
20
21impl NumaTopology {
22 pub fn detect() -> Self {
25 #[cfg(target_os = "linux")]
27 if let Ok(topo) = Self::detect_linux() {
28 return topo;
29 }
30
31 let num_cpus = std::thread::available_parallelism()
33 .map(|n| n.get() as u32)
34 .unwrap_or(1);
35
36 Self {
37 nodes: vec![NumaNode {
38 node_id: 0,
39 cpu_ids: (0..num_cpus).collect(),
40 memory_mb: 0, }],
42 gpu_node_map: HashMap::new(),
43 }
44 }
45
46 #[cfg(target_os = "linux")]
47 fn detect_linux() -> anyhow::Result<Self> {
48 use std::fs;
49
50 let mut nodes = Vec::new();
51 let node_base = std::path::Path::new("/sys/devices/system/node");
52
53 if !node_base.exists() {
54 anyhow::bail!("no NUMA sysfs");
55 }
56
57 for entry in fs::read_dir(node_base)? {
58 let entry = entry?;
59 let name = entry.file_name().to_string_lossy().to_string();
60 if !name.starts_with("node") {
61 continue;
62 }
63 let node_id: u32 = name.trim_start_matches("node").parse().unwrap_or(0);
64
65 let cpulist_path = entry.path().join("cpulist");
67 let cpu_ids = if cpulist_path.exists() {
68 parse_cpu_list(&fs::read_to_string(cpulist_path).unwrap_or_default())
69 } else {
70 vec![]
71 };
72
73 let meminfo_path = entry.path().join("meminfo");
75 let memory_mb = if meminfo_path.exists() {
76 parse_node_meminfo(&fs::read_to_string(meminfo_path).unwrap_or_default())
77 } else {
78 0
79 };
80
81 nodes.push(NumaNode {
82 node_id,
83 cpu_ids,
84 memory_mb,
85 });
86 }
87
88 nodes.sort_by_key(|n| n.node_id);
89
90 let gpu_node_map = detect_gpu_numa_map();
92
93 Ok(Self {
94 nodes,
95 gpu_node_map,
96 })
97 }
98
99 pub fn gpu_numa_node(&self, gpu_id: u32) -> Option<u32> {
101 self.gpu_node_map.get(&gpu_id).copied()
102 }
103
104 pub fn cpus_for_gpu(&self, gpu_id: u32) -> Vec<u32> {
106 if let Some(node_id) = self.gpu_numa_node(gpu_id) {
107 self.nodes
108 .iter()
109 .find(|n| n.node_id == node_id)
110 .map(|n| n.cpu_ids.clone())
111 .unwrap_or_default()
112 } else {
113 self.nodes
115 .iter()
116 .flat_map(|n| n.cpu_ids.iter().copied())
117 .collect()
118 }
119 }
120
121 pub fn select_cpu(&self, gpu_id: Option<u32>, cpu_loads: &HashMap<u32, f32>) -> u32 {
124 let preferred_cpus = match gpu_id {
125 Some(gid) => self.cpus_for_gpu(gid),
126 None => self
127 .nodes
128 .iter()
129 .flat_map(|n| n.cpu_ids.iter().copied())
130 .collect(),
131 };
132
133 preferred_cpus
135 .iter()
136 .copied()
137 .min_by(|a, b| {
138 let la = cpu_loads.get(a).unwrap_or(&0.0);
139 let lb = cpu_loads.get(b).unwrap_or(&0.0);
140 la.partial_cmp(lb).unwrap_or(std::cmp::Ordering::Equal)
141 })
142 .unwrap_or(0)
143 }
144
145 pub fn total_cpus(&self) -> usize {
146 self.nodes.iter().map(|n| n.cpu_ids.len()).sum()
147 }
148}
149
/// Expands a kernel-style CPU list such as `"0-2,5,8-9"` into individual ids.
///
/// The input is a comma-separated sequence of single ids and inclusive
/// `start-end` ranges. Surrounding whitespace and empty items are ignored, and
/// any item that does not parse as `u32` (or a range with an unparsable bound)
/// is silently dropped. Ids are returned in the order they appear.
fn parse_cpu_list(s: &str) -> Vec<u32> {
    let mut ids: Vec<u32> = Vec::new();

    let tokens = s
        .trim()
        .split(',')
        .map(str::trim)
        .filter(|token| !token.is_empty());

    for token in tokens {
        match token.split_once('-') {
            // `lo-hi` range: both bounds must be valid, otherwise skip it.
            Some((lo, hi)) => {
                if let (Ok(first), Ok(last)) = (lo.parse::<u32>(), hi.parse::<u32>()) {
                    ids.extend(first..=last);
                }
            }
            // Single id: contributes one element when it parses, none otherwise.
            None => ids.extend(token.parse::<u32>().ok()),
        }
    }

    ids
}
168
/// Extracts the node's total memory, in megabytes, from per-node `meminfo` text.
///
/// Looks for the first line mentioning `MemTotal` whose fourth
/// whitespace-separated field is an integer (the layout
/// `Node <id> MemTotal: <value> kB`), and returns that value divided by 1024.
/// Returns `0` when no such line exists.
fn parse_node_meminfo(s: &str) -> u64 {
    s.lines()
        .filter(|line| line.contains("MemTotal"))
        // Field index 3 holds the amount; lines that are too short or carry a
        // non-numeric value are skipped and the search continues.
        .filter_map(|line| line.split_whitespace().nth(3)?.parse::<u64>().ok())
        .map(|kib| kib / 1024)
        .next()
        .unwrap_or(0)
}
184
/// Scans `/sys/bus/pci/devices` for NVIDIA GPUs and maps each GPU's ordinal to
/// the NUMA node reported in its `numa_node` attribute.
///
/// A device counts as a GPU when its `vendor` is `0x10de` (NVIDIA) and its
/// `class` starts with `0x0300` (VGA controller) or `0x0302` (3D controller).
/// Ordinals are assigned in ascending PCI-address order so the numbering is
/// stable across runs; directory listing order alone is unspecified.
/// NOTE(review): this presumably needs to line up with the GPU ids used by
/// callers (e.g. PCI-bus-ordered enumeration) — confirm against the caller.
///
/// Every detected GPU consumes an ordinal, but only GPUs with a non-negative
/// `numa_node` get an entry (the kernel reports `-1` when the node is unknown).
/// Returns an empty map when the PCI sysfs directory cannot be read.
#[cfg(target_os = "linux")]
fn detect_gpu_numa_map() -> HashMap<u32, u32> {
    // PCI vendor id of NVIDIA as printed by sysfs.
    const NVIDIA_VENDOR_ID: &str = "0x10de";
    // PCI class prefixes: 3D controller and VGA-compatible controller.
    const GPU_CLASS_PREFIXES: [&str; 2] = ["0x0302", "0x0300"];

    // Reads one sysfs attribute of a device, trimmed; `None` if unreadable.
    fn read_attr(device: &std::path::Path, attr: &str) -> Option<String> {
        std::fs::read_to_string(device.join(attr))
            .ok()
            .map(|value| value.trim().to_owned())
    }

    let mut devices: Vec<std::path::PathBuf> = match std::fs::read_dir("/sys/bus/pci/devices") {
        Ok(entries) => entries.flatten().map(|entry| entry.path()).collect(),
        Err(_) => return HashMap::new(),
    };
    // Entry names are PCI addresses; sorting makes the ordinal deterministic.
    devices.sort();

    let mut map = HashMap::new();
    let mut gpu_idx = 0u32;

    for device in &devices {
        let is_nvidia = read_attr(device, "vendor")
            .map_or(false, |vendor| vendor == NVIDIA_VENDOR_ID);
        if !is_nvidia {
            continue;
        }

        let is_gpu = read_attr(device, "class").map_or(false, |class| {
            GPU_CLASS_PREFIXES
                .iter()
                .any(|prefix| class.starts_with(*prefix))
        });
        if !is_gpu {
            continue;
        }

        let numa_node = read_attr(device, "numa_node").and_then(|raw| raw.parse::<i32>().ok());
        if let Some(node) = numa_node {
            if node >= 0 {
                map.insert(gpu_idx, node as u32);
            }
        }

        // The ordinal advances for every GPU, mapped or not.
        gpu_idx += 1;
    }

    map
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a node fixture.
    fn node(node_id: u32, cpu_ids: Vec<u32>, memory_mb: u64) -> NumaNode {
        NumaNode {
            node_id,
            cpu_ids,
            memory_mb,
        }
    }

    // A single range expands to every id it covers, bounds included.
    #[test]
    fn parse_cpu_list_simple() {
        let parsed = parse_cpu_list("0-3");
        assert_eq!(parsed, vec![0, 1, 2, 3]);
    }

    // Ranges and lone ids may be mixed; order of appearance is kept.
    #[test]
    fn parse_cpu_list_ranges_and_singles() {
        let parsed = parse_cpu_list("0-2,5,8-9");
        assert_eq!(parsed, vec![0, 1, 2, 5, 8, 9]);
    }

    // Empty input yields no CPUs.
    #[test]
    fn parse_cpu_list_empty() {
        let parsed = parse_cpu_list("");
        assert_eq!(parsed, Vec::<u32>::new());
    }

    // Detection on the host must report at least one node and one CPU.
    #[test]
    fn single_node_topology() {
        let detected = NumaTopology::detect();
        assert!(!detected.nodes.is_empty());
        assert!(detected.total_cpus() > 0);
    }

    // Among the CPUs local to the GPU, the one with the smallest load wins.
    #[test]
    fn select_cpu_picks_lowest_load() {
        let mut gpu_node_map = HashMap::new();
        gpu_node_map.insert(0, 0);
        let topo = NumaTopology {
            nodes: vec![node(0, vec![0, 1, 2, 3], 32768)],
            gpu_node_map,
        };

        let mut loads = HashMap::new();
        loads.insert(0, 0.9);
        loads.insert(1, 0.2);
        loads.insert(2, 0.5);
        loads.insert(3, 0.8);

        assert_eq!(topo.select_cpu(Some(0), &loads), 1);
    }

    // Each GPU resolves to the CPU set of the node it is mapped to.
    #[test]
    fn cpus_for_gpu_with_mapping() {
        let mut gpu_node_map = HashMap::new();
        gpu_node_map.insert(0, 0);
        gpu_node_map.insert(1, 1);
        let topo = NumaTopology {
            nodes: vec![
                node(0, vec![0, 1, 2, 3], 16384),
                node(1, vec![4, 5, 6, 7], 16384),
            ],
            gpu_node_map,
        };

        assert_eq!(topo.cpus_for_gpu(0), vec![0, 1, 2, 3]);
        assert_eq!(topo.cpus_for_gpu(1), vec![4, 5, 6, 7]);
    }
}