comptime sql bindings for zig
ziglang
sql
//! comptime sql parsing
//!
//! extracts metadata from sql strings at compile time:
//! - column names from SELECT clauses
//! - named parameter names (`:name`)
//! - a positional form of the sql, with each `:name` rewritten to `?`
//!
//! sql injection safety:
//! - sql strings must be comptime-known, so runtime user input cannot be concatenated into them
//! - parameters are bound via prepared statements, not interpolated
//! - the :name syntax reinforces parameterized query patterns
//!
//! limitations:
//! - SELECT * yields no columns (the table schema is not known at compile time)
//! - no subquery support in column extraction
//! - no quoted identifier support ("column name")
//! - a column's name is inferred as the last bare identifier of its select item, so an
//!   expression without an alias may yield no name or only part of one
//! - only `:name` and `?` placeholders are recognized (`@name` and `$name` are not)
17
18const std = @import("std");
19
/// max named parameters per query
pub const MAX_PARAMS = 32;

/// max columns per SELECT
pub const MAX_COLS = 64;

/// max sql string length
pub const MAX_SQL_LEN = 4096;

/// result of parsing a sql string at comptime
///
/// Each fixed-capacity array is paired with a `_len` count, and only the first `_len`
/// entries are initialized. The slice accessors below return exactly those entries,
/// so callers need not slice the arrays by hand.
pub const ParseResult = struct {
    /// sql with :name params replaced by ? (first `positional_len` bytes are valid)
    positional: [MAX_SQL_LEN]u8,
    positional_len: usize,

    /// total parameter count (named + positional)
    param_count: usize,

    /// extracted named parameter names in order (first `params_len` entries are valid)
    params: [MAX_PARAMS][]const u8,
    params_len: usize,

    /// extracted column names/aliases from SELECT (first `columns_len` entries are valid)
    columns: [MAX_COLS][]const u8,
    columns_len: usize,

    /// The positional sql text, i.e. `positional[0..positional_len]`.
    /// The slice borrows from `self` and is valid as long as `self` is.
    pub fn positionalSql(self: *const ParseResult) []const u8 {
        return self.positional[0..self.positional_len];
    }

    /// The named parameter names in order of appearance, i.e. `params[0..params_len]`.
    /// The slice borrows from `self` and is valid as long as `self` is.
    pub fn paramNames(self: *const ParseResult) []const []const u8 {
        return self.params[0..self.params_len];
    }

    /// The extracted column names/aliases in SELECT order, i.e. `columns[0..columns_len]`.
    /// The slice borrows from `self` and is valid as long as `self` is.
    pub fn columnNames(self: *const ParseResult) []const []const u8 {
        return self.columns[0..self.columns_len];
    }
};
46
/// Parse `sql`, extracting its positional form, placeholder count, named parameter
/// names and SELECT column names (see `ParseResult`).
///
/// Intended to be evaluated at comptime (`comptime parse("...")`). A `sql` longer than
/// `MAX_SQL_LEN` is rejected with a compile error. More than `MAX_PARAMS` named
/// parameters, or more than `MAX_COLS` columns, overflows the fixed-size result arrays.
/// That fails as an index-out-of-bounds error, which is a compile error when evaluated
/// at comptime.
pub fn parse(comptime sql: []const u8) ParseResult {
    // `positional` never outgrows the input (each `:name` shrinks to a single `?`), so
    // bounding `sql.len` also bounds `positional_len`. The condition is comptime-known,
    // so the error branch is only analyzed when the limit is actually exceeded.
    if (sql.len > MAX_SQL_LEN) {
        @compileError(std.fmt.comptimePrint(
            "sql string is {d} bytes, but MAX_SQL_LEN is {d}",
            .{ sql.len, MAX_SQL_LEN },
        ));
    }

    // Every pass below scans `sql`, so scale the comptime branch budget with its length.
    @setEvalBranchQuota(sql.len * 100);

    var result = ParseResult{
        .positional = undefined,
        .positional_len = 0,
        .param_count = 0,
        .params = undefined,
        .params_len = 0,
        .columns = undefined,
        .columns_len = 0,
    };

    parseParams(sql, &result);
    parseColumns(sql, &result);

    return result;
}
64
/// Rewrite `sql` into `result.positional`, replacing every `:name` placeholder with `?`.
///
/// Each placeholder (`:name` or a literal `?`) increments `result.param_count`. Each
/// `:name` also appends `name` (without the colon) to `result.params`, in order of
/// appearance. Single-quoted string literals, double-quoted identifiers, `--` line
/// comments and `/* */` block comments are copied through unchanged. As a result, a
/// `:` or `?` inside them is not mistaken for a placeholder.
fn parseParams(comptime sql: []const u8, result: *ParseResult) void {
    var i: usize = 0;
    while (i < sql.len) {
        const c = sql[i];
        // NUL stands in for "past the end": it is neither an identifier char nor a comment opener.
        const next: u8 = if (i + 1 < sql.len) sql[i + 1] else 0;

        // Exclusive end of a quoted span or comment starting at `i`, or null if none starts
        // here. An unterminated span runs to the end of `sql`.
        const verbatim_end: ?usize = blk: {
            if (c == '\'' or c == '"') {
                // A doubled quote ('it''s') closes here and reopens on the next iteration.
                const close = std.mem.indexOfScalarPos(u8, sql, i + 1, c) orelse break :blk sql.len;
                break :blk close + 1;
            }
            if (c == '-' and next == '-') {
                break :blk std.mem.indexOfScalarPos(u8, sql, i, '\n') orelse sql.len;
            }
            if (c == '/' and next == '*') {
                const close = std.mem.indexOfPos(u8, sql, i + 2, "*/") orelse break :blk sql.len;
                break :blk close + 2;
            }
            break :blk null;
        };

        if (verbatim_end) |stop| {
            const span = sql[i..stop];
            @memcpy(result.positional[result.positional_len..][0..span.len], span);
            result.positional_len += span.len;
            i = stop;
        } else if (c == '?') {
            result.positional[result.positional_len] = '?';
            result.positional_len += 1;
            result.param_count += 1;
            i += 1;
        } else if (c == ':' and isIdentStart(next)) {
            result.positional[result.positional_len] = '?';
            result.positional_len += 1;
            result.param_count += 1;

            const start = i + 1;
            var end = start;
            while (end < sql.len and isIdentChar(sql[end])) : (end += 1) {}
            result.params[result.params_len] = sql[start..end];
            result.params_len += 1;
            i = end;
        } else {
            result.positional[result.positional_len] = c;
            result.positional_len += 1;
            i += 1;
        }
    }
}

