atproto relay implementation in zig (zlay.waow.tech)
//! zat relay — AT Protocol firehose relay server
//!
//! crawls PDS instances directly via the Slurper (one subscriber per host),
//! validates frames via DID resolution and signature verification, persists
//! to disk with relay-assigned seq numbers, and rebroadcasts to downstream
//! consumers over WebSocket.
//!
//! port 3000 (RELAY_PORT): WebSocket firehose + HTTP API (via httpFallback)
//! /xrpc/com.atproto.sync.subscribeRepos — firehose WebSocket (supports ?cursor=N)
//! /xrpc/com.atproto.sync.listRepos — paginated account listing
//! /xrpc/com.atproto.sync.getRepoStatus — single account status
//! /xrpc/com.atproto.sync.getLatestCommit — latest commit CID + rev
//! /xrpc/com.atproto.sync.listReposByCollection — repos with records in a collection
//! /xrpc/com.atproto.sync.listHosts — paginated active host listing
//! /xrpc/com.atproto.sync.getHostStatus — single host status
//! /xrpc/com.atproto.sync.requestCrawl — request PDS crawl (POST)
//! /admin/hosts — list all hosts (GET, admin)
//! /admin/hosts/block — block a host (POST, admin)
//! /admin/hosts/unblock — unblock a host (POST, admin)
//! /_health, /_stats — health, stats
//!
//! port 3001 (RELAY_METRICS_PORT): internal metrics + health
//! /metrics — prometheus metrics
//! /_health, /_healthz, /_readyz — health probes (/_health and /_readyz check the DB)
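//!
//! illustrative consumer usage (a sketch; the local port is the RELAY_PORT
//! default, and the tool names and example PDS hostname are assumptions, not
//! part of this file):
//!   websocat 'ws://localhost:3000/xrpc/com.atproto.sync.subscribeRepos?cursor=0'
//!   curl -X POST http://localhost:3000/xrpc/com.atproto.sync.requestCrawl \
//!     -H 'content-type: application/json' -d '{"hostname":"pds.example.com"}'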
const std = @import("std");
const http = std.http;
const websocket = @import("websocket");
const broadcaster = @import("broadcaster.zig");
const validator_mod = @import("validator.zig");
const slurper_mod = @import("slurper.zig");
const event_log_mod = @import("event_log.zig");
const collection_index_mod = @import("collection_index.zig");
const backfill_mod = @import("backfill.zig");
const api = @import("api.zig");
const build_options = @import("build_options");

const log = std.log.scoped(.relay);

/// zig's default thread stack is 16 MB. with ~2,750 subscriber threads that's
/// 44 GB of virtual memory. 8 MB supports ReleaseSafe — tls.Client.init alone
/// needs ~134 KiB of stack, and deep call chains under inline-else cipher
/// dispatch need headroom. only touched pages count as RSS.
pub const default_stack_size = 8 * 1024 * 1024;

var shutdown_flag: std.atomic.Value(bool) = .{ .raw = false };

/// metrics-only server on the internal port
const MetricsServer = struct {
    server: std.net.Server,
    stats: *broadcaster.Stats,
    validator: *validator_mod.Validator,
    data_dir: []const u8,
    persist: *event_log_mod.DiskPersist,
    bc: *broadcaster.Broadcaster,
    slurper: *slurper_mod.Slurper,

    fn run(self: *MetricsServer) void {
        while (!shutdown_flag.load(.acquire)) {
            const conn = self.server.accept() catch |err| {
                if (shutdown_flag.load(.acquire)) return;
                log.debug("metrics accept error: {s}", .{@errorName(err)});
                continue;
            };
            // 5s read timeout — prevents stale connections from blocking the single-threaded server
            const timeout = std.posix.timeval{ .sec = 5, .usec = 0 };
            std.posix.setsockopt(conn.stream.handle, std.posix.SOL.SOCKET, std.posix.SO.RCVTIMEO, std.mem.asBytes(&timeout)) catch {};
            handleMetricsConn(conn.stream, self.stats, self.validator, self.data_dir, self.persist, self.bc, self.slurper);
        }
    }
};

fn handleMetricsConn(stream: std.net.Stream, stats: *broadcaster.Stats, validator: *validator_mod.Validator, data_dir: []const u8, persist: *event_log_mod.DiskPersist, bc: *broadcaster.Broadcaster, slurp: *slurper_mod.Slurper) void {
    defer stream.close();

    var recv_buf: [4096]u8 = undefined;
    var send_buf: [4096]u8 = undefined;
    var connection_reader = stream.reader(&recv_buf);
    var connection_writer = stream.writer(&send_buf);
    var server = http.Server.init(connection_reader.interface(), &connection_writer.interface);

    var request = server.receiveHead() catch return;
    const path = request.head.target;

    if (std.mem.eql(u8, path, "/_healthz")) {
        // trivial liveness — constant-time, no dependencies
        request.respond("{\"status\":\"ok\"}", .{ .status = .ok, .keep_alive = false, .extra_headers = &.{
            .{ .name = "content-type", .value = "application/json" },
            .{ .name = "server", .value = "zlay (atproto-relay)" },
        } }) catch {};
    } else if (std.mem.eql(u8, path, "/_health") or std.mem.eql(u8, path, "/_readyz")) {
        const db_ok = if (persist.db.exec("SELECT 1", .{})) |_| true else |_| false;
        const status: http.Status = if (db_ok) .ok else .internal_server_error;
        const body = if (db_ok) "{\"status\":\"ok\"}" else "{\"status\":\"error\",\"msg\":\"database unavailable\"}";
        request.respond(body, .{ .status = status, .keep_alive = false, .extra_headers = &.{
            .{ .name = "content-type", .value = "application/json" },
            .{ .name = "server", .value = "zlay (atproto-relay)" },
        } }) catch {};
    } else if (std.mem.eql(u8, path, "/metrics")) {
        const cache_entries = validator.cacheSize();
        const attribution = broadcaster.AttributionMetrics{
            .history_entries = bc.history.count(),
            .evtbuf_entries = persist.evtbufLen(),
            .did_cache_entries = persist.didCacheLen(),
            .resolve_queue_len = validator.resolveQueueLen(),
            .resolve_queued_set_count = validator.resolveQueuedSetCount(),
            .validator_cache_map_cap = validator.cacheMapCapacity(),
            .did_cache_map_cap = persist.didCacheMapCap(),
            .queued_set_map_cap = validator.resolveQueuedSetCapacity(),
            .evtbuf_cap = persist.evtbufCap(),
            .outbuf_cap = persist.outbufCap(),
            .workers_count = slurp.workerCount(),
        };

        var metrics_buf: [65536]u8 = undefined;
        const body = broadcaster.formatPrometheusMetrics(stats, cache_entries, attribution, data_dir, &metrics_buf);
        request.respond(body, .{ .status = .ok, .keep_alive = false, .extra_headers = &.{
            .{ .name = "content-type", .value = "text/plain; version=0.0.4; charset=utf-8" },
            .{ .name = "server", .value = "zlay (atproto-relay)" },
        } }) catch {};
    } else {
        request.respond("not found", .{ .status = .not_found, .keep_alive = false, .extra_headers = &.{
            .{ .name = "content-type", .value = "text/plain" },
            .{ .name = "server", .value = "zlay (atproto-relay)" },
        } }) catch {};
    }
}

pub fn main() !void {
    // exp-002: optional GPA wrapper for leak detection.
    // build with -Duse_gpa=true to enable. on clean shutdown (SIGTERM),
    // GPA logs every allocation that was never freed, with stack traces.
    var gpa: std.heap.GeneralPurposeAllocator(.{
        .stack_trace_frames = if (build_options.use_gpa) 8 else 0,
    }) = .init;
    defer if (build_options.use_gpa) {
        log.info("GPA: checking for leaks...", .{});
        const status = gpa.deinit();
        if (status == .leak) {
            log.err("GPA: leaks detected! see stderr for details", .{});
        } else {
            log.info("GPA: no leaks detected", .{});
        }
    };
    const allocator = if (build_options.use_gpa) gpa.allocator() else std.heap.c_allocator;

    // parse config from env
    const port = parseEnvInt(u16, "RELAY_PORT", 3000);
    const metrics_port = parseEnvInt(u16, "RELAY_METRICS_PORT", 3001);
    const upstream = std.posix.getenv("RELAY_UPSTREAM") orelse "bsky.network";
    const data_dir = std.posix.getenv("RELAY_DATA_DIR") orelse "data/events";
    const retention_hours = parseEnvInt(u64, "RELAY_RETENTION_HOURS", 72);
    const frame_workers = parseEnvInt(u16, "FRAME_WORKERS", 16);
    const frame_queue_capacity = parseEnvInt(u16, "FRAME_QUEUE_CAPACITY", 4096);
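    // example invocation (illustrative: the binary name is assumed, every
    // variable is optional and falls back to the defaults parsed above):
    //   RELAY_PORT=3000 RELAY_METRICS_PORT=3001 RELAY_UPSTREAM=bsky.network \
    //   RELAY_DATA_DIR=data/events RELAY_RETENTION_HOURS=72 \
    //   FRAME_WORKERS=16 FRAME_QUEUE_CAPACITY=4096 ./zlay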
    // install signal handlers (including SIGPIPE ignore)
    installSignalHandlers();

    // init components
    var bc = broadcaster.Broadcaster.init(allocator);
    defer bc.deinit();

    var val = validator_mod.Validator.init(allocator, &bc.stats);
    defer val.deinit();
    try val.start();

    // init disk persistence (indigo-compatible diskpersist format + Postgres index)
    const database_url = std.posix.getenv("DATABASE_URL") orelse "postgres://relay:relay@localhost:5432/relay";
    var dp = event_log_mod.DiskPersist.init(allocator, data_dir, database_url) catch |err| {
        log.err("failed to init disk persist at {s}: {s}", .{ data_dir, @errorName(err) });
        return err;
    };
    defer dp.deinit();
    dp.retention_hours = retention_hours;

    if (dp.lastSeq()) |last| {
        log.info("event log recovered: last_seq={d}", .{last});
    }

    // start flush thread
    try dp.start();

    // wire persist into broadcaster for cursor replay and validator for migration checks
    bc.persist = &dp;
    val.persist = &dp;

    // init collection index (RocksDB — inspired by lightrail/microcosm.blue)
    const ci_dir = std.posix.getenv("COLLECTION_INDEX_DIR") orelse "data/collection-index";
    var ci = collection_index_mod.CollectionIndex.open(allocator, ci_dir) catch |err| {
        log.err("failed to init collection index at {s}: {s}", .{ ci_dir, @errorName(err) });
        return err;
    };
    defer ci.deinit();

    // init backfiller (collection index backfill from source relay)
    var backfiller = backfill_mod.Backfiller.init(allocator, &ci, dp.db);

    // init slurper (multi-host crawl manager)
    var slurper = slurper_mod.Slurper.init(
        allocator,
        &bc,
        &val,
        &dp,
        &shutdown_flag,
        .{
            .seed_host = upstream,
            .max_message_size = 5 * 1024 * 1024,
            .frame_workers = frame_workers,
            .frame_queue_capacity = frame_queue_capacity,
        },
    );
    defer slurper.deinit();
    slurper.collection_index = &ci;

    // start: loads active hosts from DB, spawns subscriber threads
    try slurper.start();
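    // note: the slurper runs one subscriber thread per active host (see
    // default_stack_size above), so the thread count scales with crawled hosts.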
    // start GC thread (runs every 10 minutes)
    const gc_thread = try std.Thread.spawn(.{ .stack_size = default_stack_size }, gcLoop, .{&dp});

    // wire HTTP fallback into broadcaster (all API endpoints served on WS port)
    var http_context = api.HttpContext{
        .stats = &bc.stats,
        .persist = &dp,
        .slurper = &slurper,
        .collection_index = &ci,
        .backfiller = &backfiller,
        .bc = &bc,
        .validator = &val,
    };
    bc.http_fallback = api.handleHttpRequest;
    bc.http_fallback_ctx = @ptrCast(&http_context);

    // start metrics-only server (internal port)
    const metrics_address = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, metrics_port);
    var metrics_srv = MetricsServer{
        .server = metrics_address.listen(.{ .reuse_address = true }) catch |err| {
            log.err("metrics server failed to listen on :{d}: {s}", .{ metrics_port, @errorName(err) });
            return err;
        },
        .stats = &bc.stats,
        .validator = &val,
        .data_dir = data_dir,
        .persist = &dp,
        .bc = &bc,
        .slurper = &slurper,
    };
    const metrics_thread = try std.Thread.spawn(.{ .stack_size = default_stack_size }, MetricsServer.run, .{&metrics_srv});

    // start downstream WebSocket server (also serves HTTP API via httpFallback)
    log.info("relay listening on :{d} (ws+http), :{d} (metrics)", .{ port, metrics_port });
    log.info("seed host: {s}", .{upstream});
    log.info("data dir: {s} (retention: {d}h)", .{ data_dir, retention_hours });

    var server = try websocket.Server(broadcaster.Handler).init(allocator, .{
        .port = port,
        .address = "0.0.0.0",
        .max_conn = 4096,
        .max_message_size = 5 * 1024 * 1024,
    });
    defer server.deinit();

    const server_thread = try server.listenInNewThread(&bc);

    // wait for shutdown signal
    while (!shutdown_flag.load(.acquire)) {
        std.posix.nanosleep(0, 100 * std.time.ns_per_ms);
    }

    log.info("shutdown signal received, stopping...", .{});

    // stop WebSocket server (closes all downstream connections)
    server.stop();
    server_thread.join();

    // wait for GC thread
    gc_thread.join();

    // close metrics listener socket to unblock accept(), then join
    metrics_srv.server.stream.close();
    metrics_thread.join();

    log.info("relay stopped cleanly", .{});
}

const malloc_h = @cImport(@cInclude("malloc.h"));

fn gcLoop(dp: *event_log_mod.DiskPersist) void {
    const gc_interval: u64 = 10 * 60; // 10 minutes in seconds
    while (!shutdown_flag.load(.acquire)) {
        // sleep in small increments to check shutdown
        var remaining: u64 = gc_interval;
        while (remaining > 0 and !shutdown_flag.load(.acquire)) {
            const chunk = @min(remaining, 1);
            std.posix.nanosleep(chunk, 0);
            remaining -= chunk;
        }
        if (shutdown_flag.load(.acquire)) return;

        dp.gc() catch |err| {
            log.warn("event log GC failed: {s}", .{@errorName(err)});
        };
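        // malloc_trim is a glibc extension: it releases free heap memory back
        // to the OS so RSS actually drops after a GC pass (assumes glibc).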
        _ = malloc_h.malloc_trim(0);
        log.info("gc: malloc_trim complete", .{});
    }
}

fn signalHandler(_: c_int) callconv(.c) void {
    shutdown_flag.store(true, .release);
}

fn installSignalHandlers() void {
    const act: std.posix.Sigaction = .{
        .handler = .{ .handler = signalHandler },
        .mask = std.posix.sigemptyset(),
        .flags = 0,
    };
    std.posix.sigaction(std.posix.SIG.INT, &act, null);
    std.posix.sigaction(std.posix.SIG.TERM, &act, null);

    // ignore SIGPIPE — writing to disconnected consumers must not crash the process
    const ignore_act: std.posix.Sigaction = .{
        .handler = .{ .handler = std.posix.SIG.IGN },
        .mask = std.posix.sigemptyset(),
        .flags = 0,
    };
    std.posix.sigaction(std.posix.SIG.PIPE, &ignore_act, null);
}

fn parseEnvInt(comptime T: type, key: []const u8, default: T) T {
    const val = std.posix.getenv(key) orelse return default;
    return std.fmt.parseInt(T, val, 10) catch default;
}
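// a minimal sanity check of the env-parsing fallback (a sketch; assumes the
// made-up variable name ZLAY_TEST_UNSET_VAR is not set in the test environment).
test "parseEnvInt returns the default when the variable is unset" {
    try std.testing.expectEqual(@as(u16, 3000), parseEnvInt(u16, "ZLAY_TEST_UNSET_VAR", 3000));
}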