atproto relay implementation in zig zlay.waow.tech
at main 267 lines 11 kB view raw
//! admin endpoint handlers for relay management.
//!
//! all handlers require Bearer token auth against RELAY_ADMIN_PASSWORD.
//! includes host blocking/unblocking, account bans, and backfill control.

const std = @import("std");
const h = @import("http.zig");
const router = @import("router.zig");
const websocket = @import("websocket");
const broadcaster = @import("../broadcaster.zig");
const event_log_mod = @import("../event_log.zig");
const backfill_mod = @import("../backfill.zig");

const log = std.log.scoped(.relay);

const HttpContext = router.HttpContext;

/// backfill source used when the request does not name one.
const default_backfill_source = "bsky.network";

/// compare two byte strings without exiting early on the first mismatch,
/// so response timing does not reveal how much of a secret matched.
/// only the length comparison short-circuits.
fn constantTimeEql(a: []const u8, b: []const u8) bool {
    if (a.len != b.len) return false;
    var diff: u8 = 0;
    for (a, b) |x, y| diff |= x ^ y;
    return diff == 0;
}

/// true when `s` can be placed between JSON double quotes verbatim
/// (no quote, backslash, or control character).
fn isJsonSafe(s: []const u8) bool {
    for (s) |c| {
        if (c < 0x20 or c == '"' or c == '\\') return false;
    }
    return true;
}

/// write `s` to `w` as a quoted JSON string, escaping quotes, backslashes
/// and control characters. bytes >= 0x20 are passed through unchanged.
fn writeJsonString(w: anytype, s: []const u8) !void {
    const hex_digits = "0123456789abcdef";
    try w.writeByte('"');
    for (s) |c| {
        switch (c) {
            '"' => try w.writeAll("\\\""),
            '\\' => try w.writeAll("\\\\"),
            '\n' => try w.writeAll("\\n"),
            '\r' => try w.writeAll("\\r"),
            '\t' => try w.writeAll("\\t"),
            0x00...0x08, 0x0b, 0x0c, 0x0e...0x1f => {
                try w.writeAll("\\u00");
                try w.writeByte(hex_digits[c >> 4]);
                try w.writeByte(hex_digits[c & 0x0f]);
            },
            else => try w.writeByte(c),
        }
    }
    try w.writeByte('"');
}

/// check admin auth via headers, send error response if not authorized. returns true if authorized.
///
/// an unset *or empty* RELAY_ADMIN_PASSWORD disables the admin endpoints (403);
/// otherwise an empty bearer token would match an empty password.
pub fn checkAdmin(conn: *h.Conn, headers: ?*const websocket.Handshake.KeyValue) bool {
    const admin_pw: []const u8 = std.posix.getenv("RELAY_ADMIN_PASSWORD") orelse "";
    if (admin_pw.len == 0) {
        h.respondJson(conn, .forbidden, "{\"error\":\"admin endpoint not configured\"}");
        return false;
    }

    const kv = headers orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    // handshake parser lowercases all header names
    const auth_value = kv.get("authorization") orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    const bearer_prefix = "Bearer ";
    if (!std.mem.startsWith(u8, auth_value, bearer_prefix)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid authorization scheme\"}");
        return false;
    }
    const token = auth_value[bearer_prefix.len..];
    // constant-time compare: a plain std.mem.eql returns at the first differing byte
    if (!constantTimeEql(token, admin_pw)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid token\"}");
        return false;
    }
    return true;
}

