this repo has no description
1open Types.Wire
2
(** [encode_ping p] builds the msgpack map for a ping. Keys, in order:
    ["SeqNo"], ["Node"], ["SourceAddr"], ["SourcePort"], ["SourceNode"]. *)
let encode_ping (p : ping) : Msgpck.t =
  let str s = Msgpck.String s in
  let num n = Msgpck.of_int n in
  Msgpck.Map
    (List.map
       (fun (name, value) -> (Msgpck.String name, value))
       [
         ("SeqNo", num p.seq_no);
         ("Node", str p.node);
         ("SourceAddr", str p.source_addr);
         ("SourcePort", num p.source_port);
         ("SourceNode", str p.source_node);
       ])
12
(** [decode_ping m] decodes a ping from its msgpack map form.

    ["SeqNo"], ["Node"] and ["SourceAddr"] are required. ["SourcePort"]
    falls back to [0] and ["SourceNode"] to [""] when missing or ill-typed.
    String fields accept msgpack str or bin, and [Nil] reads as [""].

    Returns [Error "missing or invalid <key>"] for the first bad required
    field, or [Error "expected map for ping"] if [m] is not a map. *)
let decode_ping (m : Msgpck.t) : (ping, string) result =
  match m with
  | Msgpck.Map fields ->
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        (* Uint32 stores an unsigned value in an int32; a plain
           [Int32.to_int] would turn values >= 2^31 negative. *)
        | Some (Msgpck.Uint32 i) -> (
            match Int32.unsigned_to_int i with
            | Some n -> Ok n
            | None -> invalid key)
        | _ -> invalid key
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | Some Msgpck.Nil -> Ok ""
        | _ -> invalid key
      in
      let ( let* ) = Result.bind in
      let* seq_no = get_int "SeqNo" in
      let* node = get_string "Node" in
      let* source_addr = get_string "SourceAddr" in
      (* Optional fields: any failure falls back to a default. *)
      let source_port = Result.value (get_int "SourcePort") ~default:0 in
      let source_node = Result.value (get_string "SourceNode") ~default:"" in
      Ok ({ seq_no; node; source_addr; source_port; source_node } : ping)
  | _ -> Error "expected map for ping"
42
(** [encode_indirect_ping p] builds the msgpack map for an indirect ping
    request. Keys, in order: SeqNo, Target, Port, Node, Nack, SourceAddr,
    SourcePort, SourceNode. *)
let encode_indirect_ping (p : indirect_ping_req) : Msgpck.t =
  let entry name value = (Msgpck.String name, value) in
  let text s = Msgpck.String s in
  let num n = Msgpck.of_int n in
  Msgpck.Map
    [
      entry "SeqNo" (num p.seq_no);
      entry "Target" (text p.target);
      entry "Port" (num p.port);
      entry "Node" (text p.node);
      entry "Nack" (Msgpck.Bool p.nack);
      entry "SourceAddr" (text p.source_addr);
      entry "SourcePort" (num p.source_port);
      entry "SourceNode" (text p.source_node);
    ]
55
(** [decode_indirect_ping m] decodes an indirect ping request from its msgpack
    map form.

    ["SeqNo"], ["Target"] and ["Node"] are required. The other fields have
    defaults when missing or ill-typed:
    - ["Port"]: [0]
    - ["Nack"]: [false]
    - ["SourceAddr"]: [""]
    - ["SourcePort"]: [0]
    - ["SourceNode"]: [""]

    Returns [Error "missing or invalid <key>"] for the first bad required
    field, or [Error "expected map for indirect_ping"] if [m] is not a map. *)
let decode_indirect_ping (m : Msgpck.t) : (indirect_ping_req, string) result =
  match m with
  | Msgpck.Map fields ->
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        (* Interpret Uint32 as unsigned so large values stay positive. *)
        | Some (Msgpck.Uint32 i) -> (
            match Int32.unsigned_to_int i with
            | Some n -> Ok n
            | None -> invalid key)
        | _ -> invalid key
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | Some Msgpck.Nil -> Ok ""
        | _ -> invalid key
      in
      let get_bool key =
        match find key with Some (Msgpck.Bool b) -> b | _ -> false
      in
      let ( let* ) = Result.bind in
      let* seq_no = get_int "SeqNo" in
      let* target = get_string "Target" in
      let port = Result.value (get_int "Port") ~default:0 in
      let* node = get_string "Node" in
      let nack = get_bool "Nack" in
      let source_addr = Result.value (get_string "SourceAddr") ~default:"" in
      let source_port = Result.value (get_int "SourcePort") ~default:0 in
      let source_node = Result.value (get_string "SourceNode") ~default:"" in
      Ok
        ({
           seq_no;
           target;
           port;
           node;
           nack;
           source_addr;
           source_port;
           source_node;
         }
          : indirect_ping_req)
  | _ -> Error "expected map for indirect_ping"
105
(** [encode_ack a] builds the msgpack map [{"SeqNo"; "Payload"}] for an ack
    response. The payload is emitted as a msgpack str. *)
let encode_ack (a : ack_resp) : Msgpck.t =
  let seq = Msgpck.of_int a.seq_no and body = Msgpck.String a.payload in
  Msgpck.Map [ (Msgpck.String "SeqNo", seq); (Msgpck.String "Payload", body) ]
112
(** [decode_ack m] decodes an ack response from its msgpack map form.

    ["SeqNo"] is required. ["Payload"] accepts msgpack bin or str, and
    anything else (including a missing key) yields [""].

    Returns [Error "missing or invalid SeqNo"] or
    [Error "expected map for ack"] on failure. *)
let decode_ack (m : Msgpck.t) : (ack_resp, string) result =
  match m with
  | Msgpck.Map fields -> (
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        (* Interpret Uint32 as unsigned so large values stay positive. *)
        | Some (Msgpck.Uint32 i) -> (
            match Int32.unsigned_to_int i with
            | Some n -> Ok n
            | None -> invalid key)
        | _ -> invalid key
      in
      let payload =
        match find "Payload" with
        | Some (Msgpck.Bytes s | Msgpck.String s) -> s
        | _ -> ""
      in
      match get_int "SeqNo" with
      | Ok seq_no -> Ok ({ seq_no; payload } : ack_resp)
      | Error e -> Error e)
  | _ -> Error "expected map for ack"
