zernel/experiments/
tracker.rs1use regex::Regex;
4use std::collections::HashMap;
5
6pub struct MetricExtractor {
14 patterns: Vec<CompiledPattern>,
15}
16
17struct CompiledPattern {
18 name: String,
19 regex: Regex,
20}
21
22impl MetricExtractor {
23 pub fn new() -> Self {
24 let patterns = vec![
25 ("loss", r"(?i)\bloss[=:\s]+([0-9]+\.?[0-9]*)"),
26 (
27 "accuracy",
28 r"(?i)\b(?:accuracy|acc)[=:\s]+([0-9]+\.?[0-9]*)",
29 ),
30 ("grad_norm", r"(?i)\bgrad_norm[=:\s]+([0-9]+\.?[0-9]*)"),
31 (
32 "learning_rate",
33 r"(?i)\b(?:learning_rate|lr)[=:\s]+([0-9]+\.?[0-9eE\-]*)",
34 ),
35 (
36 "throughput",
37 r"(?i)\b(?:throughput|samples/s|it/s)[=:\s]+([0-9]+\.?[0-9]*)",
38 ),
39 ("epoch", r"(?i)\bepoch[=:\s]+([0-9]+\.?[0-9]*)"),
40 ("step", r"(?i)\b(?:step|global_step)[=:\s]+([0-9]+)"),
41 (
42 "perplexity",
43 r"(?i)\b(?:perplexity|ppl)[=:\s]+([0-9]+\.?[0-9]*)",
44 ),
45 ("eval_loss", r"(?i)\beval_loss[=:\s]+([0-9]+\.?[0-9]*)"),
46 ];
47
48 Self {
49 patterns: patterns
50 .into_iter()
51 .filter_map(|(name, pat)| {
52 Regex::new(pat).ok().map(|regex| CompiledPattern {
53 name: name.to_string(),
54 regex,
55 })
56 })
57 .collect(),
58 }
59 }
60
61 pub fn extract_from_line(&self, line: &str) -> HashMap<String, f64> {
63 let mut metrics = HashMap::new();
64
65 for pattern in &self.patterns {
66 if let Some(caps) = pattern.regex.captures(line) {
67 if let Some(m) = caps.get(1) {
68 if let Ok(val) = m.as_str().parse::<f64>() {
69 metrics.insert(pattern.name.clone(), val);
70 }
71 }
72 }
73 }
74
75 metrics
76 }
77}
78
79pub fn generate_experiment_id() -> String {
81 let now = chrono::Utc::now();
82 let short_uuid = &uuid::Uuid::new_v4().to_string()[..8];
83 format!("exp-{}-{}", now.format("%Y%m%d-%H%M%S"), short_uuid)
84}
85
86pub fn zernel_dir() -> std::path::PathBuf {
88 dirs::home_dir()
89 .unwrap_or_else(|| std::path::PathBuf::from("."))
90 .join(".zernel")
91}
92
93pub fn experiments_db_path() -> std::path::PathBuf {
95 let dir = zernel_dir().join("experiments");
96 std::fs::create_dir_all(&dir).ok();
97 dir.join("experiments.db")
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Colon-separated metrics on a typical epoch summary line.
    #[test]
    fn extract_loss() {
        let ext = MetricExtractor::new();
        let m = ext.extract_from_line("Epoch 1/10, loss: 1.2345, accuracy: 0.876");
        assert!((m["loss"] - 1.2345).abs() < 0.0001);
        assert!((m["accuracy"] - 0.876).abs() < 0.001);
        // "Epoch 1/10" also yields the epoch number before the slash.
        assert!((m["epoch"] - 1.0).abs() < 1e-9);
    }

    /// `key=value` pairs, including a learning rate in scientific notation.
    #[test]
    fn extract_equals_format() {
        let ext = MetricExtractor::new();
        let m = ext.extract_from_line("loss=0.456 lr=3e-4 grad_norm=0.89");
        assert!((m["loss"] - 0.456).abs() < 0.001);
        assert!((m["learning_rate"] - 3e-4).abs() < 1e-9);
        assert!((m["grad_norm"] - 0.89).abs() < 0.01);
    }

    /// `key: value` pairs with a zero-padded exponent on the learning rate.
    #[test]
    fn extract_huggingface_format() {
        let ext = MetricExtractor::new();
        let m = ext.extract_from_line("loss: 2.1, learning_rate: 5e-05, epoch: 0.5");
        assert!((m["loss"] - 2.1).abs() < 0.01);
        assert!((m["learning_rate"] - 5e-05).abs() < 1e-9);
        assert!((m["epoch"] - 0.5).abs() < 0.01);
    }

    /// Only the numerator of `Step N/M` is taken as the step.
    #[test]
    fn extract_step() {
        let ext = MetricExtractor::new();
        let m = ext.extract_from_line("Step 4821/10000, loss: 1.23");
        assert_eq!(m["step"], 4821.0);
    }

    /// Abbreviated and alternative keys map onto the canonical metric names.
    #[test]
    fn extract_abbreviated_keys() {
        let ext = MetricExtractor::new();
        let m = ext.extract_from_line("acc=0.91 ppl=12.5 global_step=300");
        assert!((m["accuracy"] - 0.91).abs() < 0.001);
        assert!((m["perplexity"] - 12.5).abs() < 0.001);
        assert_eq!(m["step"], 300.0);
    }

    /// `eval_loss` must not be reported as plain `loss` as well: the
    /// underscore before `loss` is a word character, so no word boundary
    /// exists there.
    #[test]
    fn eval_loss_is_distinct_from_loss() {
        let ext = MetricExtractor::new();
        let m = ext.extract_from_line("eval_loss: 0.75");
        assert!((m["eval_loss"] - 0.75).abs() < 0.001);
        assert!(!m.contains_key("loss"));
    }

    #[test]
    fn no_match_returns_empty() {
        let ext = MetricExtractor::new();
        let m = ext.extract_from_line("Loading dataset from disk...");
        assert!(m.is_empty());
    }

    #[test]
    fn experiment_id_is_unique() {
        let a = generate_experiment_id();
        let b = generate_experiment_id();
        assert_ne!(a, b);
        assert!(a.starts_with("exp-"));
    }
}