This repository has no description.
1open Types.Wire
2
(** [encode_ping p] serialises [p] as a MessagePack map keyed by the wire
    field names [SeqNo], [Node], [SourceAddr], [SourcePort] and [SourceNode].
    Integer fields go through [Msgpck.of_int]; text fields are emitted as
    MessagePack [String] values. *)
let encode_ping (p : ping) : Msgpck.t =
  let key name = Msgpck.String name in
  let text value = Msgpck.String value in
  let num value = Msgpck.of_int value in
  Msgpck.Map
    [
      (key "SeqNo", num p.seq_no);
      (key "Node", text p.node);
      (key "SourceAddr", text p.source_addr);
      (key "SourcePort", num p.source_port);
      (key "SourceNode", text p.source_node);
    ]
12
(** [decode_ping m] rebuilds a [ping] from a MessagePack map such as the one
    produced by [encode_ping].

    Required fields are [SeqNo], [Node] and [SourceAddr]. A missing or
    mistyped required field yields [Error "missing or invalid <key>"].
    [SourcePort] falls back to [0] and [SourceNode] to [""].

    Strings may arrive as MessagePack [String] or [Bytes]; [Nil] reads as
    [""]. Integers may arrive as [Int], [Int32] or [Uint32]. [Uint32] is
    zero-extended, so values >= 2^31 stay non-negative instead of being
    sign-extended through [Int32.to_int]. *)
