An asynchronous IO runtime
at main 601 lines 19 kB view raw
const std = @import("std");
const stda = @import("../../stda.zig");
const ourio = @import("ourio");

const Allocator = std.mem.Allocator;
const Context = ourio.Context;
const Ring = ourio.Ring;
const Task = ourio.Task;
const assert = std.debug.assert;
const net = stda.net;
const posix = std.posix;

/// Nameservers used when /etc/resolv.conf cannot be opened or does not contain any usable
/// `nameserver` entry.
const default_dns = &stda.options.nameservers;

/// Asynchronous DNS stub resolver driven by an `ourio.Ring`.
///
/// `init` loads /etc/resolv.conf; afterwards each `resolveQuery` call sends one UDP query
/// through its own heap-allocated `Connection`.
pub const Resolver = struct {
    gpa: Allocator,
    /// Receives the "initialization finished" notification.
    ctx: Context,
    config: Config = .{},

    const Msg = enum { open_resolv, read_resolv };

    /// Initialize a Resolver instance. When initialization is complete, a userptr result
    /// pointing at `self` is delivered to `ctx` and the resolver is ready to resolve DNS
    /// queries. If reading resolv.conf fails, the read error is delivered instead.
    ///
    /// The pending IO tasks hold `self`, so it must not move until that callback fires.
    pub fn init(
        self: *Resolver,
        gpa: Allocator,
        io: *Ring,
        ctx: Context,
    ) Allocator.Error!void {
        self.* = .{
            .gpa = gpa,
            .ctx = ctx,
        };

        _ = try io.open("/etc/resolv.conf", .{ .CLOEXEC = true }, 0, .{
            .cb = Resolver.onCompletion,
            .ptr = self,
            .msg = @intFromEnum(Msg.open_resolv),
        });
    }

    /// Free the nameserver list owned by the resolver.
    pub fn deinit(self: *Resolver) void {
        self.gpa.free(self.config.nameservers);
    }

    /// Resolve `query` asynchronously. The raw response bytes are delivered to `ctx` as a
    /// `userbytes` result (valid only for the duration of that callback); once every
    /// nameserver has used up its attempts, `error.Timeout` is delivered instead.
    ///
    /// Must only be called after initialization has completed.
    pub fn resolveQuery(self: *Resolver, io: *Ring, query: Question, ctx: ourio.Context) !void {
        assert(self.config.nameservers.len > 0);

        const conn = try self.gpa.create(Connection);
        conn.* = .{ .gpa = self.gpa, .ctx = ctx, .config = self.config };

        // The connection has not been handed to the ring yet, so we still own it here.
        conn.writeQuestion(query) catch |err| {
            conn.write_buffer.deinit(self.gpa);
            self.gpa.destroy(conn);
            return err;
        };

        try conn.tryNext(io);
    }

    pub fn onCompletion(io: *Ring, task: Task) anyerror!void {
        const self = task.userdataCast(Resolver);
        const msg = task.msgToEnum(Resolver.Msg);
        const result = task.result.?;

        switch (msg) {
            .open_resolv => {
                const fd = result.open catch {
                    // No readable resolv.conf: fall back to the built-in nameservers.
                    self.config.nameservers = try self.gpa.dupe(std.net.Address, default_dns);
                    return self.notifyReady(io);
                };

                const buffer = try self.gpa.alloc(u8, 4096);
                errdefer self.gpa.free(buffer);

                _ = try io.read(fd, buffer, .{
                    .cb = Resolver.onCompletion,
                    .ptr = self,
                    .msg = @intFromEnum(Resolver.Msg.read_resolv),
                });
            },

            .read_resolv => {
                const buffer = task.req.read.buffer;
                defer self.gpa.free(buffer);

                _ = try io.close(task.req.read.fd, .{});

                const n = result.read catch |err| {
                    const t: Task = .{
                        .callback = self.ctx.cb,
                        .msg = self.ctx.msg,
                        .userdata = self.ctx.ptr,
                        .result = .{ .userptr = err },
                    };
                    try self.ctx.cb(io, t);
                    return;
                };

                // A completely full buffer means the file may have been cut short.
                if (n >= buffer.len) {
                    @panic("TODO: more to read");
                }

                try self.parseResolvConf(buffer[0..n]);
                try self.notifyReady(io);
            },
        }
    }

    /// Parse resolv.conf `contents` into `self.config`.
    ///
    /// Recognized directives are `nameserver <ip>` and the `timeout:`, `attempts:` and
    /// `edns0` options. Everything else is ignored. Nameserver entries that are not a plain
    /// IP address are skipped. If no usable nameserver is found, `default_dns` is used.
    fn parseResolvConf(self: *Resolver, contents: []const u8) Allocator.Error!void {
        var addresses: std.ArrayListUnmanaged(std.net.Address) = .empty;
        defer addresses.deinit(self.gpa);

        var line_iter = std.mem.splitScalar(u8, contents, '\n');
        while (line_iter.next()) |line| {
            if (line.len == 0 or line[0] == ';' or line[0] == '#') continue;

            // Tokenize (rather than split) so runs of whitespace, a trailing '\r' and
            // trailing comments do not end up inside the value.
            var tokens = std.mem.tokenizeAny(u8, line, &std.ascii.whitespace);
            const key = tokens.next() orelse continue;

            if (std.mem.eql(u8, key, "nameserver")) {
                const ip = tokens.next() orelse continue;
                // One unparsable entry (e.g. an IPv6 address with a zone id) must not abort
                // the whole initialization.
                const addr = std.net.Address.parseIp(ip, 53) catch continue;
                try addresses.append(self.gpa, addr);
                continue;
            }

            if (std.mem.eql(u8, key, "options")) {
                while (tokens.next()) |opt| self.applyOption(opt);
            }
        }

        self.config.nameservers = if (addresses.items.len == 0)
            try self.gpa.dupe(std.net.Address, default_dns)
        else
            try addresses.toOwnedSlice(self.gpa);
    }

    /// Apply a single `options` token from resolv.conf, clamping values to the limits
    /// documented in resolv.conf(5). Unknown options are ignored.
    fn applyOption(self: *Resolver, opt: []const u8) void {
        if (std.mem.startsWith(u8, opt, "timeout:")) {
            // Unparsable or out-of-range values fall back to the cap.
            const timeout = std.fmt.parseInt(u5, opt["timeout:".len..], 10) catch 30;
            self.config.timeout_s = @min(30, timeout);
        } else if (std.mem.startsWith(u8, opt, "attempts:")) {
            const attempts = std.fmt.parseInt(u3, opt["attempts:".len..], 10) catch 5;
            self.config.attempts = @max(@min(5, attempts), 1);
        } else if (std.mem.eql(u8, opt, "edns0")) {
            self.config.edns0 = true;
        }
    }

    /// Deliver the "initialization complete" userptr result (pointing at `self`) to `ctx`.
    fn notifyReady(self: *Resolver, io: *Ring) !void {
        const t: Task = .{
            .callback = self.ctx.cb,
            .msg = self.ctx.msg,
            .userdata = self.ctx.ptr,
            .result = .{ .userptr = self },
        };
        try self.ctx.cb(io, t);
    }
};

