Matrix protocol in OCaml, Eio specialised
at main 447 lines 15 kB view raw
(** E2EE key management for Matrix.

    This module handles device keys, one-time keys, and key exchange
    for end-to-end encryption using Olm/Megolm protocols. *)

module Ed25519 = Mirage_crypto_ec.Ed25519
module X25519 = Mirage_crypto_ec.X25519

(* Base64 encoding/decoding - Matrix uses unpadded base64 *)

(** [base64_encode s] encodes [s] as base64 without trailing padding. *)
let base64_encode s =
  Base64.encode_string ~pad:false s

(** [base64_decode s] decodes base64 text written without padding. The
    [Error] case is whatever [Base64.decode] reports. *)
let base64_decode s =
  Base64.decode ~pad:false s

(* Key types *)
type ed25519_keypair = {
  priv : Ed25519.priv;
  pub : Ed25519.pub;
}

type curve25519_keypair = {
  secret : X25519.secret;
  public : string; (* public key as returned by [X25519.gen_key]; not base64 *)
}

(* Device keys structure *)
type device_keys = {
  user_id : Matrix_proto.Id.User_id.t;
  device_id : string;
  algorithms : string list;
  ed25519_key : string; (* base64 *)
  curve25519_key : string; (* base64 *)
  signatures : (string * (string * string) list) list; (* user_id -> key_id -> sig *)
}

(* One-time key *)
type one_time_key = {
  key_id : string;
  key : string; (* base64 curve25519 public key *)
  signature : string option; (* optional signature *)
}

(* Fallback key *)
type fallback_key = {
  key_id : string;
  key : string;
  signature : string option;
}

(* Key generation *)

(** Generate a fresh Ed25519 signing key pair.
    NOTE(review): no generator is passed, so this presumably draws from the
    default Mirage_crypto RNG, which must be initialised first - confirm. *)
let generate_ed25519 () =
  let priv, pub = Ed25519.generate () in
  { priv; pub }

(** Generate a fresh Curve25519 key pair (same RNG caveat as
    [generate_ed25519]). *)
let generate_curve25519 () =
  let secret, public = X25519.gen_key () in
  { secret; public }

(* Serialize keys to base64 *)
let ed25519_pub_to_base64 pub =
  Ed25519.pub_to_octets pub |> base64_encode

let ed25519_priv_to_base64 priv =
  Ed25519.priv_to_octets priv |> base64_encode

let curve25519_pub_to_base64 public =
  base64_encode public

let curve25519_secret_to_base64 secret =
  X25519.secret_to_octets secret |> base64_encode

(* Deserialize keys from base64. Each returns [Error msg], where [msg] says
   whether the base64 layer or the key parsing failed. *)
let ed25519_pub_of_base64 s =
  match base64_decode s with
  | Error _ -> Error "Invalid base64"
  | Ok octets ->
    Ed25519.pub_of_octets octets
    |> Result.map_error (fun _ -> "Invalid Ed25519 public key")

let ed25519_priv_of_base64 s =
  match base64_decode s with
  | Error _ -> Error "Invalid base64"
  | Ok octets ->
    Ed25519.priv_of_octets octets
    |> Result.map_error (fun _ -> "Invalid Ed25519 private key")

(* Only the base64 layer is checked here; the key length is not validated. *)
let curve25519_pub_of_base64 s =
  base64_decode s |> Result.map_error (fun _ -> "Invalid base64")

let curve25519_secret_of_base64 s =
  match base64_decode s with
  | Error _ -> Error "Invalid base64"
  | Ok octets ->
    (* [secret_of_octets] also yields the matching public key; only the
       secret is kept. *)
    (match X25519.secret_of_octets octets with
     | Error _ -> Error "Invalid Curve25519 secret key"
     | Ok (secret, _public) -> Ok secret)

(* Signing *)

