zernel_ebpf/
gpu_watchdog.rs1use std::process::Command;
10use tracing::{info, warn};
11
/// Memory usage snapshot for a single GPU, as reported by `nvidia-smi`.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuMemoryState {
    /// GPU index (the `index` column of `nvidia-smi --query-gpu`).
    pub gpu_id: u32,
    /// Memory currently in use, in MiB (`memory.used`, queried with `nounits`).
    pub used_mb: u64,
    /// Total device memory, in MiB (`memory.total`, queried with `nounits`).
    pub total_mb: u64,
    /// `used_mb / total_mb * 100`.
    pub usage_pct: f64,
}
20
/// Thresholds and behaviour switches for the GPU memory watchdog.
#[derive(Debug, Clone, PartialEq)]
pub struct WatchdogConfig {
    /// Usage percentage at or above which a memory-pressure warning is logged.
    pub warn_pct: f64,
    /// Usage percentage at or above which the GPU is considered critical and,
    /// if `auto_checkpoint` is set, its compute processes receive `SIGUSR1`.
    pub critical_pct: f64,
    /// Delay between successive `nvidia-smi` polls, in milliseconds.
    pub poll_interval_ms: u64,
    /// Whether to send `SIGUSR1` to compute processes on a critical GPU.
    pub auto_checkpoint: bool,
}

impl Default for WatchdogConfig {
    /// Warn at 85%, treat 95% as critical, poll every 2 s, auto-checkpoint on.
    fn default() -> Self {
        Self {
            warn_pct: 85.0,
            critical_pct: 95.0,
            poll_interval_ms: 2000,
            auto_checkpoint: true,
        }
    }
}
44
45pub fn poll_gpu_memory() -> Vec<GpuMemoryState> {
47 let output = Command::new("nvidia-smi")
48 .args([
49 "--query-gpu=index,memory.used,memory.total",
50 "--format=csv,noheader,nounits",
51 ])
52 .output();
53
54 let Ok(o) = output else { return Vec::new() };
55 if !o.status.success() {
56 return Vec::new();
57 }
58
59 let stdout = String::from_utf8_lossy(&o.stdout);
60 let mut states = Vec::new();
61
62 for line in stdout.lines() {
63 let f: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
64 if f.len() >= 3 {
65 let gpu_id: u32 = f[0].parse().unwrap_or(0);
66 let used_mb: u64 = f[1].parse().unwrap_or(0);
67 let total_mb: u64 = f[2].parse().unwrap_or(1);
68 let usage_pct = used_mb as f64 / total_mb as f64 * 100.0;
69
70 states.push(GpuMemoryState {
71 gpu_id,
72 used_mb,
73 total_mb,
74 usage_pct,
75 });
76 }
77 }
78
79 states
80}
81
/// Return the PIDs of compute processes running on the GPU with index `gpu_id`.
///
/// The process list is keyed by UUID, not index. So the index is first
/// resolved to a UUID with `nvidia-smi --query-gpu=index,uuid`. The output of
/// `nvidia-smi --query-compute-apps=gpu_uuid,pid` is then filtered by that
/// UUID.
///
/// Returns an empty vector if either command fails, the index is unknown, or
/// no process row parses. PID 0 is never returned.
fn get_cuda_pids(gpu_id: u32) -> Vec<u32> {
    let Some(uuid_csv) = nvidia_smi_stdout(&["--query-gpu=index,uuid", "--format=csv,noheader"])
    else {
        return Vec::new();
    };

    let target_uuid = uuid_csv.lines().find_map(|line| {
        let mut fields = line.split(',').map(str::trim);
        let index: u32 = fields.next()?.parse().ok()?;
        let uuid = fields.next()?;
        // An empty UUID would substring-match every process row below.
        (index == gpu_id && !uuid.is_empty()).then(|| uuid.to_owned())
    });
    let Some(target_uuid) = target_uuid else {
        return Vec::new();
    };

    // Use the documented `--query-compute-apps=<fields>` form, matching the
    // other queries in this file.
    let Some(apps_csv) = nvidia_smi_stdout(&[
        "--query-compute-apps=gpu_uuid,pid",
        "--format=csv,noheader",
    ]) else {
        return Vec::new();
    };

    apps_csv
        .lines()
        .filter_map(|line| {
            let mut fields = line.split(',').map(str::trim);
            let uuid = fields.next()?;
            let pid: u32 = fields.next()?.parse().ok()?;
            // NOTE(review): this is a substring match, not an exact match. It
            // presumably tolerates decorated UUIDs (e.g. MIG) — confirm
            // before tightening it to equality.
            (uuid.contains(target_uuid.as_str()) && pid != 0).then_some(pid)
        })
        .collect()
}

/// Run `nvidia-smi` with `args` and return its stdout.
///
/// Returns `None` if the process could not be spawned or exited unsuccessfully.
fn nvidia_smi_stdout(args: &[&str]) -> Option<String> {
    let output = Command::new("nvidia-smi").args(args).output().ok()?;
    output
        .status
        .success()
        .then(|| String::from_utf8_lossy(&output.stdout).into_owned())
}
131
132pub async fn run_watchdog(config: WatchdogConfig) {
134 info!(
135 warn = config.warn_pct,
136 critical = config.critical_pct,
137 auto_checkpoint = config.auto_checkpoint,
138 "GPU memory watchdog started"
139 );
140
141 let mut warned: std::collections::HashSet<u32> = std::collections::HashSet::new();
142
143 loop {
144 let states = poll_gpu_memory();
145
146 for state in &states {
147 if state.usage_pct >= config.critical_pct {
148 warn!(
149 gpu = state.gpu_id,
150 used_mb = state.used_mb,
151 total_mb = state.total_mb,
152 usage_pct = format!("{:.1}", state.usage_pct),
153 "GPU memory CRITICAL — OOM imminent"
154 );
155
156 if config.auto_checkpoint {
157 let pids = get_cuda_pids(state.gpu_id);
158 for pid in &pids {
159 info!(
160 pid,
161 gpu = state.gpu_id,
162 "sending SIGUSR1 for emergency checkpoint"
163 );
164 #[cfg(unix)]
165 unsafe {
166 libc::kill(*pid as i32, libc::SIGUSR1);
167 }
168 }
169 }
170 } else if state.usage_pct >= config.warn_pct && !warned.contains(&state.gpu_id) {
171 warn!(
172 gpu = state.gpu_id,
173 used_mb = state.used_mb,
174 total_mb = state.total_mb,
175 usage_pct = format!("{:.1}", state.usage_pct),
176 "GPU memory pressure warning"
177 );
178 warned.insert(state.gpu_id);
179 } else if state.usage_pct < config.warn_pct * 0.9 {
180 warned.remove(&state.gpu_id);
181 }
182 }
183
184 tokio::time::sleep(tokio::time::Duration::from_millis(config.poll_interval_ms)).await;
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// The documented defaults, plus the invariant that the warning band sits
    /// below the critical threshold.
    #[test]
    fn default_config() {
        let cfg = WatchdogConfig::default();
        assert_eq!(cfg.warn_pct, 85.0);
        assert_eq!(cfg.critical_pct, 95.0);
        assert_eq!(cfg.poll_interval_ms, 2000);
        assert!(cfg.auto_checkpoint);
        // Otherwise the pressure-warning branch of the watchdog would be unreachable.
        assert!(cfg.warn_pct < cfg.critical_pct);
    }
}