open Types.Wire

(** Port number passed as [~self_port] when encoding packets and as
    [~default_port] when decoding them. *)
let default_port = 7946

(* ------------------------------------------------------------------ *)
(* Shared helpers for reading fields out of a decoded MessagePack map  *)
(* ------------------------------------------------------------------ *)

(** [lookup fields key] is the value bound to the string key [key] in the
    map entries [fields], if any. *)
let lookup fields key = List.assoc_opt (Msgpck.String key) fields

(** Error text used by every field accessor below. *)
let field_error key = Printf.sprintf "missing or invalid %s" key

(** [int_of_value v] converts an integer-like MessagePack value to [int].
    [Uint32] is read as unsigned, so values at or above 2^31 stay positive;
    [None] is returned for non-integers and for unsigned values that do not
    fit in a native [int]. *)
let int_of_value = function
  | Msgpck.Int i -> Some i
  | Msgpck.Int32 i -> Some (Int32.to_int i)
  | Msgpck.Uint32 i -> Int32.unsigned_to_int i
  | _ -> None

(** [field_int fields key] reads a required integer field. *)
let field_int fields key =
  match lookup fields key with
  | Some v -> (
      match int_of_value v with
      | Some i -> Ok i
      | None -> Error (field_error key))
  | None -> Error (field_error key)

(** [field_strict_string fields key] reads a required field that must be a
    MessagePack [String]. *)
let field_strict_string fields key =
  match lookup fields key with
  | Some (Msgpck.String s) -> Ok s
  | _ -> Error (field_error key)

(** [field_string_or_bytes fields key] reads a required field that may be a
    MessagePack [String] or [Bytes]. *)
let field_string_or_bytes fields key =
  match lookup fields key with
  | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
  | _ -> Error (field_error key)

(** [field_lenient_string fields key] is like {!field_string_or_bytes} but
    additionally maps an explicit [Nil] to the empty string. A missing key
    is still an error. *)
let field_lenient_string fields key =
  match lookup fields key with
  | Some (Msgpck.String s | Msgpck.Bytes s) -> Ok s
  | Some Msgpck.Nil -> Ok ""
  | _ -> Error (field_error key)

(* ------------------------------------------------------------------ *)
(* Per-message encoders / decoders                                     *)
(* ------------------------------------------------------------------ *)

(** Encodes a ping as a MessagePack map. *)
let encode_ping (p : ping) : Msgpck.t =
  Msgpck.Map
    [
      (Msgpck.String "SeqNo", Msgpck.of_int p.seq_no);
      (Msgpck.String "Node", Msgpck.String p.node);
      (Msgpck.String "SourceAddr", Msgpck.String p.source_addr);
      (Msgpck.String "SourcePort", Msgpck.of_int p.source_port);
      (Msgpck.String "SourceNode", Msgpck.String p.source_node);
    ]

(** Decodes a ping. [SeqNo], [Node] and [SourceAddr] are required;
    [SourcePort] defaults to [0] and [SourceNode] to [""] when missing or
    invalid. *)
