//! slurper — multi-host PDS crawl manager
//!
//! manages one Subscriber thread per tracked PDS host. handles:
//! - loading known hosts from DB on startup
//! - spawning/stopping subscriber workers
//! - processing crawl requests (adding new hosts)
//! - host validation (format, domain ban, describeServer, relay loop detection)
//! - tracking host lifecycle (active → exhausted → blocked)
//!
//! all downstream components (Broadcaster, DiskPersist, Validator) are
//! thread-safe for N concurrent producers, so this just orchestrates.

const std = @import("std");
const http = std.http;

const broadcaster = @import("broadcaster.zig");
const validator_mod = @import("validator.zig");
const event_log_mod = @import("event_log.zig");
const subscriber_mod = @import("subscriber.zig");
const collection_index_mod = @import("collection_index.zig");
const frame_worker_mod = @import("frame_worker.zig");

const Allocator = std.mem.Allocator;
const log = std.log.scoped(.relay);

/// tunables for the slurper; defaults mirror a typical relay deployment.
pub const Options = struct {
    /// relay queried once at startup by pullHosts (com.atproto.sync.listHosts)
    seed_host: []const u8 = "bsky.network",
    /// per-subscriber cap on inbound message size, in bytes
    max_message_size: usize = 5 * 1024 * 1024,
    /// number of FramePool worker threads
    frame_workers: u16 = 16,
    /// capacity of the FramePool submission queue
    frame_queue_capacity: u16 = 4096,
};

// --- host validation ---
// mirrors indigo relay's ParseHostname + CheckHost + relay loop detection

/// everything that can go wrong while validating/probing a candidate host.
pub const HostValidationError = error{
    EmptyHostname,
    InvalidCharacter,
    InvalidLabel,
    TooFewLabels,
    LooksLikeIpAddress,
    PortNotAllowed,
    LocalhostNotAllowed,
    DomainBanned,
    HostUnreachable,
    NotAPds,
    IsARelay,
};

/// validate and normalize a hostname for crawling.
/// rejects IPs, ports, localhost, invalid DNS names.
/// returns lowercased hostname on success.
pub fn validateHostname(allocator: Allocator, raw: []const u8) HostValidationError![]u8 {
    if (raw.len == 0) return error.EmptyHostname;

    var host = raw;

    // drop a leading URL scheme, if any (only when something follows it)
    const schemes = [_][]const u8{ "https://", "http://", "wss://", "ws://" };
    for (schemes) |prefix| {
        if (host.len > prefix.len and std.ascii.startsWithIgnoreCase(host, prefix)) {
            host = host[prefix.len..];
            break;
        }
    }

    // keep only the authority: cut at the first '/' (drops trailing slash and path)
    if (std.mem.indexOfScalar(u8, host, '/')) |slash| host = host[0..slash];

    // an explicit port is never accepted (Go relay rejects non-localhost ports)
    if (std.mem.indexOfScalar(u8, host, ':') != null) return error.PortNotAllowed;

    // localhost is never crawlable
    if (std.ascii.eqlIgnoreCase(host, "localhost")) return error.LocalhostNotAllowed;

    // walk the dot-separated labels, validating each one as we go
    var labels: usize = 0;
    var numeric_only = true;
    var parts = std.mem.splitScalar(u8, host, '.');
    while (parts.next()) |part| {
        if (part.len == 0 or part.len > 63) return error.InvalidLabel;
        // labels must not start or end with a hyphen
        if (part[0] == '-' or part[part.len - 1] == '-') return error.InvalidLabel;
        var part_numeric = true;
        for (part) |ch| {
            if (!std.ascii.isAlphanumeric(ch) and ch != '-') return error.InvalidCharacter;
            if (!std.ascii.isDigit(ch)) part_numeric = false;
        }
        if (!part_numeric) numeric_only = false;
        labels += 1;
    }

    // need at least 2 labels (e.g. "pds.example.com", "bsky.network")
    if (labels < 2) return error.TooFewLabels;
    // every label all-digits means this is an IPv4 literal (e.g. "192.168.1.1")
    if (numeric_only) return error.LooksLikeIpAddress;

    // return a lowercased copy owned by the caller.
    // NOTE(review): allocation failure is surfaced as error.EmptyHostname because
    // OutOfMemory is not in HostValidationError — misleading, but preserved here.
    const out = allocator.alloc(u8, host.len) catch return error.EmptyHostname;
    for (host, 0..) |ch, i| out[i] = std.ascii.toLower(ch);
    return out;
}