135
(** [encode_nack n] builds the single-entry msgpack map [{"SeqNo"}]. *)
let encode_nack (n : nack_resp) : Msgpck.t =
  let seq_field = (Msgpck.String "SeqNo", Msgpck.of_int n.seq_no) in
  Msgpck.Map [ seq_field ]
138
(** [decode_nack m] decodes a nack response, whose only field is the
    required ["SeqNo"].

    Returns [Error "missing or invalid SeqNo"] or
    [Error "expected map for nack"] on failure. *)
let decode_nack (m : Msgpck.t) : (nack_resp, string) result =
  match m with
  | Msgpck.Map fields -> (
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match List.assoc_opt (Msgpck.String key) fields with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        (* Interpret Uint32 as unsigned so large values stay positive. *)
        | Some (Msgpck.Uint32 i) -> (
            match Int32.unsigned_to_int i with
            | Some n -> Ok n
            | None -> invalid key)
        | _ -> invalid key
      in
      match get_int "SeqNo" with
      | Ok seq_no -> Ok ({ seq_no } : nack_resp)
      | Error e -> Error e)
  | _ -> Error "expected map for nack"
153
(** [encode_suspect s] builds the msgpack map
    [{"Incarnation"; "Node"; "From"}] for a suspect notice. *)
let encode_suspect (s : suspect) : Msgpck.t =
  let incarnation =
    (Msgpck.String "Incarnation", Msgpck.of_int s.incarnation)
  in
  let node = (Msgpck.String "Node", Msgpck.String s.node) in
  let accuser = (Msgpck.String "From", Msgpck.String s.from) in
  Msgpck.Map [ incarnation; node; accuser ]
161
(** [decode_suspect m] decodes a suspect notice. All three fields
    (["Incarnation"], ["Node"], ["From"]) are required.

    String fields accept msgpack str or bin, matching the other decoders in
    this module.

    Returns [Error "missing or invalid <key>"] for the first bad field, or
    [Error "expected map for suspect"] if [m] is not a map. *)
let decode_suspect (m : Msgpck.t) : (suspect, string) result =
  match m with
  | Msgpck.Map fields ->
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        (* Interpret Uint32 as unsigned so large incarnations stay positive. *)
        | Some (Msgpck.Uint32 i) -> (
            match Int32.unsigned_to_int i with
            | Some n -> Ok n
            | None -> invalid key)
        | _ -> invalid key
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | _ -> invalid key
      in
      let ( let* ) = Result.bind in
      let* incarnation = get_int "Incarnation" in
      let* node = get_string "Node" in
      let* from = get_string "From" in
      Ok ({ incarnation; node; from } : suspect)
  | _ -> Error "expected map for suspect"
183
(** [encode_alive a] builds the msgpack map for an alive notice. Keys, in
    order: Incarnation, Node, Addr, Port, Meta, Vsn. Vsn is emitted as a
    msgpack array of integers. *)
let encode_alive (a : alive) : Msgpck.t =
  let versions = Msgpck.List (List.map (fun v -> Msgpck.of_int v) a.vsn) in
  let field name value = (Msgpck.String name, value) in
  Msgpck.Map
    [
      field "Incarnation" (Msgpck.of_int a.incarnation);
      field "Node" (Msgpck.String a.node);
      field "Addr" (Msgpck.String a.addr);
      field "Port" (Msgpck.of_int a.port);
      field "Meta" (Msgpck.String a.meta);
      field "Vsn" versions;
    ]
194
(** [decode_alive m] decodes an alive notice.

    ["Incarnation"], ["Node"], ["Addr"] and ["Port"] are required.
    ["Meta"] falls back to [""] when missing or ill-typed.

    ["Vsn"] may be either:
    - a msgpack array of integers (entries that are not integers are
      skipped), or
    - a raw str/bin, where each byte is one version number.

    Anything else yields [[]].

    Returns [Error "missing or invalid <key>"] for the first bad required
    field, or [Error "expected map for alive"] if [m] is not a map. *)
let decode_alive (m : Msgpck.t) : (alive, string) result =
  match m with
  | Msgpck.Map fields ->
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      (* Any msgpack integer representation that fits an OCaml int; Uint32 is
         read as unsigned so values >= 2^31 stay positive. *)
      let int_of_value = function
        | Msgpck.Int i -> Some i
        | Msgpck.Int32 i -> Some (Int32.to_int i)
        | Msgpck.Uint32 i -> Int32.unsigned_to_int i
        | _ -> None
      in
      let get_int key =
        match Option.bind (find key) int_of_value with
        | Some i -> Ok i
        | None -> invalid key
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | _ -> invalid key
      in
      let get_vsn () =
        match find "Vsn" with
        | Some (Msgpck.List vs) -> List.filter_map int_of_value vs
        (* Presumably sent this way by encoders that serialise a byte slice as
           raw bytes; previously such input decoded to []. *)
        | Some (Msgpck.String s | Msgpck.Bytes s) ->
            List.init (String.length s) (fun i -> Char.code s.[i])
        | _ -> []
      in
      let ( let* ) = Result.bind in
      let* incarnation = get_int "Incarnation" in
      let* node = get_string "Node" in
      let* addr = get_string "Addr" in
      let* port = get_int "Port" in
      let meta = Result.value (get_string "Meta") ~default:"" in
      let vsn = get_vsn () in
      Ok ({ incarnation; node; addr; port; meta; vsn } : alive)
  | _ -> Error "expected map for alive"
234
(** [encode_dead d] builds the msgpack map [{"Incarnation"; "Node"; "From"}]
    for a dead notice. *)
let encode_dead (d : dead) : Msgpck.t =
  Msgpck.Map
    (List.map
       (fun (name, value) -> (Msgpck.String name, value))
       [
         ("Incarnation", Msgpck.of_int d.incarnation);
         ("Node", Msgpck.String d.node);
         ("From", Msgpck.String d.from);
       ])
242
(** [decode_dead m] decodes a dead notice. All three fields
    (["Incarnation"], ["Node"], ["From"]) are required.

    String fields accept msgpack str or bin, matching the other decoders in
    this module.

    Returns [Error "missing or invalid <key>"] for the first bad field, or
    [Error "expected map for dead"] if [m] is not a map. *)
