zernel/experiments/
tracker.rs

1// Copyright (C) 2026 Dyber, Inc. — Proprietary
2
3use regex::Regex;
4use std::collections::HashMap;
5
6/// Extracts metrics from training script stdout.
7///
8/// Recognizes common patterns:
9/// - `loss: 1.234` or `loss=1.234`
10/// - `accuracy: 0.95` or `acc=0.95`
11/// - tqdm-style progress bars with metric suffixes
12/// - HuggingFace Trainer log format
13pub struct MetricExtractor {
14    patterns: Vec<CompiledPattern>,
15}
16
17struct CompiledPattern {
18    name: String,
19    regex: Regex,
20}
21
22impl MetricExtractor {
23    pub fn new() -> Self {
24        let patterns = vec![
25            ("loss", r"(?i)\bloss[=:\s]+([0-9]+\.?[0-9]*)"),
26            (
27                "accuracy",
28                r"(?i)\b(?:accuracy|acc)[=:\s]+([0-9]+\.?[0-9]*)",
29            ),
30            ("grad_norm", r"(?i)\bgrad_norm[=:\s]+([0-9]+\.?[0-9]*)"),
31            (
32                "learning_rate",
33                r"(?i)\b(?:learning_rate|lr)[=:\s]+([0-9]+\.?[0-9eE\-]*)",
34            ),
35            (
36                "throughput",
37                r"(?i)\b(?:throughput|samples/s|it/s)[=:\s]+([0-9]+\.?[0-9]*)",
38            ),
39            ("epoch", r"(?i)\bepoch[=:\s]+([0-9]+\.?[0-9]*)"),
40            ("step", r"(?i)\b(?:step|global_step)[=:\s]+([0-9]+)"),
41            (
42                "perplexity",
43                r"(?i)\b(?:perplexity|ppl)[=:\s]+([0-9]+\.?[0-9]*)",
44            ),
45            ("eval_loss", r"(?i)\beval_loss[=:\s]+([0-9]+\.?[0-9]*)"),
46        ];
47
48        Self {
49            patterns: patterns
50                .into_iter()
51                .filter_map(|(name, pat)| {
52                    Regex::new(pat).ok().map(|regex| CompiledPattern {
53                        name: name.to_string(),
54                        regex,
55                    })
56                })
57                .collect(),
58        }
59    }
60
61    /// Parse a line of stdout and extract any recognized metrics.
62    pub fn extract_from_line(&self, line: &str) -> HashMap<String, f64> {
63        let mut metrics = HashMap::new();
64
65        for pattern in &self.patterns {
66            if let Some(caps) = pattern.regex.captures(line) {
67                if let Some(m) = caps.get(1) {
68                    if let Ok(val) = m.as_str().parse::<f64>() {
69                        metrics.insert(pattern.name.clone(), val);
70                    }
71                }
72            }
73        }
74
75        metrics
76    }
77}
78
79/// Generates a unique experiment ID.
80pub fn generate_experiment_id() -> String {
81    let now = chrono::Utc::now();
82    let short_uuid = &uuid::Uuid::new_v4().to_string()[..8];
83    format!("exp-{}-{}", now.format("%Y%m%d-%H%M%S"), short_uuid)
84}
85
86/// Get the zernel data directory (~/.zernel/).
87pub fn zernel_dir() -> std::path::PathBuf {
88    dirs::home_dir()
89        .unwrap_or_else(|| std::path::PathBuf::from("."))
90        .join(".zernel")
91}
92
93/// Get the experiments database path.
94pub fn experiments_db_path() -> std::path::PathBuf {
95    let dir = zernel_dir().join("experiments");
96    std::fs::create_dir_all(&dir).ok();
97    dir.join("experiments.db")
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    /// Runs a fresh extractor over a single line.
    fn extract(line: &str) -> HashMap<String, f64> {
        MetricExtractor::new().extract_from_line(line)
    }

    #[test]
    fn extract_loss() {
        let metrics = extract("Epoch 1/10, loss: 1.2345, accuracy: 0.876");
        assert_close(metrics["loss"], 1.2345, 0.0001);
        assert_close(metrics["accuracy"], 0.876, 0.001);
    }

    #[test]
    fn extract_equals_format() {
        let metrics = extract("loss=0.456 lr=3e-4 grad_norm=0.89");
        assert_close(metrics["loss"], 0.456, 0.001);
        assert_close(metrics["grad_norm"], 0.89, 0.01);
    }

    #[test]
    fn extract_huggingface_format() {
        // Key/value pairs as they appear in an HF Trainer log dictionary,
        // written here with the braces and key quotes removed.
        let metrics = extract("loss: 2.1, learning_rate: 5e-05, epoch: 0.5");
        assert_close(metrics["loss"], 2.1, 0.01);
        assert_close(metrics["epoch"], 0.5, 0.01);
    }

    #[test]
    fn extract_step() {
        let metrics = extract("Step 4821/10000, loss: 1.23");
        assert_eq!(metrics["step"], 4821.0);
    }

    #[test]
    fn no_match_returns_empty() {
        let metrics = extract("Loading dataset from disk...");
        assert!(metrics.is_empty());
    }

    #[test]
    fn experiment_id_is_unique() {
        let first = generate_experiment_id();
        let second = generate_experiment_id();
        assert_ne!(first, second);
        assert!(first.starts_with("exp-"));
    }
}