atproto relay implementation in zig zlay.waow.tech
at 3d236cdd98cd08ebe120f1ca000dd98228bc398a 450 lines 17 kB view raw
//! relay frame validator — DID key resolution + real signature verification
//!
//! validates firehose commit frames by verifying the commit signature against
//! the pre-resolved signing key for the DID. accepts pre-decoded CBOR payload
//! from the subscriber (decoded via zat SDK). on cache miss, skips validation
//! and queues background resolution (at most one queue entry per DID). no
//! frame is ever blocked on network I/O.

const std = @import("std");
const zat = @import("zat");
const broadcaster = @import("broadcaster.zig");

const Allocator = std.mem.Allocator;
const log = std.log.scoped(.relay);

/// decoded and cached signing key for a DID
const CachedKey = struct {
    key_type: zat.multicodec.KeyType,
    raw: [33]u8, // compressed public key (secp256k1 or p256); only raw[0..len] is meaningful
    len: u8,
    resolve_time: i64 = 0, // epoch seconds when resolved
};

/// outcome of validating one commit frame
pub const ValidationResult = struct {
    /// false only when verification ran and failed
    valid: bool,
    /// true when no verification was performed (missing repo field or cache miss)
    skipped: bool,
    /// CID reported by the verifier.
    /// NOTE(review): the sync 1.1 path stores the MST root (`data_cid`) here,
    /// but the legacy path stores the commit block's `commit_cid` — confirm
    /// which one callers expect.
    data_cid: ?[]const u8 = null,
    /// rev from verified commit
    commit_rev: ?[]const u8 = null,
};

/// configuration for commit validation checks
pub const ValidatorConfig = struct {
    /// verify MST structure during signature verification
    verify_mst: bool = false, // off by default for relay throughput
    /// verify commit diffs via MST inversion (sync 1.1)
    verify_commit_diff: bool = false,
    /// max allowed operations per commit
    max_ops: usize = 200,
    /// max clock skew for rev timestamps (seconds).
    /// NOTE(review): not read anywhere in this file — enforced elsewhere, if at all.
    rev_clock_skew: i64 = 300, // 5 minutes
};