/// ban (take down) an account. body: {"did":"..."}.
///
/// resolves the DID to a UID, marks it taken down, then persists and
/// broadcasts an #account event. failure to emit the event is logged but
/// does not fail the request: the takedown itself has already succeeded.
pub fn handleBan(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { did: []const u8 }, ctx.persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\"}\"}");
        return;
    };
    defer parsed.deinit();
    const did = parsed.value.did;

    // reject obviously malformed input before it reaches the DID → UID lookup
    if (!std.mem.startsWith(u8, did, "did:")) {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid DID\"}");
        return;
    }

    // resolve DID → UID and take down
    const uid = ctx.persist.uidForDid(did) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to resolve DID\"}");
        return;
    };
    ctx.persist.takeDownUser(uid) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"takedown failed\"}");
        return;
    };

    // emit #account event so downstream consumers see the takedown
    // NOTE(review): frame_bytes and the resequenced copy are never freed here —
    // confirm whether persist()/broadcast() take ownership; if not, this leaks.
    if (buildAccountFrame(ctx.persist.allocator, did)) |frame_bytes| {
        if (ctx.persist.persist(.account, uid, frame_bytes)) |relay_seq| {
            ctx.bc.stats.relay_seq.store(relay_seq, .release);
            const broadcast_data = broadcaster.resequenceFrame(ctx.persist.allocator, frame_bytes, relay_seq) orelse frame_bytes;
            ctx.bc.broadcast(relay_seq, broadcast_data);
            log.info("admin: emitted #account takedown event for {s} (seq={d})", .{ did, relay_seq });
        } else |err| {
            log.warn("admin: failed to persist #account takedown event: {s}", .{@errorName(err)});
        }
    } else {
        log.warn("admin: failed to build #account takedown event for {s}", .{did});
    }

    log.info("admin: banned {s} (uid={d})", .{ did, uid });
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// serialize the host list as {"hosts":[...],"active_workers":N} into `w`.
/// string fields are JSON-escaped.
fn writeHostListJson(w: anytype, hosts: anytype, active_workers: anytype) !void {
    try w.writeAll("{\"hosts\":[");
    for (hosts, 0..) |host, i| {
        if (i > 0) try w.writeByte(',');
        try std.fmt.format(w, "{{\"id\":{d},\"hostname\":", .{host.id});
        try writeJsonString(w, host.hostname);
        try w.writeAll(",\"status\":");
        try writeJsonString(w, host.status);
        try std.fmt.format(w, ",\"last_seq\":{d},\"failed_attempts\":{d}}}", .{
            host.last_seq,
            host.failed_attempts,
        });
    }
    try std.fmt.format(w, "],\"active_workers\":{d}}}", .{active_workers});
}

/// list every known host with its status, plus the active worker count.
pub fn handleAdminListHosts(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const persist = ctx.persist;
    const hosts = persist.listAllHosts(persist.allocator) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}");
        return;
    };
    defer {
        for (hosts) |host| {
            persist.allocator.free(host.hostname);
            persist.allocator.free(host.status);
        }
        persist.allocator.free(hosts);
    }

    var list: std.ArrayListUnmanaged(u8) = .{};
    defer list.deinit(persist.allocator);

    // always answer the request: a failed write must produce a 500, not silence
    writeHostListJson(list.writer(persist.allocator), hosts, ctx.slurper.workerCount()) catch |err| {
        log.warn("admin: failed to build host list response: {s}", .{@errorName(err)});
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"failed to build response\"}");
        return;
    };
    h.respondJson(conn, .ok, list.items);
}

/// mark a host as blocked. body: {"hostname":"..."}.
/// the host row is created if it does not exist yet.
pub fn handleAdminBlockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, persist: *event_log_mod.DiskPersist) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer parsed.deinit();
    const hostname = parsed.value.hostname;

    // getOrCreateHost would otherwise create a row for the empty string
    if (hostname.len == 0) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"hostname must not be empty\"}");
        return;
    }

    const host_info = persist.getOrCreateHost(hostname) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}");
        return;
    };

    persist.updateHostStatus(host_info.id, "blocked") catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}");
        return;
    };

    log.info("admin: blocked host {s} (id={d})", .{ hostname, host_info.id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// mark a host as active again and reset its failure counter. body: {"hostname":"..."}.
/// resetting the counter is best-effort: a failure is logged, the request still succeeds.
pub fn handleAdminUnblockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, persist: *event_log_mod.DiskPersist) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer parsed.deinit();
    const hostname = parsed.value.hostname;

    // getOrCreateHost would otherwise create a row for the empty string
    if (hostname.len == 0) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"hostname must not be empty\"}");
        return;
    }

    const host_info = persist.getOrCreateHost(hostname) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}");
        return;
    };

    persist.updateHostStatus(host_info.id, "active") catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}");
        return;
    };
    persist.resetHostFailures(host_info.id) catch |err| {
        log.warn("admin: failed to reset failure count for host {s} (id={d}): {s}", .{ hostname, host_info.id, @errorName(err) });
    };

    log.info("admin: unblocked host {s} (id={d})", .{ hostname, host_info.id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// start a backfill from the host named by the `source` query parameter
/// (default: bsky.network). responds 409 if a backfill is already running.
pub fn handleAdminBackfillTrigger(conn: *h.Conn, query: []const u8, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const requested: []const u8 = h.queryParam(query, "source") orelse default_backfill_source;
    // `?source=` with no value falls back to the default as well
    const source: []const u8 = if (requested.len == 0) default_backfill_source else requested;

    backfiller.start(source) catch |err| {
        switch (err) {
            error.AlreadyRunning => {
                h.respondJson(conn, .conflict, "{\"error\":\"backfill already in progress\"}");
            },
            else => {
                h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to start backfill\"}");
            },
        }
        return;
    };

    // echo the source only when it fits the buffer and cannot break out of the
    // JSON string; otherwise send the source-less form.
    const fallback_body = "{\"status\":\"started\"}";
    var buf: [256]u8 = undefined;
    const resp_body: []const u8 = if (isJsonSafe(source))
        std.fmt.bufPrint(&buf, "{{\"status\":\"started\",\"source\":\"{s}\"}}", .{source}) catch fallback_body
    else
        fallback_body;
    h.respondJson(conn, .ok, resp_body);
}

/// report backfill progress. the body is whatever `Backfiller.getStatus` returns,
/// allocated with (and freed back to) the backfiller's allocator.
pub fn handleAdminBackfillStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const body = backfiller.getStatus(backfiller.allocator) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to query backfill status\"}");
        return;
    };
    defer backfiller.allocator.free(body);

    h.respondJson(conn, .ok, body);
}

