zernel_ebpf/
gpu_watchdog.rs

// Copyright (C) 2026 Dyber, Inc. — GPL-2.0

//! GPU Memory Pressure Watchdog
//!
//! Continuously monitors GPU memory usage and sends early warnings
//! before OOM occurs. Can trigger automatic checkpointing in training
//! frameworks that support SIGUSR1 signal handling.
8
9use std::process::Command;
10use tracing::{info, warn};
11
/// GPU memory state.
///
/// One snapshot of a single GPU's memory usage, built from one row of
/// `nvidia-smi --query-gpu=index,memory.used,memory.total` output.
/// `PartialEq` is derived so snapshots can be compared directly in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuMemoryState {
    /// GPU index as reported by `nvidia-smi` (the `index` query field).
    pub gpu_id: u32,
    /// Memory in use (the `memory.used` query field; nvidia-smi reports MiB).
    pub used_mb: u64,
    /// Total memory (the `memory.total` query field; nvidia-smi reports MiB).
    pub total_mb: u64,
    /// `used_mb / total_mb * 100`, i.e. a percentage in the 0–100 range for sane input.
    pub usage_pct: f64,
}
20
/// Thresholds for memory pressure alerts.
///
/// Percentages are compared against [`GpuMemoryState::usage_pct`], so they are
/// expressed on a 0–100 scale (not 0–1). `PartialEq` is derived so configurations
/// can be compared directly in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct WatchdogConfig {
    /// Warning threshold (percentage). Default: 85%.
    pub warn_pct: f64,
    /// Critical threshold (percentage). Default: 95%.
    pub critical_pct: f64,
    /// Polling interval in milliseconds. Default: 2000.
    pub poll_interval_ms: u64,
    /// Send SIGUSR1 to training processes at critical threshold.
    ///
    /// The targets are the compute processes `nvidia-smi` lists for the GPU
    /// that crossed the threshold. Default: `true`.
    pub auto_checkpoint: bool,
}
33
34impl Default for WatchdogConfig {
35    fn default() -> Self {
36        Self {
37            warn_pct: 85.0,
38            critical_pct: 95.0,
39            poll_interval_ms: 2000,
40            auto_checkpoint: true,
41        }
42    }
43}
44
45/// Poll current GPU memory state.
46pub fn poll_gpu_memory() -> Vec<GpuMemoryState> {
47    let output = Command::new("nvidia-smi")
48        .args([
49            "--query-gpu=index,memory.used,memory.total",
50            "--format=csv,noheader,nounits",
51        ])
52        .output();
53
54    let Ok(o) = output else { return Vec::new() };
55    if !o.status.success() {
56        return Vec::new();
57    }
58
59    let stdout = String::from_utf8_lossy(&o.stdout);
60    let mut states = Vec::new();
61
62    for line in stdout.lines() {
63        let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
64        if f.len() >= 3 {
65            let gpu_id: u32 = f[0].parse().unwrap_or(0);
66            let used_mb: u64 = f[1].parse().unwrap_or(0);
67            let total_mb: u64 = f[2].parse().unwrap_or(1);
68            let usage_pct = used_mb as f64 / total_mb as f64 * 100.0;
69
70            states.push(GpuMemoryState {
71                gpu_id,
72                used_mb,
73                total_mb,
74                usage_pct,
75            });
76        }
77    }
78
79    states
80}
81
/// Get PIDs of processes using CUDA on a specific GPU.
///
/// Two `nvidia-smi` queries are needed because the compute-apps listing
/// identifies GPUs by UUID, not by index:
///
/// 1. `--query-gpu=index,uuid` resolves `gpu_id` to its UUID.
/// 2. `--query-compute-apps=gpu_uuid,pid` lists compute processes per UUID.
///
/// This is best-effort: an empty vector is returned when `nvidia-smi` cannot be
/// run, exits with a non-zero status, or does not know `gpu_id`.
///
/// PIDs that are 0 or do not fit in a positive `i32` are dropped. The result is
/// used as a signal target, and such values would address a process group
/// rather than a single process.
fn get_cuda_pids(gpu_id: u32) -> Vec<u32> {
    let Some(gpu_listing) = nvidia_smi_stdout(&["--query-gpu=index,uuid", "--format=csv,noheader"])
    else {
        return Vec::new();
    };
    let Some(target_uuid) = find_gpu_uuid(&gpu_listing, gpu_id) else {
        return Vec::new();
    };

    // The `--query-…=fields` form is the documented one and matches the other
    // nvidia-smi invocations in this file.
    let Some(app_listing) = nvidia_smi_stdout(&[
        "--query-compute-apps=gpu_uuid,pid",
        "--format=csv,noheader",
    ]) else {
        return Vec::new();
    };

    pids_on_gpu(&app_listing, &target_uuid)
}