test "params: placeholders inside literals and comments are ignored" {
    const r = comptime parse(
        \\SELECT * FROM t -- why? :skip
        \\WHERE note = 'a:b?' AND "c:d" = :real /* :nope ? */
    );
    try std.testing.expectEqual(1, r.params_len);
    try std.testing.expectEqualStrings("real", r.params[0]);
    try std.testing.expectEqual(1, r.param_count);
    const expected =
        \\SELECT * FROM t -- why? :skip
        \\WHERE note = 'a:b?' AND "c:d" = ? /* :nope ? */
    ;
    try std.testing.expectEqualStrings(expected, r.positional[0..r.positional_len]);
}
89
/// Collect one name per top-level SELECT item into `result.columns`.
///
/// The select list runs from just after `SELECT` to `FROM` (or the end of `sql`). It is
/// split on commas at parenthesis depth 0. For each item, the last bare identifier at
/// depth 0 becomes the name: the alias in `expr AS alias`, `col` in `t.col`, or the alias
/// in `f(...) AS alias`. Single-quoted string literals and `--` / `/* */` comments are
/// skipped, so commas, parentheses or words inside them neither split items nor become
/// names. Items without such an identifier (e.g. `SELECT 1`) contribute nothing, and a
/// bare `*` list yields no columns.
fn parseColumns(comptime sql: []const u8, result: *ParseResult) void {
    const select_start = findSelectStart(sql) orelse return;
    const from_pos = findFromPos(sql, select_start) orelse sql.len;

    // Trim the select list but keep absolute offsets, so names are slices of `sql` itself.
    const list = sql[select_start..from_pos];
    const cols_start = select_start + countLeadingWhitespace(list);
    const cols_end = from_pos - countTrailingWhitespace(list);

    if (cols_start >= cols_end) return;
    if (std.mem.eql(u8, sql[cols_start..cols_end], "*")) return;

    // Searches below are bounded by `scope`, so they never run past the select list.
    const scope = sql[0..cols_end];
    var i: usize = cols_start;
    // Parenthesis nesting; commas and identifiers only count at depth 0.
    var paren_depth: usize = 0;

    while (i < cols_end) {
        // Scan one select item, remembering its last depth-0 identifier.
        var name: ?[]const u8 = null;
        while (i < cols_end) {
            const c = scope[i];
            const next: u8 = if (i + 1 < cols_end) scope[i + 1] else 0;
            if (c == ',' and paren_depth == 0) break;

            if (c == '\'') {
                // String literal: never a name. A doubled quote ('it''s') closes and reopens.
                i = if (std.mem.indexOfScalarPos(u8, scope, i + 1, '\'')) |close| close + 1 else cols_end;
            } else if (c == '-' and next == '-') {
                i = std.mem.indexOfScalarPos(u8, scope, i, '\n') orelse cols_end;
            } else if (c == '/' and next == '*') {
                i = if (std.mem.indexOfPos(u8, scope, i + 2, "*/")) |close| close + 2 else cols_end;
            } else if (c == '(') {
                paren_depth += 1;
                i += 1;
            } else if (c == ')') {
                paren_depth -|= 1;
                i += 1;
            } else if (paren_depth == 0 and isIdentStart(c)) {
                const start = i;
                while (i < cols_end and isIdentChar(scope[i])) : (i += 1) {}
                name = scope[start..i];
            } else {
                i += 1;
            }
        }

        if (name) |n| {
            result.columns[result.columns_len] = n;
            result.columns_len += 1;
        }
        // Step over the separating comma (or past the end, which ends the loop).
        i += 1;
    }
}