// --- protocol helpers (used only by handleBan) ---

/// build a CBOR #account frame for a takedown event.
/// header: {op: 1, t: "#account"}, payload: {seq: 0, did: "...", time: "...", active: false, status: "takendown"}
///
/// returns null if encoding or allocation fails. caller owns the returned slice
/// (allocated with `allocator`).
fn buildAccountFrame(allocator: std.mem.Allocator, did: []const u8) ?[]const u8 {
    const zat = @import("zat");
    const cbor = zat.cbor;

    const header: cbor.Value = .{ .map = &.{
        .{ .key = "op", .value = .{ .unsigned = 1 } },
        .{ .key = "t", .value = .{ .text = "#account" } },
    } };

    var time_buf: [24]u8 = undefined;
    const time_str = formatTimestamp(&time_buf);

    const payload: cbor.Value = .{ .map = &.{
        .{ .key = "seq", .value = .{ .unsigned = 0 } },
        .{ .key = "did", .value = .{ .text = did } },
        .{ .key = "time", .value = .{ .text = time_str } },
        .{ .key = "active", .value = .{ .boolean = false } },
        .{ .key = "status", .value = .{ .text = "takendown" } },
    } };

    // the two encodings are temporaries; only the concatenated frame is returned
    const header_bytes = cbor.encodeAlloc(allocator, header) catch return null;
    defer allocator.free(header_bytes);
    const payload_bytes = cbor.encodeAlloc(allocator, payload) catch return null;
    defer allocator.free(payload_bytes);

    const frame = allocator.alloc(u8, header_bytes.len + payload_bytes.len) catch return null;
    @memcpy(frame[0..header_bytes.len], header_bytes);
    @memcpy(frame[header_bytes.len..], payload_bytes);

    return frame;
}

/// format current UTC time as ISO 8601 (YYYY-MM-DDTHH:MM:SSZ)
fn formatTimestamp(buf: *[24]u8) []const u8 {
    // clamp: a clock set before 1970 must not trap in the cast to unsigned
    const secs: u64 = @intCast(@max(std.time.timestamp(), 0));
    return formatTimestampAt(buf, secs);
}

/// format `secs` (seconds since the unix epoch, UTC) as ISO 8601 (YYYY-MM-DDTHH:MM:SSZ).
/// the result is a slice of `buf`.
fn formatTimestampAt(buf: *[24]u8, secs: u64) []const u8 {
    const es = std.time.epoch.EpochSeconds{ .secs = secs };
    const yd = es.getEpochDay().calculateYearDay();
    const md = yd.calculateMonthDay();
    const ds = es.getDaySeconds();

    return std.fmt.bufPrint(buf, "{d:0>4}-{d:0>2}-{d:0>2}T{d:0>2}:{d:0>2}:{d:0>2}Z", .{
        yd.year,
        // Month is already 1-based (jan = 1); only day_index is 0-based
        md.month.numeric(),
        @as(u32, md.day_index) + 1,
        ds.getHoursIntoDay(),
        ds.getMinutesIntoHour(),
        ds.getSecondsIntoMinute(),
    }) catch "1970-01-01T00:00:00Z";
}