/// Run `nvidia-smi` with `args` and return its stdout, or `None` if it could
/// not be spawned or exited with a non-zero status.
fn nvidia_smi_stdout(args: &[&str]) -> Option<String> {
    let output = Command::new("nvidia-smi").args(args).output().ok()?;
    output
        .status
        .success()
        .then(|| String::from_utf8_lossy(&output.stdout).into_owned())
}

/// Find the UUID for `gpu_id` in an `index, uuid` CSV listing.
fn find_gpu_uuid(listing: &str, gpu_id: u32) -> Option<String> {
    listing.lines().find_map(|line| {
        let mut fields = line.split(',').map(str::trim);
        let index: u32 = fields.next()?.parse().ok()?;
        let uuid = fields.next()?;
        (index == gpu_id && !uuid.is_empty()).then(|| uuid.to_string())
    })
}

/// Collect the signal-safe PIDs listed for `gpu_uuid` in a `gpu_uuid, pid` CSV listing.
fn pids_on_gpu(listing: &str, gpu_uuid: &str) -> Vec<u32> {
    listing
        .lines()
        .filter_map(|line| {
            let mut fields = line.split(',').map(str::trim);
            let uuid_field = fields.next()?;
            let pid: u32 = fields.next()?.parse().ok()?;
            // Only PIDs in 1..=i32::MAX name a single process when passed to kill(2).
            let signal_safe = pid != 0 && i32::try_from(pid).is_ok();
            (uuid_field.contains(gpu_uuid) && signal_safe).then_some(pid)
        })
        .collect()
}
131
132/// Run the GPU memory watchdog loop.
133pub async fn run_watchdog(config: WatchdogConfig) {
134    info!(
135        warn = config.warn_pct,
136        critical = config.critical_pct,
137        auto_checkpoint = config.auto_checkpoint,
138        "GPU memory watchdog started"
139    );
140
141    let mut warned: std::collections::HashSet<u32> = std::collections::HashSet::new();
142
143    loop {
144        let states = poll_gpu_memory();
145
146        for state in &states {
147            if state.usage_pct >= config.critical_pct {
148                warn!(
149                    gpu = state.gpu_id,
150                    used_mb = state.used_mb,
151                    total_mb = state.total_mb,
152                    usage_pct = format!("{:.1}", state.usage_pct),
153                    "GPU memory CRITICAL — OOM imminent"
154                );
155
156                if config.auto_checkpoint {
157                    let pids = get_cuda_pids(state.gpu_id);
158                    for pid in &pids {
159                        info!(
160                            pid,
161                            gpu = state.gpu_id,
162                            "sending SIGUSR1 for emergency checkpoint"
163                        );
164                        #[cfg(unix)]
165                        unsafe {
166                            libc::kill(*pid as i32, libc::SIGUSR1);
167                        }
168                    }
169                }
170            } else if state.usage_pct >= config.warn_pct && !warned.contains(&state.gpu_id) {
171                warn!(
172                    gpu = state.gpu_id,
173                    used_mb = state.used_mb,
174                    total_mb = state.total_mb,
175                    usage_pct = format!("{:.1}", state.usage_pct),
176                    "GPU memory pressure warning"
177                );
178                warned.insert(state.gpu_id);
179            } else if state.usage_pct < config.warn_pct * 0.9 {
180                warned.remove(&state.gpu_id);
181            }
182        }
183
184        tokio::time::sleep(tokio::time::Duration::from_millis(config.poll_interval_ms)).await;
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults documented on `WatchdogConfig` must match what `Default` produces.
    #[test]
    fn default_config() {
        let cfg = WatchdogConfig::default();
        assert_eq!(cfg.warn_pct, 85.0);
        assert_eq!(cfg.critical_pct, 95.0);
        assert_eq!(cfg.poll_interval_ms, 2000);
        assert!(cfg.auto_checkpoint);
    }

    /// The warning must fire before the critical alert, and the loop must not spin.
    #[test]
    fn default_thresholds_are_ordered() {
        let cfg = WatchdogConfig::default();
        assert!(cfg.warn_pct < cfg.critical_pct);
        assert!(cfg.poll_interval_ms > 0);
    }
}