atproto relay implementation in zig
zlay.waow.tech
1//! admin endpoint handlers for relay management.
2//!
3//! all handlers require Bearer token auth against RELAY_ADMIN_PASSWORD.
4//! includes host blocking/unblocking, account bans, and backfill control.
5
6const std = @import("std");
7const h = @import("http.zig");
8const router = @import("router.zig");
9const websocket = @import("websocket");
10const broadcaster = @import("../broadcaster.zig");
11const event_log_mod = @import("../event_log.zig");
12const backfill_mod = @import("../backfill.zig");
13
14const log = std.log.scoped(.relay);
15
16const HttpContext = router.HttpContext;
17
/// check admin auth via headers, send error response if not authorized. returns true if authorized.
///
/// responses sent on failure:
/// - 403 when RELAY_ADMIN_PASSWORD is unset or empty (admin endpoints disabled)
/// - 401 when the authorization header is missing, is not a `Bearer ` credential,
///   or carries a token that does not match RELAY_ADMIN_PASSWORD
pub fn checkAdmin(conn: *h.Conn, headers: ?*const websocket.Handshake.KeyValue) bool {
    // an empty password is treated as "not configured": otherwise an empty
    // bearer token would authenticate against it
    const admin_pw: []const u8 = std.posix.getenv("RELAY_ADMIN_PASSWORD") orelse "";
    if (admin_pw.len == 0) {
        h.respondJson(conn, .forbidden, "{\"error\":\"admin endpoint not configured\"}");
        return false;
    }

    const kv = headers orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    // handshake parser lowercases all header names
    const auth_value = kv.get("authorization") orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    const bearer_prefix = "Bearer ";
    if (!std.mem.startsWith(u8, auth_value, bearer_prefix)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid authorization scheme\"}");
        return false;
    }

    const token = auth_value[bearer_prefix.len..];
    // constant-time compare: std.mem.eql stops at the first differing byte,
    // which lets a remote caller measure how much of a guess was correct
    if (!constantTimeEql(token, admin_pw)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid token\"}");
        return false;
    }
    return true;
}

/// compare two byte strings without an early exit on the first mismatch.
/// only the lengths are allowed to influence timing.
fn constantTimeEql(a: []const u8, b: []const u8) bool {
    if (a.len != b.len) return false;
    var diff: u8 = 0;
    for (a, b) |x, y| diff |= x ^ y;
    return diff == 0;
}
48
/// admin handler: take down the account named by the `did` field of the JSON body.
///
/// after the takedown succeeds, a best-effort #account event is persisted and
/// broadcast; failure to build or persist that event is logged but does not
/// change the response, which is `{"success":true}` once the takedown is recorded.
pub fn handleBan(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const allocator = ctx.persist.allocator;

    const BanRequest = struct { did: []const u8 };
    const request = std.json.parseFromSlice(BanRequest, allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\"}\"}");
        return;
    };
    defer request.deinit();
    const did = request.value.did;

    // resolve DID → UID and take down
    const uid = ctx.persist.uidForDid(did) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to resolve DID\"}");
        return;
    };
    ctx.persist.takeDownUser(uid) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"takedown failed\"}");
        return;
    };

    // emit #account event so downstream consumers see the takedown.
    // NOTE(review): neither `frame` nor a resequenced copy is freed here —
    // confirm whether persist()/broadcast() take ownership of the buffers.
    if (buildAccountFrame(allocator, did)) |frame| emit: {
        const seq = ctx.persist.persist(.account, uid, frame) catch |err| {
            log.warn("admin: failed to persist #account takedown event: {s}", .{@errorName(err)});
            break :emit;
        };
        ctx.bc.stats.relay_seq.store(seq, .release);
        // fall back to the un-resequenced frame when resequencing yields null
        const outgoing = broadcaster.resequenceFrame(allocator, frame, seq) orelse frame;
        ctx.bc.broadcast(seq, outgoing);
        log.info("admin: emitted #account takedown event for {s} (seq={d})", .{ did, seq });
    }

    log.info("admin: banned {s} (uid={d})", .{ did, uid });
    h.respondJson(conn, .ok, "{\"success\":true}");
}
84
/// admin handler: respond with every known host plus the active worker count.
///
/// response shape:
/// `{"hosts":[{"id":..,"hostname":"..","status":"..","last_seq":..,"failed_attempts":..}],"active_workers":..}`
///
/// responds 500 if the host query fails or the response body cannot be built.
pub fn handleAdminListHosts(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const persist = ctx.persist;
    const hosts = persist.listAllHosts(persist.allocator) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}");
        return;
    };
    // the returned slice and each host's strings are freed here
    defer {
        for (hosts) |host| {
            persist.allocator.free(host.hostname);
            persist.allocator.free(host.status);
        }
        persist.allocator.free(hosts);
    }

    var list: std.ArrayListUnmanaged(u8) = .{};
    defer list.deinit(persist.allocator);

    // always answer the request: a failed write (e.g. out of memory) becomes a
    // 500 instead of returning without any response
    writeHostListJson(list.writer(persist.allocator), hosts, ctx.slurper.workerCount()) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"failed to build response\"}");
        return;
    };
    h.respondJson(conn, .ok, list.items);
}