test "columns: string literals and comments do not split or name columns" {
    const r = comptime parse(
        \\SELECT 'a, b' AS pair, -- don't split here
        \\    COALESCE(x, ')') AS y, /* it's */ id
        \\FROM t
    );
    try std.testing.expectEqual(3, r.columns_len);
    try std.testing.expectEqualStrings("pair", r.columns[0]);
    try std.testing.expectEqualStrings("y", r.columns[1]);
    try std.testing.expectEqualStrings("id", r.columns[2]);
}
135
/// Number of whitespace bytes (as defined by `isWhitespace`) at the start of `s`.
fn countLeadingWhitespace(s: []const u8) usize {
    for (s, 0..) |c, idx| {
        if (!isWhitespace(c)) return idx;
    }
    return s.len;
}

/// Number of whitespace bytes (as defined by `isWhitespace`) at the end of `s`.
fn countTrailingWhitespace(s: []const u8) usize {
    var end = s.len;
    while (end > 0 and isWhitespace(s[end - 1])) end -= 1;
    return s.len - end;
}

/// True for bytes that may begin an unquoted identifier or a `:name` parameter:
/// ASCII letters and `_`.
fn isIdentStart(c: u8) bool {
    return switch (c) {
        'a'...'z', 'A'...'Z', '_' => true,
        else => false,
    };
}

/// True for bytes that may continue an identifier: everything `isIdentStart`
/// accepts, plus ASCII digits.
fn isIdentChar(c: u8) bool {
    return switch (c) {
        '0'...'9' => true,
        else => isIdentStart(c),
    };
}

