// atproto relay implementation in zig — zlay.waow.tech (main, 656 lines, 26 kB)
//! slurper — multi-host PDS crawl manager
//!
//! manages one Subscriber thread per tracked PDS host. handles:
//! - loading known hosts from DB on startup
//! - spawning/stopping subscriber workers
//! - processing crawl requests (adding new hosts)
//! - host validation (format, domain ban, describeServer, relay loop detection)
//! - tracking host lifecycle (active → exhausted → blocked)
//!
//! all downstream components (Broadcaster, DiskPersist, Validator) are
//! thread-safe for N concurrent producers, so this just orchestrates.

const std = @import("std");
const http = std.http;
const broadcaster = @import("broadcaster.zig");
const validator_mod = @import("validator.zig");
const event_log_mod = @import("event_log.zig");
const subscriber_mod = @import("subscriber.zig");
const collection_index_mod = @import("collection_index.zig");
const frame_worker_mod = @import("frame_worker.zig");

const Allocator = std.mem.Allocator;
const log = std.log.scoped(.relay);

/// tunables for the slurper. all have working defaults; see each field.
pub const Options = struct {
    /// relay whose com.atproto.sync.listHosts output seeds the host list
    /// (consumed by pullHosts).
    seed_host: []const u8 = "bsky.network",
    /// per-subscriber cap on inbound message size, forwarded to each
    /// Subscriber's options.
    max_message_size: usize = 5 * 1024 * 1024,
    /// number of frame-pool worker threads created in start().
    frame_workers: u16 = 16,
    /// capacity of the frame pool's queue created in start().
    frame_queue_capacity: u16 = 4096,
};

// --- host validation ---
// mirrors indigo relay's ParseHostname + CheckHost + relay loop detection

/// every way a candidate host can be rejected by validateHostname,
/// rejectPrivateHost, or checkHost.
pub const HostValidationError = error{
    EmptyHostname,
    InvalidCharacter,
    InvalidLabel,
    TooFewLabels,
    LooksLikeIpAddress,
    PortNotAllowed,
    LocalhostNotAllowed,
    DomainBanned,
    HostUnreachable,
    NotAPds,
    IsARelay,
};

/// validate and normalize a hostname for crawling.
/// rejects IPs, ports, localhost, invalid DNS names.
/// returns lowercased hostname on success; caller owns the returned slice.
pub fn validateHostname(allocator: Allocator, raw: []const u8) (HostValidationError || Allocator.Error)![]u8 {
    if (raw.len == 0) return error.EmptyHostname;

    // strip scheme if present
    var hostname = raw;
    for ([_][]const u8{ "https://", "http://", "wss://", "ws://" }) |scheme| {
        if (hostname.len > scheme.len and std.ascii.startsWithIgnoreCase(hostname, scheme)) {
            hostname = hostname[scheme.len..];
            break;
        }
    }

    // strip trailing slash and path
    if (std.mem.indexOfScalar(u8, hostname, '/')) |i| {
        hostname = hostname[0..i];
    }

    // reject ports (Go relay rejects non-localhost ports)
    if (std.mem.indexOfScalar(u8, hostname, ':')) |_| {
        return error.PortNotAllowed;
    }

    // reject localhost
    if (std.ascii.eqlIgnoreCase(hostname, "localhost")) {
        return error.LocalhostNotAllowed;
    }

    // validate characters and split into labels
    var label_count: usize = 0;
    var all_labels_numeric = true;
    var it = std.mem.splitScalar(u8, hostname, '.');
    while (it.next()) |label| {
        if (label.len == 0 or label.len > 63) return error.InvalidLabel;
        // labels must be alphanumeric with hyphens, no leading/trailing hyphens
        if (label[0] == '-' or label[label.len - 1] == '-') return error.InvalidLabel;
        var is_numeric = true;
        for (label) |c| {
            if (!std.ascii.isAlphanumeric(c) and c != '-') return error.InvalidCharacter;
            if (!std.ascii.isDigit(c)) is_numeric = false;
        }
        if (!is_numeric) all_labels_numeric = false;
        label_count += 1;
    }

    // need at least 2 labels (e.g. "pds.example.com", "bsky.network")
    if (label_count < 2) return error.TooFewLabels;

    // all-numeric labels = IP address (e.g. "192.168.1.1")
    if (all_labels_numeric) return error.LooksLikeIpAddress;

    // lowercase normalize. FIX: propagate allocation failure as
    // error.OutOfMemory instead of disguising it as error.EmptyHostname
    // (previous behavior), which misled callers logging the rejection
    // reason. widening the error union is backward-compatible: all callers
    // in this file catch generically.
    const result = try allocator.alloc(u8, hostname.len);
    return std.ascii.lowerString(result, hostname);
}