let decode_ping (m : Msgpck.t) : (ping, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let get_int key =
        match List.assoc_opt (Msgpck.String key) fields with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        | Some (Msgpck.Uint32 i) -> Ok (uint32_to_int i)
        | _ -> Error (Printf.sprintf "missing or invalid %s" key)
      in
      let get_string key =
        match List.assoc_opt (Msgpck.String key) fields with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | Some Msgpck.Nil -> Ok ""
        | _ -> Error (Printf.sprintf "missing or invalid %s" key)
      in
      let ( let* ) = Result.bind in
      let* seq_no = get_int "SeqNo" in
      let* node = get_string "Node" in
      let* source_addr = get_string "SourceAddr" in
      (* Optional fields: absence or a bad type falls back to a default. *)
      let source_port = Result.value (get_int "SourcePort") ~default:0 in
      let source_node = Result.value (get_string "SourceNode") ~default:"" in
      Ok ({ seq_no; node; source_addr; source_port; source_node } : ping)
  | _ -> Error "expected map for ping"
42
(** [encode_indirect_ping p] serialises an indirect-ping request as a
    MessagePack map with the keys [SeqNo], [Target], [Port], [Node], [Nack],
    [SourceAddr], [SourcePort] and [SourceNode], in that order. *)
let encode_indirect_ping (p : indirect_ping_req) : Msgpck.t =
  let str name value = (Msgpck.String name, Msgpck.String value) in
  let num name value = (Msgpck.String name, Msgpck.of_int value) in
  Msgpck.Map
    [
      num "SeqNo" p.seq_no;
      str "Target" p.target;
      num "Port" p.port;
      str "Node" p.node;
      (Msgpck.String "Nack", Msgpck.Bool p.nack);
      str "SourceAddr" p.source_addr;
      num "SourcePort" p.source_port;
      str "SourceNode" p.source_node;
    ]
55
(** [decode_indirect_ping m] rebuilds an [indirect_ping_req] from a
    MessagePack map.

    Required fields are [SeqNo], [Target] and [Node]; the first one missing
    or mistyped yields [Error "missing or invalid <key>"]. Optional fields
    and their defaults:
    - [Port], [SourcePort]: [0]
    - [Nack]: [false]
    - [SourceAddr], [SourceNode]: [""]

    Strings accept [String]/[Bytes] ([Nil] reads as [""]). Integers accept
    [Int]/[Int32]/[Uint32], with [Uint32] zero-extended so large unsigned
    values are not turned negative. *)
let decode_indirect_ping (m : Msgpck.t) : (indirect_ping_req, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        | Some (Msgpck.Uint32 i) -> Ok (uint32_to_int i)
        | _ -> invalid key
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | Some Msgpck.Nil -> Ok ""
        | _ -> invalid key
      in
      let get_bool key =
        match find key with Some (Msgpck.Bool b) -> b | _ -> false
      in
      let ( let* ) = Result.bind in
      (* Required fields, checked in the same order as the wire layout. *)
      let* seq_no = get_int "SeqNo" in
      let* target = get_string "Target" in
      let* node = get_string "Node" in
      let port = Result.value (get_int "Port") ~default:0 in
      let nack = get_bool "Nack" in
      let source_addr = Result.value (get_string "SourceAddr") ~default:"" in
      let source_port = Result.value (get_int "SourcePort") ~default:0 in
      let source_node = Result.value (get_string "SourceNode") ~default:"" in
      Ok
        ({
           seq_no;
           target;
           port;
           node;
           nack;
           source_addr;
           source_port;
           source_node;
         }
          : indirect_ping_req)
  | _ -> Error "expected map for indirect_ping"
105
(** [encode_ack a] serialises an ack as a MessagePack map with [SeqNo]
    (via [Msgpck.of_int]) and [Payload] (emitted as a MessagePack [String]). *)
let encode_ack (a : ack_resp) : Msgpck.t =
  let seq_entry = (Msgpck.String "SeqNo", Msgpck.of_int a.seq_no) in
  let payload_entry = (Msgpck.String "Payload", Msgpck.String a.payload) in
  Msgpck.Map [ seq_entry; payload_entry ]
112
(** [decode_ack m] rebuilds an [ack_resp] from a MessagePack map.

    [SeqNo] is required ([Int], [Int32] or zero-extended [Uint32]); if it is
    missing or mistyped the result is [Error "missing or invalid SeqNo"].
    [Payload] is taken from a [Bytes] or [String] value and is [""] for
    anything else, including absence. *)
let decode_ack (m : Msgpck.t) : (ack_resp, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        | Some (Msgpck.Uint32 i) -> Ok (uint32_to_int i)
        | _ -> Error (Printf.sprintf "missing or invalid %s" key)
      in
      let payload =
        match find "Payload" with
        | Some (Msgpck.Bytes s | Msgpck.String s) -> s
        | _ -> ""
      in
      Result.map
        (fun seq_no -> ({ seq_no; payload } : ack_resp))
        (get_int "SeqNo")
  | _ -> Error "expected map for ack"
135
(** [encode_nack n] serialises a negative acknowledgement as the one-entry
    MessagePack map [{ "SeqNo": n.seq_no }]. *)
let encode_nack (n : nack_resp) : Msgpck.t =
  let seq_no = Msgpck.of_int n.seq_no in
  Msgpck.Map [ (Msgpck.String "SeqNo", seq_no) ]
138
(** [decode_nack m] rebuilds a [nack_resp] from a MessagePack map.

    [SeqNo] is required and may be [Int], [Int32] or [Uint32]. [Uint32] is
    zero-extended so values >= 2^31 do not become negative. A missing or
    mistyped [SeqNo] yields [Error "missing or invalid SeqNo"]. *)
let decode_nack (m : Msgpck.t) : (nack_resp, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let seq_no =
        match List.assoc_opt (Msgpck.String "SeqNo") fields with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        | Some (Msgpck.Uint32 i) -> Ok (uint32_to_int i)
        | _ -> Error (Printf.sprintf "missing or invalid %s" "SeqNo")
      in
      Result.map (fun seq_no -> ({ seq_no } : nack_resp)) seq_no
  | _ -> Error "expected map for nack"
153
(** [encode_suspect s] serialises a suspicion notice as a MessagePack map with
    [Incarnation] (integer), [Node] and [From] (both MessagePack [String]). *)
let encode_suspect (s : suspect) : Msgpck.t =
  let entries =
    [
      ("Incarnation", Msgpck.of_int s.incarnation);
      ("Node", Msgpck.String s.node);
      ("From", Msgpck.String s.from);
    ]
  in
  Msgpck.Map (List.map (fun (name, value) -> (Msgpck.String name, value)) entries)
161
(** [decode_suspect m] rebuilds a [suspect] from a MessagePack map.

    All of [Incarnation], [Node] and [From] are required. The first missing
    or mistyped one yields [Error "missing or invalid <key>"].

    Integers accept [Int]/[Int32]/[Uint32], with [Uint32] zero-extended.
    Strings accept both MessagePack [String] and [Bytes], matching the other
    decoders in this module. *)
let decode_suspect (m : Msgpck.t) : (suspect, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        | Some (Msgpck.Uint32 i) -> Ok (uint32_to_int i)
        | _ -> invalid key
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | _ -> invalid key
      in
      let ( let* ) = Result.bind in
      let* incarnation = get_int "Incarnation" in
      let* node = get_string "Node" in
      let* from = get_string "From" in
      Ok ({ incarnation; node; from } : suspect)
  | _ -> Error "expected map for suspect"
183
(** [encode_alive a] serialises an alive announcement as a MessagePack map.
    Keys: [Incarnation], [Node], [Addr], [Port], [Meta], and [Vsn], where
    [Vsn] is a MessagePack list of integers. Text fields are emitted as
    MessagePack [String]. *)
let encode_alive (a : alive) : Msgpck.t =
  let vsn = Msgpck.List (List.map Msgpck.of_int a.vsn) in
  Msgpck.Map
    (List.map
       (fun (name, value) -> (Msgpck.String name, value))
       [
         ("Incarnation", Msgpck.of_int a.incarnation);
         ("Node", Msgpck.String a.node);
         ("Addr", Msgpck.String a.addr);
         ("Port", Msgpck.of_int a.port);
         ("Meta", Msgpck.String a.meta);
         ("Vsn", vsn);
       ])
194
(** [decode_alive m] rebuilds an [alive] from a MessagePack map.

    Required fields are [Incarnation], [Node], [Addr] and [Port]. The first
    missing or mistyped one yields [Error "missing or invalid <key>"].
    Optional fields:
    - [Meta] defaults to [""].
    - [Vsn] defaults to [[]]. Its non-integer items are skipped.

    Integers (including [Vsn] items) accept [Int]/[Int32]/[Uint32], with
    [Uint32] zero-extended. Strings accept [String] or [Bytes]. *)
let decode_alive (m : Msgpck.t) : (alive, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let int_of_msgpck = function
        | Msgpck.Int i -> Some i
        | Msgpck.Int32 i -> Some (Int32.to_int i)
        | Msgpck.Uint32 i -> Some (uint32_to_int i)
        | _ -> None
      in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match Option.bind (find key) int_of_msgpck with
        | Some i -> Ok i
        | None -> invalid key
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | _ -> invalid key
      in
      let vsn =
        match find "Vsn" with
        | Some (Msgpck.List items) -> List.filter_map int_of_msgpck items
        | _ -> []
      in
      let ( let* ) = Result.bind in
      let* incarnation = get_int "Incarnation" in
      let* node = get_string "Node" in
      let* addr = get_string "Addr" in
      let* port = get_int "Port" in
      let meta = Result.value (get_string "Meta") ~default:"" in
      Ok ({ incarnation; node; addr; port; meta; vsn } : alive)
  | _ -> Error "expected map for alive"
234
(** [encode_dead d] serialises a death notice as a MessagePack map with
    [Incarnation] (integer), [Node] and [From] (both MessagePack [String]). *)
let encode_dead (d : dead) : Msgpck.t =
  let incarnation = Msgpck.of_int d.incarnation in
  let node = Msgpck.String d.node in
  let from = Msgpck.String d.from in
  Msgpck.Map
    [
      (Msgpck.String "Incarnation", incarnation);
      (Msgpck.String "Node", node);
      (Msgpck.String "From", from);
    ]
242
(** [decode_dead m] rebuilds a [dead] from a MessagePack map.

    All of [Incarnation], [Node] and [From] are required. The first missing
    or mistyped one yields [Error "missing or invalid <key>"].

    Integers accept [Int]/[Int32]/[Uint32], with [Uint32] zero-extended.
    Strings accept both MessagePack [String] and [Bytes], matching the other
    decoders in this module. *)
let decode_dead (m : Msgpck.t) : (dead, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        | Some (Msgpck.Uint32 i) -> Ok (uint32_to_int i)
        | _ -> invalid key
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | _ -> invalid key
      in
      let ( let* ) = Result.bind in
      let* incarnation = get_int "Incarnation" in
      let* node = get_string "Node" in
      let* from = get_string "From" in
      Ok ({ incarnation; node; from } : dead)
  | _ -> Error "expected map for dead"
264
(** [encode_compress c] serialises a compression envelope as a MessagePack
    map. [Algo] holds the algorithm id and [Buf] the compressed bytes,
    emitted as a MessagePack [String]. *)
let encode_compress (c : compress) : Msgpck.t =
  let algo = Msgpck.of_int c.algo in
  let body = Msgpck.String c.buf in
  Msgpck.Map [ (Msgpck.String "Algo", algo); (Msgpck.String "Buf", body) ]
271
(** [decode_compress m] rebuilds a [compress] envelope from a MessagePack map.

    [Algo] is required and may be [Int], [Int32] or [Uint32]. It now accepts
    [Uint32] like the other decoders here, zero-extended. [Buf] is required
    and may be [Bytes] or [String]. The first missing or mistyped field yields
    [Error "missing or invalid <key>"]. *)
let decode_compress (m : Msgpck.t) : (compress, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        | Some (Msgpck.Uint32 i) -> Ok (uint32_to_int i)
        | _ -> invalid key
      in
      let get_bytes key =
        match find key with
        | Some (Msgpck.Bytes s | Msgpck.String s) -> Ok s
        | _ -> invalid key
      in
      let ( let* ) = Result.bind in
      let* algo = get_int "Algo" in
      let* buf = get_bytes "Buf" in
      Ok ({ algo; buf } : compress)
  | _ -> Error "expected map for compress"
292
(** [wire_msg_to_msgpck msg] pairs the wire message-type tag of [msg] with
    its MessagePack body.

    [User_data] and [Compound] have no MessagePack body here and map to
    [Msgpck.Nil]. [Err e] becomes the map [{ "Error": e }]. *)
let wire_msg_to_msgpck (msg : protocol_msg) : message_type * Msgpck.t =
  match msg with
  | User_data _ -> (User_msg, Msgpck.Nil)
  | Compound _ -> (Compound_msg, Msgpck.Nil)
  | Err text ->
      let body = Msgpck.Map [ (Msgpck.String "Error", Msgpck.String text) ] in
      (Err_msg, body)
  | Ping body -> (Ping_msg, encode_ping body)
  | Indirect_ping body -> (Indirect_ping_msg, encode_indirect_ping body)
  | Ack body -> (Ack_resp_msg, encode_ack body)
  | Nack body -> (Nack_resp_msg, encode_nack body)
  | Suspect body -> (Suspect_msg, encode_suspect body)
  | Alive body -> (Alive_msg, encode_alive body)
  | Dead body -> (Dead_msg, encode_dead body)
  | Compressed body -> (Compress_msg, encode_compress body)
306
(** [encode_msg_to_cstruct msg ~buf] writes [msg] into [buf] as a one-byte
    message-type tag followed by its body.

    [User_data] bodies are copied raw. Every other message is serialised as
    MessagePack. Returns the number of bytes written, or
    [Error `Buffer_too_small] (with [buf] untouched) if [buf] is too short. *)
let encode_msg_to_cstruct (msg : protocol_msg) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let msg_type, payload = wire_msg_to_msgpck msg in
  let tag = message_type_to_int msg_type in
  (* Body length plus a writer that fills [buf] from offset 1 onwards. *)
  let body_len, write_body =
    match msg with
    | User_data data ->
        let n = String.length data in
        (n, fun () -> Cstruct.blit_from_string data 0 buf 1 n)
    | _ ->
        let n = Msgpck.size payload in
        ( n,
          fun () ->
            let scratch = Bytes.create n in
            ignore (Msgpck.Bytes.write scratch payload);
            Cstruct.blit_from_bytes scratch 0 buf 1 n )
  in
  let total_len = 1 + body_len in
  if total_len > Cstruct.length buf then Error `Buffer_too_small
  else (
    Cstruct.set_uint8 buf 0 tag;
    write_body ();
    Ok total_len)
331
(** [decode_msg_from_cstruct buf] decodes one wire message. The first byte of
    [buf] is the message-type tag.

    - [User_msg]: the remaining bytes are returned verbatim as [User_data].
    - [Compound_msg]: returns [Compound []] without inspecting the body.
      Compound bodies are split by [decode_compound_from_cstruct] instead.
    - Ping, indirect-ping, ack, nack, suspect, alive, dead, compress, err: the
      body is parsed as MessagePack and handed to the matching [decode_*].
    - Any other valid tag: [Error (Invalid_tag tag)], returned without parsing
      the body.

    An empty buffer yields [Error Truncated_message]. Malformed MessagePack,
    including an empty body, is reported as [Error (Msgpack_error _)] rather
    than letting the parser's exception escape, because [buf] may be
    untrusted input. *)
let decode_msg_from_cstruct (buf : Cstruct.t) :
    (protocol_msg, Types.decode_error) result =
  if Cstruct.length buf < 1 then Error Types.Truncated_message
  else
    let msg_type_byte = Cstruct.get_uint8 buf 0 in
    let payload_len = Cstruct.length buf - 1 in
    match message_type_of_int msg_type_byte with
    | Error n -> Error (Types.Invalid_tag n)
    | Ok User_msg ->
        Ok (User_data (Cstruct.to_string ~off:1 ~len:payload_len buf))
    | Ok Compound_msg -> Ok (Compound [])
    | Ok msg_type -> (
        let lift decode wrap m = Result.map wrap (decode m) in
        let decode_err = function
          | Msgpck.Map fields -> (
              match List.assoc_opt (Msgpck.String "Error") fields with
              | Some (Msgpck.String e) -> Ok (Err e)
              | _ -> Ok (Err "unknown error"))
          | _ -> Ok (Err "unknown error")
        in
        (* Choose the body decoder first so unsupported tags never reach
           the MessagePack parser. *)
        let decoder : (Msgpck.t -> (protocol_msg, string) result) option =
          match msg_type with
          | Ping_msg -> Some (lift decode_ping (fun p -> Ping p))
          | Indirect_ping_msg ->
              Some (lift decode_indirect_ping (fun p -> Indirect_ping p))
          | Ack_resp_msg -> Some (lift decode_ack (fun a -> Ack a))
          | Nack_resp_msg -> Some (lift decode_nack (fun n -> Nack n))
          | Suspect_msg -> Some (lift decode_suspect (fun s -> Suspect s))
          | Alive_msg -> Some (lift decode_alive (fun a -> Alive a))
          | Dead_msg -> Some (lift decode_dead (fun d -> Dead d))
          | Compress_msg -> Some (lift decode_compress (fun c -> Compressed c))
          | Err_msg -> Some decode_err
          | _ -> None
        in
        match decoder with
        | None -> Error (Types.Invalid_tag msg_type_byte)
        | Some decode -> (
            let payload_bytes = Cstruct.to_bytes ~off:1 ~len:payload_len buf in
            (* Msgpck raises on malformed or empty input instead of returning
               an error value. *)
            match Msgpck.Bytes.read payload_bytes with
            | exception (Invalid_argument e | Failure e) ->
                Error (Types.Msgpack_error e)
            | _, msgpack ->
                Result.map_error
                  (fun e -> Types.Msgpack_error e)
                  (decode msgpack)))
390
(** Lookup table for the reflected CRC-32 polynomial [0xEDB88320]. Entry [i]
    is the result of running the eight bitwise CRC steps over the byte
    value [i]. *)
let crc32_table =
  let poly = 0xEDB88320l in
  (* One bitwise step: shift right, folding in the polynomial when the bit
     shifted out was set. *)
  let step c =
    if Int32.equal (Int32.logand c 1l) 0l then Int32.shift_right_logical c 1
    else Int32.logxor (Int32.shift_right_logical c 1) poly
  in
  let rec iterate n c = if n = 0 then c else iterate (n - 1) (step c) in
  Array.init 256 (fun i -> iterate 8 (Int32.of_int i))
400
(** [crc32_cstruct buf] computes the CRC-32 of every byte of [buf] using
    [crc32_table]. The register starts at all-ones and the result is
    bit-inverted at the end. *)
let crc32_cstruct (buf : Cstruct.t) : int32 =
  let len = Cstruct.length buf in
  let rec loop i acc =
    if i >= len then acc
    else
      let byte = Int32.of_int (Cstruct.get_uint8 buf i) in
      let idx = Int32.to_int (Int32.logand (Int32.logxor acc byte) 0xFFl) in
      loop (i + 1)
        (Int32.logxor (Int32.shift_right_logical acc 8) crc32_table.(idx))
  in
  Int32.lognot (loop 0 Int32.minus_one)
411
(** [add_crc_to_cstruct ~src ~src_len ~dst] writes a CRC frame into [dst].
    The frame is the [Has_crc_msg] tag byte, the big-endian CRC-32 of the
    first [src_len] bytes of [src], then those bytes.

    Returns the frame length [5 + src_len], or [Error `Buffer_too_small]
    (with [dst] untouched) if [dst] is too short. The CRC is computed before
    anything is written to [dst]. *)
let add_crc_to_cstruct ~(src : Cstruct.t) ~src_len ~(dst : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let header_len = 5 in
  let needed = header_len + src_len in
  if Cstruct.length dst < needed then Error `Buffer_too_small
  else
    let body = Cstruct.sub src 0 src_len in
    let checksum = crc32_cstruct body in
    Cstruct.set_uint8 dst 0 (message_type_to_int Has_crc_msg);
    Cstruct.BE.set_uint32 dst 1 checksum;
    Cstruct.blit body 0 dst header_len src_len;
    Ok needed
424
(** [verify_and_strip_crc buf] checks and removes an optional CRC frame.

    - Empty [buf]: [Error Truncated_message].
    - First byte is not [Has_crc_msg]: [buf] is returned unchanged, whatever
      its length. Previously, short non-CRC buffers (1-4 bytes) were wrongly
      rejected as truncated.
    - Otherwise: bytes 1..4 hold the big-endian CRC-32 of the rest. The
      payload after the 5-byte header is returned when the checksum matches,
      else [Error Invalid_crc]. A CRC-tagged buffer shorter than 5 bytes is
      [Error Truncated_message]. *)
let verify_and_strip_crc (buf : Cstruct.t) :
    (Cstruct.t, Types.decode_error) result =
  let len = Cstruct.length buf in
  if len < 1 then Error Types.Truncated_message
  else if Cstruct.get_uint8 buf 0 <> message_type_to_int Has_crc_msg then Ok buf
  else if len < 5 then Error Types.Truncated_message
  else
    let expected = Cstruct.BE.get_uint32 buf 1 in
    let payload = Cstruct.shift buf 5 in
    if Int32.equal expected (crc32_cstruct payload) then Ok payload
    else Error Types.Invalid_crc
434
(** [add_label_to_cstruct ~label ~src ~src_len ~dst] copies the first
    [src_len] bytes of [src] into [dst]. When [label] is non-empty, they are
    prefixed by a label header: the [Has_label_msg] tag byte, one length
    byte, then the label bytes. With an empty label the bytes are copied
    unchanged.

    Returns the number of bytes written, or [Error `Buffer_too_small] (with
    [dst] untouched) when [dst] cannot hold them.

    @raise Invalid_argument if [label] is longer than 255 bytes. Its length
    must fit in the single header byte; otherwise the header would silently
    disagree with the bytes that follow. *)
let add_label_to_cstruct ~label ~(src : Cstruct.t) ~src_len ~(dst : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let label_len = String.length label in
  if label_len > 255 then
    invalid_arg "add_label_to_cstruct: label longer than 255 bytes";
  let header_len = if label_len = 0 then 0 else 2 + label_len in
  let total_len = header_len + src_len in
  if total_len > Cstruct.length dst then Error `Buffer_too_small
  else begin
    if label_len > 0 then begin
      Cstruct.set_uint8 dst 0 (message_type_to_int Has_label_msg);
      Cstruct.set_uint8 dst 1 label_len;
      Cstruct.blit_from_string label 0 dst 2 label_len
    end;
    Cstruct.blit src 0 dst header_len src_len;
    Ok total_len
  end
455
(** [strip_label buf] removes an optional label header.

    If the first byte is not [Has_label_msg], returns [(buf, "")]. Otherwise
    byte 1 gives the label length, and the result is the payload after the
    label together with the label text. An empty buffer, or one too short for
    its declared header, yields [Error Truncated_message]. *)
let strip_label (buf : Cstruct.t) :
    (Cstruct.t * string, Types.decode_error) result =
  let len = Cstruct.length buf in
  let is_labelled () =
    Cstruct.get_uint8 buf 0 = message_type_to_int Has_label_msg
  in
  if len = 0 then Error Types.Truncated_message
  else if not (is_labelled ()) then Ok (buf, "")
  else if len < 2 then Error Types.Truncated_message
  else
    let label_len = Cstruct.get_uint8 buf 1 in
    let header_len = 2 + label_len in
    if len < header_len then Error Types.Truncated_message
    else
      let label = Cstruct.to_string ~off:2 ~len:label_len buf in
      Ok (Cstruct.shift buf header_len, label)
469
(** [encode_compound_to_cstruct ~msgs ~msg_lens ~dst] packs several encoded
    messages into one compound frame in [dst].

    Layout:
    - the [Compound_msg] tag byte;
    - a one-byte part count;
    - one big-endian 16-bit length per part;
    - the parts back to back. Part [i] is the first [List.nth msg_lens i]
      bytes of [List.nth msgs i].

    Returns the frame length, or [Error `Buffer_too_small] (with [dst]
    untouched) if [dst] is too short.

    @raise Failure if there are more than 255 messages.
    @raise Invalid_argument if [msgs] and [msg_lens] differ in length, or a
    length is outside [0..65535]. That length would not fit the 16-bit
    header and would silently wrap. Both checks run before [dst] is
    modified. *)
let encode_compound_to_cstruct ~(msgs : Cstruct.t list) ~(msg_lens : int list)
    ~(dst : Cstruct.t) : (int, [ `Buffer_too_small ]) result =
  let num_msgs = List.length msgs in
  if num_msgs > 255 then failwith "too many messages for compound"
  else if List.compare_lengths msgs msg_lens <> 0 then
    invalid_arg "encode_compound_to_cstruct: msgs and msg_lens differ in length"
  else if List.exists (fun len -> len < 0 || len > 0xFFFF) msg_lens then
    invalid_arg "encode_compound_to_cstruct: part length outside 0..65535"
  else
    let header_size = 1 + 1 + (num_msgs * 2) in
    let total_len = header_size + List.fold_left ( + ) 0 msg_lens in
    if total_len > Cstruct.length dst then Error `Buffer_too_small
    else begin
      Cstruct.set_uint8 dst 0 (message_type_to_int Compound_msg);
      Cstruct.set_uint8 dst 1 num_msgs;
      List.iteri
        (fun i len -> Cstruct.BE.set_uint16 dst (2 + (i * 2)) len)
        msg_lens;
      let (_ : int) =
        List.fold_left2
          (fun offset msg len ->
            Cstruct.blit msg 0 dst offset len;
            offset + len)
          header_size msgs msg_lens
      in
      Ok total_len
    end
493
(** [decode_compound_from_cstruct buf] splits a compound body into its parts.
    [buf] must already have the compound tag byte removed.

    [buf] starts with a one-byte part count, followed by that many big-endian
    16-bit lengths, then the part bytes. Returns [(parts, truncated)]:
    [truncated] is the number of trailing parts whose declared length runs
    past the end of [buf], and [0] when every part fits. The returned parts
    are views into [buf], not copies. An empty buffer or a short header
    yields [Error Truncated_message]. *)
let decode_compound_from_cstruct (buf : Cstruct.t) :
    (Cstruct.t list * int, Types.decode_error) result =
  let total = Cstruct.length buf in
  if total < 1 then Error Types.Truncated_message
  else
    let count = Cstruct.get_uint8 buf 0 in
    let body_start = 1 + (2 * count) in
    if total < body_start then Error Types.Truncated_message
    else
      let rec go i offset acc =
        if i = count then Ok (List.rev acc, 0)
        else
          let len = Cstruct.BE.get_uint16 buf (1 + (2 * i)) in
          if offset + len > total then Ok (List.rev acc, count - i)
          else go (i + 1) (offset + len) (Cstruct.sub buf offset len :: acc)
      in
      go 0 body_start []
516
(** [encode_internal_msg_to_cstruct ~self_name ~self_port msg ~buf] converts
    the internal [msg] to its wire form with [Types.msg_to_wire] and writes
    it into [buf] via [encode_msg_to_cstruct]. Returns the number of bytes
    written, or [Error `Buffer_too_small]. *)
let encode_internal_msg_to_cstruct ~self_name ~self_port
    (msg : Types.protocol_msg) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  encode_msg_to_cstruct ~buf (Types.msg_to_wire ~self_name ~self_port msg)
522
(** [decode_internal_msg_from_cstruct ~default_port buf] decodes a wire
    message with [decode_msg_from_cstruct], then converts it with
    [Types.msg_of_wire ~default_port].

    Wire decoding errors are propagated unchanged. A wire message that
    [Types.msg_of_wire] rejects ([None]) is reported as [Invalid_tag 0]. *)
let decode_internal_msg_from_cstruct ~default_port (buf : Cstruct.t) :
    (Types.protocol_msg, Types.decode_error) result =
  Result.bind (decode_msg_from_cstruct buf) (fun wire_msg ->
      Option.to_result ~none:(Types.Invalid_tag 0)
        (Types.msg_of_wire ~default_port wire_msg))
531
(** [encode_packet packet ~buf] writes [packet] into [buf].

    Messages are converted with [~self_name:packet.cluster] and
    [~self_port:7946]. Without piggyback, the primary message is encoded
    directly into [buf]. Otherwise each message is first encoded into its own
    scratch buffer, and the results are packed with
    [encode_compound_to_cstruct].

    Piggyback messages that do not fit a scratch buffer are silently dropped.
    If the primary does not fit, the result is [Error `Buffer_too_small].

    The scratch size was previously a fixed 2048 bytes, which rejected or
    dropped messages that would fit in a larger [buf]. It now follows
    [Cstruct.length buf], is never below 2048, and is capped at 65535 because
    compound part lengths are 16-bit. *)
let encode_packet (packet : Types.packet) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let self_name = packet.cluster in
  let self_port = 7946 in
  match packet.piggyback with
  | [] ->
      encode_internal_msg_to_cstruct ~self_name ~self_port packet.primary ~buf
  | piggyback -> (
      let scratch_size = max 2048 (min (Cstruct.length buf) 0xFFFF) in
      (* Each message gets its own scratch because the returned slices alias
         it. *)
      let encode_one msg =
        let scratch = Cstruct.create scratch_size in
        match
          encode_internal_msg_to_cstruct ~self_name ~self_port msg
            ~buf:scratch
        with
        | Error `Buffer_too_small -> None
        | Ok len -> Some (Cstruct.sub scratch 0 len, len)
      in
      match encode_one packet.primary with
      | None -> Error `Buffer_too_small
      | Some (primary_cs, primary_len) ->
          let extras = List.filter_map encode_one piggyback in
          let msgs = primary_cs :: List.map fst extras in
          let msg_lens = primary_len :: List.map snd extras in
          encode_compound_to_cstruct ~msgs ~msg_lens ~dst:buf)
556
(** [decode_packet buf] decodes an incoming packet into a [Types.packet].

    If the first byte is the [Compound_msg] tag, the rest of [buf] is split
    with [decode_compound_from_cstruct]:
    - The first part must decode; it becomes [primary].
    - The remaining parts are decoded best-effort. Parts that fail are
      dropped from [piggyback].
    - The truncation count from the split is ignored.

    Otherwise the whole buffer is one message with no piggyback.

    Every message is decoded with [~default_port:7946], and [cluster] is
    always [""]. An empty buffer, or a compound with no parts, yields
    [Error Truncated_message]. *)
let decode_packet (buf : Cstruct.t) : (Types.packet, Types.decode_error) result
    =
  let decode_one part =
    decode_internal_msg_from_cstruct ~default_port:7946 part
  in
  let make primary piggyback = { Types.cluster = ""; primary; piggyback } in
  if Cstruct.length buf < 1 then Error Types.Truncated_message
  else if Cstruct.get_uint8 buf 0 <> message_type_to_int Compound_msg then
    Result.map (fun primary -> make primary []) (decode_one buf)
  else
    match decode_compound_from_cstruct (Cstruct.shift buf 1) with
    | Error e -> Error e
    | Ok ([], _) -> Error Types.Truncated_message
    | Ok (first :: rest, _truncated) ->
        Result.map
          (fun primary ->
            make primary
              (List.filter_map
                 (fun part -> Result.to_option (decode_one part))
                 rest))
          (decode_one first)
590
(** [encoded_size msg] estimates how many bytes [msg] adds to an outgoing
    packet: 1 type byte, plus the body size, plus 3 extra bytes.

    NOTE(review): the purpose of the extra 3 bytes is not visible here;
    presumably it is per-message framing slack. Confirm.

    [User_data] bodies are written raw by [encode_msg_to_cstruct], not as
    MessagePack, so their size is the data length. Previously they were
    sized as a 1-byte [Nil].

    NOTE(review): the body is built with [~self_name:""]. If
    [Types.msg_to_wire] embeds [self_name] in the body, the real encoding is
    longer by that name's length. Confirm. *)
let encoded_size (msg : Types.protocol_msg) : int =
  let wire_msg = Types.msg_to_wire ~self_name:"" ~self_port:7946 msg in
  let body_size =
    match wire_msg with
    | User_data data -> String.length data
    | _ ->
        let _, payload = wire_msg_to_msgpck wire_msg in
        Msgpck.size payload
  in
  1 + body_size + 3
595
(** [encode_internal_msg ~self_name ~self_port msg] encodes [msg] into a
    fresh 2048-byte buffer and returns the written bytes as a string. Returns
    [""] if the encoding does not fit in 2048 bytes. *)
let encode_internal_msg ~self_name ~self_port (msg : Types.protocol_msg) :
    string =
  let scratch = Cstruct.create 2048 in
  encode_internal_msg_to_cstruct ~self_name ~self_port msg ~buf:scratch
  |> Result.fold
       ~ok:(fun len -> Cstruct.to_string ~off:0 ~len scratch)
       ~error:(fun _ -> "")
602
603(* Backward-compatible string wrappers for tests *)
604
(** [add_crc data] is the string form of [add_crc_to_cstruct]. It returns
    [data] framed with the 5-byte CRC header. If framing fails, it falls back
    to returning [data] unchanged. *)
let add_crc (data : string) : string =
  let n = String.length data in
  let src = Cstruct.of_string data in
  let out = Cstruct.create (5 + n) in
  match add_crc_to_cstruct ~src ~src_len:n ~dst:out with
  | Ok written -> Cstruct.to_string ~off:0 ~len:written out
  | Error `Buffer_too_small -> data
611
(** [verify_and_strip_crc_string data] is the string form of
    [verify_and_strip_crc]. It returns the (possibly unchanged) payload as a
    fresh string, or the decode error. *)
let verify_and_strip_crc_string (data : string) :
    (string, Types.decode_error) result =
  verify_and_strip_crc (Cstruct.of_string data)
  |> Result.map (fun payload -> Cstruct.to_string payload)
618
(** [add_label label data] is the string form of [add_label_to_cstruct]. It
    returns [data] prefixed with the label header ([data] unchanged when
    [label] is empty), falling back to [data] if the write fails. *)
let add_label (label : string) (data : string) : string =
  let data_len = String.length data in
  let src = Cstruct.of_string data in
  let out = Cstruct.create (2 + String.length label + data_len) in
  add_label_to_cstruct ~label ~src ~src_len:data_len ~dst:out
  |> Result.fold
       ~ok:(fun len -> Cstruct.to_string ~off:0 ~len out)
       ~error:(fun `Buffer_too_small -> data)
625
(** [strip_label_string data] is the string form of [strip_label]. It returns
    [(payload, label)], with [label = ""] when [data] carries no label
    header. *)
let strip_label_string (data : string) :
    (string * string, Types.decode_error) result =
  Result.map
    (fun (payload, label) -> (Cstruct.to_string payload, label))
    (strip_label (Cstruct.of_string data))
632
(** [make_compound_msg msgs] is the string form of
    [encode_compound_to_cstruct]. The result starts with the [Compound_msg]
    tag byte. Returns [""] if the frame cannot be written; it propagates the
    [Failure] raised for more than 255 messages. *)
let make_compound_msg (msgs : string list) : string =
  let lens = List.map String.length msgs in
  let payload_total = List.fold_left ( + ) 0 lens in
  let header = 2 + (2 * List.length msgs) in
  let css = List.map Cstruct.of_string msgs in
  let dst = Cstruct.create (header + payload_total) in
  match encode_compound_to_cstruct ~msgs:css ~msg_lens:lens ~dst with
  | Ok len -> Cstruct.to_string ~off:0 ~len dst
  | Error `Buffer_too_small -> ""
641
(** [decode_compound_msg data] is the string form of
    [decode_compound_from_cstruct]. It returns the parts as strings plus the
    truncated-part count.

    [data] must begin at the part-count byte. The output of
    [make_compound_msg] starts with the [Compound_msg] tag byte, which has to
    be stripped before calling this. *)
let decode_compound_msg (data : string) :
    (string list * int, Types.decode_error) result =
  Result.map
    (fun (parts, truncated) -> (List.map Cstruct.to_string parts, truncated))
    (decode_compound_from_cstruct (Cstruct.of_string data))
648
(** [encode_push_pull_header h] serialises the push/pull header as a
    MessagePack map with [Nodes] and [UserStateLen] (integers) and [Join]
    (boolean). *)
let encode_push_pull_header (h : push_pull_header) : Msgpck.t =
  let field name value = (Msgpck.String name, value) in
  Msgpck.Map
    [
      field "Nodes" (Msgpck.of_int h.pp_nodes);
      field "UserStateLen" (Msgpck.of_int h.pp_user_state_len);
      field "Join" (Msgpck.Bool h.pp_join);
    ]
656
(** [decode_push_pull_header m] rebuilds a [push_pull_header] from a
    MessagePack map. It only fails when [m] is not a map.

    Missing or mistyped [Nodes] / [UserStateLen] default to [0], and [Join]
    to [false]. Integers accept [Int]/[Int32]/[Uint32]; [Uint32] is
    zero-extended so large counts are not turned negative. *)
let decode_push_pull_header (m : Msgpck.t) : (push_pull_header, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let int_or_zero key =
        match find key with
        | Some (Msgpck.Int i) -> i
        | Some (Msgpck.Int32 i) -> Int32.to_int i
        | Some (Msgpck.Uint32 i) -> uint32_to_int i
        | _ -> 0
      in
      let bool_or_false key =
        match find key with Some (Msgpck.Bool b) -> b | _ -> false
      in
      Ok
        {
          pp_nodes = int_or_zero "Nodes";
          pp_user_state_len = int_or_zero "UserStateLen";
          pp_join = bool_or_false "Join";
        }
  | _ -> Error "expected map for push_pull_header"
678
(** [encode_push_node_state s] serialises one node's state for push/pull
    exchange as a MessagePack map:
    - [Name]: MessagePack [String];
    - [Addr], [Meta]: MessagePack [Bytes];
    - [Port], [Incarnation], [State]: integers;
    - [Vsn]: a list of integers. *)
let encode_push_node_state (s : push_node_state) : Msgpck.t =
  let num v = Msgpck.of_int v in
  let vsn = Msgpck.List (List.map num s.pns_vsn) in
  Msgpck.Map
    (List.map
       (fun (name, value) -> (Msgpck.String name, value))
       [
         ("Name", Msgpck.String s.pns_name);
         ("Addr", Msgpck.Bytes s.pns_addr);
         ("Port", num s.pns_port);
         ("Meta", Msgpck.Bytes s.pns_meta);
         ("Incarnation", num s.pns_incarnation);
         ("State", num s.pns_state);
         ("Vsn", vsn);
       ])
690
(** [decode_push_node_state m] rebuilds a [push_node_state] from a
    MessagePack map. It only fails when [m] is not a map.

    Defaults for missing or mistyped fields:
    - strings ([String] or [Bytes]): [""];
    - integers: [0];
    - [Vsn]: [[]]. Non-integer list items are skipped.

    Integers accept [Int]/[Int32]/[Uint32], with [Uint32] zero-extended so
    large values (e.g. incarnations >= 2^31) are not turned negative. *)
let decode_push_node_state (m : Msgpck.t) : (push_node_state, string) result =
  match m with
  | Msgpck.Map fields ->
      (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
      let uint32_to_int i =
        Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
      in
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let int_of_msgpck = function
        | Msgpck.Int i -> Some i
        | Msgpck.Int32 i -> Some (Int32.to_int i)
        | Msgpck.Uint32 i -> Some (uint32_to_int i)
        | _ -> None
      in
      let string_or_empty key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> s
        | _ -> ""
      in
      let int_or_zero key =
        Option.value (Option.bind (find key) int_of_msgpck) ~default:0
      in
      let int_list key =
        match find key with
        | Some (Msgpck.List items) -> List.filter_map int_of_msgpck items
        | _ -> []
      in
      Ok
        {
          pns_name = string_or_empty "Name";
          pns_addr = string_or_empty "Addr";
          pns_port = int_or_zero "Port";
          pns_meta = string_or_empty "Meta";
          pns_incarnation = int_or_zero "Incarnation";
          pns_state = int_or_zero "State";
          pns_vsn = int_list "Vsn";
        }
  | _ -> Error "expected map for push_node_state"
740
(** [encode_push_pull_msg ~header ~nodes ~user_state] builds a push/pull
    message as a string: the [Push_pull_msg] tag byte, the MessagePack
    header, one MessagePack map per node in list order, then [user_state]
    appended raw. *)
let encode_push_pull_msg ~(header : push_pull_header)
    ~(nodes : push_node_state list) ~(user_state : string) : string =
  let out = Buffer.create 1024 in
  let emit value = ignore (Msgpck.StringBuf.write out value) in
  Buffer.add_char out (Char.chr (message_type_to_int Push_pull_msg));
  emit (encode_push_pull_header header);
  List.iter (fun node -> emit (encode_push_node_state node)) nodes;
  Buffer.add_string out user_state;
  Buffer.contents out
751
(** [decode_push_pull_msg data] parses a push/pull body: a MessagePack header,
    [header.pp_nodes] node-state maps, then up to [header.pp_user_state_len]
    raw user-state bytes. The user state is clipped to whatever remains.

    Errors:
    - empty [data], or running out of bytes before all nodes are read:
      [Truncated_message];
    - malformed MessagePack, or a header/node that fails to decode:
      [Msgpack_error]. Parser exceptions are caught rather than escaping.

    NOTE(review): the header is parsed from offset 0, while
    [encode_push_pull_msg] emits a leading [Push_pull_msg] tag byte. Callers
    presumably strip that byte first; confirm. *)
let decode_push_pull_msg (data : string) :
    ( push_pull_header * push_node_state list * string,
      Types.decode_error )
    result =
  (* Msgpck raises on malformed input; turn that into a decode error. *)
  let read s =
    match Msgpck.String.read s with
    | r -> Ok r
    | exception (Invalid_argument e | Failure e) ->
        Error (Types.Msgpack_error e)
  in
  let as_decode_error r = Result.map_error (fun e -> Types.Msgpack_error e) r in
  let ( let* ) = Result.bind in
  let total = String.length data in
  if total < 1 then Error Types.Truncated_message
  else
    let* (header_size, header_msgpack) = read data in
    let* header = as_decode_error (decode_push_pull_header header_msgpack) in
    let rec read_nodes offset remaining acc =
      if remaining <= 0 then Ok (List.rev acc, offset)
      else if offset >= total then Error Types.Truncated_message
      else
        let* (node_size, node_msgpack) =
          read (String.sub data offset (total - offset))
        in
        let* node = as_decode_error (decode_push_node_state node_msgpack) in
        read_nodes (offset + node_size) (remaining - 1) (node :: acc)
    in
    let* (nodes, offset) = read_nodes header_size header.pp_nodes [] in
    let user_state =
      if header.pp_user_state_len > 0 && offset < total then
        String.sub data offset (min header.pp_user_state_len (total - offset))
      else ""
    in
    Ok (header, nodes, user_state)
785
(** [decode_compress_from_cstruct buf] parses a compression envelope map from
    [buf]. It returns [(algo, compressed_bytes)], where the bytes are a fresh
    copy.

    [Algo] may be [Int], [Int32] or [Uint32] (zero-extended). It falls back
    to [-1] when missing or mistyped. [Buf] must be [Bytes] or [String],
    otherwise the result is [Error (Msgpack_error "missing Buf field")].
    Malformed or empty MessagePack yields [Error (Msgpack_error _)] instead
    of raising. *)
let decode_compress_from_cstruct (buf : Cstruct.t) :
    (int * Cstruct.t, Types.decode_error) result =
  (* Msgpck stores Uint32 in an int32; reinterpret the 32 bits as unsigned. *)
  let uint32_to_int i =
    Int64.to_int (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
  in
  match Msgpck.String.read (Cstruct.to_string buf) with
  | exception (Invalid_argument e | Failure e) -> Error (Types.Msgpack_error e)
  | _, Msgpck.Map fields -> (
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let algo =
        match find "Algo" with
        | Some (Msgpck.Int i) -> i
        | Some (Msgpck.Int32 i) -> Int32.to_int i
        | Some (Msgpck.Uint32 i) -> uint32_to_int i
        | _ -> -1
      in
      match find "Buf" with
      | Some (Msgpck.Bytes s | Msgpck.String s) ->
          Ok (algo, Cstruct.of_string s)
      | _ -> Error (Types.Msgpack_error "missing Buf field"))
  | _ -> Error (Types.Msgpack_error "expected map for compress")
808
(** [decode_push_pull_msg_cstruct buf] is the Cstruct variant of
    [decode_push_pull_msg]. It parses a MessagePack header and
    [header.pp_nodes] node-state maps. The user state is returned as a view
    into [buf] (not a copy) of up to [header.pp_user_state_len] bytes,
    clipped to what remains.

    Errors:
    - empty [buf], or running out of bytes before all nodes are read:
      [Truncated_message];
    - malformed MessagePack or an undecodable header/node: [Msgpack_error].
      Parser exceptions are caught rather than escaping.

    NOTE(review): as in the string variant, parsing starts at offset 0, so
    any [Push_pull_msg] tag byte must already be stripped. Confirm with
    callers. *)
let decode_push_pull_msg_cstruct (buf : Cstruct.t) :
    ( push_pull_header * push_node_state list * Cstruct.t,
      Types.decode_error )
    result =
  (* Msgpck raises on malformed input; turn that into a decode error. *)
  let read s =
    match Msgpck.String.read s with
    | r -> Ok r
    | exception (Invalid_argument e | Failure e) ->
        Error (Types.Msgpack_error e)
  in
  let as_decode_error r = Result.map_error (fun e -> Types.Msgpack_error e) r in
  let ( let* ) = Result.bind in
  let total = Cstruct.length buf in
  if total < 1 then Error Types.Truncated_message
  else
    (* Msgpck reads from strings; offsets into [data] equal offsets into [buf]. *)
    let data = Cstruct.to_string buf in
    let* (header_size, header_msgpack) = read data in
    let* header = as_decode_error (decode_push_pull_header header_msgpack) in
    let rec read_nodes offset remaining acc =
      if remaining <= 0 then Ok (List.rev acc, offset)
      else if offset >= total then Error Types.Truncated_message
      else
        let* (node_size, node_msgpack) =
          read (String.sub data offset (total - offset))
        in
        let* node = as_decode_error (decode_push_node_state node_msgpack) in
        read_nodes (offset + node_size) (remaining - 1) (node :: acc)
    in
    let* (nodes, offset) = read_nodes header_size header.pp_nodes [] in
    let user_state =
      if header.pp_user_state_len > 0 && offset < total then
        Cstruct.sub buf offset (min header.pp_user_state_len (total - offset))
      else Cstruct.empty
    in
    Ok (header, nodes, user_state)
842 Ok (header, nodes, user_state))