/// SSRF protection: resolve hostname and reject private/reserved IP ranges.
/// Go relay: ssrf.go PublicOnlyTransport — rejects 10/8, 172.16/12, 192.168/16, 127/8, link-local.
fn rejectPrivateHost(allocator: Allocator, hostname: []const u8) HostValidationError!void {
    const addr_list = std.net.getAddressList(allocator, hostname, 443) catch return error.HostUnreachable;
    defer addr_list.deinit();
    if (addr_list.addrs.len == 0) return error.HostUnreachable;
    // check all resolved addresses — reject if ANY is private/reserved.
    // NOTE(review): this check races with the later connect (DNS rebinding window);
    // the Go relay closes it by checking at dial time — TODO consider the same here.
    for (addr_list.addrs) |addr| {
        switch (addr.any.family) {
            std.posix.AF.INET => {
                const bytes: [4]u8 = @bitCast(addr.in.sa.addr);
                if (isPrivateIp4(bytes)) {
                    log.warn("SSRF: {s} resolves to private IP {d}.{d}.{d}.{d}", .{ hostname, bytes[0], bytes[1], bytes[2], bytes[3] });
                    return error.HostUnreachable;
                }
            },
            std.posix.AF.INET6 => {
                const b = addr.in6.sa.addr; // 16 bytes, network order
                if (std.mem.allEqual(u8, b[0..10], 0) and b[10] == 0xff and b[11] == 0xff) {
                    // IPv4-mapped (::ffff:a.b.c.d): apply the IPv4 rules to the embedded address
                    const v4: [4]u8 = b[12..16].*;
                    if (isPrivateIp4(v4)) {
                        log.warn("SSRF: {s} resolves to private IPv4-mapped IPv6 address", .{hostname});
                        return error.HostUnreachable;
                    }
                } else if ((std.mem.allEqual(u8, b[0..15], 0) and b[15] <= 1) or // :: unspecified and ::1 loopback
                    (b[0] == 0xfe and (b[1] & 0xc0) == 0x80) or // fe80::/10 link-local
                    (b[0] & 0xfe) == 0xfc) // fc00::/7 unique-local (RFC 4193)
                {
                    log.warn("SSRF: {s} resolves to private/reserved IPv6 address", .{hostname});
                    return error.HostUnreachable;
                }
            },
            else => {},
        }
    }
}

/// true if an IPv4 address (bytes in network order) falls in a private/reserved range.
/// covers: 0.0.0.0/8, 10/8, 127/8, 169.254/16 link-local, 172.16/12, 192.168/16.
fn isPrivateIp4(bytes: [4]u8) bool {
    return bytes[0] == 0 or
        bytes[0] == 10 or
        bytes[0] == 127 or
        (bytes[0] == 169 and bytes[1] == 254) or
        (bytes[0] == 172 and (bytes[1] & 0xf0) == 16) or
        (bytes[0] == 192 and bytes[1] == 168);
}

