An asynchronous IO runtime
1const std = @import("std");
2const stda = @import("../../stda.zig");
3const ourio = @import("ourio");
4
5const Allocator = std.mem.Allocator;
6const Context = ourio.Context;
7const Ring = ourio.Ring;
8const Task = ourio.Task;
9const assert = std.debug.assert;
10const net = stda.net;
11const posix = std.posix;
12
13const default_dns = &stda.options.nameservers;
14
/// Loads nameserver configuration from /etc/resolv.conf and starts DNS queries against it.
pub const Resolver = struct {
    gpa: Allocator,
    ctx: Context,
    config: Config = .{},

    const Msg = enum { open_resolv, read_resolv };

    /// Initialize a Resolver instance and queue the read of /etc/resolv.conf.
    ///
    /// `self` is passed as the userdata pointer of the queued IO, so it must stay at a stable
    /// address until loading completes. When loading is done, `ctx.cb` receives a task whose
    /// result is `.userptr = self` (or `.userptr = <read error>`). The resolver is then ready to
    /// resolve DNS queries.
    pub fn init(
        self: *Resolver,
        gpa: Allocator,
        io: *Ring,
        ctx: Context,
    ) Allocator.Error!void {
        self.* = .{
            .gpa = gpa,
            .ctx = ctx,
        };

        _ = try io.open("/etc/resolv.conf", .{ .CLOEXEC = true }, 0, .{
            .cb = Resolver.onCompletion,
            .ptr = self,
            .msg = @intFromEnum(Msg.open_resolv),
        });
    }

    /// Free the nameserver list owned by the resolver.
    pub fn deinit(self: *Resolver) void {
        self.gpa.free(self.config.nameservers);
    }

    /// Start resolving `query` on a heap-allocated `Connection`. That connection later delivers
    /// its result to `ctx` and frees itself.
    ///
    /// Asserts that at least one nameserver is configured, i.e. that loading has completed.
    pub fn resolveQuery(self: *Resolver, io: *Ring, query: Question, ctx: ourio.Context) !void {
        assert(self.config.nameservers.len > 0);

        const conn = try self.gpa.create(Connection);
        conn.* = .{ .gpa = self.gpa, .ctx = ctx, .config = self.config };

        // Until the connection is handed to the ring, we own it: free it if encoding fails.
        conn.writeQuestion(query) catch |err| {
            conn.write_buffer.deinit(self.gpa);
            self.gpa.destroy(conn);
            return err;
        };

        try conn.tryNext(io);
    }

    pub fn onCompletion(io: *Ring, task: Task) anyerror!void {
        const self = task.userdataCast(Resolver);
        const result = task.result.?;

        switch (task.msgToEnum(Resolver.Msg)) {
            .open_resolv => {
                const fd = result.open catch {
                    // No readable resolv.conf: fall back to the built-in nameservers.
                    self.config.nameservers = try self.gpa.dupe(std.net.Address, default_dns);
                    return self.notify(io, .{ .userptr = self });
                };

                const buffer = try self.gpa.alloc(u8, 4096);
                errdefer self.gpa.free(buffer);

                _ = try io.read(fd, buffer, .{
                    .cb = Resolver.onCompletion,
                    .ptr = self,
                    .msg = @intFromEnum(Resolver.Msg.read_resolv),
                });
            },

            .read_resolv => {
                const buffer = task.req.read.buffer;
                defer self.gpa.free(buffer);

                _ = try io.close(task.req.read.fd, .{});

                const n = result.read catch |err| {
                    return self.notify(io, .{ .userptr = err });
                };

                if (n >= buffer.len) {
                    @panic("TODO: more to read");
                }

                try self.parseResolvConf(buffer[0..n]);

                // A file without any usable nameserver line would leave us unable to resolve
                // anything (resolveQuery asserts a non-empty list), so use the defaults.
                if (self.config.nameservers.len == 0) {
                    self.config.nameservers = try self.gpa.dupe(std.net.Address, default_dns);
                }

                try self.notify(io, .{ .userptr = self });
            },
        }
    }

    /// Deliver `result` to the context supplied to `init`.
    fn notify(self: *Resolver, io: *Ring, result: ourio.Result) !void {
        const t: Task = .{
            .callback = self.ctx.cb,
            .msg = self.ctx.msg,
            .userdata = self.ctx.ptr,
            .result = result,
        };
        try self.ctx.cb(io, t);
    }

    /// Parse resolv.conf `contents` into `self.config`.
    ///
    /// Sets `self.config.nameservers` to a newly allocated slice owned by the resolver, which is
    /// freed in `deinit`. Nameserver entries whose address does not parse are skipped rather than
    /// aborting the whole load.
    fn parseResolvConf(self: *Resolver, contents: []const u8) Allocator.Error!void {
        var addresses: std.ArrayListUnmanaged(std.net.Address) = .empty;
        defer addresses.deinit(self.gpa);

        var lines = std.mem.splitScalar(u8, contents, '\n');
        while (lines.next()) |line| {
            if (line.len == 0 or line[0] == ';' or line[0] == '#') continue;

            // Tokenize rather than split: repeated blanks, tabs, '\r' and trailing text must not
            // yield empty or polluted fields.
            var fields = std.mem.tokenizeAny(u8, line, &std.ascii.whitespace);
            const key = fields.next() orelse continue;

            if (std.mem.eql(u8, key, "nameserver")) {
                const text = fields.next() orelse continue;
                const addr = std.net.Address.parseIp(text, 53) catch continue;
                try addresses.append(self.gpa, addr);
                continue;
            }

            if (std.mem.eql(u8, key, "options")) {
                while (fields.next()) |opt| self.applyOption(opt);
            }
        }

        self.config.nameservers = try addresses.toOwnedSlice(self.gpa);
    }

    /// Apply one `options` token (e.g. "timeout:10") to `self.config`. Unknown options are ignored.
    fn applyOption(self: *Resolver, opt: []const u8) void {
        if (std.mem.startsWith(u8, opt, "timeout:")) {
            // Unparsable or too-large (> u5) values fall back to 30; the result is capped at 30.
            const timeout = std.fmt.parseInt(u5, opt["timeout:".len..], 10) catch 30;
            self.config.timeout_s = @min(30, timeout);
        } else if (std.mem.startsWith(u8, opt, "attempts:")) {
            // Unparsable or too-large (> u3) values fall back to 5; the result is clamped to 1...5.
            const attempts = std.fmt.parseInt(u3, opt["attempts:".len..], 10) catch 5;
            self.config.attempts = @max(@min(5, attempts), 1);
        } else if (std.mem.eql(u8, opt, "edns0")) {
            self.config.edns0 = true;
        }
    }
};
157
/// Resolver settings, filled in from /etc/resolv.conf by `Resolver`.
pub const Config = struct {
    /// Nameservers to query, in order. Empty until the resolver has loaded its configuration.
    nameservers: []const std.net.Address = &[_]std.net.Address{},

    /// Per-attempt timeout in seconds. The resolv.conf parser silently caps this at 30, per
    /// man resolv.conf.
    timeout_s: u5 = 30,

    /// How many attempts to make; the resolv.conf parser caps this at 5.
    attempts: u3 = 5,

    /// Set when the `edns0` option appears in resolv.conf.
    edns0: bool = false,
};
169
/// The fixed 12-byte DNS message header (RFC 1035 section 4.1.1).
///
/// Zig lays out packed struct fields least-significant bit first. The flag bytes below list their
/// bits in that order, so each can be @bitCast to and from its wire byte directly. The opcode and
/// response-code enums are non-exhaustive: their values are decoded straight from network bytes,
/// and a code not named here (e.g. rcode 6..15) must still be a valid value rather than illegal
/// behavior.
pub const Header = packed struct(u96) {
    /// Query identifier, echoed back in the response.
    id: u16 = 0,

    /// Header byte 2, LSB first: RD, TC, AA, OPCODE (4 bits), QR.
    flags1: packed struct(u8) {
        recursion_desired: bool = true,
        truncated: bool = false,
        authoritative_answer: bool = false,
        opcode: enum(u4) {
            query = 0,
            inverse_query = 1,
            server_status_request = 2,
            _,
        } = .query,
        is_response: bool = false,
    } = .{},

    /// Header byte 3, LSB first: RCODE (4 bits), Z (3 bits), RA.
    flags2: packed struct(u8) {
        response_code: enum(u4) {
            success = 0,
            format_error = 1,
            server_failure = 2,
            name_error = 3,
            not_implemented = 4,
            refuse = 5,
            _,
        } = .success,
        z: u3 = 0,
        recursion_available: bool = false,
    } = .{},

    question_count: u16 = 0,

    answer_count: u16 = 0,

    authority_count: u16 = 0,

    additional_count: u16 = 0,

    /// Serialize the header in network (big-endian) byte order.
    pub fn asBytes(self: Header) [12]u8 {
        var bytes: [12]u8 = undefined;
        std.mem.writeInt(u16, bytes[0..2], self.id, .big);
        bytes[2] = @bitCast(self.flags1);
        bytes[3] = @bitCast(self.flags2);
        std.mem.writeInt(u16, bytes[4..6], self.question_count, .big);
        std.mem.writeInt(u16, bytes[6..8], self.answer_count, .big);
        std.mem.writeInt(u16, bytes[8..10], self.authority_count, .big);
        std.mem.writeInt(u16, bytes[10..12], self.additional_count, .big);
        return bytes;
    }
};
222
/// One DNS question: the name to look up plus the record type and class requested.
pub const Question = struct {
    /// Dotted host name, e.g. "example.com".
    host: []const u8,

    /// Record type to ask for.
    type: ResourceType = .A,

    /// Query class. Only the Internet class is modelled; CS (2), CH (3), HS (4) and the
    /// wildcard class (255) are not supported yet.
    class: enum(u16) { IN = 1 } = .IN,
};
234
/// DNS resource-record TYPE codes that this module can request and decode.
///
/// Codes not represented yet: NS (2), MD (3), MF (4), CNAME (5), SOA (6), MB (7), MG (8),
/// MR (9), NULL (10), WKS (11), PTR (12), HINFO (13), MINFO (14), MX (15), TXT (16), OPT (41).
pub const ResourceType = enum(u16) {
    /// IPv4 host address.
    A = 1,
    /// IPv6 host address.
    AAAA = 28,
    /// Service locator.
    SRV = 33,
};
256
/// A decoded resource record, tagged by its `ResourceType`.
pub const Answer = union(ResourceType) {
    /// IPv4 address bytes as they appeared in the record data.
    A: [4]u8,
    /// IPv6 address bytes as they appeared in the record data.
    AAAA: [16]u8,
    /// Service record fields.
    SRV: struct {
        priority: u16,
        weight: u16,
        port: u16,
        /// Dotted target host name. Borrowed: the Answer does not own or free this memory.
        target: []const u8,
    },
};
267
/// A read-only view over a raw DNS response message. `bytes` is borrowed; nothing here allocates.
pub const Response = struct {
    bytes: []const u8,

    /// Decode the fixed 12-byte header at the start of the message.
    ///
    /// Asserts that `bytes` holds at least 12 bytes. `answerIterator` checks this before calling.
    pub fn header(self: Response) Header {
        assert(self.bytes.len >= 12);
        const readInt = std.mem.readInt;

        return .{
            .id = readInt(u16, self.bytes[0..2], .big),
            .flags1 = @bitCast(self.bytes[2]),
            .flags2 = @bitCast(self.bytes[3]),
            .question_count = readInt(u16, self.bytes[4..6], .big),
            .answer_count = readInt(u16, self.bytes[6..8], .big),
            .authority_count = readInt(u16, self.bytes[8..10], .big),
            .additional_count = readInt(u16, self.bytes[10..12], .big),
        };
    }

    /// Return the offset just past the encoded domain name that starts at `start`. Returns null
    /// if the name runs off the end of `bytes` or uses a reserved label type.
    ///
    /// A name is a run of length-prefixed labels. It ends either at a zero byte or at a 2-byte
    /// compression pointer (top two bits set). Pointers are stepped over, not followed.
    fn skipName(bytes: []const u8, start: usize) ?usize {
        var pos = start;
        while (pos < bytes.len) {
            const len = bytes[pos];
            switch (len >> 6) {
                0b00 => {
                    if (len == 0) return pos + 1;
                    pos += 1 + @as(usize, len);
                },
                0b11 => return if (pos + 2 <= bytes.len) pos + 2 else null,
                else => return null,
            }
        }
        return null;
    }

    /// Iterates the answer section of a response. It yields records whose class is IN and whose
    /// type is a known `ResourceType`. Other records are skipped. Iteration ends at the first
    /// truncated or malformed record.
    pub const AnswerIterator = struct {
        /// Message bytes, starting at the answer section.
        bytes: []const u8,
        /// offset into bytes
        offset: usize = 0,

        /// Number of answer records in the section (ANCOUNT).
        count: usize,
        /// number of answer records consumed so far
        idx: usize = 0,

        /// Backing storage for `Answer.SRV.target`. A returned target slice is valid only until
        /// the next call to `next`.
        target_buf: [256]u8 = undefined,

        /// Fixed fields of one resource record, plus its RDATA.
        const Record = struct {
            rtype: u16,
            class: u16,
            data: []const u8,
        };

        pub fn next(self: *AnswerIterator) ?Answer {
            while (self.idx < self.count) {
                self.idx += 1;

                const record = self.readRecord() orelse {
                    // The rest of the section cannot be located; stop for good.
                    self.idx = self.count;
                    return null;
                };

                // Only the IN class is modelled.
                if (record.class != 1) continue;

                switch (record.rtype) {
                    @intFromEnum(ResourceType.A) => {
                        if (record.data.len == 4) return Answer{ .A = record.data[0..4].* };
                    },
                    @intFromEnum(ResourceType.AAAA) => {
                        if (record.data.len == 16) return Answer{ .AAAA = record.data[0..16].* };
                    },
                    @intFromEnum(ResourceType.SRV) => {
                        if (self.decodeSrv(record.data)) |srv| return srv;
                    },
                    // Types we do not model (e.g. a CNAME ahead of the A records) are skipped
                    // instead of being forced into ResourceType.
                    else => {},
                }
            }
            return null;
        }

        /// Read the record at `offset` and advance past it. Returns null if it is truncated.
        fn readRecord(self: *AnswerIterator) ?Record {
            const name_end = Response.skipName(self.bytes, self.offset) orelse return null;

            // TYPE (2) + CLASS (2) + TTL (4) + RDLENGTH (2)
            if (name_end + 10 > self.bytes.len) return null;
            const fixed = self.bytes[name_end..][0..10];
            const rd_len = std.mem.readInt(u16, fixed[8..10], .big);

            const data_start = name_end + fixed.len;
            if (data_start + rd_len > self.bytes.len) return null;
            self.offset = data_start + rd_len;

            return .{
                .rtype = std.mem.readInt(u16, fixed[0..2], .big),
                .class = std.mem.readInt(u16, fixed[2..4], .big),
                .data = self.bytes[data_start..self.offset],
            };
        }

        /// Decode SRV RDATA: priority, weight and port, then the target name.
        ///
        /// The target is copied in dotted form into `target_buf`. Returns null if the data is
        /// truncated, the target uses a compression pointer, or the name does not fit.
        fn decodeSrv(self: *AnswerIterator, data: []const u8) ?Answer {
            if (data.len <= 6) return null;

            var out_len: usize = 0;
            var pos: usize = 6;
            while (true) {
                if (pos >= data.len) return null;
                const len = data[pos];
                if (len == 0) break;
                // RFC 2782 forbids compression in the SRV target; a pointer could not be
                // resolved from this sliced view anyway.
                if (len >> 6 != 0) return null;
                pos += 1;
                if (pos + len > data.len) return null;

                const sep: usize = @intFromBool(out_len > 0);
                if (out_len + sep + len > self.target_buf.len) return null;
                if (sep == 1) self.target_buf[out_len] = '.';
                out_len += sep;

                @memcpy(self.target_buf[out_len..][0..len], data[pos..][0..len]);
                out_len += len;
                pos += len;
            }

            return Answer{ .SRV = .{
                .priority = std.mem.readInt(u16, data[0..2], .big),
                .weight = std.mem.readInt(u16, data[2..4], .big),
                .port = std.mem.readInt(u16, data[4..6], .big),
                .target = self.target_buf[0..out_len],
            } };
        }
    };

    /// Return an iterator over the answer section.
    ///
    /// Fails with `error.InvalidResponse` if the message is shorter than a header or the question
    /// section is truncated.
    pub fn answerIterator(self: Response) !AnswerIterator {
        if (self.bytes.len < 12) return error.InvalidResponse;
        const h = self.header();

        var offset: usize = 12;
        var q: u16 = 0;
        while (q < h.question_count) : (q += 1) {
            const name_end = skipName(self.bytes, offset) orelse return error.InvalidResponse;
            // QTYPE (2) + QCLASS (2): the next entry starts right after them.
            offset = name_end + 4;
            if (offset > self.bytes.len) return error.InvalidResponse;
        }

        return .{
            .bytes = self.bytes[offset..],
            .count = h.answer_count,
        };
    }
};
413
/// A single in-flight DNS query over UDP.
///
/// Heap-allocated by `Resolver.resolveQuery`. It destroys itself after delivering its result to
/// `ctx`. The nameservers are tried in order, each up to `config.attempts` times. When all are
/// exhausted, `error.Timeout` is delivered.
pub const Connection = struct {
    gpa: Allocator,
    ctx: Context,
    config: Config,

    /// Index into `config.nameservers` of the server currently being tried.
    nameserver: u8 = 0,
    /// Attempts made so far against the current nameserver.
    attempt: u5 = 0,

    read_buffer: [2048]u8 = undefined,
    /// Encoded query message, built by `writeQuestion`.
    write_buffer: std.ArrayListUnmanaged(u8) = .empty,
    /// Deadline (from `std.time.timestamp`) for the current attempt.
    deadline: i64 = 0,

    const Msg = enum { connect, recv };

    /// Longest permitted label, in bytes (RFC 1035 section 2.3.4).
    const max_label_len = 63;
    /// Longest permitted encoded name, in bytes (RFC 1035 section 2.3.4).
    const max_name_len = 255;

    /// Start the next attempt, moving to the next nameserver once the current one has used up its
    /// attempts. With no nameserver left, delivers `error.Timeout` to `ctx` and destroys `self`.
    pub fn tryNext(self: *Connection, io: *Ring) !void {
        // An attempts value of 0 would skip servers without ever querying them.
        const attempts = @max(self.config.attempts, 1);

        if (self.attempt >= attempts) {
            self.attempt = 0;
            // Advance in usize so a long nameserver list cannot overflow the u8 index.
            const next_ns = @as(usize, self.nameserver) + 1;
            if (next_ns >= self.config.nameservers.len or next_ns > std.math.maxInt(u8)) {
                return self.giveUp(io);
            }
            self.nameserver = @intCast(next_ns);
        }

        if (self.nameserver >= self.config.nameservers.len) return self.giveUp(io);

        const addr = self.config.nameservers[self.nameserver];
        self.attempt += 1;
        self.deadline = std.time.timestamp() + self.config.timeout_s;

        _ = try net.udpConnectToAddr(io, addr, .{
            .cb = Connection.onCompletion,
            .msg = @intFromEnum(Connection.Msg.connect),
            .ptr = self,
        });
    }

    /// Report `error.Timeout` to `ctx`, then free `self`.
    fn giveUp(self: *Connection, io: *Ring) !void {
        defer self.gpa.destroy(self);
        try self.sendResult(io, .{ .userbytes = error.Timeout });
    }

    pub fn onCompletion(io: *Ring, task: Task) anyerror!void {
        const self = task.userdataCast(Connection);
        const msg = task.msgToEnum(Connection.Msg);
        const result = task.result.?;

        switch (msg) {
            .connect => {
                // A failed connect counts as a failed attempt.
                const fd = result.userfd catch return self.tryNext(io);

                const recv_task = try io.recv(fd, &self.read_buffer, .{
                    .cb = Connection.onCompletion,
                    .ptr = self,
                    .msg = @intFromEnum(Connection.Msg.recv),
                });
                try recv_task.setDeadline(io, .{ .sec = self.deadline });

                const write_task = try io.write(fd, self.write_buffer.items, .{});
                try write_task.setDeadline(io, .{ .sec = self.deadline });
            },

            .recv => {
                const fd = task.req.recv.fd;

                // An error (including a deadline expiry) or an empty datagram means: retry.
                const n = result.recv catch 0;
                if (n == 0) {
                    _ = try io.close(fd, .{});
                    return self.tryNext(io);
                }

                // The response lives in `self.read_buffer`, so free `self` only after `ctx` has
                // seen it. The fd and `self` are released even if the callback fails.
                defer self.gpa.destroy(self);
                const close_task = io.close(fd, .{});
                try self.sendResult(io, .{ .userbytes = self.read_buffer[0..n] });
                _ = try close_task;
            },
        }
    }

    /// Invoke the caller's callback with `result` and release the encoded query.
    fn sendResult(self: *Connection, io: *Ring, result: ourio.Result) !void {
        defer self.write_buffer.deinit(self.gpa);
        const task: ourio.Task = .{
            .callback = self.ctx.cb,
            .userdata = self.ctx.ptr,
            .msg = self.ctx.msg,
            .result = result,
        };
        try self.ctx.cb(io, task);
    }

    /// Encode a single-question query message for `query` into `write_buffer`.
    ///
    /// One trailing dot (fully-qualified form) is accepted, and an empty host encodes the root
    /// name. Returns `error.InvalidHostname` for empty interior labels, labels longer than 63
    /// bytes, or names whose encoding exceeds 255 bytes; nothing is written in that case.
    fn writeQuestion(self: *Connection, query: Question) !void {
        const host = if (std.mem.endsWith(u8, query.host, "."))
            query.host[0 .. query.host.len - 1]
        else
            query.host;

        if (host.len > 0) {
            // Encoded size: host bytes, with each '.' replaced by a length byte, plus one
            // leading length byte and the terminating zero = host.len + 2.
            if (host.len + 2 > max_name_len) return error.InvalidHostname;
            var check = std.mem.splitScalar(u8, host, '.');
            while (check.next()) |label| {
                if (label.len == 0 or label.len > max_label_len) return error.InvalidHostname;
            }
        }

        const header: Header = .{ .question_count = 1 };
        var writer = self.write_buffer.writer(self.gpa);
        try writer.writeAll(&header.asBytes());

        if (host.len > 0) {
            var labels = std.mem.splitScalar(u8, host, '.');
            while (labels.next()) |label| {
                try writer.writeByte(@intCast(label.len));
                try writer.writeAll(label);
            }
        }
        try writer.writeByte(0x00);
        try writer.writeInt(u16, @intFromEnum(query.type), .big);
        try writer.writeInt(u16, @intFromEnum(query.class), .big);
    }
};
527
test "Resolver" {
    const testing = std.testing;

    // Mock backend: every operation succeeds, and reading resolv.conf yields a fixed file.
    const Mock = struct {
        const resolv_conf =
            \\nameserver 1.1.1.1
            \\nameserver 1.0.0.1
            \\options timeout:10 attempts:3
        ;

        fn openResult(_: *Task) ourio.Result {
            return .{ .open = 1 };
        }

        fn readResult(task: *Task) ourio.Result {
            @memcpy(task.req.read.buffer[0..resolv_conf.len], resolv_conf);
            return .{ .read = resolv_conf.len };
        }

        fn closeResult(_: *Task) ourio.Result {
            return .{ .close = {} };
        }

        fn socketResult(_: *Task) ourio.Result {
            return .{ .socket = 1 };
        }

        fn connectResult(_: *Task) ourio.Result {
            return .{ .connect = {} };
        }

        fn recvResult(_: *Task) ourio.Result {
            return .{ .recv = 1 };
        }

        fn writeResult(task: *Task) ourio.Result {
            return .{ .write = task.req.write.buffer.len };
        }
    };

    var io: ourio.Ring = try .initMock(testing.allocator, 16);
    defer io.deinit();

    io.backend.mock = .{
        .open_cb = Mock.openResult,
        .read_cb = Mock.readResult,
        .close_cb = Mock.closeResult,
        .socket_cb = Mock.socketResult,
        .connect_cb = Mock.connectResult,
        .recv_cb = Mock.recvResult,
        .write_cb = Mock.writeResult,
    };

    var resolver: Resolver = undefined;
    try resolver.init(testing.allocator, &io, .{});
    defer resolver.deinit();

    // Before the ring runs, only the defaults are in place.
    try testing.expectEqual(0, resolver.config.nameservers.len);
    try testing.expectEqual(5, resolver.config.attempts);
    try testing.expectEqual(30, resolver.config.timeout_s);

    try io.run(.until_done);

    try resolver.resolveQuery(&io, .{ .host = "timculverhouse.com" }, .{});
    try io.run(.until_done);

    // The mocked resolv.conf has been applied.
    try testing.expectEqual(2, resolver.config.nameservers.len);
    try testing.expectEqual(3, resolver.config.attempts);
    try testing.expectEqual(10, resolver.config.timeout_s);
}
594
test "Header roundtrip" {
    // Encoding a header and decoding it again must give back the same value.
    const original: Header = .{ .question_count = 1 };
    const wire = original.asBytes();
    const decoded = (Response{ .bytes = &wire }).header();
    try std.testing.expectEqual(original, decoded);
}