pub const Config = struct {
    /// Nameservers to query, in order. Freed by `Resolver.deinit`.
    nameservers: []const std.net.Address = &.{},

    /// timeout_s is silently capped to 30 according to man resolv.conf
    timeout_s: u5 = 30,

    /// Number of attempts per nameserver; capped at 5.
    attempts: u3 = 5,

    /// Set by the `edns0` option.
    /// NOTE(review): nothing in this file reads it yet.
    edns0: bool = false,
};

/// The fixed 12-byte DNS message header (RFC 1035 section 4.1.1).
///
/// Serialize with `asBytes` and parse with `Response.header`. Those functions write and read
/// each 16-bit field individually in big-endian order.
///
/// NOTE(review): `opcode` and `response_code` are exhaustive enums. A received header that
/// carries an unlisted value (e.g. rcode 6+) is bit-cast into an invalid enum value by
/// `Response.header`. Consider making them non-exhaustive.
pub const Header = packed struct(u96) {
    id: u16 = 0,

    /// Third header byte, least significant bit first: RD, TC, AA, OPCODE(4), QR.
    flags1: packed struct(u8) {
        recursion_desired: bool = true,
        truncated: bool = false,
        authoritative_answer: bool = false,
        opcode: enum(u4) {
            query = 0,
            inverse_query = 1,
            server_status_request = 2,
        } = .query,
        is_response: bool = false,
    } = .{},

    /// Fourth header byte, least significant bit first: RCODE(4), Z(3), RA.
    flags2: packed struct(u8) {
        response_code: enum(u4) {
            success = 0,
            format_error = 1,
            server_failure = 2,
            name_error = 3,
            not_implemented = 4,
            refuse = 5,
        } = .success,
        z: u3 = 0,
        recursion_available: bool = false,
    } = .{},

    question_count: u16 = 0,

    answer_count: u16 = 0,

    authority_count: u16 = 0,

    additional_count: u16 = 0,

    /// Encode the header in wire format.
    pub fn asBytes(self: Header) [12]u8 {
        var bytes: [12]u8 = undefined;
        var fbs = std.io.fixedBufferStream(&bytes);
        fbs.writer().writeInt(u16, self.id, .big) catch unreachable;

        fbs.writer().writeByte(@bitCast(self.flags1)) catch unreachable;
        fbs.writer().writeByte(@bitCast(self.flags2)) catch unreachable;

        fbs.writer().writeInt(u16, self.question_count, .big) catch unreachable;
        fbs.writer().writeInt(u16, self.answer_count, .big) catch unreachable;
        fbs.writer().writeInt(u16, self.authority_count, .big) catch unreachable;
        fbs.writer().writeInt(u16, self.additional_count, .big) catch unreachable;
        assert(fbs.pos == 12);
        return bytes;
    }
};

