zernel/experiments/
store.rs

1// Copyright (C) 2026 Dyber, Inc. — Proprietary
2
3use anyhow::Result;
4use chrono::{DateTime, Utc};
5use rusqlite::Connection;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Experiment {
12    pub id: String,
13    pub name: String,
14    pub status: ExperimentStatus,
15    pub hyperparams: HashMap<String, serde_json::Value>,
16    pub metrics: HashMap<String, f64>,
17    pub created_at: DateTime<Utc>,
18    pub finished_at: Option<DateTime<Utc>>,
19    pub git_commit: Option<String>,
20    pub script: Option<String>,
21    pub duration_secs: Option<f64>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
25pub enum ExperimentStatus {
26    Running,
27    Done,
28    Failed,
29    Queued,
30}
31
32impl std::fmt::Display for ExperimentStatus {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        match self {
35            Self::Running => write!(f, "Running"),
36            Self::Done => write!(f, "Done"),
37            Self::Failed => write!(f, "Failed"),
38            Self::Queued => write!(f, "Queued"),
39        }
40    }
41}
42
43/// SQLite-backed experiment store.
44pub struct ExperimentStore {
45    conn: Connection,
46}
47
48impl ExperimentStore {
49    pub fn open(path: &Path) -> Result<Self> {
50        let conn = Connection::open(path)?;
51        conn.execute_batch(
52            "CREATE TABLE IF NOT EXISTS experiments (
53                id TEXT PRIMARY KEY,
54                name TEXT NOT NULL,
55                status TEXT NOT NULL,
56                hyperparams TEXT NOT NULL DEFAULT '{}',
57                metrics TEXT NOT NULL DEFAULT '{}',
58                created_at TEXT NOT NULL,
59                finished_at TEXT,
60                git_commit TEXT,
61                script TEXT,
62                duration_secs REAL
63            );
64
65            CREATE INDEX IF NOT EXISTS idx_experiments_created
66                ON experiments(created_at DESC);
67            CREATE INDEX IF NOT EXISTS idx_experiments_status
68                ON experiments(status);",
69        )?;
70        Ok(Self { conn })
71    }
72
73    pub fn insert(&self, exp: &Experiment) -> Result<()> {
74        self.conn.execute(
75            "INSERT INTO experiments (id, name, status, hyperparams, metrics, created_at, finished_at, git_commit, script, duration_secs)
76             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
77            (
78                &exp.id,
79                &exp.name,
80                serde_json::to_string(&exp.status)?,
81                serde_json::to_string(&exp.hyperparams)?,
82                serde_json::to_string(&exp.metrics)?,
83                exp.created_at.to_rfc3339(),
84                exp.finished_at.map(|t| t.to_rfc3339()),
85                &exp.git_commit,
86                &exp.script,
87                exp.duration_secs,
88            ),
89        )?;
90        Ok(())
91    }
92
93    pub fn update_status(&self, id: &str, status: &ExperimentStatus) -> Result<()> {
94        self.conn.execute(
95            "UPDATE experiments SET status = ?1 WHERE id = ?2",
96            (serde_json::to_string(status)?, id),
97        )?;
98        Ok(())
99    }
100
101    pub fn update_metrics(&self, id: &str, metrics: &HashMap<String, f64>) -> Result<()> {
102        self.conn.execute(
103            "UPDATE experiments SET metrics = ?1 WHERE id = ?2",
104            (serde_json::to_string(metrics)?, id),
105        )?;
106        Ok(())
107    }
108
109    pub fn finish(&self, id: &str, status: ExperimentStatus, duration_secs: f64) -> Result<()> {
110        let now = Utc::now().to_rfc3339();
111        self.conn.execute(
112            "UPDATE experiments SET status = ?1, finished_at = ?2, duration_secs = ?3 WHERE id = ?4",
113            (serde_json::to_string(&status)?, &now, duration_secs, id),
114        )?;
115        Ok(())
116    }
117
118    pub fn get(&self, id: &str) -> Result<Option<Experiment>> {
119        let mut stmt = self.conn.prepare(
120            "SELECT id, name, status, hyperparams, metrics, created_at, finished_at, git_commit, script, duration_secs FROM experiments WHERE id = ?1"
121        )?;
122
123        let mut rows = stmt.query_map([id], Self::row_to_experiment)?;
124        match rows.next() {
125            Some(Ok(exp)) => Ok(Some(exp)),
126            _ => Ok(None),
127        }
128    }
129
130    pub fn list(&self, limit: usize) -> Result<Vec<Experiment>> {
131        let mut stmt = self.conn.prepare(
132            "SELECT id, name, status, hyperparams, metrics, created_at, finished_at, git_commit, script, duration_secs FROM experiments ORDER BY created_at DESC LIMIT ?1"
133        )?;
134
135        let experiments = stmt
136            .query_map([limit], Self::row_to_experiment)?
137            .collect::<Result<Vec<_>, _>>()?;
138
139        Ok(experiments)
140    }
141
142    pub fn delete(&self, id: &str) -> Result<bool> {
143        let count = self
144            .conn
145            .execute("DELETE FROM experiments WHERE id = ?1", [id])?;
146        Ok(count > 0)
147    }
148
149    fn row_to_experiment(row: &rusqlite::Row) -> rusqlite::Result<Experiment> {
150        Ok(Experiment {
151            id: row.get(0)?,
152            name: row.get(1)?,
153            status: serde_json::from_str(&row.get::<_, String>(2)?)
154                .unwrap_or(ExperimentStatus::Failed),
155            hyperparams: serde_json::from_str(&row.get::<_, String>(3)?).unwrap_or_default(),
156            metrics: serde_json::from_str(&row.get::<_, String>(4)?).unwrap_or_default(),
157            created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
158                .unwrap_or_default()
159                .with_timezone(&Utc),
160            finished_at: row
161                .get::<_, Option<String>>(6)?
162                .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
163                .map(|t| t.with_timezone(&Utc)),
164            git_commit: row.get(7)?,
165            script: row.get(8)?,
166            duration_secs: row.get(9)?,
167        })
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a store on a fresh database file inside a temporary directory.
    ///
    /// The `TempDir` is returned alongside the store so the caller keeps the
    /// directory alive for the duration of the test.
    fn temp_store() -> (ExperimentStore, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.db");
        let store = ExperimentStore::open(&path).unwrap();
        (store, dir)
    }

    /// Builds a running experiment with the given id and fixed sample data.
    fn make_experiment(id: &str) -> Experiment {
        Experiment {
            id: id.into(),
            name: format!("test-{id}"),
            status: ExperimentStatus::Running,
            hyperparams: HashMap::from([("lr".into(), serde_json::json!(0.001))]),
            metrics: HashMap::from([("loss".into(), 1.5)]),
            created_at: Utc::now(),
            finished_at: None,
            git_commit: Some("abc123".into()),
            script: Some("train.py".into()),
            duration_secs: None,
        }
    }

    #[test]
    fn insert_and_list() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();
        store.insert(&make_experiment("exp-002")).unwrap();

        let list = store.list(10).unwrap();
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn insert_duplicate_id_fails() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();

        // `id` is the primary key, so a second insert must be rejected.
        assert!(store.insert(&make_experiment("exp-001")).is_err());
    }

    #[test]
    fn list_respects_limit() {
        let (store, _dir) = temp_store();
        for id in ["exp-001", "exp-002", "exp-003"] {
            store.insert(&make_experiment(id)).unwrap();
        }

        assert_eq!(store.list(2).unwrap().len(), 2);
    }

    #[test]
    fn get_by_id() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();

        let exp = store.get("exp-001").unwrap().unwrap();
        assert_eq!(exp.name, "test-exp-001");
    }

    #[test]
    fn get_missing_returns_none() {
        let (store, _dir) = temp_store();
        assert!(store.get("no-such-id").unwrap().is_none());
    }

    #[test]
    fn update_status_round_trips() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();
        store
            .update_status("exp-001", &ExperimentStatus::Queued)
            .unwrap();

        let exp = store.get("exp-001").unwrap().unwrap();
        assert_eq!(exp.status, ExperimentStatus::Queued);
    }

    #[test]
    fn update_metrics_replaces_metrics() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();

        // 0.25 is exactly representable, so the JSON round trip is lossless.
        let replacement: HashMap<String, f64> = HashMap::from([("acc".into(), 0.25)]);
        store.update_metrics("exp-001", &replacement).unwrap();

        let exp = store.get("exp-001").unwrap().unwrap();
        assert_eq!(exp.metrics, replacement);
    }

    #[test]
    fn finish_updates_status() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();
        store
            .finish("exp-001", ExperimentStatus::Done, 123.4)
            .unwrap();

        let exp = store.get("exp-001").unwrap().unwrap();
        assert_eq!(exp.status, ExperimentStatus::Done);
        assert!((exp.duration_secs.unwrap() - 123.4).abs() < 0.01);
        assert!(exp.finished_at.is_some());
    }

    #[test]
    fn delete_experiment() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();
        assert!(store.delete("exp-001").unwrap());
        assert!(store.get("exp-001").unwrap().is_none());
    }

    #[test]
    fn delete_missing_returns_false() {
        let (store, _dir) = temp_store();
        assert!(!store.delete("no-such-id").unwrap());
    }
}
236}