let decode_ping (m : Msgpck.t) : (ping, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* seq_no = field_int fields "SeqNo" in
      let* node = field_lenient_string fields "Node" in
      let* source_addr = field_lenient_string fields "SourceAddr" in
      let source_port =
        Result.value (field_int fields "SourcePort") ~default:0
      in
      let source_node =
        Result.value (field_lenient_string fields "SourceNode") ~default:""
      in
      Ok ({ seq_no; node; source_addr; source_port; source_node } : ping)
  | _ -> Error "expected map for ping"

(** Encodes an indirect ping request as a MessagePack map. *)
let encode_indirect_ping (p : indirect_ping_req) : Msgpck.t =
  Msgpck.Map
    [
      (Msgpck.String "SeqNo", Msgpck.of_int p.seq_no);
      (Msgpck.String "Target", Msgpck.String p.target);
      (Msgpck.String "Port", Msgpck.of_int p.port);
      (Msgpck.String "Node", Msgpck.String p.node);
      (Msgpck.String "Nack", Msgpck.Bool p.nack);
      (Msgpck.String "SourceAddr", Msgpck.String p.source_addr);
      (Msgpck.String "SourcePort", Msgpck.of_int p.source_port);
      (Msgpck.String "SourceNode", Msgpck.String p.source_node);
    ]

(** Decodes an indirect ping request. [SeqNo], [Target] and [Node] are
    required; [Port] and [SourcePort] default to [0], [Nack] to [false],
    and [SourceAddr] / [SourceNode] to [""]. *)
let decode_indirect_ping (m : Msgpck.t) : (indirect_ping_req, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* seq_no = field_int fields "SeqNo" in
      let* target = field_lenient_string fields "Target" in
      let port = Result.value (field_int fields "Port") ~default:0 in
      let* node = field_lenient_string fields "Node" in
      let nack =
        match lookup fields "Nack" with
        | Some (Msgpck.Bool b) -> b
        | _ -> false
      in
      let source_addr =
        Result.value (field_lenient_string fields "SourceAddr") ~default:""
      in
      let source_port =
        Result.value (field_int fields "SourcePort") ~default:0
      in
      let source_node =
        Result.value (field_lenient_string fields "SourceNode") ~default:""
      in
      Ok
        ({
           seq_no;
           target;
           port;
           node;
           nack;
           source_addr;
           source_port;
           source_node;
         }
          : indirect_ping_req)
  | _ -> Error "expected map for indirect_ping"

(** Encodes an ack response as a MessagePack map. *)
let encode_ack (a : ack_resp) : Msgpck.t =
  Msgpck.Map
    [
      (Msgpck.String "SeqNo", Msgpck.of_int a.seq_no);
      (Msgpck.String "Payload", Msgpck.String a.payload);
    ]

(** Decodes an ack response. [SeqNo] is required; [Payload] falls back to
    [""] when absent, nil, or of any unexpected type. *)
let decode_ack (m : Msgpck.t) : (ack_resp, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* seq_no = field_int fields "SeqNo" in
      let payload =
        match lookup fields "Payload" with
        | Some (Msgpck.Bytes s | Msgpck.String s) -> s
        | _ -> ""
      in
      Ok ({ seq_no; payload } : ack_resp)
  | _ -> Error "expected map for ack"

(** Encodes a nack response as a MessagePack map. *)
let encode_nack (n : nack_resp) : Msgpck.t =
  Msgpck.Map [ (Msgpck.String "SeqNo", Msgpck.of_int n.seq_no) ]

(** Decodes a nack response; [SeqNo] is required. *)
let decode_nack (m : Msgpck.t) : (nack_resp, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* seq_no = field_int fields "SeqNo" in
      Ok ({ seq_no } : nack_resp)
  | _ -> Error "expected map for nack"

(** Encodes a suspect message as a MessagePack map. *)
let encode_suspect (s : suspect) : Msgpck.t =
  Msgpck.Map
    [
      (Msgpck.String "Incarnation", Msgpck.of_int s.incarnation);
      (Msgpck.String "Node", Msgpck.String s.node);
      (Msgpck.String "From", Msgpck.String s.from);
    ]

(** Decodes a suspect message. All three fields are required, and the text
    fields must be MessagePack [String] values. *)
let decode_suspect (m : Msgpck.t) : (suspect, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* incarnation = field_int fields "Incarnation" in
      let* node = field_strict_string fields "Node" in
      let* from = field_strict_string fields "From" in
      Ok ({ incarnation; node; from } : suspect)
  | _ -> Error "expected map for suspect"

(** Encodes an alive message as a MessagePack map. *)
let encode_alive (a : alive) : Msgpck.t =
  Msgpck.Map
    [
      (Msgpck.String "Incarnation", Msgpck.of_int a.incarnation);
      (Msgpck.String "Node", Msgpck.String a.node);
      (Msgpck.String "Addr", Msgpck.String a.addr);
      (Msgpck.String "Port", Msgpck.of_int a.port);
      (Msgpck.String "Meta", Msgpck.String a.meta);
      (Msgpck.String "Vsn", Msgpck.List (List.map Msgpck.of_int a.vsn));
    ]

(** Decodes an alive message. [Incarnation], [Node], [Addr] and [Port] are
    required; [Meta] defaults to [""]. [Vsn] defaults to [[]], and
    non-integer entries inside it are dropped. *)
let decode_alive (m : Msgpck.t) : (alive, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* incarnation = field_int fields "Incarnation" in
      let* node = field_string_or_bytes fields "Node" in
      let* addr = field_string_or_bytes fields "Addr" in
      let* port = field_int fields "Port" in
      let meta =
        Result.value (field_string_or_bytes fields "Meta") ~default:""
      in
      let vsn =
        match lookup fields "Vsn" with
        | Some (Msgpck.List vs) -> List.filter_map int_of_value vs
        | _ -> []
      in
      Ok ({ incarnation; node; addr; port; meta; vsn } : alive)
  | _ -> Error "expected map for alive"

(** Encodes a dead message as a MessagePack map. *)
let encode_dead (d : dead) : Msgpck.t =
  Msgpck.Map
    [
      (Msgpck.String "Incarnation", Msgpck.of_int d.incarnation);
      (Msgpck.String "Node", Msgpck.String d.node);
      (Msgpck.String "From", Msgpck.String d.from);
    ]

(** Decodes a dead message. All three fields are required, and the text
    fields must be MessagePack [String] values. *)
let decode_dead (m : Msgpck.t) : (dead, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* incarnation = field_int fields "Incarnation" in
      let* node = field_strict_string fields "Node" in
      let* from = field_strict_string fields "From" in
      Ok ({ incarnation; node; from } : dead)
  | _ -> Error "expected map for dead"

(** Encodes a compression envelope as a MessagePack map. *)
let encode_compress (c : compress) : Msgpck.t =
  Msgpck.Map
    [
      (Msgpck.String "Algo", Msgpck.of_int c.algo);
      (Msgpck.String "Buf", Msgpck.String c.buf);
    ]

(** Decodes a compression envelope; both [Algo] and [Buf] are required. *)
let decode_compress (m : Msgpck.t) : (compress, string) result =
  match m with
  | Msgpck.Map fields ->
      let ( let* ) = Result.bind in
      let* algo = field_int fields "Algo" in
      let* buf = field_string_or_bytes fields "Buf" in
      Ok ({ algo; buf } : compress)
  | _ -> Error "expected map for compress"

(* ------------------------------------------------------------------ *)
(* Top-level message framing                                           *)
(* ------------------------------------------------------------------ *)

(** [encode_msg msg] serialises [msg] as one type byte followed by its body.

    [User_data] bodies are appended verbatim. [Compound] is written with a
    MessagePack nil body only; {!make_compound_msg} builds real compound
    frames. Every other message is written as a MessagePack value. *)
let encode_msg (msg : protocol_msg) : string =
  let msg_type, payload =
    match msg with
    | Ping p -> (Ping_msg, encode_ping p)
    | Indirect_ping p -> (Indirect_ping_msg, encode_indirect_ping p)
    | Ack a -> (Ack_resp_msg, encode_ack a)
    | Nack n -> (Nack_resp_msg, encode_nack n)
    | Suspect s -> (Suspect_msg, encode_suspect s)
    | Alive a -> (Alive_msg, encode_alive a)
    | Dead d -> (Dead_msg, encode_dead d)
    | User_data _ -> (User_msg, Msgpck.Nil)
    | Compound _ -> (Compound_msg, Msgpck.Nil)
    | Compressed c -> (Compress_msg, encode_compress c)
    | Err e ->
        (Err_msg, Msgpck.Map [ (Msgpck.String "Error", Msgpck.String e) ])
  in
  let buf = Buffer.create 256 in
  Buffer.add_char buf (Char.chr (message_type_to_int msg_type));
  (match msg with
  | User_data data -> Buffer.add_string buf data
  | _ -> ignore (Msgpck.StringBuf.write buf payload));
  Buffer.contents buf

(** [decode_msg buf] parses one framed message: a type byte followed by its
    body.

    - An empty [buf] yields [Truncated_message].
    - An unknown or unsupported type byte yields [Invalid_tag].
    - A body that is not valid MessagePack, or whose fields are missing or
      ill-typed, yields [Msgpack_error].
    - [Compound_msg] always decodes to [Compound []]; the parts are not
      parsed here (see {!decode_compound_msg}). *)
let decode_msg (buf : string) : (protocol_msg, Types.decode_error) result =
  if String.length buf < 1 then Error Types.Truncated_message
  else
    let msg_type_byte = Char.code buf.[0] in
    match message_type_of_int msg_type_byte with
    | Error n -> Error (Types.Invalid_tag n)
    | Ok msg_type -> (
        let payload = String.sub buf 1 (String.length buf - 1) in
        match msg_type with
        | User_msg -> Ok (User_data payload)
        | Compound_msg -> Ok (Compound [])
        | _ -> (
            (* Parse the MessagePack body, then hand it to [decode] and wrap
               the result. Reader exceptions are turned into
               [Msgpack_error] so that untrusted input cannot raise out of
               this function. *)
            let decode_body :
                  'a.
                  (Msgpck.t -> ('a, string) result) ->
                  ('a -> protocol_msg) ->
                  (protocol_msg, Types.decode_error) result =
             fun decode wrap ->
              match Msgpck.String.read payload with
              | exception Invalid_argument reason ->
                  Error (Types.Msgpack_error reason)
              | exception Failure reason ->
                  Error (Types.Msgpack_error reason)
              | _, value -> (
                  match decode value with
                  | Ok v -> Ok (wrap v)
                  | Error e -> Error (Types.Msgpack_error e))
            in
            (* Error bodies never fail to decode: anything unexpected
               becomes the generic text. *)
            let decode_err = function
              | Msgpck.Map fields -> (
                  match lookup fields "Error" with
                  | Some (Msgpck.String e) -> Ok e
                  | _ -> Ok "unknown error")
              | _ -> Ok "unknown error"
            in
            match msg_type with
            | Ping_msg -> decode_body decode_ping (fun p -> Ping p)
            | Indirect_ping_msg ->
                decode_body decode_indirect_ping (fun p -> Indirect_ping p)
            | Ack_resp_msg -> decode_body decode_ack (fun a -> Ack a)
            | Nack_resp_msg -> decode_body decode_nack (fun n -> Nack n)
            | Suspect_msg -> decode_body decode_suspect (fun s -> Suspect s)
            | Alive_msg -> decode_body decode_alive (fun a -> Alive a)
            | Dead_msg -> decode_body decode_dead (fun d -> Dead d)
            | Compress_msg ->
                decode_body decode_compress (fun c -> Compressed c)
            | Err_msg -> decode_body decode_err (fun e -> Err e)
            | _ -> Error (Types.Invalid_tag msg_type_byte)))

(* ------------------------------------------------------------------ *)
(* Compound messages                                                   *)
(* ------------------------------------------------------------------ *)

(** [make_compound_msg msgs] frames [msgs] as: the compound type byte, a
    one-byte part count, one big-endian 16-bit length per part, then the
    parts themselves.

    @raise Failure if there are more than 255 parts, or if any part is
    longer than 65535 bytes (its length would not fit the 16-bit field). *)
let make_compound_msg (msgs : string list) : string =
  let count = List.length msgs in
  if count > 255 then failwith "too many messages for compound"
  else if List.exists (fun m -> String.length m > 0xffff) msgs then
    failwith "message too large for compound"
  else
    let buf = Buffer.create 1024 in
    Buffer.add_char buf (Char.chr (message_type_to_int Compound_msg));
    Buffer.add_char buf (Char.chr count);
    List.iter
      (fun m ->
        let len = String.length m in
        Buffer.add_char buf (Char.chr ((len lsr 8) land 0xff));
        Buffer.add_char buf (Char.chr (len land 0xff)))
      msgs;
    List.iter (Buffer.add_string buf) msgs;
    Buffer.contents buf

(** [decode_compound_msg buf] splits a compound body. [buf] must start at
    the part-count byte, i.e. after the compound type byte.

    Returns [(parts, truncated)], where [truncated] is the number of
    announced parts that could not be extracted because [buf] ended early.
    [Truncated_message] is returned when even the length table is
    incomplete. *)
let decode_compound_msg (buf : string) :
    (string list * int, Types.decode_error) result =
  let total = String.length buf in
  if total < 1 then Error Types.Truncated_message
  else
    let num_parts = Char.code buf.[0] in
    let header_size = 1 + (num_parts * 2) in
    if total < header_size then Error Types.Truncated_message
    else
      (* Big-endian 16-bit length of part [i] from the length table. *)
      let length_at i =
        let hi = Char.code buf.[1 + (i * 2)] in
        let lo = Char.code buf.[2 + (i * 2)] in
        (hi lsl 8) lor lo
      in
      let rec extract index offset acc =
        if index >= num_parts then Ok (List.rev acc, 0)
        else
          let len = length_at index in
          if offset + len > total then
            (* Stop at the first part that does not fit and report how many
               parts, including this one, were lost. *)
            Ok (List.rev acc, num_parts - index)
          else
            extract (index + 1) (offset + len)
              (String.sub buf offset len :: acc)
      in
      extract 0 header_size []

(* ------------------------------------------------------------------ *)
(* CRC32 envelope                                                      *)
(* ------------------------------------------------------------------ *)

(** Lookup table for the reflected CRC-32 polynomial [0xEDB88320]. *)
let crc32_table =
  Array.init 256 (fun i ->
      let crc = ref (Int32.of_int i) in
      for _ = 0 to 7 do
        if Int32.logand !crc 1l = 1l then
          crc := Int32.logxor (Int32.shift_right_logical !crc 1) 0xEDB88320l
        else crc := Int32.shift_right_logical !crc 1
      done;
      !crc)

(** [crc32 data] is the table-driven CRC-32 of [data], using an initial
    value and a final XOR of [0xFFFFFFFF]. *)
let crc32 (data : string) : int32 =
  let crc = ref 0xFFFFFFFFl in
  String.iter
    (fun c ->
      let byte = Int32.of_int (Char.code c) in
      let idx = Int32.to_int (Int32.logand (Int32.logxor !crc byte) 0xFFl) in
      crc := Int32.logxor (Int32.shift_right_logical !crc 8) crc32_table.(idx))
    data;
  Int32.logxor !crc 0xFFFFFFFFl

(** [add_crc buf] prefixes [buf] with the CRC type byte and the big-endian
    CRC-32 of [buf]. *)
let add_crc (buf : string) : string =
  let crc = crc32 buf in
  let byte shift =
    Char.chr (Int32.to_int (Int32.shift_right_logical crc shift) land 0xff)
  in
  let header = Bytes.create 5 in
  Bytes.set header 0 (Char.chr (message_type_to_int Has_crc_msg));
  Bytes.set header 1 (byte 24);
  Bytes.set header 2 (byte 16);
  Bytes.set header 3 (byte 8);
  Bytes.set header 4 (byte 0);
  Bytes.to_string header ^ buf

(** [verify_and_strip_crc buf] removes a CRC envelope if one is present.

    - A buffer that does not start with the CRC type byte is returned
      unchanged, whatever its length.
    - An empty buffer, or a CRC-tagged buffer shorter than the 5-byte
      header, yields [Truncated_message].
    - A checksum mismatch yields [Invalid_crc]. *)
let verify_and_strip_crc (buf : string) : (string, Types.decode_error) result =
  let total = String.length buf in
  if total < 1 then Error Types.Truncated_message
  else if Char.code buf.[0] <> message_type_to_int Has_crc_msg then Ok buf
  else if total < 5 then Error Types.Truncated_message
  else
    let byte_at i shift =
      Int32.shift_left (Int32.of_int (Char.code buf.[i])) shift
    in
    let expected =
      Int32.logor
        (Int32.logor (byte_at 1 24) (byte_at 2 16))
        (Int32.logor (byte_at 3 8) (byte_at 4 0))
    in
    let payload = String.sub buf 5 (total - 5) in
    if Int32.equal expected (crc32 payload) then Ok payload
    else Error Types.Invalid_crc

(* ------------------------------------------------------------------ *)
(* Label envelope                                                      *)
(* ------------------------------------------------------------------ *)

(** [add_label label buf] prefixes [buf] with the label type byte, a
    one-byte label length and the label itself. An empty label leaves
    [buf] unchanged.

    @raise Invalid_argument if [label] is longer than 255 bytes, because
    its length must fit in one byte. *)
let add_label (label : string) (buf : string) : string =
  if label = "" then buf
  else
    let label_len = String.length label in
    let header = Bytes.create (2 + label_len) in
    Bytes.set header 0 (Char.chr (message_type_to_int Has_label_msg));
    Bytes.set header 1 (Char.chr label_len);
    Bytes.blit_string label 0 header 2 label_len;
    Bytes.to_string header ^ buf

(** [strip_label buf] is [(payload, label)]. When [buf] carries no label
    envelope the result is [(buf, "")]. [Truncated_message] is returned for
    an empty buffer or an incomplete label header. *)
let strip_label (buf : string) : (string * string, Types.decode_error) result =
  let total = String.length buf in
  if total < 1 then Error Types.Truncated_message
  else if Char.code buf.[0] <> message_type_to_int Has_label_msg then
    Ok (buf, "")
  else if total < 2 then Error Types.Truncated_message
  else
    let label_len = Char.code buf.[1] in
    let body_start = 2 + label_len in
    if total < body_start then Error Types.Truncated_message
    else
      let label = String.sub buf 2 label_len in
      let payload = String.sub buf body_start (total - body_start) in
      Ok (payload, label)

(* ------------------------------------------------------------------ *)
(* Internal <-> wire conversion and packets                            *)
(* ------------------------------------------------------------------ *)

(** Converts an internal message to its wire form with
    [Types.msg_to_wire], then serialises it with {!encode_msg}. *)
let encode_internal_msg ~self_name ~self_port (msg : Types.protocol_msg) :
    string =
  encode_msg (Types.msg_to_wire ~self_name ~self_port msg)

(** Parses [buf] with {!decode_msg} and converts the result with
    [Types.msg_of_wire]. A wire message that has no internal counterpart is
    reported as [Invalid_tag 0]. *)
let decode_internal_msg ~default_port (buf : string) :
    (Types.protocol_msg, Types.decode_error) result =
  match decode_msg buf with
  | Error e -> Error e
  | Ok wire_msg -> (
      match Types.msg_of_wire ~default_port wire_msg with
      | Some msg -> Ok msg
      | None -> Error (Types.Invalid_tag 0))

(** [encode_packet packet ~buf] serialises [packet] into [buf] and returns
    the number of bytes written. A packet without piggyback messages is
    written as a single message; otherwise the primary message and the
    piggyback messages are bundled into one compound message.

    Returns [Error `Buffer_too_small] when the encoding does not fit.

    @raise Failure if the compound limits of {!make_compound_msg} are
    exceeded. *)
let encode_packet (packet : Types.packet) ~(buf : Cstruct.t) :
    (int, [ `Buffer_too_small ]) result =
  let encode =
    encode_internal_msg ~self_name:packet.cluster ~self_port:default_port
  in
  let primary_encoded = encode packet.primary in
  let encoded =
    match packet.piggyback with
    | [] -> primary_encoded
    | piggyback ->
        make_compound_msg (primary_encoded :: List.map encode piggyback)
  in
  let total_len = String.length encoded in
  if total_len > Cstruct.length buf then Error `Buffer_too_small
  else begin
    Cstruct.blit_from_string encoded 0 buf 0 total_len;
    Ok total_len
  end

(** [decode_packet buf] parses a single or compound message into a packet.

    For compound input the first part becomes [primary] and must decode
    successfully; later parts that fail to decode are dropped. The
    [cluster] field of the result is always [""]. *)
let decode_packet (buf : Cstruct.t) : (Types.packet, Types.decode_error) result
    =
  let str = Cstruct.to_string buf in
  if String.length str < 1 then Error Types.Truncated_message
  else if Char.code str.[0] = message_type_to_int Compound_msg then begin
    let payload = String.sub str 1 (String.length str - 1) in
    match decode_compound_msg payload with
    | Error e -> Error e
    | Ok ([], _truncated) -> Error Types.Truncated_message
    | Ok (first :: rest, _truncated) -> (
        match decode_internal_msg ~default_port first with
        | Error e -> Error e
        | Ok primary ->
            let piggyback =
              List.filter_map
                (fun part ->
                  match decode_internal_msg ~default_port part with
                  | Ok m -> Some m
                  | Error _ -> None)
                rest
            in
            Ok { Types.cluster = ""; primary; piggyback })
  end
  else begin
    match decode_internal_msg ~default_port str with
    | Error e -> Error e
    | Ok primary -> Ok { Types.cluster = ""; primary; piggyback = [] }
  end

(** [encoded_size msg] is the encoded length of [msg] plus 3 bytes.
    NOTE(review): the extra 3 bytes presumably budget for compound framing
    overhead; confirm against the callers. *)
let encoded_size (msg : Types.protocol_msg) : int =
  let encoded = encode_internal_msg ~self_name:"" ~self_port:default_port msg in
  String.length encoded + 3