/// SSRF protection: resolve hostname and reject private/reserved IP ranges.
/// Go relay: ssrf.go PublicOnlyTransport — rejects 10/8, 172.16/12, 192.168/16, 127/8, link-local.
/// returns error.HostUnreachable both for resolution failure and for a
/// private/reserved result, so callers cannot distinguish the two.
fn rejectPrivateHost(allocator: Allocator, hostname: []const u8) HostValidationError!void {
    const addr_list = std.net.getAddressList(allocator, hostname, 443) catch return error.HostUnreachable;
    defer addr_list.deinit();

    if (addr_list.addrs.len == 0) return error.HostUnreachable;

    // check all resolved addresses — reject if ANY is private
    for (addr_list.addrs) |addr| {
        switch (addr.any.family) {
            std.posix.AF.INET => {
                // sa.addr is the 32-bit address in network byte order; the
                // @bitCast yields bytes in wire order (bytes[0] is the first octet).
                const ip4 = addr.in.sa.addr;
                const bytes: [4]u8 = @bitCast(ip4);
                if (bytes[0] == 10 or // 10.0.0.0/8
                    (bytes[0] == 172 and (bytes[1] & 0xf0) == 16) or // 172.16.0.0/12
                    (bytes[0] == 192 and bytes[1] == 168) or // 192.168.0.0/16
                    bytes[0] == 127 or // 127.0.0.0/8
                    bytes[0] == 0 or // 0.0.0.0/8
                    (bytes[0] == 169 and bytes[1] == 254)) // 169.254.0.0/16 link-local
                {
                    log.warn("SSRF: {s} resolves to private IP {d}.{d}.{d}.{d}", .{ hostname, bytes[0], bytes[1], bytes[2], bytes[3] });
                    return error.HostUnreachable;
                }
            },
            else => {}, // allow IPv6 for now (could add RFC 4193 check later)
        }
    }
}

/// check that a host is a real PDS by calling describeServer.
/// also checks Server header for relay loop detection.
/// Go relay: host_checker.go CheckHost + slurper.go Server header check.
fn checkHost(allocator: Allocator, hostname: []const u8) HostValidationError!void {
    // SSRF protection: never issue a request to a host that resolves to a
    // private/reserved address.
    try rejectPrivateHost(allocator, hostname);

    // format the describeServer endpoint URL into a stack buffer
    var endpoint_buf: [512]u8 = undefined;
    const endpoint = std.fmt.bufPrint(&endpoint_buf, "https://{s}/xrpc/com.atproto.server.describeServer", .{hostname}) catch return error.HostUnreachable;
    const uri = std.Uri.parse(endpoint) catch return error.HostUnreachable;

    var client: http.Client = .{ .allocator = allocator };
    defer client.deinit();

    var req = client.request(.GET, uri, .{}) catch return error.HostUnreachable;
    defer req.deinit();
    req.sendBodiless() catch return error.HostUnreachable;

    var head_buf: [2048]u8 = undefined;
    const response = req.receiveHead(&head_buf) catch return error.HostUnreachable;

    // any non-200 response means this isn't a live PDS
    if (response.head.status != .ok) return error.NotAPds;

    // relay loop detection: check Server header for "atproto-relay"
    // Go relay: slurper.go — auto-bans hosts whose Server header contains "atproto-relay"
    const server_val = findHeaderInRaw(response.head.bytes, "server") orelse return;
    if (std.mem.indexOf(u8, server_val, "atproto-relay") != null) {
        return error.IsARelay;
    }
}