/// A single DNS question.
pub const Question = struct {
    /// Dotted host name; an optional single trailing dot is accepted.
    host: []const u8,
    type: ResourceType = .A,
    class: enum(u16) {
        IN = 1,
        // CS = 2,
        // CH = 3,
        // HS = 4,
        // WILDCARD = 255,
    } = .IN,
};

pub const ResourceType = enum(u16) {
    A = 1,
    // NS = 2,
    // MD = 3,
    // MF = 4,
    // CNAME = 5,
    // SOA = 6,
    // MB = 7,
    // MG = 8,
    // MR = 9,
    // NULL = 10,
    // WKS = 11,
    // PTR = 12,
    // HINFO = 13,
    // MINFO = 14,
    // MX = 15,
    // TXT = 16,
    AAAA = 28,
    SRV = 33,
    // OPT = 41,
};

/// A decoded resource record from the answer section.
pub const Answer = union(ResourceType) {
    /// IPv4 address bytes, in the order they appear on the wire.
    A: [4]u8,
    /// IPv6 address bytes, in the order they appear on the wire.
    AAAA: [16]u8,
    SRV: struct {
        priority: u16,
        weight: u16,
        port: u16,
        /// Dotted target name. When produced by `Response.AnswerIterator`, it points into the
        /// iterator's scratch buffer and is only valid until the next call to `next`.
        target: []const u8,
    },
};

/// Return the offset just past the (possibly compressed) domain name that starts at `start`,
/// or null if the name is malformed or runs past the end of `bytes`.
fn skipName(bytes: []const u8, start: usize) ?usize {
    var pos = start;
    while (pos < bytes.len) {
        const len = bytes[pos];
        switch (len >> 6) {
            0b00 => {
                if (len == 0) return pos + 1;
                pos += 1 + @as(usize, len);
            },
            // Compression pointer: two bytes, and it always ends the name.
            0b11 => return if (pos + 2 <= bytes.len) pos + 2 else null,
            // 0b01 and 0b10 are reserved label types.
            else => return null,
        }
    }
    return null;
}

/// Read a big-endian integer of type `T` at `pos`, or return null if `bytes` is too short.
fn readIntBig(comptime T: type, bytes: []const u8, pos: usize) ?T {
    const size = @sizeOf(T);
    if (pos > bytes.len or bytes.len - pos < size) return null;
    return std.mem.readInt(T, bytes[pos..][0..size], .big);
}