/// True for the whitespace bytes the parser trims: space, tab, LF and CR.
fn isWhitespace(c: u8) bool {
    return switch (c) {
        ' ', '\t', '\n', '\r' => true,
        else => false,
    };
}
159
/// Return the offset just past the first top-level `SELECT` keyword in `sql`, or null.
///
/// "Top-level" means outside parentheses, quotes and comments. For
/// `WITH x AS (SELECT ...) SELECT ...` the outer query's select list is found. A `SELECT`
/// embedded in an identifier such as `preselection` is not matched.
fn findSelectStart(comptime sql: []const u8) ?usize {
    const pos = findTopLevelKeyword(sql, 0, "SELECT") orelse return null;
    return pos + "SELECT".len;
}

/// Return the offset of the first top-level `FROM` keyword at or after `start`, or null.
///
/// None of the following end the select list: identifiers such as `from_date` or
/// `date_from`, quoted text such as `'x from y'`, comments, or a `FROM` inside a
/// parenthesized subquery.
fn findFromPos(comptime sql: []const u8, start: usize) ?usize {
    return findTopLevelKeyword(sql, start, "FROM");
}

/// Scan `sql` from `start` for `keyword` as a whole word (ASCII case-insensitive) at
/// parenthesis depth 0. Single- and double-quoted spans and `--` / `/* */` comments are
/// skipped. Returns the keyword's offset, or null if it does not occur before the end of
/// `sql` or before an unterminated quote or comment.
fn findTopLevelKeyword(sql: []const u8, start: usize, keyword: []const u8) ?usize {
    var depth: usize = 0;
    var i = start;
    while (i < sql.len) : (i += 1) {
        const c = sql[i];
        const next: u8 = if (i + 1 < sql.len) sql[i + 1] else 0;
        if (c == '\'' or c == '"') {
            // Land on the closing quote; the loop increment steps past it. A doubled
            // quote ('it''s') closes and immediately reopens.
            i = std.mem.indexOfScalarPos(u8, sql, i + 1, c) orelse return null;
        } else if (c == '-' and next == '-') {
            // Land on the newline that ends the line comment.
            i = std.mem.indexOfScalarPos(u8, sql, i, '\n') orelse return null;
        } else if (c == '/' and next == '*') {
            // Land on the '/' of the closing "*/".
            const close = std.mem.indexOfPos(u8, sql, i + 2, "*/") orelse return null;
            i = close + 1;
        } else if (c == '(') {
            depth += 1;
        } else if (c == ')') {
            depth -|= 1;
        } else if (depth == 0 and isKeywordAt(sql, i, keyword)) {
            return i;
        }
    }
    return null;
}

/// True if `keyword` occurs at `pos` in `sql` (ASCII case-insensitive) and is not part
/// of a longer identifier on either side.
fn isKeywordAt(sql: []const u8, pos: usize, keyword: []const u8) bool {
    const end = pos + keyword.len;
    if (end > sql.len) return false;
    if (!std.ascii.eqlIgnoreCase(sql[pos..end], keyword)) return false;
    if (pos > 0 and isIdentChar(sql[pos - 1])) return false;
    return end == sql.len or !isIdentChar(sql[end]);
}

test "columns: keywords match whole words at top level only" {
    const words = comptime parse("SELECT id, from_date, date_from FROM t");
    try std.testing.expectEqual(3, words.columns_len);
    try std.testing.expectEqualStrings("id", words.columns[0]);
    try std.testing.expectEqualStrings("from_date", words.columns[1]);
    try std.testing.expectEqualStrings("date_from", words.columns[2]);

    const cte = comptime parse("WITH x AS (SELECT a FROM b) SELECT c FROM x");
    try std.testing.expectEqual(1, cte.columns_len);
    try std.testing.expectEqualStrings("c", cte.columns[0]);

    const update = comptime parse("UPDATE preselection SET x = 1");
    try std.testing.expectEqual(0, update.columns_len);

    const quoted = comptime parse("SELECT 'x from y' AS label FROM t");
    try std.testing.expectEqual(1, quoted.columns_len);
    try std.testing.expectEqualStrings("label", quoted.columns[0]);
}
187
// -----------------------------------------------------------------------------
// column extraction tests
// -----------------------------------------------------------------------------

