//! zat relay — AT Protocol firehose relay server
//! atproto relay implementation in zig — zlay.waow.tech
//!
//! crawls PDS instances directly via the Slurper (one subscriber per host),
//! validates frames via DID resolution and signature verification, persists
//! to disk with relay-assigned seq numbers, and rebroadcasts to downstream
//! consumers over WebSocket.
//!
//! port 3000 (RELAY_PORT): WebSocket firehose + HTTP API (via httpFallback)
//!   /xrpc/com.atproto.sync.subscribeRepos — firehose WebSocket (supports ?cursor=N)
//!   /xrpc/com.atproto.sync.listRepos — paginated account listing
//!   /xrpc/com.atproto.sync.getRepoStatus — single account status
//!   /xrpc/com.atproto.sync.getLatestCommit — latest commit CID + rev
//!   /xrpc/com.atproto.sync.listReposByCollection — repos with records in a collection
//!   /xrpc/com.atproto.sync.listHosts — paginated active host listing
//!   /xrpc/com.atproto.sync.getHostStatus — single host status
//!   /xrpc/com.atproto.sync.requestCrawl — request PDS crawl (POST)
//!   /admin/hosts — list all hosts (GET, admin)
//!   /admin/hosts/block — block a host (POST, admin)
//!   /admin/hosts/unblock — unblock a host (POST, admin)
//!   /_health, /_stats — health, stats
//!
//! port 3001 (RELAY_METRICS_PORT): internal metrics + health
//!   /metrics — prometheus metrics
//!   /_health — liveness probe (DB check)

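// a hypothetical smoke test against the firehose (tool choice, host, and
// cursor value are illustrative — any WebSocket client works; frames are binary):
//
//   websocat -b 'wss://zlay.waow.tech/xrpc/com.atproto.sync.subscribeRepos?cursor=12345'
//
// with ?cursor=N the broadcaster replays persisted events from seq N (see the
// bc.persist wiring in main) before live tailing; without it, the stream
// starts at the live edge.
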
const std = @import("std");
const http = std.http;
const websocket = @import("websocket");
const broadcaster = @import("broadcaster.zig");
const validator_mod = @import("validator.zig");
const slurper_mod = @import("slurper.zig");
const event_log_mod = @import("event_log.zig");
const collection_index_mod = @import("collection_index.zig");
const backfill_mod = @import("backfill.zig");
const api = @import("api.zig");
const build_options = @import("build_options");

const log = std.log.scoped(.relay);

/// zig's default thread stack is 16 MB. with ~2,750 subscriber threads that's
/// 44 GB of virtual memory. 8 MB supports ReleaseSafe — tls.Client.init alone
/// needs ~134 KiB of stack, and deep call chains under inline-else cipher
/// dispatch need headroom. only touched pages count as RSS.
pub const default_stack_size = 8 * 1024 * 1024;
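// the arithmetic above, spelled out (thread count is approximate):
// 2_750 × 16 MiB = 44_000 MiB ≈ 43 GiB reserved, vs 2_750 × 8 MiB ≈ 21.5 GiB.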

var shutdown_flag: std.atomic.Value(bool) = .{ .raw = false };

/// metrics-only server on the internal port
const MetricsServer = struct {
    server: std.net.Server,
    stats: *broadcaster.Stats,
    validator: *validator_mod.Validator,
    data_dir: []const u8,
    persist: *event_log_mod.DiskPersist,
    bc: *broadcaster.Broadcaster,
    slurper: *slurper_mod.Slurper,

    fn run(self: *MetricsServer) void {
        while (!shutdown_flag.load(.acquire)) {
            const conn = self.server.accept() catch |err| {
                if (shutdown_flag.load(.acquire)) return;
                log.debug("metrics accept error: {s}", .{@errorName(err)});
                continue;
            };
            // 5s read timeout — prevents stale connections from blocking the single-threaded server
            const timeout = std.posix.timeval{ .sec = 5, .usec = 0 };
            std.posix.setsockopt(conn.stream.handle, std.posix.SOL.SOCKET, std.posix.SO.RCVTIMEO, std.mem.asBytes(&timeout)) catch {};
            handleMetricsConn(conn.stream, self.stats, self.validator, self.data_dir, self.persist, self.bc, self.slurper);
        }
    }
};

fn handleMetricsConn(stream: std.net.Stream, stats: *broadcaster.Stats, validator: *validator_mod.Validator, data_dir: []const u8, persist: *event_log_mod.DiskPersist, bc: *broadcaster.Broadcaster, slurp: *slurper_mod.Slurper) void {
    defer stream.close();

    var recv_buf: [4096]u8 = undefined;
    var send_buf: [4096]u8 = undefined;
    var connection_reader = stream.reader(&recv_buf);
    var connection_writer = stream.writer(&send_buf);
    var server = http.Server.init(connection_reader.interface(), &connection_writer.interface);

    var request = server.receiveHead() catch return;
    const path = request.head.target;

    if (std.mem.eql(u8, path, "/_healthz")) {
        // trivial liveness — constant-time, no dependencies
        request.respond("{\"status\":\"ok\"}", .{ .status = .ok, .keep_alive = false, .extra_headers = &.{
            .{ .name = "content-type", .value = "application/json" },
            .{ .name = "server", .value = "zlay (atproto-relay)" },
        } }) catch {};
    } else if (std.mem.eql(u8, path, "/_health") or std.mem.eql(u8, path, "/_readyz")) {
        const db_ok = if (persist.db.exec("SELECT 1", .{})) |_| true else |_| false;
        const status: http.Status = if (db_ok) .ok else .internal_server_error;
        const body = if (db_ok) "{\"status\":\"ok\"}" else "{\"status\":\"error\",\"msg\":\"database unavailable\"}";
        request.respond(body, .{ .status = status, .keep_alive = false, .extra_headers = &.{
            .{ .name = "content-type", .value = "application/json" },
            .{ .name = "server", .value = "zlay (atproto-relay)" },
        } }) catch {};
    } else if (std.mem.eql(u8, path, "/metrics")) {
        const cache_entries = validator.cacheSize();
        const attribution = broadcaster.AttributionMetrics{
            .history_entries = bc.history.count(),
            .evtbuf_entries = persist.evtbufLen(),
            .did_cache_entries = persist.didCacheLen(),
            .resolve_queue_len = validator.resolveQueueLen(),
            .resolve_queued_set_count = validator.resolveQueuedSetCount(),
            .validator_cache_map_cap = validator.cacheMapCapacity(),
            .did_cache_map_cap = persist.didCacheMapCap(),
            .queued_set_map_cap = validator.resolveQueuedSetCapacity(),
            .evtbuf_cap = persist.evtbufCap(),
            .outbuf_cap = persist.outbufCap(),
            .workers_count = slurp.workerCount(),
        };

        var metrics_buf: [65536]u8 = undefined;
        const body = broadcaster.formatPrometheusMetrics(stats, cache_entries, attribution, data_dir, &metrics_buf);
        request.respond(body, .{ .status = .ok, .keep_alive = false, .extra_headers = &.{
            .{ .name = "content-type", .value = "text/plain; version=0.0.4; charset=utf-8" },
            .{ .name = "server", .value = "zlay (atproto-relay)" },
        } }) catch {};
    } else {
        request.respond("not found", .{ .status = .not_found, .keep_alive = false, .extra_headers = &.{
            .{ .name = "content-type", .value = "text/plain" },
            .{ .name = "server", .value = "zlay (atproto-relay)" },
        } }) catch {};
    }
}
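// a hypothetical scrape of the internal port (port is the env default;
// paths and response shapes are as produced by handleMetricsConn above):
//
//   curl -s http://localhost:3001/metrics    # prometheus text format 0.0.4
//   curl -s http://localhost:3001/_health    # {"status":"ok"} iff SELECT 1 succeeds
//   curl -s http://localhost:3001/_healthz   # {"status":"ok"} always (liveness only)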

pub fn main() !void {
    // exp-002: optional GPA wrapper for leak detection.
    // build with -Duse_gpa=true to enable. on clean shutdown (SIGTERM),
    // GPA logs every allocation that was never freed, with stack traces.
    var gpa: std.heap.GeneralPurposeAllocator(.{
        .stack_trace_frames = if (build_options.use_gpa) 8 else 0,
    }) = .init;
    defer if (build_options.use_gpa) {
        log.info("GPA: checking for leaks...", .{});
        const status = gpa.deinit();
        if (status == .leak) {
            log.err("GPA: leaks detected! see stderr for details", .{});
        } else {
            log.info("GPA: no leaks detected", .{});
        }
    };
    const allocator = if (build_options.use_gpa) gpa.allocator() else std.heap.c_allocator;

    // parse config from env
    const port = parseEnvInt(u16, "RELAY_PORT", 3000);
    const metrics_port = parseEnvInt(u16, "RELAY_METRICS_PORT", 3001);
    const upstream = std.posix.getenv("RELAY_UPSTREAM") orelse "bsky.network";
    const data_dir = std.posix.getenv("RELAY_DATA_DIR") orelse "data/events";
    const retention_hours = parseEnvInt(u64, "RELAY_RETENTION_HOURS", 72);
    const frame_workers = parseEnvInt(u16, "FRAME_WORKERS", 16);
    const frame_queue_capacity = parseEnvInt(u16, "FRAME_QUEUE_CAPACITY", 4096);

    // install signal handlers (including SIGPIPE ignore)
    installSignalHandlers();

    // init components
    var bc = broadcaster.Broadcaster.init(allocator);
    defer bc.deinit();

    var val = validator_mod.Validator.init(allocator, &bc.stats);
    defer val.deinit();
    try val.start();

    // init disk persistence (indigo-compatible diskpersist format + Postgres index)
    const database_url = std.posix.getenv("DATABASE_URL") orelse "postgres://relay:relay@localhost:5432/relay";
    var dp = event_log_mod.DiskPersist.init(allocator, data_dir, database_url) catch |err| {
        log.err("failed to init disk persist at {s}: {s}", .{ data_dir, @errorName(err) });
        return err;
    };
    defer dp.deinit();
    dp.retention_hours = retention_hours;

    if (dp.lastSeq()) |last| {
        log.info("event log recovered: last_seq={d}", .{last});
    }

    // start flush thread
    try dp.start();

    // wire persist into the broadcaster for cursor replay and into the validator for migration checks
    bc.persist = &dp;
    val.persist = &dp;

    // init collection index (RocksDB — inspired by lightrail/microcosm.blue)
    const ci_dir = std.posix.getenv("COLLECTION_INDEX_DIR") orelse "data/collection-index";
    var ci = collection_index_mod.CollectionIndex.open(allocator, ci_dir) catch |err| {
        log.err("failed to init collection index at {s}: {s}", .{ ci_dir, @errorName(err) });
        return err;
    };
    defer ci.deinit();

    // init backfiller (collection index backfill from a source relay)
    var backfiller = backfill_mod.Backfiller.init(allocator, &ci, dp.db);

    // init slurper (multi-host crawl manager)
    var slurper = slurper_mod.Slurper.init(
        allocator,
        &bc,
        &val,
        &dp,
        &shutdown_flag,
        .{
            .seed_host = upstream,
            .max_message_size = 5 * 1024 * 1024,
            .frame_workers = frame_workers,
            .frame_queue_capacity = frame_queue_capacity,
        },
    );
    defer slurper.deinit();
    slurper.collection_index = &ci;

    // start: loads active hosts from DB, spawns subscriber threads
    try slurper.start();

    // start GC thread (runs every 10 minutes)
    const gc_thread = try std.Thread.spawn(.{ .stack_size = default_stack_size }, gcLoop, .{&dp});

    // wire HTTP fallback into the broadcaster (all API endpoints served on the WS port)
    var http_context = api.HttpContext{
        .stats = &bc.stats,
        .persist = &dp,
        .slurper = &slurper,
        .collection_index = &ci,
        .backfiller = &backfiller,
        .bc = &bc,
        .validator = &val,
    };
    bc.http_fallback = api.handleHttpRequest;
    // http_context lives on main's stack until after the servers are joined,
    // so the type-erased pointer stays valid for the life of the process.
    bc.http_fallback_ctx = @ptrCast(&http_context);

    // start metrics-only server (internal port)
    const metrics_address = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, metrics_port);
    var metrics_srv = MetricsServer{
        .server = metrics_address.listen(.{ .reuse_address = true }) catch |err| {
            log.err("metrics server failed to listen on :{d}: {s}", .{ metrics_port, @errorName(err) });
            return err;
        },
        .stats = &bc.stats,
        .validator = &val,
        .data_dir = data_dir,
        .persist = &dp,
        .bc = &bc,
        .slurper = &slurper,
    };
    const metrics_thread = try std.Thread.spawn(.{ .stack_size = default_stack_size }, MetricsServer.run, .{&metrics_srv});

    // start downstream WebSocket server (also serves the HTTP API via httpFallback)
    log.info("relay listening on :{d} (ws+http), :{d} (metrics)", .{ port, metrics_port });
    log.info("seed host: {s}", .{upstream});
    log.info("data dir: {s} (retention: {d}h)", .{ data_dir, retention_hours });

    var server = try websocket.Server(broadcaster.Handler).init(allocator, .{
        .port = port,
        .address = "0.0.0.0",
        .max_conn = 4096,
        .max_message_size = 5 * 1024 * 1024,
    });
    defer server.deinit();

    const server_thread = try server.listenInNewThread(&bc);

    // wait for the shutdown signal
    while (!shutdown_flag.load(.acquire)) {
        std.posix.nanosleep(0, 100 * std.time.ns_per_ms);
    }

    log.info("shutdown signal received, stopping...", .{});

    // stop WebSocket server (closes all downstream connections)
    server.stop();
    server_thread.join();

    // wait for GC thread
    gc_thread.join();

    // close the metrics listener socket to unblock accept(), then join
    metrics_srv.server.stream.close();
    metrics_thread.join();

    log.info("relay stopped cleanly", .{});
}

const malloc_h = @cImport(@cInclude("malloc.h"));

fn gcLoop(dp: *event_log_mod.DiskPersist) void {
    const gc_interval: u64 = 10 * 60; // 10 minutes, in seconds
    while (!shutdown_flag.load(.acquire)) {
        // sleep in one-second increments so a shutdown request is noticed promptly
        var remaining: u64 = gc_interval;
        while (remaining > 0 and !shutdown_flag.load(.acquire)) {
            const chunk = @min(remaining, 1);
            std.posix.nanosleep(chunk, 0);
            remaining -= chunk;
        }
        if (shutdown_flag.load(.acquire)) return;

        dp.gc() catch |err| {
            log.warn("event log GC failed: {s}", .{@errorName(err)});
        };

        // malloc_trim(0): ask glibc to release all free heap memory back to the OS
        _ = malloc_h.malloc_trim(0);
        log.info("gc: malloc_trim complete", .{});
    }
}

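/// signal handlers may only do async-signal-safe work; this one is a single
/// atomic store. all teardown (joins, socket closes) runs on the main thread
/// once the flag is observed.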
fn signalHandler(_: c_int) callconv(.c) void {
    shutdown_flag.store(true, .release);
}

fn installSignalHandlers() void {
    const act: std.posix.Sigaction = .{
        .handler = .{ .handler = signalHandler },
        .mask = std.posix.sigemptyset(),
        .flags = 0,
    };
    std.posix.sigaction(std.posix.SIG.INT, &act, null);
    std.posix.sigaction(std.posix.SIG.TERM, &act, null);

    // ignore SIGPIPE — writing to disconnected consumers must not crash the process
    const ignore_act: std.posix.Sigaction = .{
        .handler = .{ .handler = std.posix.SIG.IGN },
        .mask = std.posix.sigemptyset(),
        .flags = 0,
    };
    std.posix.sigaction(std.posix.SIG.PIPE, &ignore_act, null);
}

/// reads an integer from the environment; a missing or malformed value
/// silently falls back to the default.
fn parseEnvInt(comptime T: type, key: []const u8, default: T) T {
    const val = std.posix.getenv(key) orelse return default;
    return std.fmt.parseInt(T, val, 10) catch default;
}
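
// an illustrative launch with every knob at its documented default (binary
// name and shell syntax are assumptions; variable names and defaults are
// taken from main above):
//
//   DATABASE_URL=postgres://relay:relay@localhost:5432/relay \
//   RELAY_PORT=3000 RELAY_METRICS_PORT=3001 \
//   RELAY_UPSTREAM=bsky.network RELAY_DATA_DIR=data/events \
//   RELAY_RETENTION_HOURS=72 FRAME_WORKERS=16 FRAME_QUEUE_CAPACITY=4096 \
//   COLLECTION_INDEX_DIR=data/collection-index ./zat-relay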