/// A received DNS message. `bytes` is the raw message as read from the socket.
pub const Response = struct {
    bytes: []const u8,

    /// Decode the fixed header. Asserts that `bytes` holds at least 12 bytes.
    pub fn header(self: Response) Header {
        assert(self.bytes.len >= 12);
        const readInt = std.mem.readInt;

        return .{
            .id = readInt(u16, self.bytes[0..2], .big),
            .flags1 = @bitCast(self.bytes[2]),
            .flags2 = @bitCast(self.bytes[3]),
            .question_count = readInt(u16, self.bytes[4..6], .big),
            .answer_count = readInt(u16, self.bytes[6..8], .big),
            .authority_count = readInt(u16, self.bytes[8..10], .big),
            .additional_count = readInt(u16, self.bytes[10..12], .big),
        };
    }

    /// Iterates the answer section of a response.
    ///
    /// Records with a class other than IN, a type outside `ResourceType` (e.g. CNAME), or
    /// malformed rdata are skipped. Iteration stops for good at the first truncated record.
    pub const AnswerIterator = struct {
        /// The message starting at the answer section.
        bytes: []const u8,
        /// offset into bytes
        offset: usize = 0,

        /// Number of answer records announced by the header.
        count: usize,
        /// Number of answer records consumed so far (returned or skipped).
        idx: usize = 0,

        /// Scratch space for decoded SRV target names; an encoded name is at most 255 bytes.
        name_buf: [255]u8 = undefined,

        pub fn next(self: *AnswerIterator) ?Answer {
            while (self.idx < self.count) {
                self.idx += 1;

                // Owner name: only its length matters here.
                const name_end = skipName(self.bytes, self.offset) orelse return self.stop();
                // Fixed part: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2).
                const typ = readIntBig(u16, self.bytes, name_end) orelse return self.stop();
                const class = readIntBig(u16, self.bytes, name_end + 2) orelse return self.stop();
                const rd_len = readIntBig(u16, self.bytes, name_end + 8) orelse return self.stop();

                // The RDLENGTH read above succeeded, so rdata_start <= bytes.len.
                const rdata_start = name_end + 10;
                if (self.bytes.len - rdata_start < rd_len) return self.stop();
                const rdata = self.bytes[rdata_start..][0..rd_len];
                self.offset = rdata_start + rd_len;

                // Only IN-class records are understood.
                if (class != 1) continue;

                if (self.decode(typ, rdata)) |answer| return answer;
                // Unsupported type or malformed rdata: move on to the next record.
            }
            return null;
        }

        /// End iteration permanently; used when the message is truncated or malformed.
        fn stop(self: *AnswerIterator) ?Answer {
            self.idx = self.count;
            return null;
        }

        /// Decode `rdata` for a supported record type. Return null if `typ` is not a
        /// `ResourceType` or the rdata has the wrong shape.
        fn decode(self: *AnswerIterator, typ: u16, rdata: []const u8) ?Answer {
            const rtype = std.meta.intToEnum(ResourceType, typ) catch return null;
            switch (rtype) {
                .A => {
                    if (rdata.len != 4) return null;
                    return .{ .A = rdata[0..4].* };
                },
                .AAAA => {
                    if (rdata.len != 16) return null;
                    return .{ .AAAA = rdata[0..16].* };
                },
                .SRV => {
                    // priority, weight and port, then at least the root label.
                    if (rdata.len < 7) return null;
                    const target_len = self.decodeName(rdata[6..]) orelse return null;
                    return .{ .SRV = .{
                        .priority = std.mem.readInt(u16, rdata[0..2], .big),
                        .weight = std.mem.readInt(u16, rdata[2..4], .big),
                        .port = std.mem.readInt(u16, rdata[4..6], .big),
                        .target = self.name_buf[0..target_len],
                    } };
                },
            }
        }

        /// Decode an uncompressed wire-format name from `wire` into `name_buf` in dotted form
        /// and return its length.
        ///
        /// Return null on truncation, on overflow of `name_buf`, or on a compression pointer.
        /// RFC 2782 forbids compression in SRV targets.
        fn decodeName(self: *AnswerIterator, wire: []const u8) ?usize {
            var out: usize = 0;
            var pos: usize = 0;
            while (pos < wire.len) {
                const len = wire[pos];
                if (len == 0) return out;
                if (len & 0b1100_0000 != 0) return null;
                pos += 1;
                if (wire.len - pos < len) return null;

                if (out > 0) {
                    if (out >= self.name_buf.len) return null;
                    self.name_buf[out] = '.';
                    out += 1;
                }
                if (self.name_buf.len - out < len) return null;
                @memcpy(self.name_buf[out..][0..len], wire[pos..][0..len]);
                out += len;
                pos += len;
            }
            return null;
        }
    };

    /// Return an iterator over the answer section, positioned after the question section.
    /// Return error.InvalidResponse if the header or questions are truncated.
    pub fn answerIterator(self: Response) !AnswerIterator {
        if (self.bytes.len < 12) return error.InvalidResponse;
        const h = self.header();

        var offset: usize = 12;
        for (0..h.question_count) |_| {
            // QNAME, then 2 bytes for type and 2 bytes for class.
            offset = skipName(self.bytes, offset) orelse return error.InvalidResponse;
            offset += 4;
            if (offset > self.bytes.len) return error.InvalidResponse;
        }

        return .{
            .bytes = self.bytes[offset..],
            .count = h.answer_count,
        };
    }
};