let decode_dead (m : Msgpck.t) : (dead, string) result =
  match m with
  | Msgpck.Map fields ->
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        (* Interpret Uint32 as unsigned so large incarnations stay positive. *)
        | Some (Msgpck.Uint32 i) -> (
            match Int32.unsigned_to_int i with
            | Some n -> Ok n
            | None -> invalid key)
        | _ -> invalid key
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
        | _ -> invalid key
      in
      let ( let* ) = Result.bind in
      let* incarnation = get_int "Incarnation" in
      let* node = get_string "Node" in
      let* from = get_string "From" in
      Ok ({ incarnation; node; from } : dead)
  | _ -> Error "expected map for dead"
264
(** [encode_compress c] builds the msgpack map [{"Algo"; "Buf"}]. The
    compressed buffer is emitted as a msgpack str. *)
let encode_compress (c : compress) : Msgpck.t =
  let algo = Msgpck.of_int c.algo and payload = Msgpck.String c.buf in
  Msgpck.Map [ (Msgpck.String "Algo", algo); (Msgpck.String "Buf", payload) ]
271
(** [decode_compress m] decodes a compression envelope. Both ["Algo"] (an
    integer) and ["Buf"] (msgpack bin or str) are required.

    Integer decoding accepts Uint32, as the other decoders in this module
    do.

    Returns [Error "missing or invalid <key>"] for the first bad field, or
    [Error "expected map for compress"] if [m] is not a map. *)
let decode_compress (m : Msgpck.t) : (compress, string) result =
  match m with
  | Msgpck.Map fields ->
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let invalid key = Error (Printf.sprintf "missing or invalid %s" key) in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> Ok i
        | Some (Msgpck.Int32 i) -> Ok (Int32.to_int i)
        | Some (Msgpck.Uint32 i) -> (
            match Int32.unsigned_to_int i with
            | Some n -> Ok n
            | None -> invalid key)
        | _ -> invalid key
      in
      let get_bytes key =
        match find key with
        | Some (Msgpck.Bytes s | Msgpck.String s) -> Ok s
        | _ -> invalid key
      in
      let ( let* ) = Result.bind in
      let* algo = get_int "Algo" in
      let* buf = get_bytes "Buf" in
      Ok ({ algo; buf } : compress)
  | _ -> Error "expected map for compress"
292
(** [wire_msg_to_msgpck msg] pairs a wire message with its message-type tag
    and its msgpack body.

    - [User_data] and [Compound] have no msgpack body here and get
      [Msgpck.Nil] as a placeholder.
    - [Err text] becomes the one-entry map [{"Error": text}].

    Every other constructor is serialised by its [encode_*] function. *)
let wire_msg_to_msgpck (msg : protocol_msg) : message_type * Msgpck.t =
  match msg with
  | User_data _ -> (User_msg, Msgpck.Nil)
  | Compound _ -> (Compound_msg, Msgpck.Nil)
  | Err text ->
      let body = Msgpck.Map [ (Msgpck.String "Error", Msgpck.String text) ] in
      (Err_msg, body)
  | Ping ping -> (Ping_msg, encode_ping ping)
  | Indirect_ping req -> (Indirect_ping_msg, encode_indirect_ping req)
  | Ack resp -> (Ack_resp_msg, encode_ack resp)
  | Nack resp -> (Nack_resp_msg, encode_nack resp)
  | Suspect sus -> (Suspect_msg, encode_suspect sus)
  | Alive live -> (Alive_msg, encode_alive live)
  | Dead gone -> (Dead_msg, encode_dead gone)
  | Compressed comp -> (Compress_msg, encode_compress comp)
