comptime sql bindings for zig
ziglang sql
at main 309 lines 10 kB view raw
//! comptime sql parsing
//!
//! extracts metadata from sql strings at compile time:
//! - column names from SELECT clauses
//! - named parameter names (:name -> ?)
//! - positional sql with named params converted
//!
//! sql injection safety:
//! - sql strings are comptime, so user input cannot be concatenated
//! - parameters are bound via prepared statements, not interpolated
//! - the :name syntax reinforces parameterized query patterns
//!
//! string literals:
//! - single-quoted literals are copied verbatim; `?` and `:name` inside them
//!   are not parameters, and commas/parens/keywords inside them are ignored
//!
//! limitations:
//! - SELECT * returns empty columns (can't know schema)
//! - no subquery support in column extraction
//! - no quoted identifier support ("column name")
//! - sql comments (`--`, `/* */`) are not recognised
//! - a bare string literal without an alias produces no column entry

const std = @import("std");

/// max named parameters per query
pub const MAX_PARAMS = 32;

/// max columns per SELECT
pub const MAX_COLS = 64;

/// max sql string length
pub const MAX_SQL_LEN = 4096;

/// result of parsing a sql string at comptime
pub const ParseResult = struct {
    /// sql with :name params replaced by ?
    /// only the first `positional_len` bytes are initialized
    positional: [MAX_SQL_LEN]u8,
    positional_len: usize,

    /// total parameter count (named + positional)
    param_count: usize,

    /// extracted named parameter names in order
    /// only the first `params_len` entries are initialized
    params: [MAX_PARAMS][]const u8,
    params_len: usize,

    /// extracted column names/aliases from SELECT
    /// only the first `columns_len` entries are initialized
    columns: [MAX_COLS][]const u8,
    columns_len: usize,
};

/// parses `sql` and returns the positional sql, the named parameters and the
/// SELECT column names.
///
/// it is a compile error for `sql` to be longer than `MAX_SQL_LEN`; exceeding
/// `MAX_PARAMS` or `MAX_COLS` panics (a compile error when evaluated at
/// comptime).
pub fn parse(comptime sql: []const u8) ParseResult {
    if (sql.len > MAX_SQL_LEN) {
        @compileError("sql string is longer than MAX_SQL_LEN");
    }

    // quotas below the default are ignored by the compiler, so the constant
    // term only matters as headroom for short strings
    @setEvalBranchQuota(1000 + sql.len * 200);

    var result = ParseResult{
        .positional = undefined,
        .positional_len = 0,
        .param_count = 0,
        .params = undefined,
        .params_len = 0,
        .columns = undefined,
        .columns_len = 0,
    };

    parseParams(sql, &result);
    parseColumns(sql, &result);

    return result;
}

/// copies `sql` into `result.positional`, replacing every `:name` with `?`,
/// and records the parameter names and the total parameter count.
///
/// the output is never longer than the input, so `parse`'s length check
/// guarantees that `result.positional` cannot overflow.
fn parseParams(comptime sql: []const u8, result: *ParseResult) void {
    var i: usize = 0;
    while (i < sql.len) {
        const c = sql[i];
        if (c == '\'') {
            // string literal: copy untouched so that '?' and ':' inside it
            // are neither rewritten nor counted as parameters
            const literal_end = skipStringLiteral(sql, i, sql.len);
            for (sql[i..literal_end]) |byte| {
                result.positional[result.positional_len] = byte;
                result.positional_len += 1;
            }
            i = literal_end;
        } else if (c == '?') {
            result.positional[result.positional_len] = '?';
            result.positional_len += 1;
            result.param_count += 1;
            i += 1;
        } else if (c == ':' and i + 1 < sql.len and isIdentStart(sql[i + 1])) {
            const start = i + 1;
            var end = start;
            while (end < sql.len and isIdentChar(sql[end])) : (end += 1) {}

            if (result.params_len >= MAX_PARAMS) {
                @panic("sql: more named parameters than MAX_PARAMS");
            }
            result.params[result.params_len] = sql[start..end];
            result.params_len += 1;

            result.positional[result.positional_len] = '?';
            result.positional_len += 1;
            result.param_count += 1;
            i = end;
        } else {
            result.positional[result.positional_len] = c;
            result.positional_len += 1;
            i += 1;
        }
    }
}