/// State for one in-flight query. Heap-allocated by `Resolver.resolveQuery`, and destroyed
/// once a result (response bytes or `error.Timeout`) has been delivered to `ctx`.
pub const Connection = struct {
    gpa: Allocator,
    ctx: Context,
    config: Config,

    /// Index into `config.nameservers` of the server currently being tried.
    nameserver: u8 = 0,
    /// Attempts already made against the current nameserver.
    attempt: u5 = 0,

    read_buffer: [2048]u8 = undefined,
    /// The encoded query message.
    write_buffer: std.ArrayListUnmanaged(u8) = .empty,
    /// Absolute deadline (seconds) for the current attempt's recv and write.
    deadline: i64 = 0,

    const Msg = enum { connect, recv };

    /// Start the next attempt. Each nameserver is tried `config.attempts` times, in order.
    /// When all of them are exhausted, `error.Timeout` is delivered and `self` is destroyed.
    pub fn tryNext(self: *Connection, io: *Ring) !void {
        self.deadline = std.time.timestamp() + self.config.timeout_s;

        while (self.nameserver < self.config.nameservers.len) {
            if (self.attempt < self.config.attempts) {
                const addr = self.config.nameservers[self.nameserver];
                self.attempt += 1;

                _ = try net.udpConnectToAddr(io, addr, .{
                    .cb = Connection.onCompletion,
                    .msg = @intFromEnum(Connection.Msg.connect),
                    .ptr = self,
                });
                return;
            }

            // This nameserver is out of attempts: move on to the next one.
            self.attempt = 0;
            self.nameserver += 1;
        }

        defer self.gpa.destroy(self);
        try self.sendResult(io, .{ .userbytes = error.Timeout });
    }

    pub fn onCompletion(io: *Ring, task: Task) anyerror!void {
        const self = task.userdataCast(Connection);
        const msg = task.msgToEnum(Connection.Msg);
        const result = task.result.?;

        switch (msg) {
            .connect => {
                const fd = result.userfd catch return self.tryNext(io);

                // The receive is queued before the query is written; both share the deadline.
                const recv_task = try io.recv(fd, &self.read_buffer, .{
                    .cb = Connection.onCompletion,
                    .ptr = self,
                    .msg = @intFromEnum(Connection.Msg.recv),
                });
                try recv_task.setDeadline(io, .{ .sec = self.deadline });

                const write_task = try io.write(fd, self.write_buffer.items, .{});
                try write_task.setDeadline(io, .{ .sec = self.deadline });
            },

            .recv => {
                const fd = task.req.recv.fd;

                // A failed (e.g. timed-out) or empty read moves on to the next attempt.
                const n = result.recv catch 0;
                if (n == 0) {
                    _ = try io.close(fd, .{});
                    return self.tryNext(io);
                }

                // Free the connection even if the user callback returns an error.
                defer self.gpa.destroy(self);
                try self.sendResult(io, .{ .userbytes = self.read_buffer[0..n] });
                _ = try io.close(fd, .{});
            },
        }
    }

    /// Deliver `result` to the query's context and release the encoded query.
    fn sendResult(self: *Connection, io: *Ring, result: ourio.Result) !void {
        defer self.write_buffer.deinit(self.gpa);
        const task: ourio.Task = .{
            .callback = self.ctx.cb,
            .userdata = self.ctx.ptr,
            .msg = self.ctx.msg,
            .result = result,
        };
        try self.ctx.cb(io, task);
    }

    /// Append the query message (header plus one question) for `query` to `write_buffer`.
    ///
    /// Return error.InvalidHostname for empty interior labels, labels longer than 63 bytes,
    /// or names longer than 253 characters (RFC 1035 section 2.3.4). A single trailing dot
    /// is accepted, and an empty host encodes the root name.
    fn writeQuestion(self: *Connection, query: Question) !void {
        var host = query.host;
        if (std.mem.endsWith(u8, host, ".")) host = host[0 .. host.len - 1];
        // Encoded size is host.len + 2 (leading length byte and root byte), at most 255.
        if (host.len > 253) return error.InvalidHostname;

        const header: Header = .{ .question_count = 1 };
        const writer = self.write_buffer.writer(self.gpa);
        try writer.writeAll(&header.asBytes());

        if (host.len > 0) {
            var labels = std.mem.splitScalar(u8, host, '.');
            while (labels.next()) |label| {
                if (label.len == 0 or label.len > 63) return error.InvalidHostname;
                try writer.writeByte(@intCast(label.len));
                try writer.writeAll(label);
            }
        }
        try writer.writeByte(0x00);
        try writer.writeInt(u16, @intFromEnum(query.type), .big);
        try writer.writeInt(u16, @intFromEnum(query.class), .big);
    }
};