/// commit validator with a DID → signing-key cache and a background resolver.
/// `cache` is guarded by `cache_mutex`; `queue` and `pending` by `queue_mutex`.
pub const Validator = struct {
    allocator: Allocator,
    stats: *broadcaster.Stats,
    config: ValidatorConfig,
    // DID → signing key cache (decoded, ready for verification); keys are owned
    cache: std.StringHashMapUnmanaged(CachedKey) = .{},
    cache_mutex: std.Thread.Mutex = .{},
    // background resolve queue (FIFO); items are owned DID strings
    queue: std.ArrayListUnmanaged([]const u8) = .{},
    // DIDs currently queued or being resolved. keys borrow the queue-owned
    // strings, so an entry is always removed before its string is freed.
    pending: std.StringHashMapUnmanaged(void) = .{},
    queue_mutex: std.Thread.Mutex = .{},
    queue_cond: std.Thread.Condition = .{},
    resolver_thread: ?std.Thread = null,
    alive: std.atomic.Value(bool) = .{ .raw = true },

    /// upper bound on the raw CAR bytes accepted in a single commit frame
    const max_blocks_len: usize = 2 * 1024 * 1024;

    pub fn init(allocator: Allocator, stats: *broadcaster.Stats) Validator {
        return initWithConfig(allocator, stats, .{});
    }

    pub fn initWithConfig(allocator: Allocator, stats: *broadcaster.Stats, config: ValidatorConfig) Validator {
        return .{
            .allocator = allocator,
            .stats = stats,
            .config = config,
        };
    }

    /// stop and join the resolver thread (if started), then free all owned memory.
    pub fn deinit(self: *Validator) void {
        // flip `alive` and signal while holding the queue mutex so the resolver
        // cannot miss the wakeup between checking `alive` and starting its wait
        {
            self.queue_mutex.lock();
            defer self.queue_mutex.unlock();
            self.alive.store(false, .release);
            self.queue_cond.signal();
        }
        if (self.resolver_thread) |t| t.join();

        // free cache keys (CachedKey is inline, no separate free needed)
        var cache_it = self.cache.iterator();
        while (cache_it.next()) |entry| {
            self.allocator.free(entry.key_ptr.*);
        }
        self.cache.deinit(self.allocator);

        // `pending` only borrows queue strings: drop the table, then free the strings
        self.pending.deinit(self.allocator);
        for (self.queue.items) |did| {
            self.allocator.free(did);
        }
        self.queue.deinit(self.allocator);
    }

    /// start the background resolver thread
    pub fn start(self: *Validator) !void {
        self.resolver_thread = try std.Thread.spawn(.{}, resolveLoop, .{self});
    }

    /// validate a commit frame using a pre-decoded CBOR payload (from SDK decoder).
    /// on cache miss, queues background resolution and skips.
    pub fn validateCommit(self: *Validator, payload: zat.cbor.Value) ValidationResult {
        // extract DID from decoded payload
        const did = payload.getString("repo") orelse {
            _ = self.stats.skipped.fetchAdd(1, .monotonic);
            return .{ .valid = true, .skipped = true };
        };

        const cached_key = self.lookupKey(did) orelse {
            // cache miss — queue for background resolution, skip validation
            _ = self.stats.cache_misses.fetchAdd(1, .monotonic);
            _ = self.stats.skipped.fetchAdd(1, .monotonic);
            self.queueResolve(did);
            return .{ .valid = true, .skipped = true };
        };

        _ = self.stats.cache_hits.fetchAdd(1, .monotonic);

        // cache hit — do structure checks + signature verification
        if (self.verifyCommit(payload, did, cached_key)) |vr| {
            _ = self.stats.validated.fetchAdd(1, .monotonic);
            return vr;
        } else |err| {
            log.debug("commit verification failed for {s}: {s}", .{ did, @errorName(err) });
            _ = self.stats.failed.fetchAdd(1, .monotonic);
            return .{ .valid = false, .skipped = false };
        }
    }

    /// copy out the cached signing key for `did`, if present (takes cache_mutex)
    fn lookupKey(self: *Validator, did: []const u8) ?CachedKey {
        self.cache_mutex.lock();
        defer self.cache_mutex.unlock();
        return self.cache.get(did);
    }

    /// read `key` from a CBOR map value and return its raw CID bytes, or null
    /// when the field is absent or is not a CID (e.g. CBOR null)
    fn cidField(value: zat.cbor.Value, comptime key: []const u8) ?[]const u8 {
        const field = value.get(key) orelse return null;
        return switch (field) {
            .cid => |c| c.raw,
            else => null,
        };
    }

    /// run structure checks, then verify the commit signature (plus MST walk or
    /// sync 1.1 commit diff, depending on config) against `cached_key`.
    /// returns error.InvalidFrame for malformed frames, or the verifier's error.
    ///
    /// NOTE(review): `data_cid` / `commit_rev` come from zat calls that are
    /// handed this function's arena, which is freed on return. if zat allocates
    /// them from that arena (instead of slicing `blocks`), the returned slices
    /// dangle — confirm against zat before callers dereference them.
    fn verifyCommit(self: *Validator, payload: zat.cbor.Value, expected_did: []const u8, cached_key: CachedKey) !ValidationResult {
        // commit structure checks first (cheap, no allocation)
        try self.checkCommitStructure(payload);

        // extract blocks (raw CAR bytes) from the pre-decoded payload
        const blocks = payload.getBytes("blocks") orelse return error.InvalidFrame;
        if (blocks.len > max_blocks_len) return error.InvalidFrame;

        // build public key for verification
        const public_key = zat.multicodec.PublicKey{
            .key_type = cached_key.key_type,
            .raw = cached_key.raw[0..cached_key.len],
        };

        // real signature verification needs its own arena for CAR/MST temporaries
        var arena = std.heap.ArenaAllocator.init(self.allocator);
        defer arena.deinit();
        const alloc = arena.allocator();

        // sync 1.1 path: extract ops and use verifyCommitDiff
        if (self.config.verify_commit_diff) {
            if (self.extractOps(alloc, payload)) |msg_ops| {
                const prev_data = cidField(payload, "prevData");

                const diff_result = try zat.verifyCommitDiff(alloc, blocks, msg_ops, prev_data, public_key, .{
                    .expected_did = expected_did,
                    // without a prior data CID there is nothing to invert against
                    .skip_inversion = prev_data == null,
                });

                return .{
                    .valid = true,
                    .skipped = false,
                    .data_cid = diff_result.data_cid,
                    .commit_rev = diff_result.commit_rev,
                };
            }
        }

        // fallback: legacy verification (signature + optional MST walk)
        const result = try zat.verifyCommitCar(alloc, blocks, public_key, .{
            .verify_mst = self.config.verify_mst,
            .expected_did = expected_did,
        });

        return .{
            .valid = true,
            .skipped = false,
            .data_cid = result.commit_cid,
            .commit_rev = result.commit_rev,
        };
    }

    /// extract ops from payload and convert them to an mst.Operation array.
    /// ops missing action/collection/rkey, or with an unknown action, are
    /// skipped. returns null when there are no usable ops or allocation fails
    /// (the caller then falls back to legacy verification). memory is owned by `alloc`.
    fn extractOps(self: *Validator, alloc: Allocator, payload: zat.cbor.Value) ?[]const zat.MstOperation {
        _ = self;
        const Action = enum { create, update, delete };

        const ops_array = payload.getArray("ops") orelse return null;
        var ops: std.ArrayListUnmanaged(zat.MstOperation) = .{};
        for (ops_array) |op| {
            const action_text = op.getString("action") orelse continue;
            const collection = op.getString("collection") orelse continue;
            const rkey = op.getString("rkey") orelse continue;
            const action = std.meta.stringToEnum(Action, action_text) orelse continue;

            // create/update carry the new record CID; update/delete carry the
            // previous record CID (when the frame provides it as `prev`)
            const value: ?[]const u8 = switch (action) {
                .create, .update => cidField(op, "cid"),
                .delete => null,
            };
            const prev: ?[]const u8 = switch (action) {
                .create => null,
                .update, .delete => cidField(op, "prev"),
            };

            // build path: "collection/rkey"
            const path = std.fmt.allocPrint(alloc, "{s}/{s}", .{ collection, rkey }) catch return null;

            ops.append(alloc, .{
                .path = path,
                .value = value,
                .prev = prev,
            }) catch return null;
        }

        if (ops.items.len == 0) return null;
        return ops.items;
    }

    /// cheap syntactic checks: repo is a DID, rev (if present) is a TID, ops
    /// (if present) is an array of at most `config.max_ops` entries whose
    /// collection/rkey fields (if present) parse as NSID/record key.
    fn checkCommitStructure(self: *Validator, payload: zat.cbor.Value) !void {
        // check repo field is a valid DID
        const repo = payload.getString("repo") orelse return error.InvalidFrame;
        if (zat.Did.parse(repo) == null) return error.InvalidFrame;

        // check rev is a valid TID
        if (payload.getString("rev")) |rev| {
            if (zat.Tid.parse(rev) == null) return error.InvalidFrame;
        }

        // check ops count
        if (payload.get("ops")) |ops_value| {
            switch (ops_value) {
                .array => |ops| {
                    if (ops.len > self.config.max_ops) return error.InvalidFrame;
                    // validate each op has valid collection/rkey
                    for (ops) |op| {
                        if (op.getString("collection")) |coll| {
                            if (zat.Nsid.parse(coll) == null) return error.InvalidFrame;
                        }
                        if (op.getString("rkey")) |rk| {
                            if (zat.Rkey.parse(rk) == null) return error.InvalidFrame;
                        }
                    }
                },
                else => return error.InvalidFrame,
            }
        }
    }

    /// enqueue `did` for background resolution unless it is already cached or
    /// already pending. best-effort: allocation failure silently drops the request.
    fn queueResolve(self: *Validator, did: []const u8) void {
        // check if already cached (race between validate and resolver)
        {
            self.cache_mutex.lock();
            defer self.cache_mutex.unlock();
            if (self.cache.contains(did)) return;
        }

        self.queue_mutex.lock();
        defer self.queue_mutex.unlock();

        // one queue entry per DID: repeated misses before resolution are no-ops
        if (self.pending.contains(did)) return;

        // reserve space in both structures first so they can never disagree
        self.pending.ensureUnusedCapacity(self.allocator, 1) catch return;
        self.queue.ensureUnusedCapacity(self.allocator, 1) catch return;
        const duped = self.allocator.dupe(u8, did) catch return;

        self.queue.appendAssumeCapacity(duped);
        self.pending.putAssumeCapacityNoClobber(duped, {});
        self.queue_cond.signal();
    }

    /// resolver thread body: pop DIDs and resolve them until shutdown
    fn resolveLoop(self: *Validator) void {
        var resolver = zat.DidResolver.init(self.allocator);
        defer resolver.deinit();

        while (self.alive.load(.acquire)) {
            const did = self.popQueued() orelse continue;
            defer self.finishQueued(did);
            self.resolveAndCache(&resolver, did);
        }
    }

    /// wait (in 1s slices) for a queued DID; returns null once shut down with
    /// an empty queue. the returned DID stays in `pending` until finishQueued.
    fn popQueued(self: *Validator) ?[]const u8 {
        self.queue_mutex.lock();
        defer self.queue_mutex.unlock();
        while (self.queue.items.len == 0 and self.alive.load(.acquire)) {
            self.queue_cond.timedWait(&self.queue_mutex, 1 * std.time.ns_per_s) catch {};
        }
        if (self.queue.items.len == 0) return null;
        return self.queue.orderedRemove(0);
    }

    /// drop `did` from `pending` (so it can be queued again) and free it
    fn finishQueued(self: *Validator, did: []const u8) void {
        {
            self.queue_mutex.lock();
            defer self.queue_mutex.unlock();
            _ = self.pending.remove(did);
        }
        self.allocator.free(did);
    }

    /// resolve `did` → signing key and store it in the cache. every failure is
    /// best-effort: it is logged or ignored and the DID stays uncached.
    fn resolveAndCache(self: *Validator, resolver: *zat.DidResolver, did: []const u8) void {
        // skip if already cached (resolved while queued)
        {
            self.cache_mutex.lock();
            defer self.cache_mutex.unlock();
            if (self.cache.contains(did)) return;
        }

        // resolve DID → signing key
        const parsed = zat.Did.parse(did) orelse return;
        var doc = resolver.resolve(parsed) catch |err| {
            log.debug("DID resolve failed for {s}: {s}", .{ did, @errorName(err) });
            return;
        };
        defer doc.deinit();

        // extract and decode signing key
        const vm = doc.signingKey() orelse return;
        const key_bytes = zat.multibase.decode(self.allocator, vm.public_key_multibase) catch return;
        defer self.allocator.free(key_bytes);
        const public_key = zat.multicodec.parsePublicKey(key_bytes) catch return;

        // store decoded key in cache (fixed-size, no pointer chasing)
        var cached = CachedKey{
            .key_type = public_key.key_type,
            .raw = undefined,
            .len = 0,
            .resolve_time = std.time.timestamp(),
        };
        // the key comes from a network-fetched DID document: never trust its
        // length to fit the fixed buffer (or the u8 length field)
        if (public_key.raw.len > cached.raw.len) {
            log.debug("signing key too long for {s}: {d} bytes", .{ did, public_key.raw.len });
            return;
        }
        cached.len = @intCast(public_key.raw.len);
        @memcpy(cached.raw[0..public_key.raw.len], public_key.raw);

        const did_owned = self.allocator.dupe(u8, did) catch return;

        self.cache_mutex.lock();
        defer self.cache_mutex.unlock();
        const gop = self.cache.getOrPut(self.allocator, did_owned) catch {
            self.allocator.free(did_owned);
            return;
        };
        // an existing entry keeps its own key allocation; ours would leak otherwise
        if (gop.found_existing) self.allocator.free(did_owned);
        gop.value_ptr.* = cached;
    }

    /// evict a DID's cached signing key (e.g. on #identity event).
    /// the next commit from this DID will trigger a fresh resolution.
    pub fn evictKey(self: *Validator, did: []const u8) void {
        self.cache_mutex.lock();
        defer self.cache_mutex.unlock();
        if (self.cache.fetchRemove(did)) |entry| {
            self.allocator.free(entry.key);
        }
    }

    /// cache size (for diagnostics)
    pub fn cacheSize(self: *Validator) usize {
        self.cache_mutex.lock();
        defer self.cache_mutex.unlock();
        return self.cache.count();
    }
};

