zernel_ebpf/
prefetch.rs

1// Copyright (C) 2026 Dyber, Inc. — GPL-2.0
2
3//! Predictive Data Prefetching
4//!
5//! Uses phase timing data from the sched_ext scheduler to predict when
6//! the GpuCompute phase will end and signal the DataLoader to start
7//! prefetching the next batch BEFORE the GPU finishes.
8//!
9//! This eliminates the "GPU idle waiting for data" gap that causes
10//! 5-30% throughput loss in data-bound training.
11
12use std::collections::VecDeque;
13use tracing::debug;
14
/// Tracks phase durations to predict when to trigger prefetch.
///
/// Durations are kept in bounded FIFO histories (`VecDeque`, oldest at the
/// front) and summarised by the prediction methods on the `impl` below.
pub struct PrefetchPredictor {
    /// Recent GpuCompute phase durations (ns), oldest first.
    compute_durations: VecDeque<u64>,
    /// Recent DataLoading phase durations (ns), oldest first.
    loading_durations: VecDeque<u64>,
    /// Maximum number of samples retained per history.
    max_history: usize,
    /// How far before predicted compute end to trigger prefetch (ns).
    prefetch_lead_ns: u64,
}
26
27impl PrefetchPredictor {
28    pub fn new(max_history: usize, prefetch_lead_ns: u64) -> Self {
29        Self {
30            compute_durations: VecDeque::with_capacity(max_history),
31            loading_durations: VecDeque::with_capacity(max_history),
32            max_history,
33            prefetch_lead_ns,
34        }
35    }
36
37    /// Record a completed GpuCompute phase duration.
38    pub fn record_compute(&mut self, duration_ns: u64) {
39        if self.compute_durations.len() >= self.max_history {
40            self.compute_durations.pop_front();
41        }
42        self.compute_durations.push_back(duration_ns);
43    }
44
45    /// Record a completed DataLoading phase duration.
46    pub fn record_loading(&mut self, duration_ns: u64) {
47        if self.loading_durations.len() >= self.max_history {
48            self.loading_durations.pop_front();
49        }
50        self.loading_durations.push_back(duration_ns);
51    }
52
53    /// Predict the next GpuCompute duration using exponential moving average.
54    pub fn predicted_compute_ns(&self) -> Option<u64> {
55        if self.compute_durations.is_empty() {
56            return None;
57        }
58        // Exponential moving average (alpha = 0.3)
59        let alpha = 0.3;
60        let mut ema = self.compute_durations[0] as f64;
61        for &d in self.compute_durations.iter().skip(1) {
62            ema = alpha * d as f64 + (1.0 - alpha) * ema;
63        }
64        Some(ema as u64)
65    }
66
67    /// Predict the next DataLoading duration.
68    pub fn predicted_loading_ns(&self) -> Option<u64> {
69        if self.loading_durations.is_empty() {
70            return None;
71        }
72        let alpha = 0.3;
73        let mut ema = self.loading_durations[0] as f64;
74        for &d in self.loading_durations.iter().skip(1) {
75            ema = alpha * d as f64 + (1.0 - alpha) * ema;
76        }
77        Some(ema as u64)
78    }
79
80    /// Should we trigger prefetch now?
81    /// Call this periodically during GpuCompute phase with elapsed time.
82    pub fn should_prefetch(&self, elapsed_compute_ns: u64) -> bool {
83        let Some(predicted) = self.predicted_compute_ns() else {
84            return false;
85        };
86
87        if elapsed_compute_ns + self.prefetch_lead_ns >= predicted {
88            debug!(
89                elapsed_ms = elapsed_compute_ns / 1_000_000,
90                predicted_ms = predicted / 1_000_000,
91                lead_ms = self.prefetch_lead_ns / 1_000_000,
92                "triggering predictive prefetch"
93            );
94            return true;
95        }
96
97        false
98    }
99
100    /// Calculate the overlap efficiency.
101    /// 1.0 = perfect overlap (data ready exactly when GPU finishes)
102    /// <1.0 = GPU had to wait for data
103    /// >1.0 = data was ready before GPU finished (ideal)
104    pub fn overlap_efficiency(&self) -> f64 {
105        let compute = self.predicted_compute_ns().unwrap_or(1) as f64;
106        let loading = self.predicted_loading_ns().unwrap_or(1) as f64;
107        if loading == 0.0 {
108            return 1.0;
109        }
110        (compute + self.prefetch_lead_ns as f64) / loading
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prediction_converges() {
        let mut p = PrefetchPredictor::new(10, 5_000_000); // 5ms lead
        // Simulate consistent 100ms compute phases; the EMA of a constant
        // series stays at that constant, so the prediction should land
        // within 1ms of 100ms.
        for _ in 0..10 {
            p.record_compute(100_000_000);
        }
        let predicted = p.predicted_compute_ns().unwrap();
        assert!((predicted as f64 - 100_000_000.0).abs() < 1_000_000.0);
    }

    #[test]
    fn prefetch_triggers_near_end() {
        let mut p = PrefetchPredictor::new(10, 10_000_000); // 10ms lead
        for _ in 0..5 {
            p.record_compute(100_000_000); // 100ms phases
        }
        // Predicted end is 100ms. At 80ms elapsed we are 20ms from the end
        // (outside the 10ms lead window); at 95ms we are 5ms out (inside).
        assert!(!p.should_prefetch(80_000_000)); // too early
        assert!(p.should_prefetch(95_000_000)); // within lead time
    }

    #[test]
    fn overlap_efficiency_calculation() {
        let mut p = PrefetchPredictor::new(10, 10_000_000);
        for _ in 0..5 {
            p.record_compute(100_000_000);
            p.record_loading(50_000_000);
        }
        // (compute(100ms) + lead(10ms)) / loading(50ms) = 2.2
        let eff = p.overlap_efficiency();
        assert!(eff > 1.0); // data ready well before GPU finishes
    }
}