(** [sign_json priv json_str] signs the bytes of [json_str] with Ed25519 and
    returns the signature as unpadded base64. The bytes are signed exactly as
    given. Since Matrix signs Canonical JSON, producing that form is the
    caller's job. *)
let sign_json ed25519_priv json_str =
  let signature = Ed25519.sign ~key:ed25519_priv json_str in
  base64_encode signature

(** [verify_signature pub signature_b64 json_str] checks an unpadded-base64
    Ed25519 signature over the bytes of [json_str]. A signature that is not
    valid base64 yields [false]. *)
let verify_signature ed25519_pub signature_b64 json_str =
  match base64_decode signature_b64 with
  | Error _ -> false
  | Ok signature -> Ed25519.verify ~key:ed25519_pub signature ~msg:json_str

(* Curve25519 key exchange *)

(** Curve25519 Diffie-Hellman between [secret] and [their_public]. Every
    library error becomes the single message ["Key exchange failed"]. *)
let key_exchange ~secret ~their_public =
  X25519.key_exchange secret their_public
  |> Result.map_error (fun _ -> "Key exchange failed")

(* Generate a batch of one-time keys *)

(** [generate_one_time_keys ~count ?sign_with ?first_index ()] creates [count]
    fresh Curve25519 key pairs. Each is returned with its public description.

    Key ids are ["AAAAAA<n>"], with [n] running from [first_index]
    (default [0]) to [first_index + count - 1]. Ids come only from this
    counter, so callers generating several batches must advance
    [first_index] to avoid reusing ids.

    When [sign_with] is given, each key is signed over the JSON object whose
    only member is "key", holding the base64 public key.

    @raise Invalid_argument if [count] is negative (from [List.init]). *)
let generate_one_time_keys ~count ?sign_with ?(first_index = 0) () :
    (one_time_key * curve25519_keypair) list =
  let sign key =
    Option.map
      (fun ed_priv -> sign_json ed_priv (Printf.sprintf "{\"key\":\"%s\"}" key))
      sign_with
  in
  List.init count (fun i ->
      let kp = generate_curve25519 () in
      (* TODO: proper id generation; ids are only unique within the
         [first_index] range chosen by the caller. *)
      let key_id = Printf.sprintf "AAAAAA%d" (first_index + i) in
      let key = curve25519_pub_to_base64 kp.public in
      (({ key_id; key; signature = sign key } : one_time_key), kp))

(* JSON codecs for key upload/query *)

(* Upload keys request *)
type upload_keys_request = {
  device_keys : device_keys_json option;
  one_time_keys : (string * one_time_key_json) list;
  fallback_keys : (string * one_time_key_json) list;
} [@@warning "-69"]

and device_keys_json = {
  user_id : string;
  device_id : string;
  algorithms : string list;
  keys : (string * string) list; (* key_id -> key *)
  signatures : (string * (string * string) list) list;
} [@@warning "-69"]

and one_time_key_json = {
  key : string;
  signatures : (string * (string * string) list) list option;
} [@@warning "-69"]

module StringMap = Map.Make(String)

(* JSON object with string values <-> association list. Decoding returns the
   members in [StringMap.bindings] order (sorted by name). On encoding, a
   repeated name keeps its last value. *)
let string_string_map_jsont : (string * string) list Jsont.t =
  let map_jsont = Jsont.Object.as_string_map Jsont.string in
  Jsont.map
    ~dec:(fun m -> StringMap.bindings m)
    ~enc:(fun l -> List.to_seq l |> StringMap.of_seq)
    map_jsont

(* Two-level map used for signatures: user_id -> key_id -> signature. *)
let signatures_jsont : (string * (string * string) list) list Jsont.t =
  let inner = Jsont.Object.as_string_map Jsont.string in
  let outer = Jsont.Object.as_string_map inner in
  Jsont.map
    ~dec:(fun m ->
      StringMap.bindings m
      |> List.map (fun (k, v) -> (k, StringMap.bindings v)))
    ~enc:(fun l ->
      List.map (fun (k, v) -> (k, List.to_seq v |> StringMap.of_seq)) l
      |> List.to_seq |> StringMap.of_seq)
    outer

(* Codec for [device_keys_json]. Every member carries [~enc]: a Jsont member
   declared without it can only be decoded, and this record must be
   serialised for uploads. *)
let device_keys_json_jsont : device_keys_json Jsont.t =
  Jsont.Object.(
    map (fun user_id device_id algorithms keys signatures ->
        ({ user_id; device_id; algorithms; keys; signatures } : device_keys_json))
    |> mem "user_id" Jsont.string
         ~enc:(fun (d : device_keys_json) -> d.user_id)
    |> mem "device_id" Jsont.string
         ~enc:(fun (d : device_keys_json) -> d.device_id)
    |> mem "algorithms" (Jsont.list Jsont.string)
         ~enc:(fun (d : device_keys_json) -> d.algorithms)
    |> mem "keys" string_string_map_jsont
         ~enc:(fun (d : device_keys_json) -> d.keys)
    |> mem "signatures" signatures_jsont
         ~enc:(fun (d : device_keys_json) -> d.signatures)
    |> finish)

(* Codec for a one-time/fallback key object. An absent "signatures" member
   decodes to [None] and [None] is omitted on encoding. "key" needs [~enc] so
   the object can be serialised. *)
let one_time_key_json_jsont : one_time_key_json Jsont.t =
  Jsont.Object.(
    map (fun key signatures -> ({ key; signatures } : one_time_key_json))
    |> mem "key" Jsont.string ~enc:(fun (k : one_time_key_json) -> k.key)
    |> opt_mem "signatures" signatures_jsont
         ~enc:(fun (k : one_time_key_json) -> k.signatures)
    |> finish)
(* JSON object with arbitrary member names <-> association list. Decoding
   returns the members in [StringMap.bindings] order (sorted by name). On
   encoding, a repeated name keeps its last value. Nesting [assoc_jsont]
   yields the multi-level maps used by the key endpoints. *)
let assoc_jsont (value_jsont : 'a Jsont.t) : (string * 'a) list Jsont.t =
  Jsont.map
    ~dec:(fun m -> StringMap.bindings m)
    ~enc:(fun l -> List.to_seq l |> StringMap.of_seq)
    (Jsont.Object.as_string_map value_jsont)

(* key_id -> one-time (or fallback) key object *)
let one_time_keys_map_jsont : (string * one_time_key_json) list Jsont.t =
  assoc_jsont one_time_key_json_jsont

(* Encoder for /keys/upload request bodies. Every member carries [~enc], since
   a Jsont member declared without it can only be decoded. *)
let upload_keys_request_jsont : upload_keys_request Jsont.t =
  Jsont.Object.(
    map (fun device_keys one_time_keys fallback_keys ->
        ({ device_keys; one_time_keys; fallback_keys } : upload_keys_request))
    |> opt_mem "device_keys" device_keys_json_jsont
         ~enc:(fun (r : upload_keys_request) -> r.device_keys)
    |> mem "one_time_keys" one_time_keys_map_jsont ~dec_absent:[]
         ~enc:(fun (r : upload_keys_request) -> r.one_time_keys)
    |> mem "fallback_keys" one_time_keys_map_jsont ~dec_absent:[]
         ~enc:(fun (r : upload_keys_request) -> r.fallback_keys)
    |> finish)

(* Upload keys response *)
type upload_keys_response = {
  one_time_key_counts : (string * int) list;
}

let one_time_key_counts_jsont : (string * int) list Jsont.t =
  assoc_jsont Jsont.int

let upload_keys_response_jsont : upload_keys_response Jsont.t =
  Jsont.Object.(
    map (fun one_time_key_counts -> { one_time_key_counts })
    |> mem "one_time_key_counts" one_time_key_counts_jsont ~dec_absent:[]
         ~enc:(fun (r : upload_keys_response) -> r.one_time_key_counts)
    |> finish)

(* Shared pipeline for the key endpoints. Encode [request] with
   [request_jsont], POST it to [path], then decode the reply with
   [response_jsont]. The first failing step's error is returned unchanged. *)
let post_json client ~path request_jsont request response_jsont =
  match Client.encode_body request_jsont request with
  | Error e -> Error e
  | Ok body ->
    (match Client.post client ~path ~body () with
     | Error e -> Error e
     | Ok body -> Client.decode_response response_jsont body)

(* Upload device keys *)

(** Upload device keys and/or one-time and fallback keys to /keys/upload.
    [one_time_keys] and [fallback_keys] default to empty lists. Returns the
    decoded one-time key counts (presumably keyed by algorithm). *)
let upload_keys client ?device_keys ?(one_time_keys = []) ?(fallback_keys = []) () =
  let request : upload_keys_request =
    { device_keys; one_time_keys; fallback_keys }
  in
  post_json client ~path:"/keys/upload" upload_keys_request_jsont request
    upload_keys_response_jsont

(* Query keys request/response *)
type query_keys_request = {
  timeout : int option;
  device_keys : (string * string list) list; (* user_id -> device_ids *)
} [@@warning "-69"]

let device_keys_query_jsont : (string * string list) list Jsont.t =
  assoc_jsont (Jsont.list Jsont.string)

let query_keys_request_jsont : query_keys_request Jsont.t =
  Jsont.Object.(
    map (fun timeout device_keys ->
        ({ timeout; device_keys } : query_keys_request))
    |> opt_mem "timeout" Jsont.int
         ~enc:(fun (r : query_keys_request) -> r.timeout)
    |> mem "device_keys" device_keys_query_jsont
         ~enc:(fun (r : query_keys_request) -> r.device_keys)
    |> finish)

type queried_device_keys = {
  user_id : string;
  device_id : string;
  algorithms : string list;
  keys : (string * string) list;
  signatures : (string * (string * string) list) list;
  unsigned : Jsont.json option;
}

(* A device-keys object from a /keys/query reply. Missing "algorithms",
   "keys" and "signatures" decode as empty lists; "unsigned" is kept as raw
   JSON. *)
let queried_device_keys_jsont : queried_device_keys Jsont.t =
  Jsont.Object.(
    map (fun user_id device_id algorithms keys signatures unsigned ->
        ({ user_id; device_id; algorithms; keys; signatures; unsigned }
          : queried_device_keys))
    |> mem "user_id" Jsont.string
         ~enc:(fun (d : queried_device_keys) -> d.user_id)
    |> mem "device_id" Jsont.string
         ~enc:(fun (d : queried_device_keys) -> d.device_id)
    |> mem "algorithms" (Jsont.list Jsont.string) ~dec_absent:[]
         ~enc:(fun (d : queried_device_keys) -> d.algorithms)
    |> mem "keys" string_string_map_jsont ~dec_absent:[]
         ~enc:(fun (d : queried_device_keys) -> d.keys)
    |> mem "signatures" signatures_jsont ~dec_absent:[]
         ~enc:(fun (d : queried_device_keys) -> d.signatures)
    |> opt_mem "unsigned" Jsont.json
         ~enc:(fun (d : queried_device_keys) -> d.unsigned)
    |> finish)

type query_keys_response = {
  failures : (string * Jsont.json) list;
  device_keys : (string * (string * queried_device_keys) list) list;
}

(* user_id -> device_id -> device keys *)
let device_keys_map_jsont :
    (string * (string * queried_device_keys) list) list Jsont.t =
  assoc_jsont (assoc_jsont queried_device_keys_jsont)

(* Failure objects kept as raw JSON (presumably keyed by server name). *)
let failures_jsont : (string * Jsont.json) list Jsont.t =
  assoc_jsont Jsont.json

let query_keys_response_jsont : query_keys_response Jsont.t =
  Jsont.Object.(
    map (fun failures device_keys ->
        ({ failures; device_keys } : query_keys_response))
    |> mem "failures" failures_jsont ~dec_absent:[]
         ~enc:(fun (r : query_keys_response) -> r.failures)
    |> mem "device_keys" device_keys_map_jsont ~dec_absent:[]
         ~enc:(fun (r : query_keys_response) -> r.device_keys)
    |> finish)

(* Query device keys *)

(** Query device keys via /keys/query. [users] pairs each user with the
    device ids to ask about; the ids are sent unchanged. *)
let query_keys client ?timeout ~users () =
  let device_keys =
    List.map
      (fun (user_id, device_ids) ->
        (Matrix_proto.Id.User_id.to_string user_id, device_ids))
      users
  in
  let request : query_keys_request = { timeout; device_keys } in
  post_json client ~path:"/keys/query" query_keys_request_jsont request
    query_keys_response_jsont

(* Claim one-time keys request/response *)
type claim_keys_request = {
  timeout : int option;
  one_time_keys : (string * (string * string) list) list; (* user_id -> device_id -> algorithm *)
} [@@warning "-69"]

let one_time_keys_claim_jsont : (string * (string * string) list) list Jsont.t =
  assoc_jsont (assoc_jsont Jsont.string)

let claim_keys_request_jsont : claim_keys_request Jsont.t =
  Jsont.Object.(
    map (fun timeout one_time_keys ->
        ({ timeout; one_time_keys } : claim_keys_request))
    |> opt_mem "timeout" Jsont.int
         ~enc:(fun (r : claim_keys_request) -> r.timeout)
    |> mem "one_time_keys" one_time_keys_claim_jsont
         ~enc:(fun (r : claim_keys_request) -> r.one_time_keys)
    |> finish)

type claim_keys_response = {
  failures : (string * Jsont.json) list;
  one_time_keys : (string * (string * (string * one_time_key_json) list) list) list;
}

(* user_id -> device_id -> key_id -> key object *)
let claimed_keys_map_jsont :
    (string * (string * (string * one_time_key_json) list) list) list Jsont.t =
  assoc_jsont (assoc_jsont (assoc_jsont one_time_key_json_jsont))

let claim_keys_response_jsont : claim_keys_response Jsont.t =
  Jsont.Object.(
    map (fun failures one_time_keys ->
        ({ failures; one_time_keys } : claim_keys_response))
    |> mem "failures" failures_jsont ~dec_absent:[]
         ~enc:(fun (r : claim_keys_response) -> r.failures)
    |> mem "one_time_keys" claimed_keys_map_jsont ~dec_absent:[]
         ~enc:(fun (r : claim_keys_response) -> r.one_time_keys)
    |> finish)

(* Claim one-time keys for Olm session establishment *)

(** Claim one-time keys via /keys/claim. [keys] pairs each user with a list
    of (device_id, algorithm) requests, sent unchanged. *)
let claim_keys client ?timeout ~keys () =
  let one_time_keys =
    List.map
      (fun (user_id, device_keys) ->
        (Matrix_proto.Id.User_id.to_string user_id, device_keys))
      keys
  in
  let request : claim_keys_request = { timeout; one_time_keys } in
  post_json client ~path:"/keys/claim" claim_keys_request_jsont request
    claim_keys_response_jsont

(* Key changes tracking *)
type key_changes_response = {
  changed : string list;
  left : string list;
}

let key_changes_response_jsont : key_changes_response Jsont.t =
  Jsont.Object.(
    map (fun changed left -> { changed; left })
    |> mem "changed" (Jsont.list Jsont.string) ~dec_absent:[]
         ~enc:(fun (r : key_changes_response) -> r.changed)
    |> mem "left" (Jsont.list Jsont.string) ~dec_absent:[]
         ~enc:(fun (r : key_changes_response) -> r.left)
    |> finish)

(** GET /keys/changes. [from] and [to_] are sent as the "from" and "to"
    query parameters (presumably sync tokens). *)
let get_key_changes client ~from ~to_ =
  let query = [ ("from", from); ("to", to_) ] in
  match Client.get client ~path:"/keys/changes" ~query () with
  | Error e -> Error e
  | Ok body -> Client.decode_response key_changes_response_jsont body

(* Codec for the bytes that [create_device_keys] signs.
   Matrix signs an object with its "signatures" and "unsigned" members
   removed, serialised as Canonical JSON (members sorted by name, no
   insignificant whitespace). This codec therefore has no "signatures" member
   and declares the remaining members in sorted order. Decoding with it
   yields [signatures = []].
   NOTE(review): relies on three assumptions - confirm each:
   - Jsont emits object members in declaration order;
   - the StringMap-backed "keys" map encodes in key order;
   - [Client.encode_body] produces minified output. *)
let device_keys_signing_jsont : device_keys_json Jsont.t =
  Jsont.Object.(
    map (fun algorithms device_id keys user_id ->
        ({ user_id; device_id; algorithms; keys; signatures = [] }
          : device_keys_json))
    |> mem "algorithms" (Jsont.list Jsont.string)
         ~enc:(fun (d : device_keys_json) -> d.algorithms)
    |> mem "device_id" Jsont.string
         ~enc:(fun (d : device_keys_json) -> d.device_id)
    |> mem "keys" string_string_map_jsont
         ~enc:(fun (d : device_keys_json) -> d.keys)
    |> mem "user_id" Jsont.string
         ~enc:(fun (d : device_keys_json) -> d.user_id)
    |> finish)

(* Helper to create signed device keys for upload *)

(** [create_device_keys ~user_id ~device_id ~ed25519_keypair ~curve25519_keypair]
    builds the device-keys object for upload:
    - keys ["ed25519:<device_id>"] and ["curve25519:<device_id>"];
    - the Olm and Megolm algorithms;
    - a signature by [ed25519_keypair] stored under [user_id] and
      ["ed25519:<device_id>"].
    The signature covers the object without its signatures member (see
    [device_keys_signing_jsont]). Returns
    [Error "Failed to encode device keys"] if that encoding fails. *)
let create_device_keys ~user_id ~device_id ~ed25519_keypair ~curve25519_keypair =
  let user_str = Matrix_proto.Id.User_id.to_string user_id in
  let ed25519_key_id = Printf.sprintf "ed25519:%s" device_id in
  let keys = [
    (ed25519_key_id, ed25519_pub_to_base64 ed25519_keypair.pub);
    (Printf.sprintf "curve25519:%s" device_id,
     curve25519_pub_to_base64 curve25519_keypair.public);
  ] in
  let algorithms = [
    "m.olm.v1.curve25519-aes-sha2-256";
    "m.megolm.v1.aes-sha2-256";
  ] in
  let unsigned_keys : device_keys_json =
    { user_id = user_str; device_id; algorithms; keys; signatures = [] }
  in
  match Client.encode_body device_keys_signing_jsont unsigned_keys with
  | Error _ -> Error "Failed to encode device keys"
  | Ok canonical_json ->
    let signature = sign_json ed25519_keypair.priv canonical_json in
    Ok { unsigned_keys with
         signatures = [ (user_str, [ (ed25519_key_id, signature) ]) ] }