// --- tests ---

test "validateCommit skips on cache miss" {
    var stats = broadcaster.Stats{};
    var v = Validator.init(std.testing.allocator, &stats);
    defer v.deinit();

    // build a commit payload using SDK
    const payload: zat.cbor.Value = .{ .map = &.{
        .{ .key = "repo", .value = .{ .text = "did:plc:test123" } },
        .{ .key = "seq", .value = .{ .unsigned = 42 } },
        .{ .key = "rev", .value = .{ .text = "3k2abc000000" } },
        .{ .key = "time", .value = .{ .text = "2024-01-15T10:30:00Z" } },
    } };

    const result = v.validateCommit(payload);
    try std.testing.expect(result.valid);
    try std.testing.expect(result.skipped);
    try std.testing.expectEqual(@as(u64, 1), stats.cache_misses.load(.acquire));
}

test "validateCommit skips when no repo field" {
    var stats = broadcaster.Stats{};
    var v = Validator.init(std.testing.allocator, &stats);
    defer v.deinit();

    // payload without "repo" field
    const payload: zat.cbor.Value = .{ .map = &.{
        .{ .key = "seq", .value = .{ .unsigned = 42 } },
    } };

    const result = v.validateCommit(payload);
    try std.testing.expect(result.valid);
    try std.testing.expect(result.skipped);
    try std.testing.expectEqual(@as(u64, 1), stats.skipped.load(.acquire));
}

