use nom::{
    branch::alt,
    bytes::complete::{tag_no_case, take_while1},
    character::complete::{char, digit1, multispace0, multispace1},
    combinator::{map, map_res, opt},
    multi::separated_list1,
    number::complete::double,
    sequence::{delimited, preceded, tuple},
    IResult,
};
20
/// A fully parsed ZQL query.
///
/// Clauses appear in the fixed order `SELECT … FROM … [WHERE …] [ORDER BY …] [LIMIT …]`.
#[derive(Debug, Clone, PartialEq)]
pub struct ZqlQuery {
    /// Projected columns, in source order. The wildcard is stored verbatim as `"*"`.
    pub select: Vec<String>,
    /// Name of the table given after `FROM`.
    pub from: String,
    /// Filter conditions from the `WHERE` clause, if present.
    pub where_clause: Option<WhereClause>,
    /// Sort key from the `ORDER BY` clause, if present.
    pub order_by: Option<OrderBy>,
    /// Row cap from the `LIMIT` clause, if present.
    pub limit: Option<usize>,
}

/// The `WHERE` clause: one or more conditions joined with `AND`.
#[derive(Debug, Clone, PartialEq)]
pub struct WhereClause {
    /// Conditions in source order; all must hold (they are joined by `AND`).
    pub conditions: Vec<Condition>,
}

/// A single `field <op> value` comparison, e.g. `loss < 1.5`.
#[derive(Debug, Clone, PartialEq)]
pub struct Condition {
    /// Left-hand side column name.
    pub field: String,
    /// Comparison operator.
    pub op: CompareOp,
    /// Right-hand side literal.
    pub value: Value,
}

/// Comparison operators usable in a [`Condition`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompareOp {
    /// `=`
    Eq,
    /// Not equal.
    NotEq,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Lte,
    /// `>=`
    Gte,
}

/// A literal on the right-hand side of a [`Condition`].
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// A numeric literal, stored as `f64`.
    Number(f64),
    /// The contents of a single-quoted string literal, without the surrounding quotes.
    Text(String),
}

/// The `ORDER BY` clause: a single sort column and its direction.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderBy {
    /// Column to sort by.
    pub field: String,
    /// Sort direction; the parser uses [`Direction::Asc`] when none is written.
    pub direction: Direction,
}

