1use std::collections::VecDeque;
13use tracing::debug;
14
/// Predicts per-phase durations from recent history and decides when to
/// start prefetching ahead of the predicted end of a compute phase.
pub struct PrefetchPredictor {
    // Recent compute-phase durations in nanoseconds, oldest first.
    compute_durations: VecDeque<u64>,
    // Recent loading-phase durations in nanoseconds, oldest first.
    loading_durations: VecDeque<u64>,
    // Maximum number of samples retained in each history queue.
    max_history: usize,
    // Lead time (ns): how far before the predicted compute end a prefetch
    // should be triggered.
    prefetch_lead_ns: u64,
}
26
27impl PrefetchPredictor {
28 pub fn new(max_history: usize, prefetch_lead_ns: u64) -> Self {
29 Self {
30 compute_durations: VecDeque::with_capacity(max_history),
31 loading_durations: VecDeque::with_capacity(max_history),
32 max_history,
33 prefetch_lead_ns,
34 }
35 }
36
37 pub fn record_compute(&mut self, duration_ns: u64) {
39 if self.compute_durations.len() >= self.max_history {
40 self.compute_durations.pop_front();
41 }
42 self.compute_durations.push_back(duration_ns);
43 }
44
45 pub fn record_loading(&mut self, duration_ns: u64) {
47 if self.loading_durations.len() >= self.max_history {
48 self.loading_durations.pop_front();
49 }
50 self.loading_durations.push_back(duration_ns);
51 }
52
53 pub fn predicted_compute_ns(&self) -> Option<u64> {
55 if self.compute_durations.is_empty() {
56 return None;
57 }
58 let alpha = 0.3;
60 let mut ema = self.compute_durations[0] as f64;
61 for &d in self.compute_durations.iter().skip(1) {
62 ema = alpha * d as f64 + (1.0 - alpha) * ema;
63 }
64 Some(ema as u64)
65 }
66
67 pub fn predicted_loading_ns(&self) -> Option<u64> {
69 if self.loading_durations.is_empty() {
70 return None;
71 }
72 let alpha = 0.3;
73 let mut ema = self.loading_durations[0] as f64;
74 for &d in self.loading_durations.iter().skip(1) {
75 ema = alpha * d as f64 + (1.0 - alpha) * ema;
76 }
77 Some(ema as u64)
78 }
79
80 pub fn should_prefetch(&self, elapsed_compute_ns: u64) -> bool {
83 let Some(predicted) = self.predicted_compute_ns() else {
84 return false;
85 };
86
87 if elapsed_compute_ns + self.prefetch_lead_ns >= predicted {
88 debug!(
89 elapsed_ms = elapsed_compute_ns / 1_000_000,
90 predicted_ms = predicted / 1_000_000,
91 lead_ms = self.prefetch_lead_ns / 1_000_000,
92 "triggering predictive prefetch"
93 );
94 return true;
95 }
96
97 false
98 }
99
100 pub fn overlap_efficiency(&self) -> f64 {
105 let compute = self.predicted_compute_ns().unwrap_or(1) as f64;
106 let loading = self.predicted_loading_ns().unwrap_or(1) as f64;
107 if loading == 0.0 {
108 return 1.0;
109 }
110 (compute + self.prefetch_lead_ns as f64) / loading
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeding identical samples must converge the EMA to that value.
    #[test]
    fn prediction_converges() {
        let mut predictor = PrefetchPredictor::new(10, 5_000_000);
        for _ in 0..10 {
            predictor.record_compute(100_000_000);
        }
        let predicted = predictor.predicted_compute_ns().unwrap();
        assert!((predicted as f64 - 100_000_000.0).abs() < 1_000_000.0);
    }

    /// Prefetch fires only once elapsed time is within the lead window of
    /// the predicted compute duration.
    #[test]
    fn prefetch_triggers_near_end() {
        let mut predictor = PrefetchPredictor::new(10, 10_000_000);
        for _ in 0..5 {
            predictor.record_compute(100_000_000);
        }
        assert!(!predictor.should_prefetch(80_000_000));
        assert!(predictor.should_prefetch(95_000_000));
    }

    /// With compute (plus lead) exceeding loading time, efficiency > 1.
    #[test]
    fn overlap_efficiency_calculation() {
        let mut predictor = PrefetchPredictor::new(10, 10_000_000);
        for _ in 0..5 {
            predictor.record_compute(100_000_000);
            predictor.record_loading(50_000_000);
        }
        let efficiency = predictor.overlap_efficiency();
        assert!(efficiency > 1.0);
    }
}