306
(** [encode_msg_to_cstruct msg ~buf] writes [msg] into [buf] as one tag byte
    followed by the body, and returns the number of bytes written.

    - [User_data] bodies are copied verbatim.
    - Every other message is written as its msgpack encoding.

    Returns [Error `Buffer_too_small] without writing anything when [buf]
    cannot hold the whole frame. *)
let encode_msg_to_cstruct (msg : protocol_msg) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let msg_type, payload = wire_msg_to_msgpck msg in
  let tag = message_type_to_int msg_type in
  (* Check capacity for a body of [body_len] bytes, then write the tag and
     let [write_body] fill in the body starting at offset 1. *)
  let frame body_len write_body =
    let total = body_len + 1 in
    if Cstruct.length buf < total then Error `Buffer_too_small
    else begin
      Cstruct.set_uint8 buf 0 tag;
      write_body ();
      Ok total
    end
  in
  match msg with
  | User_data data ->
      let n = String.length data in
      frame n (fun () -> Cstruct.blit_from_string data 0 buf 1 n)
  | _ ->
      let n = Msgpck.size payload in
      frame n (fun () ->
          let scratch = Bytes.create n in
          ignore (Msgpck.Bytes.write scratch payload);
          Cstruct.blit_from_bytes scratch 0 buf 1 n)
331
(** [decode_msg_from_cstruct buf] decodes one wire message laid out as a
    single tag byte followed by its body.

    - [User_msg]: the body is returned verbatim as [User_data].
    - [Compound_msg]: the body is not inspected here; the result is always
      [Compound []].
    - Ping, indirect ping, ack, nack, suspect, alive, dead, compress and
      error messages: the body is parsed as one msgpack value and passed to
      the matching [decode_*] function.

    Errors:
    - [Truncated_message] if [buf] is empty.
    - [Invalid_tag n] if the tag is unknown, or is a known type this
      function does not decode.
    - [Msgpack_error] if the body is not well-formed msgpack, or a field
      decoder rejects it.

    Malformed bodies produce an error value; they no longer raise. *)
let decode_msg_from_cstruct (buf : Cstruct.t) :
    (protocol_msg, Types.decode_error) result =
  let body_len () = Cstruct.length buf - 1 in
  (* Parse the msgpack body after the tag byte. The msgpck reader signals
     empty, truncated or invalid input with exceptions; [buf] may be
     untrusted, so turn them into a decode error. *)
  let read_body () =
    let body = Cstruct.to_bytes ~off:1 ~len:(body_len ()) buf in
    match Msgpck.Bytes.read body with
    | _, value -> Ok value
    | exception (Invalid_argument e | Failure e) ->
        Error (Types.Msgpack_error e)
  in
  (* Read the body, run a field decoder on it and wrap the result in the
     corresponding [protocol_msg] constructor. *)
  let decode_body decode wrap =
    match read_body () with
    | Error e -> Error e
    | Ok value -> (
        match decode value with
        | Ok x -> Ok (wrap x)
        | Error e -> Error (Types.Msgpack_error e))
  in
  let error_text = function
    | Msgpck.Map fields -> (
        match List.assoc_opt (Msgpck.String "Error") fields with
        | Some (Msgpck.String e) -> e
        | _ -> "unknown error")
    | _ -> "unknown error"
  in
  if Cstruct.length buf < 1 then Error Types.Truncated_message
  else
    let tag = Cstruct.get_uint8 buf 0 in
    match message_type_of_int tag with
    | Error n -> Error (Types.Invalid_tag n)
    | Ok User_msg ->
        Ok (User_data (Cstruct.to_string ~off:1 ~len:(body_len ()) buf))
    | Ok Compound_msg -> Ok (Compound [])
    | Ok Ping_msg -> decode_body decode_ping (fun p -> Ping p)
    | Ok Indirect_ping_msg ->
        decode_body decode_indirect_ping (fun p -> Indirect_ping p)
    | Ok Ack_resp_msg -> decode_body decode_ack (fun a -> Ack a)
    | Ok Nack_resp_msg -> decode_body decode_nack (fun n -> Nack n)
    | Ok Suspect_msg -> decode_body decode_suspect (fun s -> Suspect s)
    | Ok Alive_msg -> decode_body decode_alive (fun a -> Alive a)
    | Ok Dead_msg -> decode_body decode_dead (fun d -> Dead d)
    | Ok Compress_msg -> decode_body decode_compress (fun c -> Compressed c)
    | Ok Err_msg -> decode_body (fun v -> Ok (error_text v)) (fun e -> Err e)
    | Ok _ -> Error (Types.Invalid_tag tag)
390
(** Lookup table for the reflected CRC-32 polynomial [0xEDB88320].

    Entry [i] is the register value after feeding byte [i] through eight
    shift/xor rounds. It is used by [crc32_cstruct] to process one byte per
    lookup. *)
let crc32_table =
  let poly = 0xEDB88320l in
  let step c =
    let shifted = Int32.shift_right_logical c 1 in
    if Int32.logand c 1l = 0l then shifted else Int32.logxor shifted poly
  in
  let rec rounds n c = if n = 0 then c else rounds (n - 1) (step c) in
  Array.init 256 (fun i -> rounds 8 (Int32.of_int i))
400
(** [crc32_cstruct buf] is the CRC-32 of every byte in [buf]. It uses
    [crc32_table], starts the register at [0xFFFFFFFF] and inverts the
    result, i.e. the common IEEE CRC-32 checksum. *)
let crc32_cstruct (buf : Cstruct.t) : int32 =
  let len = Cstruct.length buf in
  let rec go i crc =
    if i >= len then Int32.lognot crc
    else
      let byte = Cstruct.get_uint8 buf i in
      (* Low 8 bits of (crc xor byte) select the table entry. *)
      let idx = (Int32.to_int crc lxor byte) land 0xFF in
      go (i + 1)
        (Int32.logxor (Int32.shift_right_logical crc 8) crc32_table.(idx))
  in
  go 0 0xFFFFFFFFl
411
(** [add_crc_to_cstruct ~src ~src_len ~dst] frames the first [src_len] bytes
    of [src] into [dst] in this layout:
    - a [Has_crc_msg] tag byte,
    - the big-endian CRC-32 of those bytes (4 bytes),
    - the bytes themselves.

    Returns the frame length ([5 + src_len]). Returns
    [Error `Buffer_too_small] without writing when [dst] is too short. *)
let add_crc_to_cstruct ~(src : Cstruct.t) ~src_len ~(dst : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let header_len = 5 in
  let framed_len = header_len + src_len in
  if Cstruct.length dst < framed_len then Error `Buffer_too_small
  else begin
    let body = Cstruct.sub src 0 src_len in
    (* Checksum before touching [dst], in case it aliases [src]. *)
    let checksum = crc32_cstruct body in
    Cstruct.set_uint8 dst 0 (message_type_to_int Has_crc_msg);
    Cstruct.BE.set_uint32 dst 1 checksum;
    Cstruct.blit body 0 dst header_len src_len;
    Ok framed_len
  end
424
(** [verify_and_strip_crc buf] checks and removes an optional CRC header.

    - If [buf] does not start with [Has_crc_msg], it is returned unchanged,
      whatever its length. (Previously any buffer shorter than 5 bytes was
      rejected, even when it carried no CRC header.)
    - Otherwise the 4-byte big-endian CRC-32 at offset 1 must match the CRC
      of the remaining bytes, and those bytes are returned as a view into
      [buf].

    Errors:
    - [Truncated_message] if [buf] is empty, or is CRC-tagged but shorter
      than the 5-byte header.
    - [Invalid_crc] on a checksum mismatch. *)
let verify_and_strip_crc (buf : Cstruct.t) :
    (Cstruct.t, Types.decode_error) result =
  let len = Cstruct.length buf in
  if len < 1 then Error Types.Truncated_message
  else if Cstruct.get_uint8 buf 0 <> message_type_to_int Has_crc_msg then
    Ok buf
  else if len < 5 then Error Types.Truncated_message
  else
    let expected = Cstruct.BE.get_uint32 buf 1 in
    let payload = Cstruct.shift buf 5 in
    if Int32.equal expected (crc32_cstruct payload) then Ok payload
    else Error Types.Invalid_crc
434
(** [add_label_to_cstruct ~label ~src ~src_len ~dst] copies the first
    [src_len] bytes of [src] into [dst].

    - An empty [label] copies the bytes unchanged, with no header.
    - Otherwise the bytes are prefixed with a [Has_label_msg] tag byte, a
      one-byte label length and the label bytes.

    Returns the number of bytes written, or [Error `Buffer_too_small] without
    writing when [dst] is too short.

    @raise Invalid_argument if [label] is longer than 255 bytes. Its length
    must fit in the single length byte; previously the length was silently
    truncated and the frame corrupted. *)
let add_label_to_cstruct ~label ~(src : Cstruct.t) ~src_len ~(dst : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let label_len = String.length label in
  if label_len = 0 then begin
    if src_len > Cstruct.length dst then Error `Buffer_too_small
    else begin
      Cstruct.blit src 0 dst 0 src_len;
      Ok src_len
    end
  end
  else if label_len > 255 then
    invalid_arg "add_label_to_cstruct: label longer than 255 bytes"
  else
    let total_len = 2 + label_len + src_len in
    if total_len > Cstruct.length dst then Error `Buffer_too_small
    else begin
      Cstruct.set_uint8 dst 0 (message_type_to_int Has_label_msg);
      Cstruct.set_uint8 dst 1 label_len;
      Cstruct.blit_from_string label 0 dst 2 label_len;
      Cstruct.blit src 0 dst (2 + label_len) src_len;
      Ok total_len
    end
455
(** [strip_label buf] removes an optional label header.

    - If [buf] does not start with [Has_label_msg], it is returned unchanged
      with label [""].
    - Otherwise the result is the payload after the header (a view into
      [buf]) together with the label string.

    Returns [Truncated_message] if [buf] is empty, or if the header or the
    label it announces runs past the end of [buf]. *)
let strip_label (buf : Cstruct.t) :
    (Cstruct.t * string, Types.decode_error) result =
  let len = Cstruct.length buf in
  let labelled () =
    Cstruct.get_uint8 buf 0 = message_type_to_int Has_label_msg
  in
  if len < 1 then Error Types.Truncated_message
  else if not (labelled ()) then Ok (buf, "")
  else if len < 2 then Error Types.Truncated_message
  else
    let label_len = Cstruct.get_uint8 buf 1 in
    let header_len = 2 + label_len in
    if len < header_len then Error Types.Truncated_message
    else
      let label = Cstruct.to_string ~off:2 ~len:label_len buf in
      Ok (Cstruct.shift buf header_len, label)
469
(** [encode_compound_to_cstruct ~msgs ~msg_lens ~dst] packs several
    already-encoded messages into one compound frame in [dst]. Layout:
    - a [Compound_msg] tag byte,
    - a one-byte message count,
    - one big-endian 16-bit length per message,
    - the message bytes, in order.

    [msg_lens.(i)] bytes are taken from the start of [msgs.(i)]. The frame
    includes the tag byte; [decode_compound_from_cstruct] expects it to be
    stripped already.

    Returns the frame length, or [Error `Buffer_too_small] without writing
    when [dst] is too short.

    @raise Failure if there are more than 255 messages.
    @raise Invalid_argument if [msgs] and [msg_lens] differ in length, or a
    length is outside [0..65535]. This is checked before anything is
    written; previously the first case failed midway through writing and
    the second silently truncated the length field. *)
let encode_compound_to_cstruct ~(msgs : Cstruct.t list) ~(msg_lens : int list)
    ~(dst : Cstruct.t) : (int, [ `Buffer_too_small ]) result =
  let num_msgs = List.length msgs in
  if num_msgs > 255 then failwith "too many messages for compound";
  if List.compare_lengths msgs msg_lens <> 0 then
    invalid_arg "encode_compound_to_cstruct: msgs and msg_lens lengths differ";
  if List.exists (fun len -> len < 0 || len > 0xFFFF) msg_lens then
    invalid_arg "encode_compound_to_cstruct: part length must be in 0..65535";
  let header_size = 2 + (2 * num_msgs) in
  let total_len = header_size + List.fold_left ( + ) 0 msg_lens in
  if total_len > Cstruct.length dst then Error `Buffer_too_small
  else begin
    Cstruct.set_uint8 dst 0 (message_type_to_int Compound_msg);
    Cstruct.set_uint8 dst 1 num_msgs;
    List.iteri
      (fun i len -> Cstruct.BE.set_uint16 dst (2 + (2 * i)) len)
      msg_lens;
    let (_end : int) =
      List.fold_left2
        (fun offset msg len ->
          Cstruct.blit msg 0 dst offset len;
          offset + len)
        header_size msgs msg_lens
    in
    Ok total_len
  end
493
(** [decode_compound_from_cstruct buf] splits a compound body into its parts.

    [buf] starts at the one-byte part count, i.e. after the [Compound_msg]
    tag ([decode_packet] shifts the tag off before calling). The count is
    followed by one big-endian 16-bit length per part, then the part bytes.

    Returns [(parts, truncated)]:
    - [parts] are views into [buf];
    - [truncated] is the number of trailing parts whose bytes run past the
      end of [buf] (those parts are omitted).

    Returns [Truncated_message] if [buf] is empty or the length table is
    incomplete. *)
let decode_compound_from_cstruct (buf : Cstruct.t) :
    (Cstruct.t list * int, Types.decode_error) result =
  let total = Cstruct.length buf in
  if total < 1 then Error Types.Truncated_message
  else
    let count = Cstruct.get_uint8 buf 0 in
    let body_start = 1 + (2 * count) in
    if total < body_start then Error Types.Truncated_message
    else
      let part_len i = Cstruct.BE.get_uint16 buf (1 + (2 * i)) in
      let rec split i offset acc =
        if i = count then Ok (List.rev acc, 0)
        else
          let len = part_len i in
          if offset + len > total then Ok (List.rev acc, count - i)
          else split (i + 1) (offset + len) (Cstruct.sub buf offset len :: acc)
      in
      split 0 body_start []
516
(** [encode_internal_msg_to_cstruct ~self_name ~self_port msg ~buf] converts
    an internal message to its wire form with [Types.msg_to_wire], then
    frames it into [buf] with [encode_msg_to_cstruct]. *)
let encode_internal_msg_to_cstruct ~self_name ~self_port
    (msg : Types.protocol_msg) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  encode_msg_to_cstruct ~buf (Types.msg_to_wire ~self_name ~self_port msg)
522
(** [decode_internal_msg_from_cstruct ~default_port buf] decodes a wire
    message and converts it to the internal representation with
    [Types.msg_of_wire].

    Wire decoding errors pass through unchanged. A wire message with no
    internal equivalent ([msg_of_wire] returns [None]) is reported as
    [Invalid_tag 0]. *)
let decode_internal_msg_from_cstruct ~default_port (buf : Cstruct.t) :
    (Types.protocol_msg, Types.decode_error) result =
  Result.bind (decode_msg_from_cstruct buf) (fun wire_msg ->
      Option.to_result ~none:(Types.Invalid_tag 0)
        (Types.msg_of_wire ~default_port wire_msg))
531
(** [encode_packet packet ~buf] encodes a packet into [buf] and returns the
    number of bytes written.

    - With no piggyback messages, only the primary message is encoded.
    - Otherwise the primary and piggyback messages are encoded back to back
      after a compound header: a [Compound_msg] tag, a one-byte count, and a
      16-bit big-endian length per message. The header is filled in once
      all messages have been written.

    Returns [Error `Buffer_too_small] if any part does not fit.
    @raise Failure if the packet carries more than 255 messages in total. *)
let encode_packet (packet : Types.packet) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  (* NOTE(review): the [cluster] field is passed as [self_name] and the port
     is fixed at 7946 — unchanged from before; confirm this is intended. *)
  let encode_one msg ~buf =
    encode_internal_msg_to_cstruct ~self_name:packet.cluster ~self_port:7946
      msg ~buf
  in
  match packet.piggyback with
  | [] -> encode_one packet.primary ~buf
  | piggyback ->
      let parts = packet.primary :: piggyback in
      let count = List.length parts in
      if count > 255 then failwith "too many messages for compound"
      else
        let capacity = Cstruct.length buf in
        let body_start = 2 + (2 * count) in
        if body_start > capacity then Error `Buffer_too_small
        else
          let rec fill i offset = function
            | [] ->
                Cstruct.set_uint8 buf 0 (message_type_to_int Compound_msg);
                Cstruct.set_uint8 buf 1 count;
                Ok offset
            | msg :: rest ->
                if offset >= capacity then Error `Buffer_too_small
                else (
                  match encode_one msg ~buf:(Cstruct.shift buf offset) with
                  | Error _ -> Error `Buffer_too_small
                  | Ok len ->
                      Cstruct.BE.set_uint16 buf (2 + (2 * i)) len;
                      fill (i + 1) (offset + len) rest)
          in
          fill 0 body_start parts
570
(** [decode_packet buf] decodes a packet that is either a single message or a
    compound frame.

    For a compound frame:
    - the first part is the primary message and must decode;
    - later parts that fail to decode are silently dropped;
    - the truncation count from [decode_compound_from_cstruct] is ignored.

    Messages are decoded with default port 7946, and [cluster] is always
    [""]. An empty buffer, or a compound frame with no complete parts, gives
    [Truncated_message]. *)
let decode_packet (buf : Cstruct.t) : (Types.packet, Types.decode_error) result
    =
  let decode = decode_internal_msg_from_cstruct ~default_port:7946 in
  let packet primary piggyback = { Types.cluster = ""; primary; piggyback } in
  if Cstruct.length buf < 1 then Error Types.Truncated_message
  else if Cstruct.get_uint8 buf 0 <> message_type_to_int Compound_msg then
    Result.map (fun primary -> packet primary []) (decode buf)
  else
    match decode_compound_from_cstruct (Cstruct.shift buf 1) with
    | Error e -> Error e
    | Ok ([], _) -> Error Types.Truncated_message
    | Ok (first :: rest, _truncated) ->
        Result.map
          (fun primary ->
            let extras =
              List.filter_map (fun part -> Result.to_option (decode part)) rest
            in
            packet primary extras)
          (decode first)
604
(** [encoded_size msg] estimates the bytes [msg] occupies when framed as one
    part of a packet: one tag byte, the body, and 3 extra bytes.

    The body size now matches what [encode_msg_to_cstruct] writes.
    [User_data] is sized by its raw byte length; before, it was sized as a
    one-byte msgpack [Nil] placeholder and so underestimated.

    NOTE(review): the trailing [+ 3] is kept from the original and
    presumably reserves compound-frame overhead; confirm.
    NOTE(review): sizes are computed with [self_name = ""], so fields derived
    from the sender's name may make the real encoding larger. *)
let encoded_size (msg : Types.protocol_msg) : int =
  let wire_msg = Types.msg_to_wire ~self_name:"" ~self_port:7946 msg in
  let body_size =
    match wire_msg with
    | User_data data -> String.length data
    | other -> Msgpck.size (snd (wire_msg_to_msgpck other))
  in
  1 + body_size + 3
609
(** [encode_internal_msg ~self_name ~self_port msg] encodes [msg] as a framed
    wire message and returns its bytes.

    Encoding starts with a 2048-byte buffer and doubles it while the message
    does not fit. Before, anything larger than 2048 bytes silently became
    [""]. [""] is still returned if the message would need more than 16 MiB,
    as a guard against runaway allocation. *)
let encode_internal_msg ~self_name ~self_port (msg : Types.protocol_msg) :
    string =
  let max_size = 16 * 1024 * 1024 in
  let rec attempt size =
    let buf = Cstruct.create size in
    match encode_internal_msg_to_cstruct ~self_name ~self_port msg ~buf with
    | Ok len -> Cstruct.to_string ~off:0 ~len buf
    | Error _ ->
        if size >= max_size then "" else attempt (min max_size (2 * size))
  in
  attempt 2048
616
617(* Backward-compatible string wrappers for tests *)
618
(** [add_crc data] is [data] prefixed with a CRC header (see
    [add_crc_to_cstruct]). [data] is returned unchanged if framing fails. *)
let add_crc (data : string) : string =
  let n = String.length data in
  let dst = Cstruct.create (n + 5) in
  match add_crc_to_cstruct ~src:(Cstruct.of_string data) ~src_len:n ~dst with
  | Ok written -> Cstruct.to_string ~off:0 ~len:written dst
  | Error _ -> data
625
(** String wrapper around [verify_and_strip_crc]: returns the payload as a
    fresh string, or the same decode error. *)
let verify_and_strip_crc_string (data : string) :
    (string, Types.decode_error) result =
  Cstruct.of_string data
  |> verify_and_strip_crc
  |> Result.map (fun payload -> Cstruct.to_string payload)
632
(** [add_label label data] is [data] framed by [add_label_to_cstruct].
    [data] is returned unchanged if framing reports an error. *)
let add_label (label : string) (data : string) : string =
  let src_len = String.length data in
  let dst = Cstruct.create (String.length label + src_len + 2) in
  let src = Cstruct.of_string data in
  match add_label_to_cstruct ~label ~src ~src_len ~dst with
  | Ok written -> Cstruct.to_string ~off:0 ~len:written dst
  | Error _ -> data
639
(** String wrapper around [strip_label]: returns [(payload, label)] with the
    payload copied into a fresh string. *)
let strip_label_string (data : string) :
    (string * string, Types.decode_error) result =
  Result.map
    (fun (payload, label) -> (Cstruct.to_string payload, label))
    (strip_label (Cstruct.of_string data))
646
(** [make_compound_msg msgs] packs [msgs] into one compound frame (including
    the [Compound_msg] tag byte) using an exactly-sized buffer. Returns [""]
    if encoding reports an error. *)
let make_compound_msg (msgs : string list) : string =
  let msg_lens = List.map String.length msgs in
  let body_len = List.fold_left ( + ) 0 msg_lens in
  let dst = Cstruct.create (2 + (2 * List.length msgs) + body_len) in
  let msgs = List.map Cstruct.of_string msgs in
  match encode_compound_to_cstruct ~msgs ~msg_lens ~dst with
  | Ok len -> Cstruct.to_string ~off:0 ~len dst
  | Error _ -> ""
655
(** String wrapper around [decode_compound_from_cstruct]. [data] must start
    at the part count, not the [Compound_msg] tag. Returns the parts as
    strings plus the truncated-part count. *)
let decode_compound_msg (data : string) :
    (string list * int, Types.decode_error) result =
  Cstruct.of_string data
  |> decode_compound_from_cstruct
  |> Result.map (fun (parts, trunc) -> (List.map Cstruct.to_string parts, trunc))
662
(** [encode_push_pull_header h] builds the msgpack map
    [{"Nodes"; "UserStateLen"; "Join"}] for a push/pull header. *)
let encode_push_pull_header (h : push_pull_header) : Msgpck.t =
  let nodes = Msgpck.of_int h.pp_nodes in
  let user_state_len = Msgpck.of_int h.pp_user_state_len in
  let join = Msgpck.Bool h.pp_join in
  Msgpck.Map
    [
      (Msgpck.String "Nodes", nodes);
      (Msgpck.String "UserStateLen", user_state_len);
      (Msgpck.String "Join", join);
    ]
670
(** [decode_push_pull_header m] decodes a push/pull header.

    Every field is optional:
    - missing or ill-typed integers (["Nodes"], ["UserStateLen"]) become
      [0];
    - a missing or ill-typed ["Join"] becomes [false].

    Uint32 values are read as unsigned. Only a non-map [m] is an error. *)
let decode_push_pull_header (m : Msgpck.t) : (push_pull_header, string) result =
  match m with
  | Msgpck.Map fields ->
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let get_int key =
        match find key with
        | Some (Msgpck.Int i) -> i
        | Some (Msgpck.Int32 i) -> Int32.to_int i
        | Some (Msgpck.Uint32 i) ->
            Option.value (Int32.unsigned_to_int i) ~default:0
        | _ -> 0
      in
      let get_bool key =
        match find key with Some (Msgpck.Bool b) -> b | _ -> false
      in
      Ok
        {
          pp_nodes = get_int "Nodes";
          pp_user_state_len = get_int "UserStateLen";
          pp_join = get_bool "Join";
        }
  | _ -> Error "expected map for push_pull_header"
692
(** [encode_push_node_state s] builds the msgpack map for one node in a
    push/pull exchange.

    Keys, in order: Name, Addr, Port, Meta, Incarnation, State, Vsn.
    Addr and Meta are emitted as msgpack bin; Vsn as an array of integers. *)
let encode_push_node_state (s : push_node_state) : Msgpck.t =
  let field name value = (Msgpck.String name, value) in
  let versions = Msgpck.List (List.map (fun v -> Msgpck.of_int v) s.pns_vsn) in
  Msgpck.Map
    [
      field "Name" (Msgpck.String s.pns_name);
      field "Addr" (Msgpck.Bytes s.pns_addr);
      field "Port" (Msgpck.of_int s.pns_port);
      field "Meta" (Msgpck.Bytes s.pns_meta);
      field "Incarnation" (Msgpck.of_int s.pns_incarnation);
      field "State" (Msgpck.of_int s.pns_state);
      field "Vsn" versions;
    ]
704
(** [decode_push_node_state m] decodes one node entry of a push/pull
    exchange.

    Every field is optional:
    - strings (str, bin or [Nil]) default to [""];
    - integers default to [0], with Uint32 read as unsigned.

    ["Vsn"] may be either:
    - an integer array (non-integer entries are skipped), or
    - a raw str/bin, where each byte is one version number.

    Anything else for ["Vsn"] yields [[]]. Only a non-map [m] is an
    error. *)
let decode_push_node_state (m : Msgpck.t) : (push_node_state, string) result =
  match m with
  | Msgpck.Map fields ->
      let find key = List.assoc_opt (Msgpck.String key) fields in
      let int_of_value = function
        | Msgpck.Int i -> Some i
        | Msgpck.Int32 i -> Some (Int32.to_int i)
        | Msgpck.Uint32 i -> Int32.unsigned_to_int i
        | _ -> None
      in
      let get_string key =
        match find key with
        | Some (Msgpck.String s | Msgpck.Bytes s) -> s
        | _ -> ""
      in
      let get_int key =
        Option.value (Option.bind (find key) int_of_value) ~default:0
      in
      let get_int_list key =
        match find key with
        | Some (Msgpck.List items) -> List.filter_map int_of_value items
        (* Presumably produced by encoders that serialise a byte slice as
           raw bytes; previously such input decoded to []. *)
        | Some (Msgpck.String s | Msgpck.Bytes s) ->
            List.init (String.length s) (fun i -> Char.code s.[i])
        | _ -> []
      in
      Ok
        ({
           pns_name = get_string "Name";
           pns_addr = get_string "Addr";
           pns_port = get_int "Port";
           pns_meta = get_string "Meta";
           pns_incarnation = get_int "Incarnation";
           pns_state = get_int "State";
           pns_vsn = get_int_list "Vsn";
         }
          : push_node_state)
  | _ -> Error "expected map for push_node_state"
754
(** [encode_push_pull_msg ~header ~nodes ~user_state] serialises a full
    push/pull message in this order:
    - a [Push_pull_msg] tag byte,
    - the msgpack header,
    - one msgpack map per node,
    - the raw user-state bytes. *)
let encode_push_pull_msg ~(header : push_pull_header)
    ~(nodes : push_node_state list) ~(user_state : string) : string =
  let out = Buffer.create 1024 in
  let emit value = ignore (Msgpck.StringBuf.write out value) in
  Buffer.add_char out (Char.chr (message_type_to_int Push_pull_msg));
  emit (encode_push_pull_header header);
  List.iter (fun node -> emit (encode_push_node_state node)) nodes;
  Buffer.add_string out user_state;
  Buffer.contents out
765
(** [decode_push_pull_msg data] decodes a push/pull body, laid out as:
    - a msgpack header,
    - [pp_nodes] node maps,
    - up to [pp_user_state_len] bytes of user state (fewer if [data] ends
      early).

    NOTE(review): [data] is read from offset 0, so the [Push_pull_msg] tag
    that [encode_push_pull_msg] writes is presumably stripped by the caller —
    confirm.

    Errors:
    - [Truncated_message] if [data] is empty or ends before all announced
      nodes.
    - [Msgpack_error] if a value is malformed msgpack (previously the msgpck
      exception escaped) or fails field decoding. *)
let decode_push_pull_msg (data : string) :
    ( push_pull_header * push_node_state list * string,
      Types.decode_error )
    result =
  let total = String.length data in
  (* Parse one msgpack value at the start of [s], returning the bytes
     consumed. Reader exceptions on bad input become a decode error. *)
  let read_value s =
    match Msgpck.String.read s with
    | consumed_and_value -> Ok consumed_and_value
    | exception (Invalid_argument e | Failure e) ->
        Error (Types.Msgpack_error e)
  in
  let decode_with f v = Result.map_error (fun e -> Types.Msgpack_error e) (f v) in
  let ( let* ) = Result.bind in
  let rec read_nodes offset remaining acc =
    if remaining <= 0 then Ok (List.rev acc, offset)
    else if offset >= total then Error Types.Truncated_message
    else
      let* (node_size, node_msgpack) =
        read_value (String.sub data offset (total - offset))
      in
      let* node = decode_with decode_push_node_state node_msgpack in
      read_nodes (offset + node_size) (remaining - 1) (node :: acc)
  in
  if total < 1 then Error Types.Truncated_message
  else
    let* (header_size, header_msgpack) = read_value data in
    let* header = decode_with decode_push_pull_header header_msgpack in
    let* (nodes, offset) = read_nodes header_size header.pp_nodes [] in
    let user_state =
      if header.pp_user_state_len > 0 && offset < total then
        String.sub data offset (min header.pp_user_state_len (total - offset))
      else ""
    in
    Ok (header, nodes, user_state)
799
(** [decode_compress_from_cstruct buf] decodes a compression envelope and
    returns [(algo, compressed_bytes)].

    ["Algo"] falls back to [-1] when missing or not an integer; Uint32 is
    read as unsigned. ["Buf"] (bin or str) is required.

    Errors are [Msgpack_error] when:
    - [buf] is empty or malformed msgpack (previously this raised),
    - the value is not a map,
    - ["Buf"] is missing. *)
let decode_compress_from_cstruct (buf : Cstruct.t) :
    (int * Cstruct.t, Types.decode_error) result =
  match Msgpck.String.read (Cstruct.to_string buf) with
  | exception (Invalid_argument e | Failure e) -> Error (Types.Msgpack_error e)
  | _, Msgpck.Map fields -> (
      let algo =
        match List.assoc_opt (Msgpck.String "Algo") fields with
        | Some (Msgpck.Int i) -> i
        | Some (Msgpck.Int32 i) -> Int32.to_int i
        | Some (Msgpck.Uint32 i) ->
            Option.value (Int32.unsigned_to_int i) ~default:(-1)
        | _ -> -1
      in
      match List.assoc_opt (Msgpck.String "Buf") fields with
      | Some (Msgpck.Bytes s | Msgpck.String s) -> Ok (algo, Cstruct.of_string s)
      | _ -> Error (Types.Msgpack_error "missing Buf field"))
  | _ -> Error (Types.Msgpack_error "expected map for compress")
822
(** Cstruct variant of [decode_push_pull_msg], with the same layout and
    errors. The user state is returned as a view into [buf] (no copy);
    it is [Cstruct.empty] when absent.

    NOTE(review): [buf] is read from offset 0, so a leading [Push_pull_msg]
    tag is presumably stripped by the caller — confirm.

    Malformed msgpack from the peer yields [Msgpack_error] instead of an
    escaping msgpck exception. *)
let decode_push_pull_msg_cstruct (buf : Cstruct.t) :
    ( push_pull_header * push_node_state list * Cstruct.t,
      Types.decode_error )
    result =
  let total = Cstruct.length buf in
  (* Parse one msgpack value at the start of [s], returning the bytes
     consumed. Reader exceptions on bad input become a decode error. *)
  let read_value s =
    match Msgpck.String.read s with
    | consumed_and_value -> Ok consumed_and_value
    | exception (Invalid_argument e | Failure e) ->
        Error (Types.Msgpack_error e)
  in
  let decode_with f v = Result.map_error (fun e -> Types.Msgpack_error e) (f v) in
  let ( let* ) = Result.bind in
  if total < 1 then Error Types.Truncated_message
  else
    let data = Cstruct.to_string buf in
    let rec read_nodes offset remaining acc =
      if remaining <= 0 then Ok (List.rev acc, offset)
      else if offset >= total then Error Types.Truncated_message
      else
        let* (node_size, node_msgpack) =
          read_value (String.sub data offset (total - offset))
        in
        let* node = decode_with decode_push_node_state node_msgpack in
        read_nodes (offset + node_size) (remaining - 1) (node :: acc)
    in
    let* (header_size, header_msgpack) = read_value data in
    let* header = decode_with decode_push_pull_header header_msgpack in
    let* (nodes, offset) = read_nodes header_size header.pp_nodes [] in
    let user_state =
      if header.pp_user_state_len > 0 && offset < total then
        Cstruct.sub buf offset (min header.pp_user_state_len (total - offset))
      else Cstruct.empty
    in
    Ok (header, nodes, user_state)
856 Ok (header, nodes, user_state))