atproto relay implementation in Zig (zlay.waow.tech)

feat: host validation, dual account status, rate limiting, FutureCursor handling

- host validation pipeline: format check, domain ban, SSRF protection,
describeServer liveness, relay loop detection via Server header
- dual account status: host_id + upstream_status columns, combined
local/upstream active check on XRPC endpoints
- #account event processing: updates upstream_status, drops commits
for inactive accounts
- per-host rate limiting: 100 evt/sec token bucket per subscriber
- per-day new host rate limit: 50/day (configurable)
- FutureCursor handling: per-subscriber shutdown, host set to idle
- time-based cursor flush every 4s (matches Go relay)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

+564 -41
+1
build.zig
··· 49 49 "src/validator.zig", 50 50 "src/subscriber.zig", 51 51 "src/event_log.zig", 52 + "src/slurper.zig", 52 53 }; 53 54 inline for (test_files) |file| { 54 55 const t = b.addTest(.{
+111
src/event_log.zig
··· 135 135 \\CREATE TABLE IF NOT EXISTS account ( 136 136 \\ uid BIGSERIAL PRIMARY KEY, 137 137 \\ did TEXT NOT NULL UNIQUE, 138 + \\ host_id BIGINT NOT NULL DEFAULT 0, 138 139 \\ status TEXT NOT NULL DEFAULT 'active', 140 + \\ upstream_status TEXT NOT NULL DEFAULT 'active', 139 141 \\ created_at TIMESTAMPTZ NOT NULL DEFAULT now() 140 142 \\) 141 143 , .{}); 144 + 145 + // migration: add columns if they don't exist (for existing deployments) 146 + _ = pool.exec("ALTER TABLE account ADD COLUMN IF NOT EXISTS host_id BIGINT NOT NULL DEFAULT 0", .{}) catch {}; 147 + _ = pool.exec("ALTER TABLE account ADD COLUMN IF NOT EXISTS upstream_status TEXT NOT NULL DEFAULT 'active'", .{}) catch {}; 142 148 143 149 _ = try pool.exec( 144 150 \\CREATE TABLE IF NOT EXISTS account_repo ( ··· 214 220 /// start the background flush thread 215 221 pub fn start(self: *DiskPersist) !void { 216 222 self.flush_thread = try std.Thread.spawn(.{}, flushLoop, .{self}); 223 + } 224 + 225 + /// resolve a DID to a numeric UID, associating with a host. 226 + /// on first encounter, creates account row with host_id. 227 + /// on subsequent encounters from a different host, updates host_id. 
228 + /// Go relay: preProcessEvent → CreateAccountHost / EnsureAccountHost 229 + pub fn uidForDidFromHost(self: *DiskPersist, did: []const u8, host_id: u64) !u64 { 230 + const uid = try self.uidForDid(did); 231 + if (host_id > 0) { 232 + const current_host = self.getAccountHostId(uid) catch 0; 233 + if (current_host == 0) { 234 + // first encounter: set host_id 235 + self.setAccountHostId(uid, host_id) catch {}; 236 + } else if (current_host != host_id) { 237 + // host mismatch: account may have migrated 238 + // Go relay re-resolves DID doc here; we log and update 239 + // (full DID re-resolution for migration validation is TODO) 240 + log.info("account {s} (uid={d}) host changed: {d} → {d}", .{ did, uid, current_host, host_id }); 241 + self.setAccountHostId(uid, host_id) catch {}; 242 + } 243 + } 244 + return uid; 217 245 } 218 246 219 247 /// resolve a DID to a numeric UID. creates a new account row on first encounter. ··· 302 330 ); 303 331 } 304 332 333 + // --- account status --- 334 + 335 + /// get the host_id for an account. returns 0 if not set. 336 + pub fn getAccountHostId(self: *DiskPersist, uid: u64) !u64 { 337 + var row = (try self.db.rowUnsafe( 338 + "SELECT host_id FROM account WHERE uid = $1", 339 + .{@as(i64, @intCast(uid))}, 340 + )) orelse return 0; 341 + defer row.deinit() catch {}; 342 + const hid = row.get(i64, 0); 343 + return if (hid > 0) @intCast(hid) else 0; 344 + } 345 + 346 + /// set the host_id for an account (first encounter or migration) 347 + pub fn setAccountHostId(self: *DiskPersist, uid: u64, host_id: u64) !void { 348 + _ = try self.db.exec( 349 + "UPDATE account SET host_id = $2 WHERE uid = $1", 350 + .{ @as(i64, @intCast(uid)), @as(i64, @intCast(host_id)) }, 351 + ); 352 + } 353 + 354 + /// update the upstream status for an account (from #account events). 
355 + /// Go relay: account.go UpdateAccountUpstreamStatus 356 + pub fn updateAccountUpstreamStatus(self: *DiskPersist, uid: u64, upstream_status: []const u8) !void { 357 + _ = try self.db.exec( 358 + "UPDATE account SET upstream_status = $2 WHERE uid = $1", 359 + .{ @as(i64, @intCast(uid)), upstream_status }, 360 + ); 361 + } 362 + 363 + /// check if an account is active (both local status and upstream status). 364 + /// Go relay: models.Account.IsActive() 365 + pub fn isAccountActive(self: *DiskPersist, uid: u64) !bool { 366 + var row = (try self.db.rowUnsafe( 367 + "SELECT status, upstream_status FROM account WHERE uid = $1", 368 + .{@as(i64, @intCast(uid))}, 369 + )) orelse return false; 370 + defer row.deinit() catch {}; 371 + const status = row.get([]const u8, 0); 372 + const upstream = row.get([]const u8, 1); 373 + // active if local is active AND upstream is active 374 + const local_ok = std.mem.eql(u8, status, "active"); 375 + const upstream_ok = std.mem.eql(u8, upstream, "active"); 376 + return local_ok and upstream_ok; 377 + } 378 + 305 379 // --- host management --- 306 380 307 381 pub const Host = struct { ··· 331 405 .id = @intCast(row.get(i64, 0)), 332 406 .last_seq = @intCast(row.get(i64, 1)), 333 407 }; 408 + } 409 + 410 + /// check if a host is banned or blocked by status 411 + pub fn isHostBanned(self: *DiskPersist, hostname: []const u8) bool { 412 + var row = self.db.rowUnsafe( 413 + "SELECT status FROM host WHERE hostname = $1", 414 + .{hostname}, 415 + ) catch return false; 416 + if (row) |*r| { 417 + defer r.deinit() catch {}; 418 + const status = r.get([]const u8, 0); 419 + return std.mem.eql(u8, status, "banned") or std.mem.eql(u8, status, "blocked"); 420 + } 421 + return false; 334 422 } 335 423 336 424 /// update cursor position for a host ··· 421 509 ) orelse return 0; 422 510 defer row.deinit() catch {}; 423 511 return @intCast(row.get(i32, 0)); 512 + } 513 + 514 + /// check if a hostname (or any parent domain) is banned. 
515 + /// Go relay: domain_ban.go DomainIsBanned — suffix-based check. 516 + pub fn isDomainBanned(self: *DiskPersist, hostname: []const u8) bool { 517 + // check each suffix: "pds.host.example.com", "host.example.com", "example.com" 518 + var offset: usize = 0; 519 + while (offset < hostname.len) { 520 + const suffix = hostname[offset..]; 521 + var row = self.db.rowUnsafe( 522 + "SELECT 1 FROM domain_ban WHERE domain = $1", 523 + .{suffix}, 524 + ) catch return false; 525 + if (row) |*r| { 526 + r.deinit() catch {}; 527 + return true; 528 + } 529 + // advance past next dot 530 + if (std.mem.indexOfScalarPos(u8, hostname, offset, '.')) |dot| { 531 + offset = dot + 1; 532 + } else break; 533 + } 534 + return false; 424 535 } 425 536 426 537 /// reset failure count (on successful connection)
+60 -18
src/main.zig
··· 330 330 }; 331 331 defer parsed.deinit(); 332 332 333 - slurper.addCrawlRequest(parsed.value.hostname) catch { 333 + // fast validation: hostname format (Go relay does this synchronously in handler) 334 + const hostname = slurper_mod.validateHostname(slurper.allocator, parsed.value.hostname) catch |err| { 335 + log.warn("requestCrawl rejected '{s}': {s}", .{ parsed.value.hostname, @errorName(err) }); 336 + httpRespond(stream, "400 Bad Request", "application/json", switch (err) { 337 + error.EmptyHostname => "{\"error\":\"empty hostname\"}", 338 + error.InvalidCharacter => "{\"error\":\"hostname contains invalid characters\"}", 339 + error.InvalidLabel => "{\"error\":\"hostname has invalid label\"}", 340 + error.TooFewLabels => "{\"error\":\"hostname must have at least two labels (e.g. pds.example.com)\"}", 341 + error.LooksLikeIpAddress => "{\"error\":\"IP addresses not allowed, use a hostname\"}", 342 + error.PortNotAllowed => "{\"error\":\"port numbers not allowed\"}", 343 + error.LocalhostNotAllowed => "{\"error\":\"localhost not allowed\"}", 344 + else => "{\"error\":\"invalid hostname\"}", 345 + }); 346 + return; 347 + }; 348 + defer slurper.allocator.free(hostname); 349 + 350 + // fast validation: domain ban check 351 + if (slurper.persist.isDomainBanned(hostname)) { 352 + log.warn("requestCrawl rejected '{s}': domain banned", .{hostname}); 353 + httpRespond(stream, "400 Bad Request", "application/json", "{\"error\":\"domain is banned\"}"); 354 + return; 355 + } 356 + 357 + // enqueue for async processing (describeServer check happens in crawl processor) 358 + slurper.addCrawlRequest(hostname) catch { 334 359 httpRespond(stream, "500 Internal Server Error", "application/json", "{\"error\":\"failed to store crawl request\"}"); 335 360 return; 336 361 }; 337 362 338 - log.info("crawl requested: {s}", .{parsed.value.hostname}); 363 + log.info("crawl requested: {s}", .{hostname}); 339 364 httpRespond(stream, "200 OK", "application/json", 
"{\"success\":true}"); 340 365 } 341 366 ··· 471 496 } 472 497 473 498 // query accounts with repo state, paginated by UID 499 + // includes both local status and upstream_status for combined active check 474 500 var result = persist.db.query( 475 - \\SELECT a.uid, a.did, a.status, COALESCE(r.rev, ''), COALESCE(r.commit_data_cid, '') 501 + \\SELECT a.uid, a.did, a.status, a.upstream_status, COALESCE(r.rev, ''), COALESCE(r.commit_data_cid, '') 476 502 \\FROM account a LEFT JOIN account_repo r ON a.uid = r.uid 477 503 \\WHERE a.uid > $1 ORDER BY a.uid ASC LIMIT $2 478 504 , .{ cursor_val, limit }) catch { ··· 496 522 497 523 const uid = row.get(i64, 0); 498 524 const did = row.get([]const u8, 1); 499 - const status = row.get([]const u8, 2); 500 - const rev = row.get([]const u8, 3); 501 - const head = row.get([]const u8, 4); 525 + const local_status = row.get([]const u8, 2); 526 + const upstream_status = row.get([]const u8, 3); 527 + const rev = row.get([]const u8, 4); 528 + const head = row.get([]const u8, 5); 502 529 503 - const active = std.mem.eql(u8, status, "active"); 530 + // Go relay: Account.IsActive() — both local AND upstream must be active 531 + const local_ok = std.mem.eql(u8, local_status, "active"); 532 + const upstream_ok = std.mem.eql(u8, upstream_status, "active"); 533 + const active = local_ok and upstream_ok; 534 + // Go relay: Account.AccountStatus() — local takes priority 535 + const status = if (!local_ok) local_status else upstream_status; 504 536 505 537 w.writeAll("{\"did\":\"") catch return; 506 538 w.writeAll(did) catch return; ··· 555 587 return; 556 588 } 557 589 558 - // look up account 590 + // look up account (includes both local and upstream status) 559 591 var row = (persist.db.rowUnsafe( 560 - "SELECT a.uid, a.status, COALESCE(r.rev, '') FROM account a LEFT JOIN account_repo r ON a.uid = r.uid WHERE a.did = $1", 592 + "SELECT a.uid, a.status, a.upstream_status, COALESCE(r.rev, '') FROM account a LEFT JOIN account_repo r ON a.uid = 
r.uid WHERE a.did = $1", 561 593 .{did}, 562 594 ) catch { 563 595 httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); ··· 568 600 }; 569 601 defer row.deinit() catch {}; 570 602 571 - const status = row.get([]const u8, 1); 572 - const rev = row.get([]const u8, 2); 573 - const active = std.mem.eql(u8, status, "active"); 603 + const local_status = row.get([]const u8, 1); 604 + const upstream_status = row.get([]const u8, 2); 605 + const rev = row.get([]const u8, 3); 606 + // Go relay: Account.IsActive() / AccountStatus() 607 + const local_ok = std.mem.eql(u8, local_status, "active"); 608 + const upstream_ok = std.mem.eql(u8, upstream_status, "active"); 609 + const active = local_ok and upstream_ok; 610 + const status = if (!local_ok) local_status else upstream_status; 574 611 575 612 var buf: [4096]u8 = undefined; 576 613 var fbs = std.io.fixedBufferStream(&buf); ··· 609 646 return; 610 647 } 611 648 612 - // look up account + repo state 649 + // look up account + repo state (includes both local and upstream status) 613 650 var row = (persist.db.rowUnsafe( 614 - "SELECT a.status, COALESCE(r.rev, ''), COALESCE(r.commit_data_cid, '') FROM account a LEFT JOIN account_repo r ON a.uid = r.uid WHERE a.did = $1", 651 + "SELECT a.status, a.upstream_status, COALESCE(r.rev, ''), COALESCE(r.commit_data_cid, '') FROM account a LEFT JOIN account_repo r ON a.uid = r.uid WHERE a.did = $1", 615 652 .{did}, 616 653 ) catch { 617 654 httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); ··· 622 659 }; 623 660 defer row.deinit() catch {}; 624 661 625 - const status = row.get([]const u8, 0); 626 - const rev = row.get([]const u8, 1); 627 - const cid = row.get([]const u8, 2); 662 + const local_status = row.get([]const u8, 0); 663 + const upstream_status = row.get([]const u8, 1); 664 + const rev = row.get([]const u8, 2); 665 + const cid = row.get([]const u8, 3); 666 
+ 667 + // combined status: local takes priority (Go relay: AccountStatus()) 668 + const status = if (!std.mem.eql(u8, local_status, "active")) local_status else upstream_status; 628 669 629 670 // check account status (match Go relay behavior) 630 671 if (std.mem.eql(u8, status, "takendown") or std.mem.eql(u8, status, "suspended")) { ··· 692 733 693 734 fn httpRespond(stream: std.net.Stream, status: []const u8, content_type: []const u8, body: []const u8) void { 694 735 var buf: [4096]u8 = undefined; 695 - const response = std.fmt.bufPrint(&buf, "HTTP/1.1 {s}\r\nContent-Type: {s}\r\nContent-Length: {d}\r\nConnection: close\r\n\r\n{s}", .{ status, content_type, body.len, body }) catch return; 736 + // include Server header for relay loop detection (Go relay: "indigo-relay (atproto-relay)") 737 + const response = std.fmt.bufPrint(&buf, "HTTP/1.1 {s}\r\nContent-Type: {s}\r\nContent-Length: {d}\r\nServer: zlay (atproto-relay)\r\nConnection: close\r\n\r\n{s}", .{ status, content_type, body.len, body }) catch return; 696 738 _ = stream.write(response) catch {}; 697 739 } 698 740
+281 -4
src/slurper.zig
··· 4 4 //! - loading known hosts from DB on startup 5 5 //! - spawning/stopping subscriber workers 6 6 //! - processing crawl requests (adding new hosts) 7 + //! - host validation (format, domain ban, describeServer, relay loop detection) 7 8 //! - tracking host lifecycle (active → exhausted → blocked) 8 9 //! 9 10 //! all downstream components (Broadcaster, DiskPersist, Validator) are 10 11 //! thread-safe for N concurrent producers, so this just orchestrates. 11 12 12 13 const std = @import("std"); 14 + const http = std.http; 13 15 const broadcaster = @import("broadcaster.zig"); 14 16 const validator_mod = @import("validator.zig"); 15 17 const event_log_mod = @import("event_log.zig"); ··· 21 23 pub const Options = struct { 22 24 seed_host: []const u8 = "bsky.network", 23 25 max_message_size: usize = 5 * 1024 * 1024, 26 + new_hosts_per_day: u32 = 50, // Go relay: RELAY_NEW_HOSTS_PER_DAY_LIMIT default 50 24 27 }; 25 28 29 + // --- host validation --- 30 + // mirrors indigo relay's ParseHostname + CheckHost + relay loop detection 31 + 32 + pub const HostValidationError = error{ 33 + EmptyHostname, 34 + InvalidCharacter, 35 + InvalidLabel, 36 + TooFewLabels, 37 + LooksLikeIpAddress, 38 + PortNotAllowed, 39 + LocalhostNotAllowed, 40 + DomainBanned, 41 + HostUnreachable, 42 + NotAPds, 43 + IsARelay, 44 + }; 45 + 46 + /// validate and normalize a hostname for crawling. 47 + /// rejects IPs, ports, localhost, invalid DNS names. 48 + /// returns lowercased hostname on success. 
49 + pub fn validateHostname(allocator: Allocator, raw: []const u8) HostValidationError![]u8 { 50 + if (raw.len == 0) return error.EmptyHostname; 51 + 52 + // strip scheme if present 53 + var hostname = raw; 54 + for ([_][]const u8{ "https://", "http://", "wss://", "ws://" }) |scheme| { 55 + if (hostname.len > scheme.len and std.ascii.startsWithIgnoreCase(hostname, scheme)) { 56 + hostname = hostname[scheme.len..]; 57 + break; 58 + } 59 + } 60 + 61 + // strip trailing slash and path 62 + if (std.mem.indexOfScalar(u8, hostname, '/')) |i| { 63 + hostname = hostname[0..i]; 64 + } 65 + 66 + // reject ports (Go relay rejects non-localhost ports) 67 + if (std.mem.indexOfScalar(u8, hostname, ':')) |_| { 68 + return error.PortNotAllowed; 69 + } 70 + 71 + // reject localhost 72 + if (std.ascii.eqlIgnoreCase(hostname, "localhost")) { 73 + return error.LocalhostNotAllowed; 74 + } 75 + 76 + // validate characters and split into labels 77 + var label_count: usize = 0; 78 + var all_labels_numeric = true; 79 + var it = std.mem.splitScalar(u8, hostname, '.'); 80 + while (it.next()) |label| { 81 + if (label.len == 0 or label.len > 63) return error.InvalidLabel; 82 + // labels must be alphanumeric with hyphens, no leading/trailing hyphens 83 + if (label[0] == '-' or label[label.len - 1] == '-') return error.InvalidLabel; 84 + var is_numeric = true; 85 + for (label) |c| { 86 + if (!std.ascii.isAlphanumeric(c) and c != '-') return error.InvalidCharacter; 87 + if (!std.ascii.isDigit(c)) is_numeric = false; 88 + } 89 + if (!is_numeric) all_labels_numeric = false; 90 + label_count += 1; 91 + } 92 + 93 + // need at least 2 labels (e.g. "pds.example.com", "bsky.network") 94 + if (label_count < 2) return error.TooFewLabels; 95 + 96 + // all-numeric labels = IP address (e.g. 
"192.168.1.1") 97 + if (all_labels_numeric) return error.LooksLikeIpAddress; 98 + 99 + // lowercase normalize 100 + const result = allocator.alloc(u8, hostname.len) catch return error.EmptyHostname; 101 + for (hostname, 0..) |c, i| { 102 + result[i] = std.ascii.toLower(c); 103 + } 104 + return result; 105 + } 106 + 107 + /// SSRF protection: resolve hostname and reject private/reserved IP ranges. 108 + /// Go relay: ssrf.go PublicOnlyTransport — rejects 10/8, 172.16/12, 192.168/16, 127/8, link-local. 109 + fn rejectPrivateHost(allocator: Allocator, hostname: []const u8) HostValidationError!void { 110 + const addr_list = std.net.getAddressList(allocator, hostname, 443) catch return error.HostUnreachable; 111 + defer addr_list.deinit(); 112 + 113 + if (addr_list.addrs.len == 0) return error.HostUnreachable; 114 + 115 + // check all resolved addresses — reject if ANY is private 116 + for (addr_list.addrs) |addr| { 117 + switch (addr.any.family) { 118 + std.posix.AF.INET => { 119 + const ip4 = addr.in.sa.addr; 120 + const bytes: [4]u8 = @bitCast(ip4); 121 + if (bytes[0] == 10 or // 10.0.0.0/8 122 + (bytes[0] == 172 and (bytes[1] & 0xf0) == 16) or // 172.16.0.0/12 123 + (bytes[0] == 192 and bytes[1] == 168) or // 192.168.0.0/16 124 + bytes[0] == 127 or // 127.0.0.0/8 125 + bytes[0] == 0 or // 0.0.0.0/8 126 + (bytes[0] == 169 and bytes[1] == 254)) // 169.254.0.0/16 link-local 127 + { 128 + log.warn("SSRF: {s} resolves to private IP {d}.{d}.{d}.{d}", .{ hostname, bytes[0], bytes[1], bytes[2], bytes[3] }); 129 + return error.HostUnreachable; 130 + } 131 + }, 132 + else => {}, // allow IPv6 for now (could add RFC 4193 check later) 133 + } 134 + } 135 + } 136 + 137 + /// check that a host is a real PDS by calling describeServer. 138 + /// also checks Server header for relay loop detection. 139 + /// Go relay: host_checker.go CheckHost + slurper.go Server header check. 
140 + fn checkHost(allocator: Allocator, hostname: []const u8) HostValidationError!void { 141 + // SSRF protection: reject private IPs before making any request 142 + rejectPrivateHost(allocator, hostname) catch |err| return err; 143 + var url_buf: [512]u8 = undefined; 144 + const url = std.fmt.bufPrint(&url_buf, "https://{s}/xrpc/com.atproto.server.describeServer", .{hostname}) catch return error.HostUnreachable; 145 + 146 + var client: http.Client = .{ .allocator = allocator }; 147 + defer client.deinit(); 148 + 149 + const uri = std.Uri.parse(url) catch return error.HostUnreachable; 150 + var req = client.request(.GET, uri, .{}) catch return error.HostUnreachable; 151 + defer req.deinit(); 152 + req.sendBodiless() catch return error.HostUnreachable; 153 + 154 + var redirect_buf: [2048]u8 = undefined; 155 + const response = req.receiveHead(&redirect_buf) catch return error.HostUnreachable; 156 + 157 + if (response.head.status != .ok) return error.NotAPds; 158 + 159 + // relay loop detection: check Server header for "atproto-relay" 160 + // Go relay: slurper.go — auto-bans hosts whose Server header contains "atproto-relay" 161 + if (findHeaderInRaw(response.head.bytes, "server")) |server_val| { 162 + if (std.mem.indexOf(u8, server_val, "atproto-relay") != null) { 163 + return error.IsARelay; 164 + } 165 + } 166 + } 167 + 168 + /// search raw HTTP headers for a header by name (case-insensitive). 169 + /// returns the trimmed value, or null if not found. 
170 + fn findHeaderInRaw(raw: []const u8, name: []const u8) ?[]const u8 { 171 + var it = std.mem.splitSequence(u8, raw, "\r\n"); 172 + _ = it.next(); // skip status line 173 + while (it.next()) |line| { 174 + if (line.len == 0) break; 175 + const colon = std.mem.indexOfScalar(u8, line, ':') orelse continue; 176 + const key = std.mem.trim(u8, line[0..colon], " "); 177 + if (key.len != name.len) continue; 178 + // case-insensitive compare 179 + var match = true; 180 + for (key, name) |a, b| { 181 + if (std.ascii.toLower(a) != std.ascii.toLower(b)) { 182 + match = false; 183 + break; 184 + } 185 + } 186 + if (match) { 187 + return std.mem.trim(u8, line[colon + 1 ..], " "); 188 + } 189 + } 190 + return null; 191 + } 192 + 26 193 const WorkerEntry = struct { 27 194 thread: std.Thread, 28 195 subscriber: *subscriber_mod.Subscriber, ··· 47 214 48 215 // crawl processing thread 49 216 crawl_thread: ?std.Thread = null, 217 + 218 + // per-day new host rate limit (Go relay: HostPerDayLimiter) 219 + hosts_added_today: u32 = 0, 220 + rate_limit_day_start: i64 = 0, 50 221 51 222 pub fn init( 52 223 allocator: Allocator, ··· 103 274 self.crawl_cond.signal(); 104 275 } 105 276 106 - /// add a host and spawn a worker for it. idempotent — skips if already tracked. 107 - fn addHost(self: *Slurper, hostname: []const u8) !void { 277 + /// validate and add a host: format check, domain ban, describeServer, then spawn. 278 + /// mirrors Go relay's requestCrawl → SubscribeToHost pipeline. 
279 + fn addHost(self: *Slurper, raw_hostname: []const u8) !void { 280 + // step 1: validate and normalize hostname format 281 + // Go relay: host.go ParseHostname 282 + const hostname = validateHostname(self.allocator, raw_hostname) catch |err| { 283 + log.warn("host validation failed for '{s}': {s}", .{ raw_hostname, @errorName(err) }); 284 + return; 285 + }; 286 + defer self.allocator.free(hostname); 287 + 288 + // step 2: domain ban check (suffix-based) 289 + // Go relay: domain_ban.go DomainIsBanned 290 + if (self.persist.isDomainBanned(hostname)) { 291 + log.warn("host {s}: domain is banned, rejecting", .{hostname}); 292 + return; 293 + } 294 + 295 + // step 3: check if host is banned/blocked in DB 296 + // Go relay: crawl.go checks host.Status == HostStatusBanned 297 + if (self.persist.isHostBanned(hostname)) { 298 + log.warn("host {s}: banned/blocked in DB, rejecting", .{hostname}); 299 + return; 300 + } 301 + 302 + // step 4: dedup — check if already tracked 303 + // Go relay: crawl.go CheckIfSubscribed 108 304 const host_info = try self.persist.getOrCreateHost(hostname); 109 - 110 - // check if worker already running 111 305 { 112 306 self.workers_mutex.lock(); 113 307 defer self.workers_mutex.unlock(); ··· 117 311 } 118 312 } 119 313 314 + // step 5: per-day rate limit for new hosts 315 + // Go relay: HostPerDayLimiter (sliding window, 50/day default) 316 + if (!self.checkHostRateLimit()) { 317 + log.warn("host {s}: new-hosts-per-day limit reached ({d}), rejecting", .{ hostname, self.options.new_hosts_per_day }); 318 + return; 319 + } 320 + 321 + // step 6: describeServer liveness check 322 + // Go relay: host_checker.go CheckHost (with SSRF protection) 323 + checkHost(self.allocator, hostname) catch |err| { 324 + log.warn("host {s}: describeServer check failed: {s}", .{ hostname, @errorName(err) }); 325 + return; 326 + }; 327 + 120 328 try self.spawnWorker(host_info.id, hostname); 121 329 log.info("added host {s} (id={d})", .{ hostname, host_info.id }); 
330 + } 331 + 332 + /// check and consume one token from the daily host rate limit. 333 + /// resets the counter once 24 hours have elapsed since the window began (rolling window anchored at the first check, not a UTC calendar-day boundary). 334 + fn checkHostRateLimit(self: *Slurper) bool { 335 + const now = std.time.timestamp(); 336 + const day_seconds: i64 = 86400; 337 + if (now - self.rate_limit_day_start >= day_seconds) { 338 + self.hosts_added_today = 0; 339 + self.rate_limit_day_start = now; 340 + } 341 + if (self.hosts_added_today >= self.options.new_hosts_per_day) return false; 342 + self.hosts_added_today += 1; 343 + return true; 122 344 } 123 345 124 346 /// spawn a subscriber thread for a host ··· 228 450 self.crawl_queue.deinit(self.allocator); 229 451 } 230 452 }; 453 + 454 + // --- tests --- 455 + 456 + test "validateHostname accepts valid PDS hostnames" { 457 + const alloc = std.testing.allocator; 458 + 459 + // basic valid hostnames 460 + const h1 = try validateHostname(alloc, "pds.example.com"); 461 + defer alloc.free(h1); 462 + try std.testing.expectEqualStrings("pds.example.com", h1); 463 + 464 + // two labels minimum 465 + const h2 = try validateHostname(alloc, "bsky.network"); 466 + defer alloc.free(h2); 467 + try std.testing.expectEqualStrings("bsky.network", h2); 468 + 469 + // lowercases 470 + const h3 = try validateHostname(alloc, "PDS.Example.COM"); 471 + defer alloc.free(h3); 472 + try std.testing.expectEqualStrings("pds.example.com", h3); 473 + 474 + // strips scheme 475 + const h4 = try validateHostname(alloc, "https://pds.example.com"); 476 + defer alloc.free(h4); 477 + try std.testing.expectEqualStrings("pds.example.com", h4); 478 + 479 + // strips path 480 + const h5 = try validateHostname(alloc, "pds.example.com/some/path"); 481 + defer alloc.free(h5); 482 + try std.testing.expectEqualStrings("pds.example.com", h5); 483 + } 484 + 485 + test "validateHostname rejects invalid hostnames" { 486 + const alloc = std.testing.allocator; 487 + 488 + // empty 489 + try std.testing.expectError(error.EmptyHostname, validateHostname(alloc, "")); 
490 + // localhost 491 + try std.testing.expectError(error.LocalhostNotAllowed, validateHostname(alloc, "localhost")); 492 + // single label (non-localhost) 493 + try std.testing.expectError(error.TooFewLabels, validateHostname(alloc, "intranet")); 494 + // IP address 495 + try std.testing.expectError(error.LooksLikeIpAddress, validateHostname(alloc, "192.168.1.1")); 496 + try std.testing.expectError(error.LooksLikeIpAddress, validateHostname(alloc, "10.0.0.1")); 497 + // port 498 + try std.testing.expectError(error.PortNotAllowed, validateHostname(alloc, "pds.example.com:443")); 499 + // invalid characters 500 + try std.testing.expectError(error.InvalidCharacter, validateHostname(alloc, "pds.exam ple.com")); 501 + try std.testing.expectError(error.InvalidCharacter, validateHostname(alloc, "pds.exam_ple.com")); 502 + // leading/trailing hyphens 503 + try std.testing.expectError(error.InvalidLabel, validateHostname(alloc, "-pds.example.com")); 504 + try std.testing.expectError(error.InvalidLabel, validateHostname(alloc, "pds-.example.com")); 505 + // empty label 506 + try std.testing.expectError(error.InvalidLabel, validateHostname(alloc, "pds..example.com")); 507 + }
+111 -19
src/subscriber.zig
··· 17 17 const log = std.log.scoped(.relay); 18 18 19 19 const max_consecutive_failures = 15; 20 - const cursor_flush_interval = 1000; // flush cursor to DB every N frames 20 + const cursor_flush_interval_sec = 4; // flush cursor to DB every N seconds (Go relay: 4s) 21 + const default_rate_limit: u64 = 100; // events per second per host (Go relay: 50/sec baseline) 21 22 22 23 pub const Options = struct { 23 24 hostname: []const u8 = "bsky.network", ··· 33 34 persist: ?*event_log_mod.DiskPersist, 34 35 shutdown: *std.atomic.Value(bool), 35 36 last_upstream_seq: ?u64 = null, 36 - frames_since_flush: u64 = 0, 37 + last_cursor_flush: i64 = 0, 38 + 39 + // per-host shutdown (e.g. FutureCursor — stops only this subscriber) 40 + host_shutdown: std.atomic.Value(bool) = .{ .raw = false }, 41 + 42 + // per-host rate limiting (token bucket) 43 + rate_tokens: u64 = default_rate_limit, 44 + rate_last_refill: i64 = 0, 45 + rate_dropped: u64 = 0, 37 46 38 47 pub fn init( 39 48 allocator: Allocator, ··· 53 62 }; 54 63 } 55 64 65 + /// check if this subscriber should stop (global or per-host shutdown) 66 + fn shouldStop(self: *Subscriber) bool { 67 + return self.shutdown.load(.acquire) or self.host_shutdown.load(.acquire); 68 + } 69 + 56 70 /// run the subscriber loop. reconnects with exponential backoff. 57 71 /// blocks until shutdown flag is set or host is exhausted. 
58 72 pub fn run(self: *Subscriber) void { ··· 72 86 } 73 87 } 74 88 75 - while (!self.shutdown.load(.acquire)) { 89 + while (!self.shouldStop()) { 76 90 log.info("host {s}: connecting...", .{self.options.hostname}); 77 91 78 92 self.connectAndRead() catch |err| { 79 - if (self.shutdown.load(.acquire)) return; 93 + if (self.shouldStop()) return; 80 94 log.err("host {s}: error: {s}, reconnecting in {d}s...", .{ self.options.hostname, @errorName(err), backoff }); 81 95 }; 82 96 83 - if (self.shutdown.load(.acquire)) return; 97 + if (self.shouldStop()) return; 84 98 85 99 // track failures for this host 86 100 if (self.options.host_id > 0) { ··· 96 110 97 111 // backoff sleep in small increments so we can check shutdown 98 112 var remaining: u64 = backoff; 99 - while (remaining > 0 and !self.shutdown.load(.acquire)) { 113 + while (remaining > 0 and !self.shouldStop()) { 100 114 const chunk = @min(remaining, 1); 101 115 std.posix.nanosleep(chunk, 0); 102 116 remaining -= chunk; ··· 179 193 180 194 // check op field (1 = message, -1 = error) 181 195 const op = header.getInt("op") orelse return; 182 - if (op == -1) return; // error frame from upstream, skip 196 + if (op == -1) { 197 + // error frame from upstream — check for FutureCursor 198 + // Go relay: slurper.go — sets host to idle and disconnects 199 + if (zat.cbor.decodeAll(alloc, payload_data) catch null) |err_payload| { 200 + const err_name = err_payload.getString("error") orelse "unknown"; 201 + const err_msg = err_payload.getString("message") orelse ""; 202 + log.warn("host {s}: error frame: {s}: {s}", .{ sub.options.hostname, err_name, err_msg }); 203 + if (std.mem.eql(u8, err_name, "FutureCursor")) { 204 + // our cursor is ahead of the PDS — set host to idle, stop this subscriber only 205 + if (sub.persist) |dp| { 206 + if (sub.options.host_id > 0) { 207 + dp.updateHostStatus(sub.options.host_id, "idle") catch {}; 208 + } 209 + } 210 + sub.host_shutdown.store(true, .release); 211 + } 212 + } 213 + return; 
214 + } 183 215 184 216 const frame_type = header.getString("t") orelse return; 185 217 const payload = zat.cbor.decodeAll(alloc, payload_data) catch |err| { ··· 195 227 sub.bc.stats.seq.store(s, .release); 196 228 } 197 229 198 - // periodic cursor flush 199 - sub.frames_since_flush += 1; 200 - if (sub.frames_since_flush >= cursor_flush_interval) { 201 - sub.flushCursor(); 202 - sub.frames_since_flush = 0; 230 + // time-based cursor flush (Go relay: every 4 seconds) 231 + { 232 + const now = std.time.timestamp(); 233 + if (now - sub.last_cursor_flush >= cursor_flush_interval_sec) { 234 + sub.flushCursor(); 235 + sub.last_cursor_flush = now; 236 + } 237 + } 238 + 239 + // per-host rate limiting (token bucket, refills once per second) 240 + // Go relay: sliding window limiters per host (50/sec baseline) 241 + { 242 + const now = std.time.timestamp(); 243 + if (now > sub.rate_last_refill) { 244 + sub.rate_tokens = default_rate_limit; 245 + sub.rate_last_refill = now; 246 + if (sub.rate_dropped > 0) { 247 + log.warn("host {s}: rate limited, dropped {d} events in last window", .{ sub.options.hostname, sub.rate_dropped }); 248 + sub.rate_dropped = 0; 249 + } 250 + } 251 + if (sub.rate_tokens == 0) { 252 + sub.rate_dropped += 1; 253 + return; 254 + } 255 + sub.rate_tokens -= 1; 203 256 } 204 257 205 258 // route by frame type 206 259 const is_commit = std.mem.eql(u8, frame_type, "#commit"); 260 + const is_account = std.mem.eql(u8, frame_type, "#account"); 207 261 208 262 // extract DID: "repo" for commits, "did" for identity/account 209 263 const did: ?[]const u8 = if (is_commit) ··· 216 270 if (did) |d| sub.validator.evictKey(d); 217 271 } 218 272 219 - // validate commit frames using pre-decoded payload 273 + // resolve DID → numeric UID for event header (host-aware) 274 + const uid: u64 = if (sub.persist) |dp| blk: { 275 + break :blk if (did) |d| 276 + dp.uidForDidFromHost(d, sub.options.host_id) catch 0 277 + else 278 + 0; 279 + } else 0; 280 + 281 + // process 
#account events: update upstream status 282 + // Go relay: ingest.go processAccountEvent 283 + if (is_account) { 284 + if (sub.persist) |dp| { 285 + if (uid > 0) { 286 + const active = payload.get("active"); 287 + const status_str = payload.getString("status"); 288 + const new_status: []const u8 = if (active) |a| switch (a) { 289 + .true => "active", 290 + else => status_str orelse "inactive", 291 + } else status_str orelse "inactive"; 292 + dp.updateAccountUpstreamStatus(uid, new_status) catch |err| { 293 + log.debug("upstream status update failed: {s}", .{@errorName(err)}); 294 + }; 295 + } 296 + } 297 + } 298 + 299 + // for commits: check account is active, validate, extract state 220 300 var commit_data_cid: ?[]const u8 = null; 221 301 var commit_rev: ?[]const u8 = null; 222 302 if (is_commit) { 303 + // drop commits for inactive accounts 304 + // Go relay: EnsureAccountActive — silently drops 305 + if (sub.persist) |dp| { 306 + if (uid > 0) { 307 + const active = dp.isAccountActive(uid) catch true; 308 + if (!active) { 309 + _ = sub.bc.stats.skipped.fetchAdd(1, .monotonic); 310 + return; 311 + } 312 + } 313 + } 314 + 223 315 const result = sub.validator.validateCommit(payload); 224 316 if (!result.valid) return; 225 317 commit_data_cid = result.data_cid; ··· 227 319 } 228 320 229 321 // determine event kind for persistence 230 - const kind: event_log_mod.EvtKind = if (is_commit) .commit else .identity; 231 - 232 - // resolve DID → numeric UID for event header 233 - const uid: u64 = if (sub.persist) |dp| blk: { 234 - break :blk if (did) |d| dp.uidForDid(d) catch 0 else 0; 235 - } else 0; 322 + const kind: event_log_mod.EvtKind = if (is_commit) 323 + .commit 324 + else if (is_account) 325 + .account 326 + else 327 + .identity; 236 328 237 329 // persist and get relay-assigned seq, broadcast raw bytes 238 330 if (sub.persist) |dp| {