/// Sort direction for [`OrderBy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Ascending (`ASC`).
    Asc,
    /// Descending (`DESC`).
    Desc,
}
69
70fn identifier(input: &str) -> IResult<&str, String> {
71 let (input, id) = take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '*')(input)?;
72 Ok((input, id.to_string()))
73}
74
75fn select_clause(input: &str) -> IResult<&str, Vec<String>> {
76 let (input, _) = tag_no_case("SELECT")(input)?;
77 let (input, _) = multispace1(input)?;
78 separated_list1(tuple((multispace0, char(','), multispace0)), identifier)(input)
79}
80
81fn from_clause(input: &str) -> IResult<&str, String> {
82 let (input, _) = multispace1(input)?;
83 let (input, _) = tag_no_case("FROM")(input)?;
84 let (input, _) = multispace1(input)?;
85 identifier(input)
86}
87
88fn compare_op(input: &str) -> IResult<&str, CompareOp> {
89 alt((
90 map(tag_no_case("!="), |_| CompareOp::NotEq),
91 map(tag_no_case("<="), |_| CompareOp::Lte),
92 map(tag_no_case(">="), |_| CompareOp::Gte),
93 map(char('<'), |_| CompareOp::Lt),
94 map(char('>'), |_| CompareOp::Gt),
95 map(char('='), |_| CompareOp::Eq),
96 ))(input)
97}
98
99fn quoted_string(input: &str) -> IResult<&str, String> {
100 let (input, s) = delimited(char('\''), take_while1(|c: char| c != '\''), char('\''))(input)?;
101 Ok((input, s.to_string()))
102}
103
104fn value(input: &str) -> IResult<&str, Value> {
105 alt((map(quoted_string, Value::Text), map(double, Value::Number)))(input)
106}
107
108fn condition(input: &str) -> IResult<&str, Condition> {
109 let (input, field) = identifier(input)?;
110 let (input, _) = multispace0(input)?;
111 let (input, op) = compare_op(input)?;
112 let (input, _) = multispace0(input)?;
113 let (input, val) = value(input)?;
114 Ok((
115 input,
116 Condition {
117 field,
118 op,
119 value: val,
120 },
121 ))
122}
123
124fn where_clause(input: &str) -> IResult<&str, WhereClause> {
125 let (input, _) = multispace1(input)?;
126 let (input, _) = tag_no_case("WHERE")(input)?;
127 let (input, _) = multispace1(input)?;
128 let (input, conditions) = separated_list1(
129 tuple((multispace1, tag_no_case("AND"), multispace1)),
130 condition,
131 )(input)?;
132 Ok((input, WhereClause { conditions }))
133}
134
135fn direction(input: &str) -> IResult<&str, Direction> {
136 alt((
137 map(tag_no_case("ASC"), |_| Direction::Asc),
138 map(tag_no_case("DESC"), |_| Direction::Desc),
139 ))(input)
140}
141
142fn order_by_clause(input: &str) -> IResult<&str, OrderBy> {
143 let (input, _) = multispace1(input)?;
144 let (input, _) = tag_no_case("ORDER")(input)?;
145 let (input, _) = multispace1(input)?;
146 let (input, _) = tag_no_case("BY")(input)?;
147 let (input, _) = multispace1(input)?;
148 let (input, field) = identifier(input)?;
149 let (input, dir) = opt(preceded(multispace1, direction))(input)?;
150 Ok((
151 input,
152 OrderBy {
153 field,
154 direction: dir.unwrap_or(Direction::Asc),
155 },
156 ))
157}
158
159fn limit_clause(input: &str) -> IResult<&str, usize> {
160 let (input, _) = multispace1(input)?;
161 let (input, _) = tag_no_case("LIMIT")(input)?;
162 let (input, _) = multispace1(input)?;
163 let (input, n) = double(input)?;
164 Ok((input, n as usize))
165}
166
167fn zql_query(input: &str) -> IResult<&str, ZqlQuery> {
168 let (input, _) = multispace0(input)?;
169 let (input, select) = select_clause(input)?;
170 let (input, from) = from_clause(input)?;
171 let (input, where_cl) = opt(where_clause)(input)?;
172 let (input, order) = opt(order_by_clause)(input)?;
173 let (input, limit) = opt(limit_clause)(input)?;
174 let (input, _) = multispace0(input)?;
175 let (input, _) = opt(char(';'))(input)?;
176 let (input, _) = multispace0(input)?;
177
178 Ok((
179 input,
180 ZqlQuery {
181 select,
182 from,
183 where_clause: where_cl,
184 order_by: order,
185 limit,
186 },
187 ))
188}
189
190pub fn parse(input: &str) -> Result<ZqlQuery, String> {
192 match zql_query(input) {
193 Ok(("", query)) => Ok(query),
194 Ok((remaining, _)) => Err(format!("unexpected trailing input: '{remaining}'")),
195 Err(e) => Err(format!("parse error: {e}")),
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, panicking (via `unwrap`) if the parser rejects it.
    fn must_parse(src: &str) -> ZqlQuery {
        parse(src).unwrap()
    }

    #[test]
    fn parse_simple_select() {
        let query = must_parse("SELECT name, loss FROM experiments");
        assert_eq!(query.select, vec!["name", "loss"]);
        assert_eq!(query.from, "experiments");
        assert!(query.where_clause.is_none());
    }

    #[test]
    fn parse_with_where() {
        let query = must_parse("SELECT * FROM experiments WHERE loss < 1.5");
        assert_eq!(query.select, vec!["*"]);
        let conditions = query.where_clause.unwrap().conditions;
        assert_eq!(conditions.len(), 1);
        let first = &conditions[0];
        assert_eq!(first.field, "loss");
        assert_eq!(first.op, CompareOp::Lt);
        assert_eq!(first.value, Value::Number(1.5));
    }

    #[test]
    fn parse_with_order_and_limit() {
        let query = must_parse("SELECT name FROM experiments ORDER BY loss ASC LIMIT 10;");
        let ordering = query.order_by.unwrap();
        assert_eq!(ordering.field, "loss");
        assert_eq!(ordering.direction, Direction::Asc);
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn parse_where_with_string() {
        let query = must_parse("SELECT * FROM experiments WHERE name = 'baseline'");
        let conditions = query.where_clause.unwrap().conditions;
        assert_eq!(conditions[0].value, Value::Text("baseline".into()));
    }

    #[test]
    fn parse_multiple_conditions() {
        let query = must_parse("SELECT * FROM experiments WHERE loss < 1.5 AND accuracy > 0.9");
        assert_eq!(query.where_clause.unwrap().conditions.len(), 2);
    }

    #[test]
    fn parse_error_on_garbage() {
        assert!(parse("not a query").is_err());
    }
}
250
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Arbitrary input must never panic the parser; it may only return `Err`.
        #[test]
        fn parse_never_panics(input in ".*") {
            let _ = parse(&input);
        }

        // Parsing the same input twice must give identical results. On success,
        // this compares the parsed queries themselves, not just that both succeeded.
        #[test]
        fn parse_is_deterministic(input in ".*") {
            prop_assert_eq!(parse(&input), parse(&input));
        }

        // A minimal well-formed query over simple lowercase names always parses.
        #[test]
        fn valid_select_always_parses(
            col in "[a-z_]{1,10}",
            table in "(experiments|jobs|models)",
        ) {
            let query = format!("SELECT {col} FROM {table}");
            prop_assert!(parse(&query).is_ok(), "failed to parse: {}", query);
        }
    }
}