test "repeated cache misses queue a DID only once" {
    var stats = broadcaster.Stats{};
    var v = Validator.init(std.testing.allocator, &stats);
    defer v.deinit();

    const payload: zat.cbor.Value = .{ .map = &.{
        .{ .key = "repo", .value = .{ .text = "did:plc:test123" } },
    } };

    _ = v.validateCommit(payload);
    _ = v.validateCommit(payload);
    v.queueResolve("did:plc:other456");

    try std.testing.expectEqual(@as(u64, 2), stats.cache_misses.load(.acquire));
    try std.testing.expectEqual(@as(usize, 2), v.queue.items.len);
}

test "checkCommitStructure rejects invalid DID" {
    var stats = broadcaster.Stats{};
    var v = Validator.init(std.testing.allocator, &stats);
    defer v.deinit();

    const payload: zat.cbor.Value = .{ .map = &.{
        .{ .key = "repo", .value = .{ .text = "not-a-did" } },
    } };

    try std.testing.expectError(error.InvalidFrame, v.checkCommitStructure(payload));
}

test "checkCommitStructure accepts valid commit" {
    var stats = broadcaster.Stats{};
    var v = Validator.init(std.testing.allocator, &stats);
    defer v.deinit();

    const payload: zat.cbor.Value = .{ .map = &.{
        .{ .key = "repo", .value = .{ .text = "did:plc:test123" } },
        .{ .key = "rev", .value = .{ .text = "3k2abcdefghij" } },
    } };

    try v.checkCommitStructure(payload);
}

test "checkCommitStructure rejects too many ops" {
    var stats = broadcaster.Stats{};
    var v = Validator.initWithConfig(std.testing.allocator, &stats, .{ .max_ops = 2 });
    defer v.deinit();

    // build ops array with 3 items (over limit of 2)
    const ops = [_]zat.cbor.Value{
        .{ .map = &.{.{ .key = "action", .value = .{ .text = "create" } }} },
        .{ .map = &.{.{ .key = "action", .value = .{ .text = "create" } }} },
        .{ .map = &.{.{ .key = "action", .value = .{ .text = "create" } }} },
    };

    const payload: zat.cbor.Value = .{ .map = &.{
        .{ .key = "repo", .value = .{ .text = "did:plc:test123" } },
        .{ .key = "ops", .value = .{ .array = &ops } },
    } };

    try std.testing.expectError(error.InvalidFrame, v.checkCommitStructure(payload));
}