/// Parse `sql` at comptime and check that exactly the `expected` column names were
/// extracted, in order.
fn expectColumns(comptime sql: []const u8, expected: []const []const u8) !void {
    const parsed = comptime parse(sql);
    try std.testing.expectEqual(expected.len, parsed.columns_len);
    for (expected, parsed.columns[0..parsed.columns_len]) |want, got| {
        try std.testing.expectEqualStrings(want, got);
    }
}

test "columns: basic select" {
    try expectColumns("SELECT id, name, age FROM users", &.{ "id", "name", "age" });
}

test "columns: with alias" {
    try expectColumns("SELECT id, first_name AS name FROM users", &.{ "id", "name" });
}

test "columns: with function" {
    try expectColumns("SELECT COUNT(*) AS total, MAX(age) AS oldest FROM users", &.{ "total", "oldest" });
}

test "columns: nested function" {
    try expectColumns("SELECT COALESCE(name, 'unknown') AS name FROM users", &.{"name"});
}

test "columns: table qualified" {
    try expectColumns("SELECT u.id, u.name FROM users u", &.{ "id", "name" });
}

test "columns: case expression" {
    try expectColumns("SELECT CASE WHEN x > 0 THEN 1 ELSE 0 END AS flag FROM t", &.{"flag"});
}

test "columns: empty string literal" {
    try expectColumns("SELECT id, '' AS empty FROM users", &.{ "id", "empty" });
}

test "columns: select star returns empty" {
    try expectColumns("SELECT * FROM users", &.{});
}

test "columns: multiline sql" {
    const query =
        \\SELECT id, name,
        \\    created_at
        \\FROM users
    ;
    try expectColumns(query, &.{ "id", "name", "created_at" });
}

test "columns: snippet function (fts5)" {
    const query =
        \\SELECT uri, snippet(docs_fts, 1, '<b>', '</b>', '...', 32) AS snippet
        \\FROM docs_fts
    ;
    try expectColumns(query, &.{ "uri", "snippet" });
}
266
// -----------------------------------------------------------------------------
// parameter extraction tests
// -----------------------------------------------------------------------------

/// Parse `sql` at comptime and check that exactly the `expected` named parameters
/// were extracted, in order.
fn expectNamedParams(comptime sql: []const u8, expected: []const []const u8) !void {
    const parsed = comptime parse(sql);
    try std.testing.expectEqual(expected.len, parsed.params_len);
    for (expected, parsed.params[0..parsed.params_len]) |want, got| {
        try std.testing.expectEqualStrings(want, got);
    }
}

test "params: named" {
    try expectNamedParams("SELECT * FROM users WHERE id = :id AND age > :min_age", &.{ "id", "min_age" });
}

test "params: positional passthrough" {
    const query = "SELECT * FROM users WHERE id = ? AND age > ?";
    try expectNamedParams(query, &.{}); // no named params
    const parsed = comptime parse(query);
    try std.testing.expectEqual(2, parsed.param_count); // but two positional
}

test "params: mixed named and positional" {
    const query = "SELECT * FROM users WHERE id = :id AND age > ?";
    try expectNamedParams(query, &.{"id"});
    const parsed = comptime parse(query);
    try std.testing.expectEqual(2, parsed.param_count);
}

test "params: conversion to positional" {
    const parsed = comptime parse("INSERT INTO users (name, age) VALUES (:name, :age)");
    const positional = parsed.positional[0..parsed.positional_len];
    try std.testing.expectEqualStrings("INSERT INTO users (name, age) VALUES (?, ?)", positional);
}

test "params: underscore in name" {
    try expectNamedParams("SELECT * FROM t WHERE x = :my_param_name", &.{"my_param_name"});
}

test "params: no params" {
    const query = "SELECT id FROM users";
    try expectNamedParams(query, &.{});
    const parsed = comptime parse(query);
    try std.testing.expectEqual(0, parsed.param_count);
}