/// search raw HTTP headers for a header by name (case-insensitive).
/// returns the trimmed value, or null if not found.
fn findHeaderInRaw(raw: []const u8, name: []const u8) ?[]const u8 {
    var it = std.mem.splitSequence(u8, raw, "\r\n");
    _ = it.next(); // skip status line
    while (it.next()) |line| {
        if (line.len == 0) break; // blank line terminates the header section
        const colon = std.mem.indexOfScalar(u8, line, ':') orelse continue;
        // FIX: HTTP optional whitespace (OWS) is SP / HTAB (RFC 9110), so
        // trim tabs as well as spaces; use std.ascii.eqlIgnoreCase instead
        // of a hand-rolled length-check + lowercase loop.
        const key = std.mem.trim(u8, line[0..colon], " \t");
        if (std.ascii.eqlIgnoreCase(key, name)) {
            return std.mem.trim(u8, line[colon + 1 ..], " \t");
        }
    }
    return null;
}

/// a running subscriber: its reader thread plus the heap-allocated
/// Subscriber state (both cleaned up by runWorker on exit).
const WorkerEntry = struct {
    thread: std.Thread,
    subscriber: *subscriber_mod.Subscriber,
};

pub const Slurper = struct {
    allocator: Allocator,
    bc: *broadcaster.Broadcaster,
    validator: *validator_mod.Validator,
    persist: *event_log_mod.DiskPersist,
    collection_index: ?*collection_index_mod.CollectionIndex = null,
    shutdown: *std.atomic.Value(bool),
    options: Options,

    // frame processing pool — offloads heavy work from reader threads
    frame_pool: ?frame_worker_mod.FramePool = null,

    // shared TLS CA bundle — loaded once, used by all subscriber connections
    ca_bundle: ?std.crypto.Certificate.Bundle = null,

    // active subscriber threads, keyed by host_id; guarded by workers_mutex
    workers: std.AutoHashMapUnmanaged(u64, WorkerEntry) = .{},
    workers_mutex: std.Thread.Mutex = .{},

    // crawl request queue; guarded by crawl_mutex, signaled via crawl_cond
    crawl_queue: std.ArrayListUnmanaged([]const u8) = .{},
    crawl_mutex: std.Thread.Mutex = .{},
    crawl_cond: std.Thread.Condition = .{},

    // background threads (set in start, joined in deinit)
    startup_thread: ?std.Thread = null,
    crawl_thread: ?std.Thread = null,

    /// wire up dependencies; no allocation, no threads. call start() to run.
    pub fn init(
        allocator: Allocator,
        bc: *broadcaster.Broadcaster,
        val: *validator_mod.Validator,
        persist: *event_log_mod.DiskPersist,
        shutdown: *std.atomic.Value(bool),
        options: Options,
    ) Slurper {
        return .{
            .allocator = allocator,
            .bc = bc,
            .validator = val,
            .persist = persist,
            .shutdown = shutdown,
            .options = options,
        };
    }

    /// start the slurper: bootstrap hosts from seed relay, load from DB, spawn workers.
    /// Go relay: pull-hosts bootstraps from bsky.network's listHosts, then crawls each PDS directly.
    pub fn start(self: *Slurper) !void {
        // load CA bundle once — shared by all subscriber TLS connections
        var bundle: std.crypto.Certificate.Bundle = .{};
        try bundle.rescan(self.allocator);
        self.ca_bundle = bundle;
        log.info("loaded shared CA bundle", .{});

        // create frame processing pool — worker threads handle heavy decode/validate/persist
        self.frame_pool = try frame_worker_mod.FramePool.init(self.allocator, .{
            .num_workers = self.options.frame_workers,
            .queue_capacity = self.options.frame_queue_capacity,
            .stack_size = @import("main.zig").default_stack_size,
        });
        log.info("frame pool started: {d} workers, queue capacity {d}", .{ self.options.frame_workers, self.options.frame_queue_capacity });

        // spawn worker startup in background so HTTP server + probes come up immediately.
        // pullHosts + listActiveHosts + spawnWorker all happen in the background thread.
        self.startup_thread = try std.Thread.spawn(.{ .stack_size = @import("main.zig").default_stack_size }, spawnWorkers, .{self});
        self.crawl_thread = try std.Thread.spawn(.{ .stack_size = @import("main.zig").default_stack_size }, processCrawlQueue, .{self});
    }

    /// pull PDS host list from the seed relay's com.atproto.sync.listHosts endpoint.
    /// Go relay: cmd/relay/pull.go — one-time bootstrap, reads REST API, not firehose.
272 pub fn pullHosts(self: *Slurper) !void { 273 var cursor: ?[]const u8 = null; 274 var total: usize = 0; 275 const limit = 500; 276 277 var client: http.Client = .{ .allocator = self.allocator }; 278 defer client.deinit(); 279 280 while (true) { 281 if (self.shutdown.load(.acquire)) break; 282 283 // build URL with pagination 284 var url_buf: [512]u8 = undefined; 285 const url = if (cursor) |c| 286 std.fmt.bufPrint(&url_buf, "https://{s}/xrpc/com.atproto.sync.listHosts?limit={d}&cursor={s}", .{ self.options.seed_host, limit, c }) catch break 287 else 288 std.fmt.bufPrint(&url_buf, "https://{s}/xrpc/com.atproto.sync.listHosts?limit={d}", .{ self.options.seed_host, limit }) catch break; 289 290 var aw: std.Io.Writer.Allocating = .init(self.allocator); 291 defer aw.deinit(); 292 293 const result = client.fetch(.{ 294 .location = .{ .url = url }, 295 .response_writer = &aw.writer, 296 .method = .GET, 297 }) catch |err| { 298 log.warn("pullHosts: fetch failed: {s}", .{@errorName(err)}); 299 break; 300 }; 301 302 if (result.status != .ok) { 303 log.warn("pullHosts: got status {d}", .{@intFromEnum(result.status)}); 304 break; 305 } 306 307 const body = aw.written(); 308 309 // parse JSON response: { "hosts": [{"hostname": "...", "status": "..."}, ...], "cursor": "..." } 310 const parsed = std.json.parseFromSlice(ListHostsResponse, self.allocator, body, .{ .ignore_unknown_fields = true }) catch |err| { 311 log.warn("pullHosts: JSON parse failed: {s}", .{@errorName(err)}); 312 break; 313 }; 314 defer parsed.deinit(); 315 316 const hosts = parsed.value.hosts orelse break; 317 if (hosts.len == 0) break; 318 319 var added: usize = 0; 320 for (hosts) |host| { 321 // skip non-active hosts 322 if (host.status) |s| { 323 if (!std.mem.eql(u8, s, "active")) continue; 324 } 325 // validate hostname format (rejects IPs, localhost, etc.) 
326 const normalized = validateHostname(self.allocator, host.hostname) catch continue; 327 defer self.allocator.free(normalized); 328 329 // skip banned domains 330 if (self.persist.isDomainBanned(normalized)) continue; 331 332 // insert into DB (no describeServer check — the seed relay already vetted them) 333 _ = self.persist.getOrCreateHost(normalized) catch continue; 334 added += 1; 335 } 336 total += added; 337 log.info("pullHosts: page fetched, {d} hosts added ({d} total)", .{ added, total }); 338 339 // advance cursor 340 if (parsed.value.cursor) |next_cursor| { 341 // free previous cursor if we allocated one 342 if (cursor) |prev| self.allocator.free(prev); 343 cursor = self.allocator.dupe(u8, next_cursor) catch break; 344 } else { 345 break; // no more pages 346 } 347 } 348 349 // free final cursor 350 if (cursor) |c| self.allocator.free(c); 351 log.info("pullHosts: bootstrap complete, {d} hosts added from {s}", .{ total, self.options.seed_host }); 352 } 353 354 const ListHostsResponse = struct { 355 hosts: ?[]const ListHostEntry = null, 356 cursor: ?[]const u8 = null, 357 }; 358 359 const ListHostEntry = struct { 360 hostname: []const u8, 361 status: ?[]const u8 = null, 362 }; 363 364 /// add a crawl request (from requestCrawl endpoint) 365 pub fn addCrawlRequest(self: *Slurper, hostname: []const u8) !void { 366 const duped = try self.allocator.dupe(u8, hostname); 367 self.crawl_mutex.lock(); 368 defer self.crawl_mutex.unlock(); 369 try self.crawl_queue.append(self.allocator, duped); 370 self.crawl_cond.signal(); 371 } 372 373 /// validate and add a host: format check, domain ban, describeServer, then spawn. 374 /// mirrors Go relay's requestCrawl → SubscribeToHost pipeline. 
    fn addHost(self: *Slurper, raw_hostname: []const u8) !void {
        // step 1: validate and normalize hostname format
        // Go relay: host.go ParseHostname
        const hostname = validateHostname(self.allocator, raw_hostname) catch |err| {
            log.warn("host validation failed for '{s}': {s}", .{ raw_hostname, @errorName(err) });
            return; // invalid requests are dropped silently (logged above), not errors
        };
        defer self.allocator.free(hostname);

        // step 2: domain ban check (suffix-based)
        // Go relay: domain_ban.go DomainIsBanned
        if (self.persist.isDomainBanned(hostname)) {
            log.warn("host {s}: domain is banned, rejecting", .{hostname});
            return;
        }

        // step 3: check if host is banned/blocked in DB
        // Go relay: crawl.go checks host.Status == HostStatusBanned
        if (self.persist.isHostBanned(hostname)) {
            log.warn("host {s}: banned/blocked in DB, rejecting", .{hostname});
            return;
        }

        // step 4: dedup — check if already tracked
        // Go relay: crawl.go CheckIfSubscribed
        const host_info = try self.persist.getOrCreateHost(hostname);
        {
            self.workers_mutex.lock();
            defer self.workers_mutex.unlock();
            if (self.workers.contains(host_info.id)) {
                log.debug("host {s} already has a worker, skipping", .{hostname});
                return;
            }
        }

        // step 5: describeServer liveness check
        // Go relay: host_checker.go CheckHost (with SSRF protection)
        // note: this network round-trip happens outside workers_mutex, so a
        // concurrent addHost for the same host could pass the dedup check
        // above — NOTE(review): confirm double-spawn for the same host_id is
        // prevented elsewhere (workers.put in spawnWorker would overwrite).
        checkHost(self.allocator, hostname) catch |err| {
            log.warn("host {s}: describeServer check failed: {s}", .{ hostname, @errorName(err) });
            return;
        };

        // reset status and failure count — host passed describeServer, give it a fresh start.
        // without this, exhausted hosts accumulate failures across requestCrawl cycles
        // and immediately re-exhaust on a single failure.
        self.persist.updateHostStatus(host_info.id, "active") catch {};
        self.persist.resetHostFailures(host_info.id) catch {};

        try self.spawnWorker(host_info.id, hostname);
        log.info("added host {s} (id={d})", .{ hostname, host_info.id });
    }

    /// spawn a subscriber thread for a host.
    /// ownership: dupes `hostname` and heap-allocates the Subscriber; both
    /// are freed by runWorker when the thread exits.
    fn spawnWorker(self: *Slurper, host_id: u64, hostname: []const u8) !void {
        const hostname_duped = try self.allocator.dupe(u8, hostname);
        errdefer self.allocator.free(hostname_duped);

        const sub = try self.allocator.create(subscriber_mod.Subscriber);
        errdefer self.allocator.destroy(sub);

        const account_count: u64 = self.persist.getEffectiveAccountCount(host_id);

        sub.* = subscriber_mod.Subscriber.init(
            self.allocator,
            self.bc,
            self.validator,
            self.persist,
            self.shutdown,
            .{
                .hostname = hostname_duped,
                .max_message_size = self.options.max_message_size,
                .host_id = host_id,
                .account_count = account_count,
                .ca_bundle = self.ca_bundle,
            },
        );
        sub.collection_index = self.collection_index;
        if (self.frame_pool) |*fp| sub.pool = fp;

        const thread = try std.Thread.spawn(.{ .stack_size = @import("main.zig").default_stack_size }, runWorker, .{ self, host_id, sub });

        self.workers_mutex.lock();
        defer self.workers_mutex.unlock();
        // NOTE(review): if this put fails (OOM), the errdefers above free
        // `sub` and the hostname while the just-spawned thread may still be
        // using them — consider reserving map capacity before spawning.
        try self.workers.put(self.allocator, host_id, .{
            .thread = thread,
            .subscriber = sub,
        });
        _ = self.bc.stats.connected_inbound.fetchAdd(1, .monotonic);
    }

    /// worker thread wrapper — runs subscriber, cleans up on exit
    fn runWorker(self: *Slurper, host_id: u64, sub: *subscriber_mod.Subscriber) void {
        sub.run();

        // subscriber returned — remove from active workers
        self.workers_mutex.lock();
        defer self.workers_mutex.unlock();
        _ = self.workers.remove(host_id);
        _ = self.bc.stats.connected_inbound.fetchSub(1, .monotonic);

        log.info("worker for host_id={d} ({s}) exited", .{ host_id, sub.options.hostname });

        // free the allocations made in spawnWorker (hostname was just logged,
        // so it must outlive the log call above)
        self.allocator.free(sub.options.hostname);
        self.allocator.destroy(sub);
    }

    /// background thread: load hosts from DB and spawn all workers.
    /// runs in background so HTTP server + probes come up immediately.
    /// Go relay: ResubscribeAllHosts loops with 1ms sleep per host (goroutines).
    /// we spawn all at once — the brief memory spike from concurrent TLS handshakes
    /// is shorter than a throttled ramp (many hosts fail-fast, freeing memory quickly).
    fn spawnWorkers(self: *Slurper) void {
        // pull hosts from seed relay first — idempotent (getOrCreateHost skips existing)
        log.info("pulling hosts from {s}...", .{self.options.seed_host});
        self.pullHosts() catch |err| {
            log.warn("pullHosts from {s} failed: {s}", .{ self.options.seed_host, @errorName(err) });
        };

        const hosts = self.persist.listActiveHosts(self.allocator) catch |err| {
            log.err("failed to load hosts: {s}", .{@errorName(err)});
            return;
        };
        defer {
            for (hosts) |h| {
                self.allocator.free(h.hostname);
                self.allocator.free(h.status);
            }
            self.allocator.free(hosts);
        }

        for (hosts) |host| {
            if (self.shutdown.load(.acquire)) break;
            self.spawnWorker(host.id, host.hostname) catch |err| {
                log.warn("failed to spawn worker for {s}: {s}", .{ host.hostname, @errorName(err) });
            };
        }

        log.info("startup complete: {d} host(s) spawned", .{hosts.len});
    }

    /// background thread: process crawl requests.
    /// pops one hostname at a time off crawl_queue and runs it through
    /// addHost; exits once `shutdown` is observed.
    fn processCrawlQueue(self: *Slurper) void {
        while (!self.shutdown.load(.acquire)) {
            var hostname: ?[]const u8 = null;
            {
                self.crawl_mutex.lock();
                defer self.crawl_mutex.unlock();
                // 1s timed wait so shutdown is noticed even if the signal is missed
                while (self.crawl_queue.items.len == 0 and !self.shutdown.load(.acquire)) {
                    self.crawl_cond.timedWait(&self.crawl_mutex, 1 * std.time.ns_per_s) catch {};
                }
                if (self.crawl_queue.items.len > 0) {
                    hostname = self.crawl_queue.orderedRemove(0); // FIFO
                }
            }

            if (hostname) |h| {
                defer self.allocator.free(h); // queue entry was duped in addCrawlRequest
                self.addHost(h) catch |err| {
                    log.warn("crawl request failed for {s}: {s}", .{ h, @errorName(err) });
                };
            }
        }
    }

    /// number of active workers
    pub fn workerCount(self: *Slurper) usize {
        self.workers_mutex.lock();
        defer self.workers_mutex.unlock();
        return self.workers.count();
    }

    /// update rate limits for a running subscriber (called from admin API).
    /// if the host has a worker, recomputes and applies new limits immediately.
    /// no-op when the host has no active worker.
    pub fn updateHostLimits(self: *Slurper, host_id: u64, account_count: u64) void {
        self.workers_mutex.lock();
        defer self.workers_mutex.unlock();
        if (self.workers.get(host_id)) |entry| {
            const trusted = subscriber_mod.isTrustedHost(entry.subscriber.options.hostname);
            const limits = subscriber_mod.computeLimits(trusted, account_count);
            entry.subscriber.rate_limiter.updateLimits(limits.sec, limits.hour, limits.day);
            log.info("updated rate limits for host_id={d}: sec={d} hour={d} day={d}", .{
                host_id, limits.sec, limits.hour, limits.day,
            });
        }
    }

    /// shutdown all workers and clean up.
    /// the background loops only exit once `shutdown` is observed, so the
    /// flag should be set before calling this — otherwise the joins block.
    pub fn deinit(self: *Slurper) void {
        // join background threads
        if (self.startup_thread) |t| t.join();
        self.crawl_cond.signal();
        if (self.crawl_thread) |t| t.join();

        // collect threads to join (can't join while holding workers_mutex)
        // NOTE(review): a worker that already exited removed itself from the
        // map in runWorker, so its thread handle is never joined here —
        // confirm unjoined handles are acceptable on the target platforms.
        var threads_to_join: std.ArrayListUnmanaged(std.Thread) = .{};
        defer threads_to_join.deinit(self.allocator);

        {
            self.workers_mutex.lock();
            defer self.workers_mutex.unlock();
            var it = self.workers.iterator();
            while (it.next()) |entry| {
                threads_to_join.append(self.allocator, entry.value_ptr.thread) catch {};
            }
        }

        // join all reader threads FIRST (they stop submitting to pool)
        for (threads_to_join.items) |t| t.join();

        // then drain + join pool workers (processes remaining queued frames)
        if (self.frame_pool) |*fp| {
            fp.shutdown();
            fp.deinit();
            self.frame_pool = null;
        }

        // clean up workers map
        self.workers.deinit(self.allocator);

        // clean up crawl queue
        for (self.crawl_queue.items) |h| self.allocator.free(h);
        self.crawl_queue.deinit(self.allocator);

        // free shared CA bundle
        if (self.ca_bundle) |*b| b.deinit(self.allocator);
    }
};

// --- tests ---

test "validateHostname accepts valid PDS hostnames" {
    const alloc = std.testing.allocator;

    // basic valid hostnames
    const h1 = try validateHostname(alloc, "pds.example.com");
    defer alloc.free(h1);
    try std.testing.expectEqualStrings("pds.example.com", h1);

    // two labels minimum
    const h2 = try validateHostname(alloc, "bsky.network");
    defer alloc.free(h2);
    try std.testing.expectEqualStrings("bsky.network", h2);

    // lowercases
    const h3 = try validateHostname(alloc, "PDS.Example.COM");
    defer alloc.free(h3);
    try std.testing.expectEqualStrings("pds.example.com", h3);

    // strips scheme
    const h4 = try validateHostname(alloc, "https://pds.example.com");
    defer alloc.free(h4);
    try std.testing.expectEqualStrings("pds.example.com", h4);

    // strips path
    const h5 = try validateHostname(alloc, "pds.example.com/some/path");
    defer alloc.free(h5);
    try std.testing.expectEqualStrings("pds.example.com", h5);
}

test "validateHostname rejects invalid hostnames" {
    const alloc = std.testing.allocator;

    // empty
    try std.testing.expectError(error.EmptyHostname, validateHostname(alloc, ""));
    // localhost
    try std.testing.expectError(error.LocalhostNotAllowed, validateHostname(alloc, "localhost"));
    // single label (non-localhost)
    try std.testing.expectError(error.TooFewLabels, validateHostname(alloc, "intranet"));
    // IP address
    try std.testing.expectError(error.LooksLikeIpAddress, validateHostname(alloc, "192.168.1.1"));
    try std.testing.expectError(error.LooksLikeIpAddress, validateHostname(alloc, "10.0.0.1"));
    // port
    try std.testing.expectError(error.PortNotAllowed, validateHostname(alloc, "pds.example.com:443"));
    // invalid characters
    try std.testing.expectError(error.InvalidCharacter, validateHostname(alloc, "pds.exam ple.com"));
    try std.testing.expectError(error.InvalidCharacter, validateHostname(alloc, "pds.exam_ple.com"));
    // leading/trailing hyphens
    try std.testing.expectError(error.InvalidLabel, validateHostname(alloc, "-pds.example.com"));
    try std.testing.expectError(error.InvalidLabel, validateHostname(alloc, "pds-.example.com"));
    // empty label
    try std.testing.expectError(error.InvalidLabel, validateHostname(alloc, "pds..example.com"));
}