/// check that a host is a real PDS by calling describeServer.
/// also checks Server header for relay loop detection.
/// Go relay: host_checker.go CheckHost + slurper.go Server header check.
fn checkHost(allocator: Allocator, hostname: []const u8) HostValidationError!void {
    // SSRF protection first: never issue a request toward a private/reserved address
    try rejectPrivateHost(allocator, hostname);

    var url_storage: [512]u8 = undefined;
    const endpoint = std.fmt.bufPrint(
        &url_storage,
        "https://{s}/xrpc/com.atproto.server.describeServer",
        .{hostname},
    ) catch return error.HostUnreachable;

    var client: http.Client = .{ .allocator = allocator };
    defer client.deinit();

    const target = std.Uri.parse(endpoint) catch return error.HostUnreachable;
    var request = client.request(.GET, target, .{}) catch return error.HostUnreachable;
    defer request.deinit();
    request.sendBodiless() catch return error.HostUnreachable;

    var head_buf: [2048]u8 = undefined;
    const response = request.receiveHead(&head_buf) catch return error.HostUnreachable;
    if (response.head.status != .ok) return error.NotAPds;

    // relay loop detection: a Server header containing "atproto-relay" means we
    // would be subscribing to another relay, not a PDS.
    // Go relay: slurper.go — auto-bans hosts whose Server header contains "atproto-relay"
    const server_header = findHeaderInRaw(response.head.bytes, "server") orelse return;
    if (std.mem.indexOf(u8, server_header, "atproto-relay") != null) {
        return error.IsARelay;
    }
}

/// search raw HTTP headers for a header by name (case-insensitive).
/// returns the trimmed value, or null if not found.
fn findHeaderInRaw(raw: []const u8, name: []const u8) ?[]const u8 {
    var it = std.mem.splitSequence(u8, raw, "\r\n");
    _ = it.next(); // skip status line
    while (it.next()) |line| {
        if (line.len == 0) break; // blank line = end of header section
        const colon = std.mem.indexOfScalar(u8, line, ':') orelse continue; // not a header line
        const key = std.mem.trim(u8, line[0..colon], " ");
        if (key.len != name.len) continue; // fast-path length check before comparing
        // case-insensitive compare
        var match = true;
        for (key, name) |a, b| {
            if (std.ascii.toLower(a) != std.ascii.toLower(b)) {
                match = false;
                break;
            }
        }
        if (match) {
            // value is everything after the colon, with surrounding spaces trimmed;
            // returned slice aliases `raw` — valid only as long as `raw` is
            return std.mem.trim(u8, line[colon + 1 ..], " ");
        }
    }
    return null;
}

/// a tracked subscriber worker: the OS thread and its heap-allocated Subscriber.
const WorkerEntry = struct {
    thread: std.Thread,
    subscriber: *subscriber_mod.Subscriber,
};