/// extracts one name per comma separated item between the first SELECT and
/// the matching top-level FROM (or the end of the string).
///
/// the name of an item is its last identifier outside of parentheses, which
/// is the alias when there is one and the (possibly table qualified) column
/// name otherwise. a lone `*` yields no columns.
fn parseColumns(comptime sql: []const u8, result: *ParseResult) void {
    const select_start = findSelectStart(sql) orelse return;
    const from_pos = findFromPos(sql, select_start) orelse sql.len;

    // work directly with sql and offset, not a sub-slice
    const cols_start = select_start + countLeadingWhitespace(sql[select_start..from_pos]);
    const cols_end = from_pos - countTrailingWhitespace(sql[select_start..from_pos]);

    if (cols_start >= cols_end) return;
    if (std.mem.eql(u8, sql[cols_start..cols_end], "*")) return;

    var col_i: usize = cols_start;
    var paren_depth: usize = 0;

    while (col_i < cols_end) {
        while (col_i < cols_end and isWhitespace(sql[col_i])) : (col_i += 1) {}
        if (col_i >= cols_end) break;

        var last_ident_start: ?usize = null;
        var last_ident_end: usize = col_i;

        // scan a single select item, stopping on its top-level comma
        while (col_i < cols_end) {
            const c = sql[col_i];
            if (c == ',' and paren_depth == 0) break;

            if (c == '\'') {
                // literals may contain commas, parens and words; skip them
                col_i = skipStringLiteral(sql, col_i, cols_end);
            } else if (isIdentStart(c) and paren_depth == 0) {
                last_ident_start = col_i;
                while (col_i < cols_end and isIdentChar(sql[col_i])) : (col_i += 1) {}
                last_ident_end = col_i;
            } else {
                if (c == '(') {
                    paren_depth += 1;
                } else if (c == ')') {
                    paren_depth -|= 1;
                }
                col_i += 1;
            }
        }

        if (last_ident_start) |s| {
            if (result.columns_len >= MAX_COLS) {
                @panic("sql: more select columns than MAX_COLS");
            }
            result.columns[result.columns_len] = sql[s..last_ident_end];
            result.columns_len += 1;
        }

        if (col_i < cols_end and sql[col_i] == ',') col_i += 1;
    }
}

/// returns the index just past the single-quoted literal opening at `start`,
/// or `limit` when the literal is not closed before `limit`.
///
/// an escaped quote (`''`) ends one literal and immediately opens the next,
/// so callers that re-check for `'` after this returns handle it naturally.
fn skipStringLiteral(sql: []const u8, start: usize, limit: usize) usize {
    var i = start + 1;
    while (i < limit) : (i += 1) {
        if (sql[i] == '\'') return i + 1;
    }
    return limit;
}

/// returns true when `keyword` (upper case) occurs at `pos` in `sql`,
/// compared case-insensitively, and is not part of a longer identifier.
fn isKeywordAt(sql: []const u8, pos: usize, keyword: []const u8) bool {
    const end = pos + keyword.len;
    if (end > sql.len) return false;

    // cheap first-byte rejection keeps the comptime branch count low
    if (std.ascii.toUpper(sql[pos]) != keyword[0]) return false;

    // word boundaries: `from_user` and `selected` are identifiers
    if (pos > 0 and isIdentChar(sql[pos - 1])) return false;
    if (end < sql.len and isIdentChar(sql[end])) return false;

    return std.ascii.eqlIgnoreCase(sql[pos..end], keyword);
}

/// number of whitespace bytes at the start of `s`
fn countLeadingWhitespace(s: []const u8) usize {
    var i: usize = 0;
    while (i < s.len and isWhitespace(s[i])) : (i += 1) {}
    return i;
}

/// number of whitespace bytes at the end of `s`
fn countTrailingWhitespace(s: []const u8) usize {
    var i: usize = 0;
    while (i < s.len and isWhitespace(s[s.len - 1 - i])) : (i += 1) {}
    return i;
}

