1use anyhow::Result;
4use chrono::{DateTime, Utc};
5use rusqlite::Connection;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9
/// One tracked experiment run, as persisted by `ExperimentStore`.
///
/// `ExperimentStore` writes `status`, `hyperparams` and `metrics` to the
/// database as JSON text and the timestamps as RFC 3339 text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experiment {
    /// Unique identifier; the primary key of the `experiments` table.
    pub id: String,
    /// Name of the experiment (the schema does not require it to be unique).
    pub name: String,
    /// Current lifecycle status.
    pub status: ExperimentStatus,
    /// Hyperparameter name -> arbitrary JSON value.
    pub hyperparams: HashMap<String, serde_json::Value>,
    /// Metric name -> numeric value (e.g. `"loss"`). `ExperimentStore::update_metrics`
    /// replaces the whole map rather than merging into it.
    pub metrics: HashMap<String, f64>,
    /// Creation time; `ExperimentStore::list` returns newest first by this field.
    pub created_at: DateTime<Utc>,
    /// Completion time; `ExperimentStore::finish` sets it to the current time.
    pub finished_at: Option<DateTime<Utc>>,
    /// Source revision the run came from, if recorded
    /// (presumably a git commit hash — confirm with callers).
    pub git_commit: Option<String>,
    /// Script that launched the run, if recorded (tests use `"train.py"`).
    pub script: Option<String>,
    /// Run duration in seconds, as supplied by the caller of `ExperimentStore::finish`.
    pub duration_secs: Option<f64>,
}
23
/// Lifecycle state of an [`Experiment`].
///
/// `ExperimentStore` persists this via `serde_json::to_string`, so the
/// `status` column holds the JSON form (e.g. `"Running"` including quotes),
/// not the `Display` form. A stored value that fails to deserialize is read
/// back as `Failed`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ExperimentStatus {
    Running,
    Done,
    Failed,
    Queued,
}
31
32impl std::fmt::Display for ExperimentStatus {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 match self {
35 Self::Running => write!(f, "Running"),
36 Self::Done => write!(f, "Done"),
37 Self::Failed => write!(f, "Failed"),
38 Self::Queued => write!(f, "Queued"),
39 }
40 }
41}
42
/// SQLite-backed persistence for [`Experiment`] records.
///
/// Owns a single `rusqlite::Connection`; every operation goes through it and
/// takes `&self`.
pub struct ExperimentStore {
    conn: Connection,
}
47
48impl ExperimentStore {
49 pub fn open(path: &Path) -> Result<Self> {
50 let conn = Connection::open(path)?;
51 conn.execute_batch(
52 "CREATE TABLE IF NOT EXISTS experiments (
53 id TEXT PRIMARY KEY,
54 name TEXT NOT NULL,
55 status TEXT NOT NULL,
56 hyperparams TEXT NOT NULL DEFAULT '{}',
57 metrics TEXT NOT NULL DEFAULT '{}',
58 created_at TEXT NOT NULL,
59 finished_at TEXT,
60 git_commit TEXT,
61 script TEXT,
62 duration_secs REAL
63 );
64
65 CREATE INDEX IF NOT EXISTS idx_experiments_created
66 ON experiments(created_at DESC);
67 CREATE INDEX IF NOT EXISTS idx_experiments_status
68 ON experiments(status);",
69 )?;
70 Ok(Self { conn })
71 }
72
73 pub fn insert(&self, exp: &Experiment) -> Result<()> {
74 self.conn.execute(
75 "INSERT INTO experiments (id, name, status, hyperparams, metrics, created_at, finished_at, git_commit, script, duration_secs)
76 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
77 (
78 &exp.id,
79 &exp.name,
80 serde_json::to_string(&exp.status)?,
81 serde_json::to_string(&exp.hyperparams)?,
82 serde_json::to_string(&exp.metrics)?,
83 exp.created_at.to_rfc3339(),
84 exp.finished_at.map(|t| t.to_rfc3339()),
85 &exp.git_commit,
86 &exp.script,
87 exp.duration_secs,
88 ),
89 )?;
90 Ok(())
91 }
92
93 pub fn update_status(&self, id: &str, status: &ExperimentStatus) -> Result<()> {
94 self.conn.execute(
95 "UPDATE experiments SET status = ?1 WHERE id = ?2",
96 (serde_json::to_string(status)?, id),
97 )?;
98 Ok(())
99 }
100
101 pub fn update_metrics(&self, id: &str, metrics: &HashMap<String, f64>) -> Result<()> {
102 self.conn.execute(
103 "UPDATE experiments SET metrics = ?1 WHERE id = ?2",
104 (serde_json::to_string(metrics)?, id),
105 )?;
106 Ok(())
107 }
108
109 pub fn finish(&self, id: &str, status: ExperimentStatus, duration_secs: f64) -> Result<()> {
110 let now = Utc::now().to_rfc3339();
111 self.conn.execute(
112 "UPDATE experiments SET status = ?1, finished_at = ?2, duration_secs = ?3 WHERE id = ?4",
113 (serde_json::to_string(&status)?, &now, duration_secs, id),
114 )?;
115 Ok(())
116 }
117
118 pub fn get(&self, id: &str) -> Result<Option<Experiment>> {
119 let mut stmt = self.conn.prepare(
120 "SELECT id, name, status, hyperparams, metrics, created_at, finished_at, git_commit, script, duration_secs FROM experiments WHERE id = ?1"
121 )?;
122
123 let mut rows = stmt.query_map([id], Self::row_to_experiment)?;
124 match rows.next() {
125 Some(Ok(exp)) => Ok(Some(exp)),
126 _ => Ok(None),
127 }
128 }
129
130 pub fn list(&self, limit: usize) -> Result<Vec<Experiment>> {
131 let mut stmt = self.conn.prepare(
132 "SELECT id, name, status, hyperparams, metrics, created_at, finished_at, git_commit, script, duration_secs FROM experiments ORDER BY created_at DESC LIMIT ?1"
133 )?;
134
135 let experiments = stmt
136 .query_map([limit], Self::row_to_experiment)?
137 .collect::<Result<Vec<_>, _>>()?;
138
139 Ok(experiments)
140 }
141
142 pub fn delete(&self, id: &str) -> Result<bool> {
143 let count = self
144 .conn
145 .execute("DELETE FROM experiments WHERE id = ?1", [id])?;
146 Ok(count > 0)
147 }
148
149 fn row_to_experiment(row: &rusqlite::Row) -> rusqlite::Result<Experiment> {
150 Ok(Experiment {
151 id: row.get(0)?,
152 name: row.get(1)?,
153 status: serde_json::from_str(&row.get::<_, String>(2)?)
154 .unwrap_or(ExperimentStatus::Failed),
155 hyperparams: serde_json::from_str(&row.get::<_, String>(3)?).unwrap_or_default(),
156 metrics: serde_json::from_str(&row.get::<_, String>(4)?).unwrap_or_default(),
157 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
158 .unwrap_or_default()
159 .with_timezone(&Utc),
160 finished_at: row
161 .get::<_, Option<String>>(6)?
162 .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
163 .map(|t| t.with_timezone(&Utc)),
164 git_commit: row.get(7)?,
165 script: row.get(8)?,
166 duration_secs: row.get(9)?,
167 })
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a store backed by a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for the whole test so the
    /// directory holding the database is not cleaned up early.
    fn temp_store() -> (ExperimentStore, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.db");
        let store = ExperimentStore::open(&path).unwrap();
        (store, dir)
    }

    /// Builds a `Running` experiment with one hyperparameter and one metric.
    fn make_experiment(id: &str) -> Experiment {
        Experiment {
            id: id.into(),
            name: format!("test-{id}"),
            status: ExperimentStatus::Running,
            hyperparams: HashMap::from([("lr".into(), serde_json::json!(0.001))]),
            metrics: HashMap::from([("loss".into(), 1.5)]),
            created_at: Utc::now(),
            finished_at: None,
            git_commit: Some("abc123".into()),
            script: Some("train.py".into()),
            duration_secs: None,
        }
    }

    #[test]
    fn insert_and_list() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();
        store.insert(&make_experiment("exp-002")).unwrap();

        let list = store.list(10).unwrap();
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn list_respects_limit_and_newest_first() {
        let (store, _dir) = temp_store();
        let mut old = make_experiment("old");
        old.created_at = DateTime::parse_from_rfc3339("2020-01-01T00:00:00+00:00")
            .unwrap()
            .with_timezone(&Utc);
        store.insert(&old).unwrap();
        store.insert(&make_experiment("new")).unwrap();

        let top = store.list(1).unwrap();
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].id, "new");

        let all = store.list(10).unwrap();
        let ids: Vec<&str> = all.iter().map(|e| e.id.as_str()).collect();
        assert_eq!(ids, ["new", "old"]);
    }

    #[test]
    fn get_by_id() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();

        let exp = store.get("exp-001").unwrap().unwrap();
        assert_eq!(exp.name, "test-exp-001");
        assert_eq!(exp.status, ExperimentStatus::Running);
        assert_eq!(exp.hyperparams.get("lr"), Some(&serde_json::json!(0.001)));
        assert_eq!(exp.metrics.get("loss"), Some(&1.5));
    }

    #[test]
    fn get_missing_returns_none() {
        let (store, _dir) = temp_store();
        assert!(store.get("does-not-exist").unwrap().is_none());
    }

    #[test]
    fn update_status_and_metrics() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();
        store
            .update_status("exp-001", &ExperimentStatus::Queued)
            .unwrap();
        store
            .update_metrics("exp-001", &HashMap::from([("acc".to_string(), 0.9)]))
            .unwrap();

        let exp = store.get("exp-001").unwrap().unwrap();
        assert_eq!(exp.status, ExperimentStatus::Queued);
        // The metrics map is replaced wholesale, not merged.
        assert_eq!(exp.metrics.get("acc"), Some(&0.9));
        assert!(!exp.metrics.contains_key("loss"));
    }

    #[test]
    fn finish_updates_status() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();
        store
            .finish("exp-001", ExperimentStatus::Done, 123.4)
            .unwrap();

        let exp = store.get("exp-001").unwrap().unwrap();
        assert_eq!(exp.status, ExperimentStatus::Done);
        assert!(exp.finished_at.is_some());
        assert!((exp.duration_secs.unwrap() - 123.4).abs() < 0.01);
    }

    #[test]
    fn delete_experiment() {
        let (store, _dir) = temp_store();
        store.insert(&make_experiment("exp-001")).unwrap();
        assert!(store.delete("exp-001").unwrap());
        assert!(store.get("exp-001").unwrap().is_none());
        // A second delete finds nothing to remove.
        assert!(!store.delete("exp-001").unwrap());
    }
}