zernel_scheduler/
numa.rs

1// Copyright (C) 2026 Dyber, Inc. — GPL-2.0
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6/// NUMA topology information for a system.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct NumaTopology {
9    pub nodes: Vec<NumaNode>,
10    /// GPU-to-NUMA-node mapping (gpu_id -> node_id).
11    pub gpu_node_map: HashMap<u32, u32>,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct NumaNode {
16    pub node_id: u32,
17    pub cpu_ids: Vec<u32>,
18    pub memory_mb: u64,
19}
20
21impl NumaTopology {
22    /// Detect NUMA topology from the system.
23    /// On non-Linux or single-node systems, returns a single flat node.
24    pub fn detect() -> Self {
25        // Try to read from /sys/devices/system/node/ on Linux
26        #[cfg(target_os = "linux")]
27        if let Ok(topo) = Self::detect_linux() {
28            return topo;
29        }
30
31        // Fallback: single node with all available CPUs
32        let num_cpus = std::thread::available_parallelism()
33            .map(|n| n.get() as u32)
34            .unwrap_or(1);
35
36        Self {
37            nodes: vec![NumaNode {
38                node_id: 0,
39                cpu_ids: (0..num_cpus).collect(),
40                memory_mb: 0, // unknown
41            }],
42            gpu_node_map: HashMap::new(),
43        }
44    }
45
46    #[cfg(target_os = "linux")]
47    fn detect_linux() -> anyhow::Result<Self> {
48        use std::fs;
49
50        let mut nodes = Vec::new();
51        let node_base = std::path::Path::new("/sys/devices/system/node");
52
53        if !node_base.exists() {
54            anyhow::bail!("no NUMA sysfs");
55        }
56
57        for entry in fs::read_dir(node_base)? {
58            let entry = entry?;
59            let name = entry.file_name().to_string_lossy().to_string();
60            if !name.starts_with("node") {
61                continue;
62            }
63            let node_id: u32 = name.trim_start_matches("node").parse().unwrap_or(0);
64
65            // Read CPU list
66            let cpulist_path = entry.path().join("cpulist");
67            let cpu_ids = if cpulist_path.exists() {
68                parse_cpu_list(&fs::read_to_string(cpulist_path).unwrap_or_default())
69            } else {
70                vec![]
71            };
72
73            // Read memory info
74            let meminfo_path = entry.path().join("meminfo");
75            let memory_mb = if meminfo_path.exists() {
76                parse_node_meminfo(&fs::read_to_string(meminfo_path).unwrap_or_default())
77            } else {
78                0
79            };
80
81            nodes.push(NumaNode {
82                node_id,
83                cpu_ids,
84                memory_mb,
85            });
86        }
87
88        nodes.sort_by_key(|n| n.node_id);
89
90        // Detect GPU NUMA affinity via nvidia-smi or sysfs
91        let gpu_node_map = detect_gpu_numa_map();
92
93        Ok(Self {
94            nodes,
95            gpu_node_map,
96        })
97    }
98
99    /// Get the NUMA node that a GPU is connected to.
100    pub fn gpu_numa_node(&self, gpu_id: u32) -> Option<u32> {
101        self.gpu_node_map.get(&gpu_id).copied()
102    }
103
104    /// Get CPUs on the same NUMA node as a given GPU.
105    pub fn cpus_for_gpu(&self, gpu_id: u32) -> Vec<u32> {
106        if let Some(node_id) = self.gpu_numa_node(gpu_id) {
107            self.nodes
108                .iter()
109                .find(|n| n.node_id == node_id)
110                .map(|n| n.cpu_ids.clone())
111                .unwrap_or_default()
112        } else {
113            // No NUMA info — return all CPUs
114            self.nodes
115                .iter()
116                .flat_map(|n| n.cpu_ids.iter().copied())
117                .collect()
118        }
119    }
120
121    /// Select the best CPU for a task given its GPU affinity.
122    /// Returns the CPU ID from the preferred NUMA node with the lowest load.
123    pub fn select_cpu(&self, gpu_id: Option<u32>, cpu_loads: &HashMap<u32, f32>) -> u32 {
124        let preferred_cpus = match gpu_id {
125            Some(gid) => self.cpus_for_gpu(gid),
126            None => self
127                .nodes
128                .iter()
129                .flat_map(|n| n.cpu_ids.iter().copied())
130                .collect(),
131        };
132
133        // Pick the CPU with lowest load from preferred set
134        preferred_cpus
135            .iter()
136            .copied()
137            .min_by(|a, b| {
138                let la = cpu_loads.get(a).unwrap_or(&0.0);
139                let lb = cpu_loads.get(b).unwrap_or(&0.0);
140                la.partial_cmp(lb).unwrap_or(std::cmp::Ordering::Equal)
141            })
142            .unwrap_or(0)
143    }
144
145    pub fn total_cpus(&self) -> usize {
146        self.nodes.iter().map(|n| n.cpu_ids.len()).sum()
147    }
148}
149
/// Expand a Linux CPU list such as `"0-3,8-11"` into individual CPU IDs.
///
/// Tokens are comma-separated and may be a single ID (`5`) or an inclusive
/// range (`8-11`). Surrounding whitespace is ignored; empty or malformed
/// tokens, and ranges whose start exceeds their end, contribute nothing.
fn parse_cpu_list(s: &str) -> Vec<u32> {
    /// Convert one token into the inclusive range of IDs it denotes.
    fn token_ids(token: &str) -> std::ops::RangeInclusive<u32> {
        const NOTHING: std::ops::RangeInclusive<u32> = 1..=0;
        match token.split_once('-') {
            Some((lo, hi)) => match (lo.parse::<u32>(), hi.parse::<u32>()) {
                (Ok(lo), Ok(hi)) => lo..=hi,
                _ => NOTHING,
            },
            None => token.parse::<u32>().map_or(NOTHING, |id| id..=id),
        }
    }

    s.trim()
        .split(',')
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .flat_map(token_ids)
        .collect()
}
168
/// Extract a node's total memory, in MB, from its sysfs `meminfo` contents.
///
/// Lines look like `Node 0 MemTotal:       12345678 kB`; the fourth
/// whitespace-separated field is the size in kB. The first `MemTotal` line
/// with a parsable value wins. Returns `0` if there is none.
fn parse_node_meminfo(s: &str) -> u64 {
    s.lines()
        .filter(|line| line.contains("MemTotal"))
        .find_map(|line| line.split_whitespace().nth(3)?.parse::<u64>().ok())
        .map_or(0, |kb| kb / 1024) // kB -> MB
}
184
/// Detect which NUMA node each NVIDIA GPU is attached to.
///
/// Scans `/sys/bus/pci/devices` for NVIDIA (vendor `0x10de`) devices of PCI
/// class `0x0300xx` or `0x0302xx`. GPUs are numbered 0, 1, 2, … in ascending
/// PCI address order. A GPU whose `numa_node` is unreadable or negative
/// (no known node) still consumes an index but gets no map entry. Returns an
/// empty map if the PCI device directory cannot be read.
///
/// NOTE(review): PCI-bus order is presumably what nvidia-smi uses. CUDA's
/// default device order may differ unless `CUDA_DEVICE_ORDER=PCI_BUS_ID` is
/// set. Confirm which numbering callers expect.
#[cfg(target_os = "linux")]
fn detect_gpu_numa_map() -> HashMap<u32, u32> {
    /// Read sysfs attribute `attr` of device directory `dev`, trimmed.
    fn read_attr(dev: &std::path::Path, attr: &str) -> Option<String> {
        std::fs::read_to_string(dev.join(attr))
            .ok()
            .map(|s| s.trim().to_string())
    }

    let mut map = HashMap::new();
    let entries = match std::fs::read_dir("/sys/bus/pci/devices") {
        Ok(entries) => entries,
        Err(_) => return map,
    };

    // read_dir order is unspecified. Sort by entry name (the fixed-width PCI
    // address, e.g. `0000:3b:00.0`) so GPU indices are stable across runs.
    let mut devices: Vec<std::path::PathBuf> = entries.flatten().map(|e| e.path()).collect();
    devices.sort();

    let mut gpu_idx = 0u32;
    for dev in &devices {
        // NVIDIA vendor ID = 0x10de
        if read_attr(dev, "vendor").as_deref() != Some("0x10de") {
            continue;
        }
        // GPU class = 0x030000 or 0x030200
        let is_gpu = read_attr(dev, "class")
            .map_or(false, |c| c.starts_with("0x0300") || c.starts_with("0x0302"));
        if !is_gpu {
            continue;
        }
        // Negative values mean "no known node"; such GPUs keep their index.
        if let Some(node) = read_attr(dev, "numa_node")
            .and_then(|n| n.parse::<i32>().ok())
            .and_then(|n| u32::try_from(n).ok())
        {
            map.insert(gpu_idx, node);
        }
        gpu_idx += 1;
    }
    map
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Two 2-CPU nodes with no GPU mapping.
    fn two_node_topology() -> NumaTopology {
        NumaTopology {
            nodes: vec![
                NumaNode {
                    node_id: 0,
                    cpu_ids: vec![0, 1],
                    memory_mb: 8192,
                },
                NumaNode {
                    node_id: 1,
                    cpu_ids: vec![4, 5],
                    memory_mb: 8192,
                },
            ],
            gpu_node_map: HashMap::new(),
        }
    }

    #[test]
    fn parse_cpu_list_simple() {
        assert_eq!(parse_cpu_list("0-3"), vec![0, 1, 2, 3]);
    }

    #[test]
    fn parse_cpu_list_ranges_and_singles() {
        assert_eq!(parse_cpu_list("0-2,5,8-9"), vec![0, 1, 2, 5, 8, 9]);
    }

    #[test]
    fn parse_cpu_list_empty() {
        assert_eq!(parse_cpu_list(""), Vec::<u32>::new());
    }

    #[test]
    fn parse_cpu_list_trailing_newline_and_reversed_range() {
        // sysfs files end with a newline.
        assert_eq!(parse_cpu_list("0-1\n"), vec![0, 1]);
        assert_eq!(parse_cpu_list("3-1"), Vec::<u32>::new());
    }

    #[test]
    fn parse_node_meminfo_reads_memtotal() {
        let meminfo = "Node 0 MemTotal:       2097152 kB\nNode 0 MemFree:        1024 kB\n";
        assert_eq!(parse_node_meminfo(meminfo), 2048);
        assert_eq!(parse_node_meminfo(""), 0);
    }

    #[test]
    fn single_node_topology() {
        // Uses the real system (sysfs on Linux, flat fallback elsewhere).
        let topo = NumaTopology::detect();
        assert!(!topo.nodes.is_empty());
        assert!(topo.total_cpus() > 0);
    }

    #[test]
    fn select_cpu_picks_lowest_load() {
        let topo = NumaTopology {
            nodes: vec![NumaNode {
                node_id: 0,
                cpu_ids: vec![0, 1, 2, 3],
                memory_mb: 32768,
            }],
            gpu_node_map: HashMap::from([(0, 0)]),
        };
        let loads = HashMap::from([(0, 0.9), (1, 0.2), (2, 0.5), (3, 0.8)]);
        assert_eq!(topo.select_cpu(Some(0), &loads), 1);
    }

    #[test]
    fn select_cpu_without_loads_picks_first_cpu() {
        let topo = two_node_topology();
        assert_eq!(topo.select_cpu(None, &HashMap::new()), 0);
    }

    #[test]
    fn cpus_for_gpu_with_mapping() {
        let topo = NumaTopology {
            nodes: vec![
                NumaNode {
                    node_id: 0,
                    cpu_ids: vec![0, 1, 2, 3],
                    memory_mb: 16384,
                },
                NumaNode {
                    node_id: 1,
                    cpu_ids: vec![4, 5, 6, 7],
                    memory_mb: 16384,
                },
            ],
            gpu_node_map: HashMap::from([(0, 0), (1, 1)]),
        };
        assert_eq!(topo.cpus_for_gpu(0), vec![0, 1, 2, 3]);
        assert_eq!(topo.cpus_for_gpu(1), vec![4, 5, 6, 7]);
    }

    #[test]
    fn cpus_for_unmapped_gpu_returns_all_cpus() {
        let topo = two_node_topology();
        assert_eq!(topo.cpus_for_gpu(7), vec![0, 1, 4, 5]);
    }
}