test "Response.answerIterator decodes records and skips unsupported types" {
    const msg = [_]u8{
        // header: id 0, QR|RD, RA, 1 question, 4 answers
        0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00,
        // question: "a.io" A IN
        0x01, 'a',  0x02, 'i',  'o',  0x00, 0x00, 0x01, 0x00, 0x01,
        // CNAME (unsupported type, must be skipped)
        0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x02, 0xc0, 0x0c,
        // A 1.2.3.4
        0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x04, 1,    2,
        3,    4,
        // AAAA ::1
        0xc0, 0x0c, 0x00, 0x1c, 0x00, 0x01, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x10, 0,    0,
        0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    1,
        // SRV priority 10, weight 5, port 5060, target "sip.a.io"
        0xc0, 0x0c, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x10, 0x00, 0x0a,
        0x00, 0x05, 0x13, 0xc4, 0x03, 's',  'i',  'p',  0x01, 'a',  0x02, 'i',  'o',  0x00,
    };
    const response: Response = .{ .bytes = &msg };
    var iter = try response.answerIterator();

    try std.testing.expectEqual(Answer{ .A = .{ 1, 2, 3, 4 } }, iter.next().?);

    var loopback = [_]u8{0} ** 16;
    loopback[15] = 1;
    try std.testing.expectEqual(Answer{ .AAAA = loopback }, iter.next().?);

    const srv = iter.next().?.SRV;
    try std.testing.expectEqual(10, srv.priority);
    try std.testing.expectEqual(5, srv.weight);
    try std.testing.expectEqual(5060, srv.port);
    try std.testing.expectEqualStrings("sip.a.io", srv.target);

    try std.testing.expect(iter.next() == null);
}

test "Response.answerIterator rejects truncated messages" {
    const short = [_]u8{0} ** 4;
    try std.testing.expectError(error.InvalidResponse, (Response{ .bytes = &short }).answerIterator());

    // The header announces one answer, but the record ends inside the TTL.
    const cut = [_]u8{ 0, 0, 0x81, 0x80, 0, 0, 0, 1, 0, 0, 0, 0, 0xc0, 0x0c, 0, 1, 0, 1, 0, 0 };
    var iter = try (Response{ .bytes = &cut }).answerIterator();
    try std.testing.expect(iter.next() == null);
}