/// true for bytes that may start an identifier: ascii letters and `_`
fn isIdentStart(c: u8) bool {
    return (c >= 'a' and c <= 'z') or (c >= 'A' and c <= 'Z') or c == '_';
}

/// true for bytes that may continue an identifier: ascii letters, digits, `_`
fn isIdentChar(c: u8) bool {
    return isIdentStart(c) or (c >= '0' and c <= '9');
}

/// true for space, tab, line feed and carriage return
fn isWhitespace(c: u8) bool {
    return c == ' ' or c == '\t' or c == '\n' or c == '\r';
}

/// returns the index just past the first SELECT keyword outside of string
/// literals, or null when there is none.
fn findSelectStart(comptime sql: []const u8) ?usize {
    const keyword = "SELECT";
    var i: usize = 0;
    while (i + keyword.len <= sql.len) {
        if (sql[i] == '\'') {
            i = skipStringLiteral(sql, i, sql.len);
        } else if (isKeywordAt(sql, i, keyword)) {
            return i + keyword.len;
        } else {
            i += 1;
        }
    }
    return null;
}

/// returns the index of the first FROM keyword at or after `start` that is
/// outside of parentheses and string literals, or null when there is none.
fn findFromPos(comptime sql: []const u8, start: usize) ?usize {
    const keyword = "FROM";
    var paren_depth: usize = 0;
    var j = start;
    while (j + keyword.len <= sql.len) {
        const c = sql[j];
        if (c == '\'') {
            j = skipStringLiteral(sql, j, sql.len);
            continue;
        }

        if (c == '(') {
            paren_depth += 1;
        } else if (c == ')') {
            paren_depth -|= 1;
        } else if (paren_depth == 0 and isKeywordAt(sql, j, keyword)) {
            return j;
        }
        j += 1;
    }
    return null;
}

// -----------------------------------------------------------------------------
// column extraction tests
// -----------------------------------------------------------------------------

test "columns: basic select" {
    const r = comptime parse("SELECT id, name, age FROM users");
    try std.testing.expectEqual(3, r.columns_len);
    try std.testing.expectEqualStrings("id", r.columns[0]);
    try std.testing.expectEqualStrings("name", r.columns[1]);
    try std.testing.expectEqualStrings("age", r.columns[2]);
}

test "columns: with alias" {
    const r = comptime parse("SELECT id, first_name AS name FROM users");
    try std.testing.expectEqual(2, r.columns_len);
    try std.testing.expectEqualStrings("id", r.columns[0]);
    try std.testing.expectEqualStrings("name", r.columns[1]);
}

test "columns: with function" {
    const r = comptime parse("SELECT COUNT(*) AS total, MAX(age) AS oldest FROM users");
    try std.testing.expectEqual(2, r.columns_len);
    try std.testing.expectEqualStrings("total", r.columns[0]);
    try std.testing.expectEqualStrings("oldest", r.columns[1]);
}

test "columns: nested function" {
    const r = comptime parse("SELECT COALESCE(name, 'unknown') AS name FROM users");
    try std.testing.expectEqual(1, r.columns_len);
    try std.testing.expectEqualStrings("name", r.columns[0]);
}

test "columns: table qualified" {
    const r = comptime parse("SELECT u.id, u.name FROM users u");
    try std.testing.expectEqual(2, r.columns_len);
    try std.testing.expectEqualStrings("id", r.columns[0]);
    try std.testing.expectEqualStrings("name", r.columns[1]);
}

test "columns: case expression" {
    const r = comptime parse("SELECT CASE WHEN x > 0 THEN 1 ELSE 0 END AS flag FROM t");
    try std.testing.expectEqual(1, r.columns_len);
    try std.testing.expectEqualStrings("flag", r.columns[0]);
}

test "columns: empty string literal" {
    const r = comptime parse("SELECT id, '' AS empty FROM users");
    try std.testing.expectEqual(2, r.columns_len);
    try std.testing.expectEqualStrings("id", r.columns[0]);
    try std.testing.expectEqualStrings("empty", r.columns[1]);
}

