open Types.Wire

(* ------------------------------------------------------------------------ *)
(* Constants                                                                *)
(* ------------------------------------------------------------------------ *)

(** Port used wherever this module has to supply one itself (packet
    encoding, packet decoding and size estimation). *)
let default_port = 7946

(** Largest number of parts in a compound message: the part count is stored
    in a single byte. *)
let max_compound_parts = 255

(** Size of the temporary buffer [encode_packet] and [encode_internal_msg]
    use to encode one message. *)
let scratch_buffer_size = 2048

(* ------------------------------------------------------------------------ *)
(* Msgpack map helpers shared by every decoder                              *)
(* ------------------------------------------------------------------------ *)

(** [lookup_field fields key] is the value stored under string key [key] in a
    decoded msgpack map, if any. *)
let lookup_field (fields : (Msgpck.t * Msgpck.t) list) (key : string) :
    Msgpck.t option =
  List.assoc_opt (Msgpck.String key) fields

(** Error reported when a required field is absent or has the wrong type. *)
let field_error (key : string) : ('a, string) result =
  Error (Printf.sprintf "missing or invalid %s" key)

(** Converts an [int64] to a native [int], or [None] if it does not fit. *)
let int_of_int64 (i : int64) : int option =
  if
    Int64.compare i (Int64.of_int Stdlib.min_int) >= 0
    && Int64.compare i (Int64.of_int Stdlib.max_int) <= 0
  then Some (Int64.to_int i)
  else None

(** Reads any msgpack integer encoding as a native [int].

    [Uint32]/[Uint64] values are carried in signed [int32]/[int64] cells, so
    they are reinterpreted as unsigned here rather than sign-extended (a plain
    [Int32.to_int] would turn e.g. 0xFFFF_FFFF into -1). Values that do not
    fit in [int] yield [None]. *)
let int_of_msgpck : Msgpck.t -> int option = function
  | Msgpck.Int i -> Some i
  | Msgpck.Int32 i -> Some (Int32.to_int i)
  | Msgpck.Uint32 i ->
      int_of_int64 (Int64.logand (Int64.of_int32 i) 0xFFFF_FFFFL)
  | Msgpck.Int64 i -> int_of_int64 i
  | Msgpck.Uint64 i -> if Int64.compare i 0L < 0 then None else int_of_int64 i
  | _ -> None

(** Required integer field. *)
let int_field fields key =
  match Option.bind (lookup_field fields key) int_of_msgpck with
  | Some i -> Ok i
  | None -> field_error key

(** Required string field; accepts both msgpack str and bin encodings. *)
let string_field fields key =
  match lookup_field fields key with
  | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
  | _ -> field_error key

(** Like [string_field], but an explicit msgpack nil reads as [""]. *)
let string_or_nil_field fields key =
  match lookup_field fields key with
  | Some Msgpck.Nil -> Ok ""
  | _ -> string_field fields key

(** The ["Vsn"] field as a list of ints. Accepts an array of integers (the
    form [encode_alive] writes) or a raw byte string, one int per byte.
    NOTE(review): the byte-string form is accepted because a peer may
    serialize a byte slice as msgpack str/bin — confirm against the senders
    in use. Anything else yields []. *)
let vsn_field fields =
  match lookup_field fields "Vsn" with
  | Some (Msgpck.List items) -> List.filter_map int_of_msgpck items
  | Some (Msgpck.String s | Msgpck.Bytes s) ->
      List.init (String.length s) (fun i -> Char.code s.[i])
  | _ -> []

(** One [key => value] entry of an encoded msgpack map. *)
let map_entry (key : string) (value : Msgpck.t) : Msgpck.t * Msgpck.t =
  (Msgpck.String key, value)

(* ------------------------------------------------------------------------ *)
(* Per-message msgpack codecs                                               *)
(* ------------------------------------------------------------------------ *)

(** Encodes a ping as a msgpack map. *)
let encode_ping (p : ping) : Msgpck.t =
  Msgpck.Map
    [
      map_entry "SeqNo" (Msgpck.of_int p.seq_no);
      map_entry "Node" (Msgpck.String p.node);
      map_entry "SourceAddr" (Msgpck.String p.source_addr);
      map_entry "SourcePort" (Msgpck.of_int p.source_port);
      map_entry "SourceNode" (Msgpck.String p.source_node);
    ]

(** Decodes a ping. ["SeqNo"], ["Node"] and ["SourceAddr"] are required (a nil
    string reads as [""]). ["SourcePort"] defaults to [0] and ["SourceNode"]
    to [""] when absent or ill-typed. *)
let decode_ping (m : Msgpck.t) : (ping, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* seq_no = int_field fields "SeqNo" in
      let* node = string_or_nil_field fields "Node" in
      let* source_addr = string_or_nil_field fields "SourceAddr" in
      let source_port =
        Result.value (int_field fields "SourcePort") ~default:0
      in
      let source_node =
        Result.value (string_or_nil_field fields "SourceNode") ~default:""
      in
      Ok { seq_no; node; source_addr; source_port; source_node }
  | _ -> Error "expected map for ping"

(** Encodes an indirect ping request as a msgpack map. *)
let encode_indirect_ping (p : indirect_ping_req) : Msgpck.t =
  Msgpck.Map
    [
      map_entry "SeqNo" (Msgpck.of_int p.seq_no);
      map_entry "Target" (Msgpck.String p.target);
      map_entry "Port" (Msgpck.of_int p.port);
      map_entry "Node" (Msgpck.String p.node);
      map_entry "Nack" (Msgpck.Bool p.nack);
      map_entry "SourceAddr" (Msgpck.String p.source_addr);
      map_entry "SourcePort" (Msgpck.of_int p.source_port);
      map_entry "SourceNode" (Msgpck.String p.source_node);
    ]

(** Decodes an indirect ping request. ["SeqNo"], ["Target"] and ["Node"] are
    required (a nil string reads as [""]). Every other field falls back to a
    default when absent or ill-typed: ports to [0], strings to [""] and
    ["Nack"] to [false]. *)
let decode_indirect_ping (m : Msgpck.t) : (indirect_ping_req, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* seq_no = int_field fields "SeqNo" in
      let* target = string_or_nil_field fields "Target" in
      let port = Result.value (int_field fields "Port") ~default:0 in
      let* node = string_or_nil_field fields "Node" in
      let nack =
        match lookup_field fields "Nack" with
        | Some (Msgpck.Bool b) -> b
        | _ -> false
      in
      let source_addr =
        Result.value (string_or_nil_field fields "SourceAddr") ~default:""
      in
      let source_port =
        Result.value (int_field fields "SourcePort") ~default:0
      in
      let source_node =
        Result.value (string_or_nil_field fields "SourceNode") ~default:""
      in
      Ok
        {
          seq_no;
          target;
          port;
          node;
          nack;
          source_addr;
          source_port;
          source_node;
        }
  | _ -> Error "expected map for indirect_ping"

(** Encodes an ack response as a msgpack map. *)
let encode_ack (a : ack_resp) : Msgpck.t =
  Msgpck.Map
    [
      map_entry "SeqNo" (Msgpck.of_int a.seq_no);
      map_entry "Payload" (Msgpck.String a.payload);
    ]

(** Decodes an ack response. ["SeqNo"] is required. ["Payload"] is taken from
    a str or bin value and is [""] otherwise (including when absent). *)
let decode_ack (m : Msgpck.t) : (ack_resp, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* seq_no = int_field fields "SeqNo" in
      let payload =
        match lookup_field fields "Payload" with
        | Some (Msgpck.Bytes s | Msgpck.String s) -> s
        | _ -> ""
      in
      Ok { seq_no; payload }
  | _ -> Error "expected map for ack"

(** Encodes a nack response as a msgpack map. *)
let encode_nack (n : nack_resp) : Msgpck.t =
  Msgpck.Map [ map_entry "SeqNo" (Msgpck.of_int n.seq_no) ]

(** Decodes a nack response; ["SeqNo"] is required. *)
let decode_nack (m : Msgpck.t) : (nack_resp, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* seq_no = int_field fields "SeqNo" in
      Ok { seq_no }
  | _ -> Error "expected map for nack"

(** Encodes a suspect message as a msgpack map. *)
let encode_suspect (s : suspect) : Msgpck.t =
  Msgpck.Map
    [
      map_entry "Incarnation" (Msgpck.of_int s.incarnation);
      map_entry "Node" (Msgpck.String s.node);
      map_entry "From" (Msgpck.String s.from);
    ]

(** Decodes a suspect message. ["Incarnation"], ["Node"] and ["From"] are
    required; the strings may be msgpack str or bin. *)
let decode_suspect (m : Msgpck.t) : (suspect, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* incarnation = int_field fields "Incarnation" in
      let* node = string_field fields "Node" in
      let* from = string_field fields "From" in
      Ok ({ incarnation; node; from } : suspect)
  | _ -> Error "expected map for suspect"

(** Encodes an alive message as a msgpack map; ["Vsn"] is written as an array
    of integers. *)
let encode_alive (a : alive) : Msgpck.t =
  Msgpck.Map
    [
      map_entry "Incarnation" (Msgpck.of_int a.incarnation);
      map_entry "Node" (Msgpck.String a.node);
      map_entry "Addr" (Msgpck.String a.addr);
      map_entry "Port" (Msgpck.of_int a.port);
      map_entry "Meta" (Msgpck.String a.meta);
      map_entry "Vsn" (Msgpck.List (List.map Msgpck.of_int a.vsn));
    ]

(** Decodes an alive message. ["Incarnation"], ["Node"], ["Addr"] and ["Port"]
    are required. ["Meta"] defaults to [""]. ["Vsn"] is read by [vsn_field]
    and defaults to []. *)
let decode_alive (m : Msgpck.t) : (alive, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* incarnation = int_field fields "Incarnation" in
      let* node = string_field fields "Node" in
      let* addr = string_field fields "Addr" in
      let* port = int_field fields "Port" in
      let meta = Result.value (string_field fields "Meta") ~default:"" in
      let vsn = vsn_field fields in
      Ok { incarnation; node; addr; port; meta; vsn }
  | _ -> Error "expected map for alive"

(** Encodes a dead message as a msgpack map. *)
let encode_dead (d : dead) : Msgpck.t =
  Msgpck.Map
    [
      map_entry "Incarnation" (Msgpck.of_int d.incarnation);
      map_entry "Node" (Msgpck.String d.node);
      map_entry "From" (Msgpck.String d.from);
    ]

(** Decodes a dead message. ["Incarnation"], ["Node"] and ["From"] are
    required; the strings may be msgpack str or bin. *)
let decode_dead (m : Msgpck.t) : (dead, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* incarnation = int_field fields "Incarnation" in
      let* node = string_field fields "Node" in
      let* from = string_field fields "From" in
      Ok ({ incarnation; node; from } : dead)
  | _ -> Error "expected map for dead"

(** Encodes a compressed-payload wrapper as a msgpack map. *)
let encode_compress (c : compress) : Msgpck.t =
  Msgpck.Map
    [
      map_entry "Algo" (Msgpck.of_int c.algo);
      map_entry "Buf" (Msgpck.String c.buf);
    ]

(** Decodes a compressed-payload wrapper; ["Algo"] and ["Buf"] are required. *)
let decode_compress (m : Msgpck.t) : (compress, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* algo = int_field fields "Algo" in
      let* buf = string_field fields "Buf" in
      Ok { algo; buf }
  | _ -> Error "expected map for compress"

(** Decodes the body of an error message. It never fails: a missing or
    ill-typed ["Error"] field reads as ["unknown error"]. *)
let decode_err (m : Msgpck.t) : (string, string) result =
  match m with
  | Msgpck.Map fields -> (
      match lookup_field fields "Error" with
      | Some (Msgpck.String e | Msgpck.Bytes e) -> Ok e
      | _ -> Ok "unknown error")
  | _ -> Ok "unknown error"

(* ------------------------------------------------------------------------ *)
(* Single-message framing: [tag byte][body]                                 *)
(* ------------------------------------------------------------------------ *)

(** Maps a wire message to its type tag and msgpack body.

    [User_data] and [Compound] have no msgpack body here and map to
    [Msgpck.Nil]. [encode_msg_to_cstruct] writes user data raw.
    NOTE(review): a [Compound] value is therefore encoded as just its tag plus
    a msgpack nil. Real compound framing is done by
    [encode_compound_to_cstruct]. *)
let wire_msg_to_msgpck (msg : protocol_msg) : message_type * Msgpck.t =
  match msg with
  | Ping p -> (Ping_msg, encode_ping p)
  | Indirect_ping p -> (Indirect_ping_msg, encode_indirect_ping p)
  | Ack a -> (Ack_resp_msg, encode_ack a)
  | Nack n -> (Nack_resp_msg, encode_nack n)
  | Suspect s -> (Suspect_msg, encode_suspect s)
  | Alive a -> (Alive_msg, encode_alive a)
  | Dead d -> (Dead_msg, encode_dead d)
  | User_data _ -> (User_msg, Msgpck.Nil)
  | Compound _ -> (Compound_msg, Msgpck.Nil)
  | Compressed c -> (Compress_msg, encode_compress c)
  | Err e -> (Err_msg, Msgpck.Map [ map_entry "Error" (Msgpck.String e) ])

(** Number of bytes [encode_msg_to_cstruct] writes after the tag byte:
    the raw length for [User_data], otherwise the msgpack size of [payload]. *)
let wire_body_size (msg : protocol_msg) (payload : Msgpck.t) : int =
  match msg with
  | User_data data -> String.length data
  | _ -> Msgpck.size payload

(** Writes [msg] into [buf] as one tag byte followed by its body: raw bytes for
    [User_data], msgpack for everything else.

    @return the number of bytes written, or [`Buffer_too_small] if [buf]
    cannot hold the whole message (nothing is written in that case). *)
let encode_msg_to_cstruct (msg : protocol_msg) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let msg_type, payload = wire_msg_to_msgpck msg in
  let body_len = wire_body_size msg payload in
  let total_len = 1 + body_len in
  if total_len > Cstruct.length buf then Error `Buffer_too_small
  else begin
    Cstruct.set_uint8 buf 0 (message_type_to_int msg_type);
    (match msg with
    | User_data data -> Cstruct.blit_from_string data 0 buf 1 body_len
    | _ ->
        let body = Bytes.create body_len in
        let (_ : int) = Msgpck.Bytes.write body payload in
        Cstruct.blit_from_bytes body 0 buf 1 body_len);
    Ok total_len
  end

(** Parses one [tag byte][body] message.

    Errors: [Truncated_message] for an empty buffer or a msgpack-bodied
    message with no body. [Invalid_tag] for unknown or unsupported tags.
    [Msgpack_error] when the body is malformed, including when the msgpack
    reader raises on untrusted input, or when a field decoder rejects it.
    [Compound_msg] decodes to [Compound []]; compound parts are split by
    [decode_compound_from_cstruct] instead. *)
let decode_msg_from_cstruct (buf : Cstruct.t) :
    (protocol_msg, Types.decode_error) result =
  if Cstruct.length buf < 1 then Error Types.Truncated_message
  else
    let msg_type_byte = Cstruct.get_uint8 buf 0 in
    match message_type_of_int msg_type_byte with
    | Error n -> Error (Types.Invalid_tag n)
    | Ok msg_type -> (
        let payload_len = Cstruct.length buf - 1 in
        (* Reads the msgpack body, runs [decode] on it and wraps the result.
           Network bytes are untrusted, so reader exceptions are turned into
           decode errors instead of escaping. *)
        let decode_body decode wrap =
          if payload_len = 0 then Error Types.Truncated_message
          else
            let body = Cstruct.to_bytes ~off:1 ~len:payload_len buf in
            match Msgpck.Bytes.read body with
            | exception (Invalid_argument msg | Failure msg) ->
                Error (Types.Msgpack_error msg)
            | _, value -> (
                match decode value with
                | Ok v -> Ok (wrap v)
                | Error e -> Error (Types.Msgpack_error e))
        in
        match msg_type with
        | User_msg ->
            Ok (User_data (Cstruct.to_string ~off:1 ~len:payload_len buf))
        | Compound_msg -> Ok (Compound [])
        | Ping_msg -> decode_body decode_ping (fun p -> Ping p)
        | Indirect_ping_msg ->
            decode_body decode_indirect_ping (fun p -> Indirect_ping p)
        | Ack_resp_msg -> decode_body decode_ack (fun a -> Ack a)
        | Nack_resp_msg -> decode_body decode_nack (fun n -> Nack n)
        | Suspect_msg -> decode_body decode_suspect (fun s -> Suspect s)
        | Alive_msg -> decode_body decode_alive (fun a -> Alive a)
        | Dead_msg -> decode_body decode_dead (fun d -> Dead d)
        | Compress_msg -> decode_body decode_compress (fun c -> Compressed c)
        | Err_msg -> decode_body decode_err (fun e -> Err e)
        | _ -> Error (Types.Invalid_tag msg_type_byte))

(* ------------------------------------------------------------------------ *)
(* CRC-32 framing: [Has_crc tag][crc32 BE, 4 bytes][payload]                *)
(* ------------------------------------------------------------------------ *)

(** Lookup table for the reflected CRC-32 polynomial 0xEDB88320. *)
let crc32_table : int32 array =
  Array.init 256 (fun index ->
      let crc = ref (Int32.of_int index) in
      for _ = 1 to 8 do
        let shifted = Int32.shift_right_logical !crc 1 in
        crc :=
          if Int32.equal (Int32.logand !crc 1l) 1l then
            Int32.logxor shifted 0xEDB88320l
          else shifted
      done;
      !crc)

(** CRC-32 of the whole of [buf] (reflected 0xEDB88320, initial value and
    final XOR 0xFFFFFFFF). *)
let crc32_cstruct (buf : Cstruct.t) : int32 =
  let crc = ref 0xFFFFFFFFl in
  for i = 0 to Cstruct.length buf - 1 do
    (* Only the low 8 bits select the table slot, so an int round trip is
       exact even though [Int32.to_int] may be negative. *)
    let idx = (Int32.to_int !crc lxor Cstruct.get_uint8 buf i) land 0xFF in
    crc := Int32.logxor (Int32.shift_right_logical !crc 8) crc32_table.(idx)
  done;
  Int32.logxor !crc 0xFFFFFFFFl

(** Writes [Has_crc tag][crc32 of src[0..src_len)][src[0..src_len)] to [dst].
    @return bytes written, or [`Buffer_too_small]. *)
let add_crc_to_cstruct ~(src : Cstruct.t) ~src_len ~(dst : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let total_len = 5 + src_len in
  if total_len > Cstruct.length dst then Error `Buffer_too_small
  else begin
    let payload = Cstruct.sub src 0 src_len in
    Cstruct.set_uint8 dst 0 (message_type_to_int Has_crc_msg);
    Cstruct.BE.set_uint32 dst 1 (crc32_cstruct payload);
    Cstruct.blit payload 0 dst 5 src_len;
    Ok total_len
  end

(** Checks and removes a CRC header. A buffer whose first byte is not the CRC
    tag is returned unchanged, whatever its length. A CRC-tagged buffer must
    hold the 5-byte header, and its checksum must match the remainder.

    Errors: [Truncated_message] for an empty buffer or a CRC-tagged buffer
    shorter than 5 bytes. [Invalid_crc] on a checksum mismatch. *)
let verify_and_strip_crc (buf : Cstruct.t) :
    (Cstruct.t, Types.decode_error) result =
  let len = Cstruct.length buf in
  if len < 1 then Error Types.Truncated_message
  else if Cstruct.get_uint8 buf 0 <> message_type_to_int Has_crc_msg then
    Ok buf
  else if len < 5 then Error Types.Truncated_message
  else
    let expected = Cstruct.BE.get_uint32 buf 1 in
    let payload = Cstruct.shift buf 5 in
    if Int32.equal expected (crc32_cstruct payload) then Ok payload
    else Error Types.Invalid_crc

(* ------------------------------------------------------------------------ *)
(* Label framing: [Has_label tag][label length, 1 byte][label][payload]     *)
(* ------------------------------------------------------------------------ *)

(** Prefixes [src[0..src_len)] with a label header. With an empty [label] the
    payload is copied to [dst] unchanged.

    @return bytes written, or [`Buffer_too_small].
    @raise Invalid_argument if [label] is longer than 255 bytes: its length
    must fit the one-byte length field. *)
let add_label_to_cstruct ~label ~(src : Cstruct.t) ~src_len ~(dst : Cstruct.t)
    : (int, [ `Buffer_too_small ]) result =
  let label_len = String.length label in
  if label_len = 0 then begin
    if src_len > Cstruct.length dst then Error `Buffer_too_small
    else begin
      Cstruct.blit src 0 dst 0 src_len;
      Ok src_len
    end
  end
  else if label_len > 255 then
    invalid_arg "add_label_to_cstruct: label longer than 255 bytes"
  else
    let total_len = 2 + label_len + src_len in
    if total_len > Cstruct.length dst then Error `Buffer_too_small
    else begin
      Cstruct.set_uint8 dst 0 (message_type_to_int Has_label_msg);
      Cstruct.set_uint8 dst 1 label_len;
      Cstruct.blit_from_string label 0 dst 2 label_len;
      Cstruct.blit src 0 dst (2 + label_len) src_len;
      Ok total_len
    end

(** Splits a label header off [buf], returning [(payload, label)]. A buffer
    without the label tag is returned as-is with label [""].
    Errors: [Truncated_message] if the buffer is empty or the header or
    label runs past the end. *)
let strip_label (buf : Cstruct.t) :
    (Cstruct.t * string, Types.decode_error) result =
  let len = Cstruct.length buf in
  if len < 1 then Error Types.Truncated_message
  else if Cstruct.get_uint8 buf 0 <> message_type_to_int Has_label_msg then
    Ok (buf, "")
  else if len < 2 then Error Types.Truncated_message
  else
    let label_len = Cstruct.get_uint8 buf 1 in
    if len < 2 + label_len then Error Types.Truncated_message
    else
      let label = Cstruct.to_string ~off:2 ~len:label_len buf in
      Ok (Cstruct.shift buf (2 + label_len), label)

(* ------------------------------------------------------------------------ *)
(* Compound framing:                                                        *)
(*   [Compound tag][count, 1 byte][count x len, u16 BE][part bytes...]      *)
(* ------------------------------------------------------------------------ *)

(** Packs [msgs] into a single compound message in [dst]. The i-th part is the
    first [List.nth msg_lens i] bytes of the i-th message.

    @return bytes written, or [`Buffer_too_small].
    @raise Failure if there are more than 255 messages.
    @raise Invalid_argument if [msgs] and [msg_lens] differ in length, or if
    a length is negative or exceeds the 16-bit length field. *)
let encode_compound_to_cstruct ~(msgs : Cstruct.t list) ~(msg_lens : int list)
    ~(dst : Cstruct.t) : (int, [ `Buffer_too_small ]) result =
  let num_msgs = List.length msgs in
  if num_msgs > max_compound_parts then
    failwith "too many messages for compound"
  else if List.compare_lengths msgs msg_lens <> 0 then
    invalid_arg "encode_compound_to_cstruct: msgs and msg_lens differ in length"
  else if List.exists (fun len -> len < 0 || len > 0xFFFF) msg_lens then
    invalid_arg "encode_compound_to_cstruct: part length does not fit in 16 bits"
  else
    let header_size = 2 + (2 * num_msgs) in
    let total_len = header_size + List.fold_left ( + ) 0 msg_lens in
    if total_len > Cstruct.length dst then Error `Buffer_too_small
    else begin
      Cstruct.set_uint8 dst 0 (message_type_to_int Compound_msg);
      Cstruct.set_uint8 dst 1 num_msgs;
      List.iteri
        (fun i len -> Cstruct.BE.set_uint16 dst (2 + (2 * i)) len)
        msg_lens;
      let (_ : int) =
        List.fold_left2
          (fun offset msg len ->
            Cstruct.blit msg 0 dst offset len;
            offset + len)
          header_size msgs msg_lens
      in
      Ok total_len
    end

(** Splits a compound body into its parts. [buf] must start at the part-count
    byte, i.e. with the compound tag already removed ([decode_packet] shifts
    it off).

    @return [(parts, truncated)]. [parts] holds every part that fits before
    the first one that would run past the end of [buf]; [truncated] counts
    the declared parts that were dropped. Errors with [Truncated_message] if
    the count or length table itself is cut short. *)
let decode_compound_from_cstruct (buf : Cstruct.t) :
    (Cstruct.t list * int, Types.decode_error) result =
  let buf_len = Cstruct.length buf in
  if buf_len < 1 then Error Types.Truncated_message
  else
    let num_parts = Cstruct.get_uint8 buf 0 in
    let header_size = 1 + (2 * num_parts) in
    if buf_len < header_size then Error Types.Truncated_message
    else
      let lengths =
        List.init num_parts (fun i -> Cstruct.BE.get_uint16 buf (1 + (2 * i)))
      in
      let rec collect offset acc = function
        | [] -> Ok (List.rev acc, 0)
        | len :: rest as remaining ->
            if offset + len > buf_len then
              Ok (List.rev acc, List.length remaining)
            else collect (offset + len) (Cstruct.sub buf offset len :: acc) rest
      in
      collect header_size [] lengths

(* ------------------------------------------------------------------------ *)
(* Internal message / packet layer                                          *)
(* ------------------------------------------------------------------------ *)

(** Converts an internal message with [Types.msg_to_wire] and encodes it
    into [buf]. *)
let encode_internal_msg_to_cstruct ~self_name ~self_port
    (msg : Types.protocol_msg) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let wire_msg = Types.msg_to_wire ~self_name ~self_port msg in
  encode_msg_to_cstruct wire_msg ~buf

(** Decodes one wire message and converts it with [Types.msg_of_wire].
    A wire message with no internal counterpart is reported as
    [Invalid_tag 0]. *)
let decode_internal_msg_from_cstruct ~default_port (buf : Cstruct.t) :
    (Types.protocol_msg, Types.decode_error) result =
  match decode_msg_from_cstruct buf with
  | Error e -> Error e
  | Ok wire_msg -> (
      match Types.msg_of_wire ~default_port wire_msg with
      | Some msg -> Ok msg
      | None -> Error (Types.Invalid_tag 0))

(** Encodes a packet into [buf]. Without piggyback it is a single message;
    otherwise a compound message whose first part is the primary.

    Piggyback messages that do not fit a [scratch_buffer_size] scratch buffer
    are dropped. So are any beyond the 255-part compound limit.
    NOTE(review): the packet's cluster name is passed as [self_name] and the
    port is fixed to [default_port] — confirm this is what
    [Types.msg_to_wire] expects.

    @return bytes written, or [`Buffer_too_small] if the primary cannot be
    encoded or the result does not fit [buf]. *)
let encode_packet (packet : Types.packet) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let self_name = packet.cluster in
  let self_port = default_port in
  match packet.piggyback with
  | [] ->
      encode_internal_msg_to_cstruct ~self_name ~self_port packet.primary ~buf
  | piggyback -> (
      let encode_one msg =
        let scratch = Cstruct.create scratch_buffer_size in
        match
          encode_internal_msg_to_cstruct ~self_name ~self_port msg ~buf:scratch
        with
        | Error _ -> None
        | Ok len -> Some (Cstruct.sub scratch 0 len, len)
      in
      let rec take n = function
        | x :: rest when n > 0 -> x :: take (n - 1) rest
        | _ -> []
      in
      match encode_one packet.primary with
      | None -> Error `Buffer_too_small
      | Some primary ->
          (* One compound slot is taken by the primary message. *)
          let extras =
            take (max_compound_parts - 1) (List.filter_map encode_one piggyback)
          in
          let msgs, msg_lens = List.split (primary :: extras) in
          encode_compound_to_cstruct ~msgs ~msg_lens ~dst:buf)

(** Decodes a packet produced by [encode_packet]. For a compound message the
    first part must decode and becomes the primary. Later parts that fail to
    decode are skipped, as are truncated parts. The returned [cluster] is
    always [""]. *)
let decode_packet (buf : Cstruct.t) : (Types.packet, Types.decode_error) result
    =
  let decode_one part = decode_internal_msg_from_cstruct ~default_port part in
  if Cstruct.length buf < 1 then Error Types.Truncated_message
  else if Cstruct.get_uint8 buf 0 <> message_type_to_int Compound_msg then
    match decode_one buf with
    | Error e -> Error e
    | Ok primary -> Ok { Types.cluster = ""; primary; piggyback = [] }
  else
    match decode_compound_from_cstruct (Cstruct.shift buf 1) with
    | Error e -> Error e
    | Ok ([], _) -> Error Types.Truncated_message
    | Ok (first :: rest, _truncated) -> (
        match decode_one first with
        | Error e -> Error e
        | Ok primary ->
            let piggyback =
              List.filter_map (fun p -> Result.to_option (decode_one p)) rest
            in
            Ok { Types.cluster = ""; primary; piggyback })

(** Estimated encoded size of [msg]: tag byte + body + a fixed 3-byte
    allowance. NOTE(review): the 3 bytes presumably cover per-part compound
    framing — confirm. The estimate encodes with an empty [self_name], so it
    may under-count if [Types.msg_to_wire] embeds that name. *)
let encoded_size (msg : Types.protocol_msg) : int =
  let wire_msg = Types.msg_to_wire ~self_name:"" ~self_port:default_port msg in
  let _, payload = wire_msg_to_msgpck wire_msg in
  1 + wire_body_size wire_msg payload + 3

(** Encodes [msg] to a string. Returns [""] if it does not fit in a
    [scratch_buffer_size]-byte buffer. *)
let encode_internal_msg ~self_name ~self_port (msg : Types.protocol_msg) :
    string =
  let buf = Cstruct.create scratch_buffer_size in
  match encode_internal_msg_to_cstruct ~self_name ~self_port msg ~buf with
  | Error _ -> ""
  | Ok len -> Cstruct.to_string ~off:0 ~len buf

(* Backward-compatible string wrappers for tests *)

(** String version of [add_crc_to_cstruct]. *)
let add_crc (data : string) : string =
  let src = Cstruct.of_string data in
  let dst = Cstruct.create (5 + String.length data) in
  match add_crc_to_cstruct ~src ~src_len:(String.length data) ~dst with
  | Error _ -> data
  | Ok len -> Cstruct.to_string ~off:0 ~len dst

(** String version of [verify_and_strip_crc]. *)
let verify_and_strip_crc_string (data : string) :
    (string, Types.decode_error) result =
  match verify_and_strip_crc (Cstruct.of_string data) with
  | Error e -> Error e
  | Ok cs -> Ok (Cstruct.to_string cs)

(** String version of [add_label_to_cstruct].
    @raise Invalid_argument if [label] is longer than 255 bytes. *)
let add_label (label : string) (data : string) : string =
  let src = Cstruct.of_string data in
  let dst = Cstruct.create (2 + String.length label + String.length data) in
  match add_label_to_cstruct ~label ~src ~src_len:(String.length data) ~dst with
  | Error _ -> data
  | Ok len -> Cstruct.to_string ~off:0 ~len dst

(** String version of [strip_label]; returns [(payload, label)]. *)
let strip_label_string (data : string) :
    (string * string, Types.decode_error) result =
  match strip_label (Cstruct.of_string data) with
  | Error e -> Error e
  | Ok (cs, label) -> Ok (Cstruct.to_string cs, label)

(** String version of [encode_compound_to_cstruct]; the output includes the
    leading compound tag. *)
let make_compound_msg (msgs : string list) : string =
  let css = List.map Cstruct.of_string msgs in
  let lens = List.map String.length msgs in
  let total_len = 2 + (List.length msgs * 2) + List.fold_left ( + ) 0 lens in
  let dst = Cstruct.create total_len in
  match encode_compound_to_cstruct ~msgs:css ~msg_lens:lens ~dst with
  | Error _ -> ""
  | Ok len -> Cstruct.to_string ~off:0 ~len dst

(** String version of [decode_compound_from_cstruct]. Like that function, it
    expects [data] to start at the part-count byte, without the compound tag
    that [make_compound_msg] writes first. *)
let decode_compound_msg (data : string) :
    (string list * int, Types.decode_error) result =
  match decode_compound_from_cstruct (Cstruct.of_string data) with
  | Error e -> Error e
  | Ok (css, trunc) -> Ok (List.map Cstruct.to_string css, trunc)