test "Connection.writeQuestion encodes labels and rejects invalid hostnames" {
    const gpa = std.testing.allocator;
    var conn: Connection = .{ .gpa = gpa, .ctx = undefined, .config = .{} };
    defer conn.write_buffer.deinit(gpa);

    try conn.writeQuestion(.{ .host = "a.io." });
    try std.testing.expectEqualSlices(
        u8,
        &.{ 1, 'a', 2, 'i', 'o', 0, 0, 1, 0, 1 },
        conn.write_buffer.items[12..],
    );

    try std.testing.expectError(error.InvalidHostname, conn.writeQuestion(.{ .host = "a..io" }));
    try std.testing.expectError(error.InvalidHostname, conn.writeQuestion(.{ .host = "a" ** 64 }));
}

test "Resolver.parseResolvConf tolerates comments and extra whitespace" {
    var resolver: Resolver = .{ .gpa = std.testing.allocator, .ctx = undefined };
    defer resolver.deinit();

    try resolver.parseResolvConf(
        "nameserver  8.8.8.8 # google\r\nnameserver bogus\noptions attempts:9 timeout:2 edns0\n",
    );

    try std.testing.expectEqual(1, resolver.config.nameservers.len);
    try std.testing.expect(resolver.config.nameservers[0].eql(try std.net.Address.parseIp("8.8.8.8", 53)));
    try std.testing.expectEqual(5, resolver.config.attempts);
    try std.testing.expectEqual(2, resolver.config.timeout_s);
    try std.testing.expect(resolver.config.edns0);
}

test "Resolver.parseResolvConf falls back to default nameservers" {
    var resolver: Resolver = .{ .gpa = std.testing.allocator, .ctx = undefined };
    defer resolver.deinit();

    try resolver.parseResolvConf("# no nameservers here\n");
    try std.testing.expectEqual(default_dns.len, resolver.config.nameservers.len);
}
test "Resolver" {
    const testing = std.testing;
    const gpa = testing.allocator;

    // Mock backend: every syscall succeeds immediately with a canned result.
    const Mock = struct {
        const resolv_conf =
            \\nameserver 1.1.1.1
            \\nameserver 1.0.0.1
            \\options timeout:10 attempts:3
        ;

        fn open(_: *Task) ourio.Result {
            return .{ .open = 1 };
        }

        /// Serve a fixed resolv.conf with two nameservers and non-default options.
        fn read(task: *Task) ourio.Result {
            const dest = task.req.read.buffer[0..resolv_conf.len];
            @memcpy(dest, resolv_conf);
            return .{ .read = resolv_conf.len };
        }

        fn close(_: *Task) ourio.Result {
            return .{ .close = {} };
        }

        fn socket(_: *Task) ourio.Result {
            return .{ .socket = 1 };
        }

        fn connect(_: *Task) ourio.Result {
            return .{ .connect = {} };
        }

        /// Report that a single byte of response arrived.
        fn recv(_: *Task) ourio.Result {
            return .{ .recv = 1 };
        }

        /// Report the whole buffer as written.
        fn write(task: *Task) ourio.Result {
            return .{ .write = task.req.write.buffer.len };
        }
    };

    var io: ourio.Ring = try .initMock(gpa, 16);
    defer io.deinit();

    io.backend.mock = .{
        .open_cb = Mock.open,
        .read_cb = Mock.read,
        .close_cb = Mock.close,
        .socket_cb = Mock.socket,
        .connect_cb = Mock.connect,
        .recv_cb = Mock.recv,
        .write_cb = Mock.write,
    };

    var resolver: Resolver = undefined;
    try resolver.init(gpa, &io, .{});
    defer resolver.deinit();

    // Nothing has run yet, so the config still holds its defaults.
    try testing.expectEqual(0, resolver.config.nameservers.len);
    try testing.expectEqual(5, resolver.config.attempts);
    try testing.expectEqual(30, resolver.config.timeout_s);

    try io.run(.until_done);

    try resolver.resolveQuery(&io, .{ .host = "timculverhouse.com" }, .{});
    try io.run(.until_done);

    // Values now come from the mocked resolv.conf.
    try testing.expectEqual(2, resolver.config.nameservers.len);
    try testing.expectEqual(3, resolver.config.attempts);
    try testing.expectEqual(10, resolver.config.timeout_s);
}

test "Header roundtrip" {
    const original: Header = .{ .question_count = 1 };
    const encoded = original.asBytes();
    const decoded = (Response{ .bytes = &encoded }).header();
    try std.testing.expectEqual(original, decoded);
}