test "columns: select star returns empty" {
    const r = comptime parse("SELECT * FROM users");
    try std.testing.expectEqual(0, r.columns_len);
}

test "columns: multiline sql" {
    const r = comptime parse(
        \\SELECT id, name,
        \\ created_at
        \\FROM users
    );
    try std.testing.expectEqual(3, r.columns_len);
    try std.testing.expectEqualStrings("id", r.columns[0]);
    try std.testing.expectEqualStrings("name", r.columns[1]);
    try std.testing.expectEqualStrings("created_at", r.columns[2]);
}

test "columns: snippet function (fts5)" {
    const r = comptime parse(
        \\SELECT uri, snippet(docs_fts, 1, '<b>', '</b>', '...', 32) AS snippet
        \\FROM docs_fts
    );
    try std.testing.expectEqual(2, r.columns_len);
    try std.testing.expectEqualStrings("uri", r.columns[0]);
    try std.testing.expectEqualStrings("snippet", r.columns[1]);
}

test "columns: identifier starting with a keyword" {
    const r = comptime parse("SELECT id, from_user FROM messages");
    try std.testing.expectEqual(2, r.columns_len);
    try std.testing.expectEqualStrings("id", r.columns[0]);
    try std.testing.expectEqualStrings("from_user", r.columns[1]);
}

test "columns: keyword inside identifier is not a select" {
    const r = comptime parse("UPDATE t SET selected = 1");
    try std.testing.expectEqual(0, r.columns_len);
}

test "columns: comma and paren inside string literal" {
    const r = comptime parse("SELECT 'a, (b' AS label, id FROM t");
    try std.testing.expectEqual(2, r.columns_len);
    try std.testing.expectEqualStrings("label", r.columns[0]);
    try std.testing.expectEqualStrings("id", r.columns[1]);
}

// -----------------------------------------------------------------------------
// parameter extraction tests
// -----------------------------------------------------------------------------

test "params: named" {
    const r = comptime parse("SELECT * FROM users WHERE id = :id AND age > :min_age");
    try std.testing.expectEqual(2, r.params_len);
    try std.testing.expectEqualStrings("id", r.params[0]);
    try std.testing.expectEqualStrings("min_age", r.params[1]);
}

test "params: positional passthrough" {
    const r = comptime parse("SELECT * FROM users WHERE id = ? AND age > ?");
    try std.testing.expectEqual(0, r.params_len); // no named params
    try std.testing.expectEqual(2, r.param_count); // but two positional
}

test "params: mixed named and positional" {
    const r = comptime parse("SELECT * FROM users WHERE id = :id AND age > ?");
    try std.testing.expectEqual(1, r.params_len);
    try std.testing.expectEqualStrings("id", r.params[0]);
    try std.testing.expectEqual(2, r.param_count);
}

test "params: conversion to positional" {
    const r = comptime parse("INSERT INTO users (name, age) VALUES (:name, :age)");
    try std.testing.expectEqualStrings(
        "INSERT INTO users (name, age) VALUES (?, ?)",
        r.positional[0..r.positional_len],
    );
}

test "params: underscore in name" {
    const r = comptime parse("SELECT * FROM t WHERE x = :my_param_name");
    try std.testing.expectEqual(1, r.params_len);
    try std.testing.expectEqualStrings("my_param_name", r.params[0]);
}

test "params: no params" {
    const r = comptime parse("SELECT id FROM users");
    try std.testing.expectEqual(0, r.params_len);
    try std.testing.expectEqual(0, r.param_count);
}

test "params: ignored inside string literal" {
    const r = comptime parse("SELECT * FROM t WHERE label = ':tag?' AND id = :id");
    try std.testing.expectEqual(1, r.params_len);
    try std.testing.expectEqualStrings("id", r.params[0]);
    try std.testing.expectEqual(1, r.param_count);
    try std.testing.expectEqualStrings(
        "SELECT * FROM t WHERE label = ':tag?' AND id = ?",
        r.positional[0..r.positional_len],
    );
}