/// serialize the host list response body to `w`.
fn writeHostListJson(w: anytype, hosts: anytype, active_workers: anytype) !void {
    try w.writeAll("{\"hosts\":[");
    for (hosts, 0..) |host, i| {
        if (i > 0) try w.writeByte(',');
        try std.fmt.format(w, "{{\"id\":{d},\"hostname\":", .{host.id});
        try writeJsonStringEscaped(w, host.hostname);
        try w.writeAll(",\"status\":");
        try writeJsonStringEscaped(w, host.status);
        try std.fmt.format(w, ",\"last_seq\":{d},\"failed_attempts\":{d}}}", .{ host.last_seq, host.failed_attempts });
    }
    try std.fmt.format(w, "],\"active_workers\":{d}}}", .{active_workers});
}

/// write `s` as a quoted JSON string, escaping quotes, backslashes and control
/// characters so stored values cannot break out of the string literal.
fn writeJsonStringEscaped(w: anytype, s: []const u8) !void {
    try w.writeByte('"');
    for (s) |c| {
        switch (c) {
            '"' => try w.writeAll("\\\""),
            '\\' => try w.writeAll("\\\\"),
            '\n' => try w.writeAll("\\n"),
            '\r' => try w.writeAll("\\r"),
            '\t' => try w.writeAll("\\t"),
            0x00...0x08, 0x0B...0x0C, 0x0E...0x1F => try std.fmt.format(w, "\\u{x:0>4}", .{c}),
            else => try w.writeByte(c),
        }
    }
    try w.writeByte('"');
}
121
/// admin handler: mark the host named by the `hostname` field of the JSON body as "blocked".
///
/// the host row is looked up via getOrCreateHost, so the status is recorded
/// under whatever id that call returns. responds 400 on unparsable JSON and
/// 500 when the lookup or the status update fails.
pub fn handleAdminBlockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, persist: *event_log_mod.DiskPersist) void {
    if (!checkAdmin(conn, headers)) return;

    const BlockRequest = struct { hostname: []const u8 };
    const request = std.json.parseFromSlice(BlockRequest, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer request.deinit();
    const hostname = request.value.hostname;

    const host = persist.getOrCreateHost(hostname) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}");
        return;
    };

    persist.updateHostStatus(host.id, "blocked") catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}");
        return;
    };

    log.info("admin: blocked host {s} (id={d})", .{ hostname, host.id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}
144
/// admin handler: set the host named by the `hostname` field of the JSON body
/// back to "active" and reset its failure counter.
///
/// responds 400 on unparsable JSON and 500 when the lookup or the status
/// update fails. resetting the failure counter is best-effort: a failure there
/// is logged and the request still succeeds.
pub fn handleAdminUnblockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, persist: *event_log_mod.DiskPersist) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer parsed.deinit();
    const hostname = parsed.value.hostname;

    const host_info = persist.getOrCreateHost(hostname) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}");
        return;
    };

    persist.updateHostStatus(host_info.id, "active") catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}");
        return;
    };

    // the unblock itself already succeeded; surface a failed counter reset in
    // the log instead of dropping it silently
    persist.resetHostFailures(host_info.id) catch |err| {
        log.warn("admin: failed to reset failure count for host {s} (id={d}): {s}", .{ hostname, host_info.id, @errorName(err) });
    };

    log.info("admin: unblocked host {s} (id={d})", .{ hostname, host_info.id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}
168
/// admin handler: start a backfill from the host given by the `source` query
/// parameter (default "bsky.network").
///
/// responds 409 when a backfill is already running, 500 on any other start
/// failure, and 200 with `{"status":"started",...}` on success.
pub fn handleAdminBackfillTrigger(conn: *h.Conn, query: []const u8, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const source = h.queryParam(query, "source") orelse "bsky.network";

    backfiller.start(source) catch |err| {
        switch (err) {
            error.AlreadyRunning => {
                h.respondJson(conn, .conflict, "{\"error\":\"backfill already in progress\"}");
            },
            else => {
                h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to start backfill\"}");
            },
        }
        return;
    };

    // `source` comes from the request's query string: only echo it back when it
    // can sit inside a JSON string literal verbatim (no quotes, backslashes or
    // control characters), so the response is always well-formed JSON
    const echo_safe = for (source) |c| {
        if (c < 0x20 or c == '"' or c == '\\') break false;
    } else true;

    var buf: [256]u8 = undefined;
    const resp_body: ?[]const u8 = if (echo_safe)
        std.fmt.bufPrint(&buf, "{{\"status\":\"started\",\"source\":\"{s}\"}}", .{source}) catch null
    else
        null;

    // fall back to the minimal body when the source is unsafe to echo or does not fit in `buf`
    h.respondJson(conn, .ok, resp_body orelse "{\"status\":\"started\"}");
}
193
/// admin handler: respond with the backfiller's current status document.
///
/// the body returned by getStatus is allocated with the backfiller's allocator
/// and released once the response has been sent. responds 500 if the status
/// cannot be obtained.
pub fn handleAdminBackfillStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const allocator = backfiller.allocator;
    if (backfiller.getStatus(allocator)) |status_json| {
        defer allocator.free(status_json);
        h.respondJson(conn, .ok, status_json);
    } else |_| {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to query backfill status\"}");
    }
}
205
206// --- protocol helpers (used only by handleBan) ---
207
/// build a CBOR #account frame for a takedown event.
///
/// the frame is the encoded header immediately followed by the encoded payload:
/// header: {op: 1, t: "#account"}
/// payload: {seq: 0, did: "...", time: "...", active: false, status: "takendown"}
///
/// returns null if any allocation or encoding step fails. on success the
/// returned slice was allocated with `allocator`.
fn buildAccountFrame(allocator: std.mem.Allocator, did: []const u8) ?[]const u8 {
    const zat = @import("zat");
    const cbor = zat.cbor;

    var time_buf: [24]u8 = undefined;
    const now_text = formatTimestamp(&time_buf);

    const header_value: cbor.Value = .{ .map = &.{
        .{ .key = "op", .value = .{ .unsigned = 1 } },
        .{ .key = "t", .value = .{ .text = "#account" } },
    } };
    const payload_value: cbor.Value = .{ .map = &.{
        .{ .key = "seq", .value = .{ .unsigned = 0 } },
        .{ .key = "did", .value = .{ .text = did } },
        .{ .key = "time", .value = .{ .text = now_text } },
        .{ .key = "active", .value = .{ .boolean = false } },
        .{ .key = "status", .value = .{ .text = "takendown" } },
    } };

    // both encodings are temporaries: they are released on every exit path
    // once the combined frame has (or has not) been produced
    const encoded_header = cbor.encodeAlloc(allocator, header_value) catch return null;
    defer allocator.free(encoded_header);
    const encoded_payload = cbor.encodeAlloc(allocator, payload_value) catch return null;
    defer allocator.free(encoded_payload);

    return std.mem.concat(allocator, u8, &[_][]const u8{ encoded_header, encoded_payload }) catch null;
}
249
/// format current UTC time as ISO 8601 (YYYY-MM-DDTHH:MM:SSZ)
///
/// the result is a slice of `buf`. a clock reading before the unix epoch is
/// clamped to the epoch.
fn formatTimestamp(buf: *[24]u8) []const u8 {
    // std.time.timestamp() is signed; avoid @intCast on a negative value
    const secs = std.math.cast(u64, std.time.timestamp()) orelse 0;
    return formatEpochSeconds(buf, secs);
}

/// format `secs` (seconds since the unix epoch, UTC) as YYYY-MM-DDTHH:MM:SSZ into `buf`.
fn formatEpochSeconds(buf: *[24]u8, secs: u64) []const u8 {
    const es = std.time.epoch.EpochSeconds{ .secs = secs };
    const yd = es.getEpochDay().calculateYearDay();
    const md = yd.calculateMonthDay();
    const ds = es.getDaySeconds();

    return std.fmt.bufPrint(buf, "{d:0>4}-{d:0>2}-{d:0>2}T{d:0>2}:{d:0>2}:{d:0>2}Z", .{
        yd.year,
        // Month is already 1-based (jan = 1); only day_index is 0-based
        @as(u32, md.month.numeric()),
        @as(u32, md.day_index) + 1,
        ds.getHoursIntoDay(),
        ds.getMinutesIntoHour(),
        ds.getSecondsIntoMinute(),
    }) catch "1970-01-01T00:00:00Z";
}