pub const Slurper = struct {
    allocator: Allocator,
    bc: *broadcaster.Broadcaster,
    validator: *validator_mod.Validator,
    persist: *event_log_mod.DiskPersist,
    collection_index: ?*collection_index_mod.CollectionIndex = null,
    // shared shutdown flag, owned by the caller; checked by all background threads
    shutdown: *std.atomic.Value(bool),
    options: Options,

    // frame processing pool — offloads heavy work from reader threads
    frame_pool: ?frame_worker_mod.FramePool = null,

    // shared TLS CA bundle — loaded once, used by all subscriber connections
    ca_bundle: ?std.crypto.Certificate.Bundle = null,

    // active subscriber threads, keyed by host_id; guarded by workers_mutex
    workers: std.AutoHashMapUnmanaged(u64, WorkerEntry) = .{},
    workers_mutex: std.Thread.Mutex = .{},

    // crawl request queue (owned hostname slices); guarded by crawl_mutex
    crawl_queue: std.ArrayListUnmanaged([]const u8) = .{},
    crawl_mutex: std.Thread.Mutex = .{},
    crawl_cond: std.Thread.Condition = .{},

    // background threads (set by start, joined by deinit)
    startup_thread: ?std.Thread = null,
    crawl_thread: ?std.Thread = null,

    /// wire up dependencies; no threads are started until start() is called.
    pub fn init(
        allocator: Allocator,
        bc: *broadcaster.Broadcaster,
        val: *validator_mod.Validator,
        persist: *event_log_mod.DiskPersist,
        shutdown: *std.atomic.Value(bool),
        options: Options,
    ) Slurper {
        return .{
            .allocator = allocator,
            .bc = bc,
            .validator = val,
            .persist = persist,
            .shutdown = shutdown,
            .options = options,
        };
    }

    /// start the slurper: bootstrap hosts from seed relay, load from DB, spawn workers.
/// Go relay: pull-hosts bootstraps from bsky.network's listHosts, then crawls each PDS directly. pub fn start(self: *Slurper) !void { // load CA bundle once — shared by all subscriber TLS connections var bundle: std.crypto.Certificate.Bundle = .{}; try bundle.rescan(self.allocator); self.ca_bundle = bundle; log.info("loaded shared CA bundle", .{}); // create frame processing pool — worker threads handle heavy decode/validate/persist self.frame_pool = try frame_worker_mod.FramePool.init(self.allocator, .{ .num_workers = self.options.frame_workers, .queue_capacity = self.options.frame_queue_capacity, .stack_size = @import("main.zig").default_stack_size, }); log.info("frame pool started: {d} workers, queue capacity {d}", .{ self.options.frame_workers, self.options.frame_queue_capacity }); // spawn worker startup in background so HTTP server + probes come up immediately. // pullHosts + listActiveHosts + spawnWorker all happen in the background thread. self.startup_thread = try std.Thread.spawn(.{ .stack_size = @import("main.zig").default_stack_size }, spawnWorkers, .{self}); self.crawl_thread = try std.Thread.spawn(.{ .stack_size = @import("main.zig").default_stack_size }, processCrawlQueue, .{self}); } /// pull PDS host list from the seed relay's com.atproto.sync.listHosts endpoint. /// Go relay: cmd/relay/pull.go — one-time bootstrap, reads REST API, not firehose. 
pub fn pullHosts(self: *Slurper) !void { var cursor: ?[]const u8 = null; var total: usize = 0; const limit = 500; var client: http.Client = .{ .allocator = self.allocator }; defer client.deinit(); while (true) { if (self.shutdown.load(.acquire)) break; // build URL with pagination var url_buf: [512]u8 = undefined; const url = if (cursor) |c| std.fmt.bufPrint(&url_buf, "https://{s}/xrpc/com.atproto.sync.listHosts?limit={d}&cursor={s}", .{ self.options.seed_host, limit, c }) catch break else std.fmt.bufPrint(&url_buf, "https://{s}/xrpc/com.atproto.sync.listHosts?limit={d}", .{ self.options.seed_host, limit }) catch break; var aw: std.Io.Writer.Allocating = .init(self.allocator); defer aw.deinit(); const result = client.fetch(.{ .location = .{ .url = url }, .response_writer = &aw.writer, .method = .GET, }) catch |err| { log.warn("pullHosts: fetch failed: {s}", .{@errorName(err)}); break; }; if (result.status != .ok) { log.warn("pullHosts: got status {d}", .{@intFromEnum(result.status)}); break; } const body = aw.written(); // parse JSON response: { "hosts": [{"hostname": "...", "status": "..."}, ...], "cursor": "..." } const parsed = std.json.parseFromSlice(ListHostsResponse, self.allocator, body, .{ .ignore_unknown_fields = true }) catch |err| { log.warn("pullHosts: JSON parse failed: {s}", .{@errorName(err)}); break; }; defer parsed.deinit(); const hosts = parsed.value.hosts orelse break; if (hosts.len == 0) break; var added: usize = 0; for (hosts) |host| { // skip non-active hosts if (host.status) |s| { if (!std.mem.eql(u8, s, "active")) continue; } // validate hostname format (rejects IPs, localhost, etc.) 
const normalized = validateHostname(self.allocator, host.hostname) catch continue; defer self.allocator.free(normalized); // skip banned domains if (self.persist.isDomainBanned(normalized)) continue; // insert into DB (no describeServer check — the seed relay already vetted them) _ = self.persist.getOrCreateHost(normalized) catch continue; added += 1; } total += added; log.info("pullHosts: page fetched, {d} hosts added ({d} total)", .{ added, total }); // advance cursor if (parsed.value.cursor) |next_cursor| { // free previous cursor if we allocated one if (cursor) |prev| self.allocator.free(prev); cursor = self.allocator.dupe(u8, next_cursor) catch break; } else { break; // no more pages } } // free final cursor if (cursor) |c| self.allocator.free(c); log.info("pullHosts: bootstrap complete, {d} hosts added from {s}", .{ total, self.options.seed_host }); } const ListHostsResponse = struct { hosts: ?[]const ListHostEntry = null, cursor: ?[]const u8 = null, }; const ListHostEntry = struct { hostname: []const u8, status: ?[]const u8 = null, }; /// add a crawl request (from requestCrawl endpoint) pub fn addCrawlRequest(self: *Slurper, hostname: []const u8) !void { const duped = try self.allocator.dupe(u8, hostname); self.crawl_mutex.lock(); defer self.crawl_mutex.unlock(); try self.crawl_queue.append(self.allocator, duped); self.crawl_cond.signal(); } /// validate and add a host: format check, domain ban, describeServer, then spawn. /// mirrors Go relay's requestCrawl → SubscribeToHost pipeline. 
fn addHost(self: *Slurper, raw_hostname: []const u8) !void { // step 1: validate and normalize hostname format // Go relay: host.go ParseHostname const hostname = validateHostname(self.allocator, raw_hostname) catch |err| { log.warn("host validation failed for '{s}': {s}", .{ raw_hostname, @errorName(err) }); return; }; defer self.allocator.free(hostname); // step 2: domain ban check (suffix-based) // Go relay: domain_ban.go DomainIsBanned if (self.persist.isDomainBanned(hostname)) { log.warn("host {s}: domain is banned, rejecting", .{hostname}); return; } // step 3: check if host is banned/blocked in DB // Go relay: crawl.go checks host.Status == HostStatusBanned if (self.persist.isHostBanned(hostname)) { log.warn("host {s}: banned/blocked in DB, rejecting", .{hostname}); return; } // step 4: dedup — check if already tracked // Go relay: crawl.go CheckIfSubscribed const host_info = try self.persist.getOrCreateHost(hostname); { self.workers_mutex.lock(); defer self.workers_mutex.unlock(); if (self.workers.contains(host_info.id)) { log.debug("host {s} already has a worker, skipping", .{hostname}); return; } } // step 5: describeServer liveness check // Go relay: host_checker.go CheckHost (with SSRF protection) checkHost(self.allocator, hostname) catch |err| { log.warn("host {s}: describeServer check failed: {s}", .{ hostname, @errorName(err) }); return; }; // reset status and failure count — host passed describeServer, give it a fresh start. // without this, exhausted hosts accumulate failures across requestCrawl cycles // and immediately re-exhaust on a single failure. 
self.persist.updateHostStatus(host_info.id, "active") catch {}; self.persist.resetHostFailures(host_info.id) catch {}; try self.spawnWorker(host_info.id, hostname); log.info("added host {s} (id={d})", .{ hostname, host_info.id }); } /// spawn a subscriber thread for a host fn spawnWorker(self: *Slurper, host_id: u64, hostname: []const u8) !void { const hostname_duped = try self.allocator.dupe(u8, hostname); errdefer self.allocator.free(hostname_duped); const sub = try self.allocator.create(subscriber_mod.Subscriber); errdefer self.allocator.destroy(sub); const account_count: u64 = self.persist.getEffectiveAccountCount(host_id); sub.* = subscriber_mod.Subscriber.init( self.allocator, self.bc, self.validator, self.persist, self.shutdown, .{ .hostname = hostname_duped, .max_message_size = self.options.max_message_size, .host_id = host_id, .account_count = account_count, .ca_bundle = self.ca_bundle, }, ); sub.collection_index = self.collection_index; if (self.frame_pool) |*fp| sub.pool = fp; const thread = try std.Thread.spawn(.{ .stack_size = @import("main.zig").default_stack_size }, runWorker, .{ self, host_id, sub }); self.workers_mutex.lock(); defer self.workers_mutex.unlock(); try self.workers.put(self.allocator, host_id, .{ .thread = thread, .subscriber = sub, }); _ = self.bc.stats.connected_inbound.fetchAdd(1, .monotonic); } /// worker thread wrapper — runs subscriber, cleans up on exit fn runWorker(self: *Slurper, host_id: u64, sub: *subscriber_mod.Subscriber) void { sub.run(); // subscriber returned — remove from active workers self.workers_mutex.lock(); defer self.workers_mutex.unlock(); _ = self.workers.remove(host_id); _ = self.bc.stats.connected_inbound.fetchSub(1, .monotonic); log.info("worker for host_id={d} ({s}) exited", .{ host_id, sub.options.hostname }); self.allocator.free(sub.options.hostname); self.allocator.destroy(sub); } /// background thread: load hosts from DB and spawn all workers. 
    /// runs in background so HTTP server + probes come up immediately.
    /// Go relay: ResubscribeAllHosts loops with 1ms sleep per host (goroutines).
    /// we spawn all at once — the brief memory spike from concurrent TLS handshakes
    /// is shorter than a throttled ramp (many hosts fail-fast, freeing memory quickly).
    fn spawnWorkers(self: *Slurper) void {
        // pull hosts from seed relay first — idempotent (getOrCreateHost skips existing)
        log.info("pulling hosts from {s}...", .{self.options.seed_host});
        self.pullHosts() catch |err| {
            log.warn("pullHosts from {s} failed: {s}", .{ self.options.seed_host, @errorName(err) });
        };

        const hosts = self.persist.listActiveHosts(self.allocator) catch |err| {
            log.err("failed to load hosts: {s}", .{@errorName(err)});
            return;
        };
        defer {
            // listActiveHosts hands us owned hostname/status slices plus the outer slice
            for (hosts) |h| {
                self.allocator.free(h.hostname);
                self.allocator.free(h.status);
            }
            self.allocator.free(hosts);
        }

        for (hosts) |host| {
            if (self.shutdown.load(.acquire)) break;
            self.spawnWorker(host.id, host.hostname) catch |err| {
                log.warn("failed to spawn worker for {s}: {s}", .{ host.hostname, @errorName(err) });
            };
        }
        // NOTE(review): hosts.len counts hosts *attempted*, not successfully spawned
        log.info("startup complete: {d} host(s) spawned", .{hosts.len});
    }

    /// background thread: process crawl requests.
    /// wakes on crawl_cond or every 1s to re-check the shutdown flag.
    fn processCrawlQueue(self: *Slurper) void {
        while (!self.shutdown.load(.acquire)) {
            var hostname: ?[]const u8 = null;
            {
                self.crawl_mutex.lock();
                defer self.crawl_mutex.unlock();
                while (self.crawl_queue.items.len == 0 and !self.shutdown.load(.acquire)) {
                    // timed wait so a missed signal can't block shutdown for long
                    self.crawl_cond.timedWait(&self.crawl_mutex, 1 * std.time.ns_per_s) catch {};
                }
                if (self.crawl_queue.items.len > 0) {
                    // FIFO; O(n) shift is fine at crawl-request rates
                    hostname = self.crawl_queue.orderedRemove(0);
                }
            }
            if (hostname) |h| {
                defer self.allocator.free(h); // queue entries are owned copies
                self.addHost(h) catch |err| {
                    log.warn("crawl request failed for {s}: {s}", .{ h, @errorName(err) });
                };
            }
        }
    }

    /// number of active workers
    pub fn workerCount(self: *Slurper) usize {
        self.workers_mutex.lock();
        defer self.workers_mutex.unlock();
        return self.workers.count();
    }

    /// update rate limits for a running subscriber (called from admin API).
    /// if the host has a worker, recomputes and applies new limits immediately.
    pub fn updateHostLimits(self: *Slurper, host_id: u64, account_count: u64) void {
        self.workers_mutex.lock();
        defer self.workers_mutex.unlock();
        if (self.workers.get(host_id)) |entry| {
            const trusted = subscriber_mod.isTrustedHost(entry.subscriber.options.hostname);
            const limits = subscriber_mod.computeLimits(trusted, account_count);
            entry.subscriber.rate_limiter.updateLimits(limits.sec, limits.hour, limits.day);
            log.info("updated rate limits for host_id={d}: sec={d} hour={d} day={d}", .{
                host_id, limits.sec, limits.hour, limits.day,
            });
        }
    }

    /// shutdown all workers and clean up.
    /// NOTE(review): assumes the shared `shutdown` flag was set to true before this
    /// is called — otherwise the crawl thread and subscribers never exit and the
    /// joins below hang. TODO confirm against the caller.
    pub fn deinit(self: *Slurper) void {
        // join background threads
        if (self.startup_thread) |t| t.join();
        self.crawl_cond.signal();
        if (self.crawl_thread) |t| t.join();

        // collect threads to join (can't join while holding workers_mutex)
        var threads_to_join: std.ArrayListUnmanaged(std.Thread) = .{};
        defer threads_to_join.deinit(self.allocator);
        {
            self.workers_mutex.lock();
            defer self.workers_mutex.unlock();
            var it = self.workers.iterator();
            while (it.next()) |entry| {
                // NOTE(review): on OOM this skips a handle and that thread is never
                // joined — a use-after-free risk if self is freed while it runs
                threads_to_join.append(self.allocator, entry.value_ptr.thread) catch {};
            }
        }
        // join all reader threads FIRST (they stop submitting to pool)
        for (threads_to_join.items) |t| t.join();

        // then drain + join pool workers (processes remaining queued frames)
        if (self.frame_pool) |*fp| {
            fp.shutdown();
            fp.deinit();
            self.frame_pool = null;
        }

        // clean up workers map
        self.workers.deinit(self.allocator);

        // clean up crawl queue
        for (self.crawl_queue.items) |h| self.allocator.free(h);
        self.crawl_queue.deinit(self.allocator);

        // free shared CA bundle
        if (self.ca_bundle) |*b| b.deinit(self.allocator);
    }
};

// --- tests ---

test "validateHostname accepts valid PDS hostnames" {
    const alloc = std.testing.allocator;

    // basic valid hostnames
    const h1 = try validateHostname(alloc, "pds.example.com");
    defer alloc.free(h1);
    try std.testing.expectEqualStrings("pds.example.com", h1);

    // two labels minimum
    const h2 = try validateHostname(alloc, "bsky.network");
    defer alloc.free(h2);
    try std.testing.expectEqualStrings("bsky.network", h2);

    // lowercases
    const h3 = try validateHostname(alloc, "PDS.Example.COM");
    defer alloc.free(h3);
    try std.testing.expectEqualStrings("pds.example.com", h3);

    // strips scheme
    const h4 = try validateHostname(alloc, "https://pds.example.com");
    defer alloc.free(h4);
    try std.testing.expectEqualStrings("pds.example.com", h4);

    // strips path
    const h5 = try validateHostname(alloc, "pds.example.com/some/path");
    defer alloc.free(h5);
    try std.testing.expectEqualStrings("pds.example.com", h5);
}

test "validateHostname rejects invalid hostnames" {
    const alloc = std.testing.allocator;

    // empty
    try std.testing.expectError(error.EmptyHostname, validateHostname(alloc, ""));
    // localhost
    try std.testing.expectError(error.LocalhostNotAllowed, validateHostname(alloc, "localhost"));
    // single label (non-localhost)
    try std.testing.expectError(error.TooFewLabels, validateHostname(alloc, "intranet"));
    // IP address
    try std.testing.expectError(error.LooksLikeIpAddress, validateHostname(alloc, "192.168.1.1"));
    try std.testing.expectError(error.LooksLikeIpAddress, validateHostname(alloc, "10.0.0.1"));
    // port
    try std.testing.expectError(error.PortNotAllowed, validateHostname(alloc, "pds.example.com:443"));
    // invalid characters
    try std.testing.expectError(error.InvalidCharacter, validateHostname(alloc, "pds.exam ple.com"));
    try std.testing.expectError(error.InvalidCharacter, validateHostname(alloc, "pds.exam_ple.com"));
    // leading/trailing hyphens
    try std.testing.expectError(error.InvalidLabel, validateHostname(alloc, "-pds.example.com"));
    try std.testing.expectError(error.InvalidLabel, validateHostname(alloc, "pds-.example.com"));
    // empty label
    try std.testing.expectError(error.InvalidLabel, validateHostname(alloc, "pds..example.com"));
}