(* Matrix protocol in OCaml, specialised for Eio. Source view: branch main, 895 lines, 32 kB. *)
(** Olm/Megolm cryptographic session management.

    This module implements the Olm double-ratchet algorithm for encrypted
    to-device messages, and Megolm for encrypted room messages.

    NOTE(review): the payload formats, KDF labels and Megolm ratchet used
    here are this module's own simplifications (e.g. Olm payloads are
    ["ratchet|index|ciphertext"] strings). Confirm before assuming
    interoperability with libolm/vodozemac peers. *)

module Ed25519 = Mirage_crypto_ec.Ed25519
module X25519 = Mirage_crypto_ec.X25519

(* Base64 encoding/decoding - Matrix uses unpadded base64 *)
let base64_encode s = Base64.encode_string ~pad:false s
let base64_decode s = Base64.decode ~pad:false s

(** Current wall-clock time as a [Ptime.t]. Falls back to [Ptime.epoch] if
    the system time cannot be represented. *)
let now () =
  match Ptime.of_float_s (Unix.gettimeofday ()) with
  | Some t -> t
  | None -> Ptime.epoch

(** AES block size in bytes. *)
let aes_block_size = 16

(** [pkcs7_pad s] appends PKCS#7 padding: between 1 and [aes_block_size]
    bytes, each holding the padding length. *)
let pkcs7_pad s =
  let pad_len = aes_block_size - (String.length s mod aes_block_size) in
  s ^ String.make pad_len (Char.chr pad_len)

(** [pkcs7_unpad s] validates and strips PKCS#7 padding.

    It returns [Error "Invalid padding"] for any of these inputs:
    - [s] is empty or not block-aligned;
    - the padding length is 0 or longer than one block;
    - the padding bytes do not all equal the padding length. *)
let pkcs7_unpad s =
  let len = String.length s in
  if len = 0 || len mod aes_block_size <> 0 then Error "Invalid padding"
  else
    let pad_len = Char.code s.[len - 1] in
    (* Short-circuiting keeps String.sub from seeing an out-of-range length. *)
    let padding_ok =
      pad_len >= 1
      && pad_len <= aes_block_size
      && String.for_all
           (fun c -> Char.code c = pad_len)
           (String.sub s (len - pad_len) pad_len)
    in
    if padding_ok then Ok (String.sub s 0 (len - pad_len))
    else Error "Invalid padding"

(** [constant_time_equal a b] compares two strings without stopping at the
    first differing byte. This way a MAC comparison does not reveal how long
    the matching prefix was. Lengths are compared directly. *)
let constant_time_equal a b =
  String.length a = String.length b
  && (let diff = ref 0 in
      String.iteri
        (fun i c -> diff := !diff lor (Char.code c lxor Char.code b.[i]))
        a;
      !diff = 0)

(** Olm account - manages identity keys and one-time keys.

    An Olm account contains:
    - Ed25519 identity key (for signing)
    - Curve25519 identity key (for key exchange)
    - One-time keys (for session establishment) *)
module Account = struct
  type t = {
    (* Identity keys *)
    ed25519_priv : Ed25519.priv;
    ed25519_pub : Ed25519.pub;
    curve25519_secret : X25519.secret;
    curve25519_public : string;
    (* One-time keys - key_id -> (secret, raw public key) *)
    mutable one_time_keys : (string * (X25519.secret * string)) list;
    (* Fallback key - key_id, (secret, raw public key) *)
    mutable fallback_key : (string * (X25519.secret * string)) option;
    (* Key counter for generating key IDs *)
    mutable next_key_id : int;
    (* Max number of one-time keys to store *)
    max_one_time_keys : int;
  }

  (** Generate a new Olm account with fresh identity keys and no one-time
      or fallback keys. *)
  let create () =
    let ed25519_priv, ed25519_pub = Ed25519.generate () in
    let curve25519_secret, curve25519_public = X25519.gen_key () in
    {
      ed25519_priv;
      ed25519_pub;
      curve25519_secret;
      curve25519_public;
      one_time_keys = [];
      fallback_key = None;
      next_key_id = 0;
      max_one_time_keys = 100;
    }

  (** Get the Ed25519 identity key as unpadded base64. *)
  let ed25519_key t = Ed25519.pub_to_octets t.ed25519_pub |> base64_encode

  (** Get the Curve25519 identity key as unpadded base64. *)
  let curve25519_key t = base64_encode t.curve25519_public

  (** Get the identity keys as a pair [(ed25519, curve25519)]. *)
  let identity_keys t = (ed25519_key t, curve25519_key t)

  (** Sign [message] with the account's Ed25519 key. Returns the signature
      as unpadded base64. *)
  let sign t message =
    Ed25519.sign ~key:t.ed25519_priv message |> base64_encode

  (** Generate a key ID unique within this account (counter-based). *)
  let generate_key_id t =
    let id = Printf.sprintf "AAAA%02dAA" t.next_key_id in
    t.next_key_id <- t.next_key_id + 1;
    id

  (** Generate up to [count] new one-time keys. The number is capped so the
      account never holds more than [max_one_time_keys]. *)
  let generate_one_time_keys t count =
    let count = min count (t.max_one_time_keys - List.length t.one_time_keys) in
    for _ = 1 to count do
      let secret, public = X25519.gen_key () in
      let key_id = generate_key_id t in
      t.one_time_keys <- (key_id, (secret, public)) :: t.one_time_keys
    done

  (** Get one-time keys for upload as [(key_id, base64 public key)]. *)
  let one_time_keys t =
    List.map
      (fun (key_id, (_secret, public)) -> (key_id, base64_encode public))
      t.one_time_keys

  (** Get signed one-time keys for upload as
      [(key_id, base64 public key, signature)]. The signature covers the
      JSON text [{"key":"<base64>"}]. *)
  let signed_one_time_keys t =
    List.map
      (fun (key_id, (_secret, public)) ->
        let pub_b64 = base64_encode public in
        let to_sign = Printf.sprintf {|{"key":"%s"}|} pub_b64 in
        (key_id, pub_b64, sign t to_sign))
      t.one_time_keys

  (** Mark one-time keys as published. This is currently a no-op: keys are
      kept until consumed by [remove_one_time_key] during session creation. *)
  let mark_keys_as_published _t = ()

  (** Generate a fallback key, replacing any existing one. *)
  let generate_fallback_key t =
    let secret, public = X25519.gen_key () in
    let key_id = generate_key_id t in
    t.fallback_key <- Some (key_id, (secret, public))

  (** Get the fallback key, if any, as [(key_id, base64 public key)]. *)
  let fallback_key t =
    Option.map
      (fun (key_id, (_secret, public)) -> (key_id, base64_encode public))
      t.fallback_key

  (** Remove a one-time key by ID (after session creation). The fallback key
      is not affected. *)
  let remove_one_time_key t key_id =
    t.one_time_keys <- List.filter (fun (id, _) -> id <> key_id) t.one_time_keys

  (** Get the secret for a one-time key, falling back to the fallback key
      when [key_id] names it. *)
  let get_one_time_key_secret t key_id =
    match List.assoc_opt key_id t.one_time_keys with
    | Some (secret, _public) -> Some secret
    | None -> (
        match t.fallback_key with
        | Some (id, (secret, _public)) when id = key_id -> Some secret
        | _ -> None)

  (** Number of one-time keys currently held. *)
  let one_time_keys_count t = List.length t.one_time_keys

  (** Maximum number of one-time keys this account can hold. *)
  let max_one_time_keys t = t.max_one_time_keys
end

(** Olm session state for the double ratchet algorithm. *)
module Session = struct
  (** Ratchet chain key *)
  type chain_key = {
    key : string; (* 32 bytes *)
    index : int;
  }

  (** Root key used for deriving new chain keys *)
  type root_key = string (* 32 bytes *)

  (** Message key for encrypting a single message *)
  type message_key = string (* 32 bytes *)

  (** Session state *)
  type t = {
    (* Session ID *)
    session_id : string;
    (* Their identity key (raw Curve25519 bytes) *)
    their_identity_key : string;
    (* Their most recently adopted ratchet key *)
    mutable their_ratchet_key : string option;
    (* Our current ratchet key pair *)
    mutable our_ratchet_secret : X25519.secret;
    mutable our_ratchet_public : string;
    (* Root key for deriving chain keys *)
    mutable root_key : root_key;
    (* Sending chain; [None] forces a DH ratchet step on the next send *)
    mutable sending_chain : chain_key option;
    (* Receiving chains (their_ratchet_key -> chain) *)
    mutable receiving_chains : (string * chain_key) list;
    (* Skipped message keys for out-of-order decryption, newest first *)
    mutable skipped_keys : ((string * int) * message_key) list;
    (* Creation time *)
    creation_time : Ptime.t;
  }

  (** Largest number of message keys derived to skip ahead within one
      receiving chain. This bounds the work a single incoming message can
      trigger. *)
  let max_message_gap = 2000

  (** Maximum number of skipped message keys retained; the newest are kept. *)
  let max_skipped_keys = 40

  (** HKDF for key derivation using SHA-256 *)
  let hkdf_sha256 ~salt ~info ~ikm length =
    let prk = Hkdf.extract ~hash:`SHA256 ~salt ikm in
    Hkdf.expand ~hash:`SHA256 ~prk ~info length

  (** Derive [(new_root, chain_key)] from a root key (used as HKDF salt) and
      a DH shared secret. *)
  let kdf_rk root_key shared_secret =
    let derived = hkdf_sha256 ~salt:root_key ~info:"OLM_ROOT" ~ikm:shared_secret 64 in
    (String.sub derived 0 32, String.sub derived 32 32)

  (** Derive [(message_key, next_chain_key)] from a chain key. *)
  let kdf_ck chain_key =
    let mk = hkdf_sha256 ~salt:"" ~info:"OLM_CHAIN_MESSAGE" ~ikm:chain_key.key 32 in
    let new_ck = hkdf_sha256 ~salt:"" ~info:"OLM_CHAIN_KEY" ~ikm:chain_key.key 32 in
    (mk, { key = new_ck; index = chain_key.index + 1 })

  (** X25519 key agreement.
      @raise Failure if the key exchange fails. *)
  let dh secret public =
    match X25519.key_exchange secret public with
    | Ok shared -> shared
    | Error _ -> failwith "Key exchange failed"

  (** Decode an unpadded base64 key.
      @raise Failure ["Invalid <what>"] on malformed input. *)
  let decode_key ~what b64 =
    match base64_decode b64 with
    | Ok key -> key
    | Error _ -> failwith ("Invalid " ^ what)

  (** Session ID: unpadded base64 of SHA-256 over the initial root key. *)
  let session_id_of_root root_key =
    Digestif.SHA256.(digest_string root_key |> to_raw_string) |> base64_encode

  (** Perform X3DH key agreement for an outbound session. Returns
      DH(I_a, OTK_b) || DH(E_a, I_b) || DH(E_a, OTK_b).
      @raise Failure if any key exchange fails. *)
  let x3dh_outbound ~our_identity_secret ~our_ephemeral_secret
      ~their_identity_key ~their_one_time_key =
    let dh1 = dh our_identity_secret their_one_time_key in
    let dh2 = dh our_ephemeral_secret their_identity_key in
    let dh3 = dh our_ephemeral_secret their_one_time_key in
    dh1 ^ dh2 ^ dh3

  (** Create a new outbound session (when sending the first message).

      The X3DH ephemeral key pair doubles as our initial ratchet key. Its
      public half therefore travels in the first payload's ratchet field,
      which is the key the peer must pass as [their_ephemeral_key] to
      {!create_inbound}.

      @raise Failure on malformed keys or a failed key exchange. *)
  let create_outbound account ~their_identity_key ~their_one_time_key =
    let their_identity = decode_key ~what:"identity key" their_identity_key in
    let their_otk = decode_key ~what:"one-time key" their_one_time_key in
    let ephemeral_secret, ephemeral_public = X25519.gen_key () in
    let shared_secret =
      x3dh_outbound ~our_identity_secret:account.Account.curve25519_secret
        ~our_ephemeral_secret:ephemeral_secret ~their_identity_key:their_identity
        ~their_one_time_key:their_otk
    in
    (* The first 32 bytes equal the original 32-byte root derivation. The
       next 32 give a chain key distinct from the root key. *)
    let root_key, chain_key = kdf_rk "" shared_secret in
    {
      session_id = session_id_of_root root_key;
      their_identity_key = their_identity;
      their_ratchet_key = None;
      our_ratchet_secret = ephemeral_secret;
      our_ratchet_public = ephemeral_public;
      root_key;
      sending_chain = Some { key = chain_key; index = 0 };
      receiving_chains = [];
      skipped_keys = [];
      creation_time = now ();
    }

  (** Create a new inbound session (when receiving the first message).
      Consumes the one-time key [one_time_key_id] from [account].
      @raise Failure if the one-time key is unknown, a key is malformed, or
      a key exchange fails. *)
  let create_inbound account ~their_identity_key ~their_ephemeral_key
      ~one_time_key_id =
    let our_otk_secret =
      match Account.get_one_time_key_secret account one_time_key_id with
      | Some s -> s
      | None -> failwith "One-time key not found"
    in
    let their_identity = decode_key ~what:"identity key" their_identity_key in
    let their_ephemeral = decode_key ~what:"ephemeral key" their_ephemeral_key in
    (* Mirror of x3dh_outbound with the roles swapped. *)
    let dh1 = dh our_otk_secret their_identity in
    let dh2 = dh account.Account.curve25519_secret their_ephemeral in
    let dh3 = dh our_otk_secret their_ephemeral in
    let root_key, chain_key = kdf_rk "" (dh1 ^ dh2 ^ dh3) in
    let our_ratchet_secret, our_ratchet_public = X25519.gen_key () in
    Account.remove_one_time_key account one_time_key_id;
    {
      session_id = session_id_of_root root_key;
      their_identity_key = their_identity;
      their_ratchet_key = Some their_ephemeral;
      our_ratchet_secret;
      our_ratchet_public;
      root_key;
      sending_chain = None;
      receiving_chains = [ (their_ephemeral, { key = chain_key; index = 0 }) ];
      skipped_keys = [];
      creation_time = now ();
    }

  (** Get session ID *)
  let session_id t = t.session_id

  (** Get their identity key as unpadded base64 *)
  let their_identity_key t = base64_encode t.their_identity_key

  (** Encrypt with AES-256-CBC and PKCS#7 padding. Returns [iv ^ ciphertext].
      The IV is derived from the key, which relies on every message key
      being used once.
      NOTE(review): there is no MAC, so Olm payloads are unauthenticated and
      malleable. *)
  let aes_encrypt key plaintext =
    let aes_key = String.sub key 0 32 in
    let iv =
      String.sub
        Digestif.SHA256.(digest_string (aes_key ^ "IV") |> to_raw_string)
        0 aes_block_size
    in
    let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
    iv ^ Mirage_crypto.AES.CBC.encrypt ~key:cipher ~iv (pkcs7_pad plaintext)

  (** Decrypt the output of {!aes_encrypt}. Returns [Error] (never raises)
      for short, unaligned, or badly padded input. *)
  let aes_decrypt key ciphertext =
    let len = String.length ciphertext in
    if len < aes_block_size then Error "Ciphertext too short"
    else
      let iv = String.sub ciphertext 0 aes_block_size in
      let data = String.sub ciphertext aes_block_size (len - aes_block_size) in
      if data = "" then Error "Empty plaintext"
      else if String.length data mod aes_block_size <> 0 then
        (* CBC decryption raises Invalid_argument on unaligned input. *)
        Error "Invalid ciphertext length"
      else
        let cipher = Mirage_crypto.AES.CBC.of_secret (String.sub key 0 32) in
        Mirage_crypto.AES.CBC.decrypt ~key:cipher ~iv data |> pkcs7_unpad

  (** Start a new sending chain. With a known peer ratchet key this performs
      a DH ratchet step, mixing DH(new, theirs) into the root key. State is
      only mutated after the key exchange succeeds.
      @raise Failure if the key exchange fails. *)
  let start_sending_chain t =
    let secret, public = X25519.gen_key () in
    let chain =
      match t.their_ratchet_key with
      | Some their_ratchet ->
          let new_root, chain_key = kdf_rk t.root_key (dh secret their_ratchet) in
          t.root_key <- new_root;
          { key = chain_key; index = 0 }
      | None -> { key = t.root_key; index = 0 }
    in
    t.our_ratchet_secret <- secret;
    t.our_ratchet_public <- public;
    chain

  (** Encrypt a plaintext message. Returns [(message_type, payload)], where
      the payload is ["<ratchet key>|<chain index>|<ciphertext>"] and the
      chain index is the one the message key was derived at. The message
      type is 0 (pre-key) until a message from the peer has been received,
      and 1 (normal) afterwards.
      @raise Failure if a required DH ratchet step fails. *)
  let encrypt t plaintext =
    let chain =
      match t.sending_chain with
      | Some chain -> chain
      | None -> start_sending_chain t
    in
    let message_key, next_chain = kdf_ck chain in
    t.sending_chain <- Some next_chain;
    let ciphertext = aes_encrypt message_key plaintext in
    let msg_type = match t.receiving_chains with [] -> 0 | _ -> 1 in
    let payload =
      Printf.sprintf "%s|%d|%s"
        (base64_encode t.our_ratchet_public)
        chain.index (base64_encode ciphertext)
    in
    (msg_type, payload)

  (** Split a payload into [(ratchet key, index, ciphertext)], decoded. *)
  let parse_payload payload =
    match String.split_on_char '|' payload with
    | [ ratchet_b64; index_str; ciphertext_b64 ] -> (
        match
          ( base64_decode ratchet_b64,
            int_of_string_opt index_str,
            base64_decode ciphertext_b64 )
        with
        | Error _, _, _ -> Error "Invalid ratchet key"
        | _, None, _ -> Error "Invalid message index"
        | _, Some index, _ when index < 0 -> Error "Invalid message index"
        | _, _, Error _ -> Error "Invalid ciphertext"
        | Ok ratchet, Some index, Ok ciphertext -> Ok (ratchet, index, ciphertext))
    | _ -> Error "Invalid message format"

  (** Decrypt using the receiving chain for [their_ratchet], performing a
      DH ratchet step if that key is new.

      Session state is only updated once decryption succeeds, so a bad
      message cannot desynchronise the session. Keys for any skipped
      indices are stored for out-of-order delivery. *)
  let decrypt_with_chain t ~their_ratchet ~msg_index ciphertext =
    let located =
      match List.assoc_opt their_ratchet t.receiving_chains with
      | Some chain -> Ok (chain, None)
      | None -> (
          match X25519.key_exchange t.our_ratchet_secret their_ratchet with
          | Error _ -> Error "Key exchange failed"
          | Ok dh_out ->
              let new_root, chain_key = kdf_rk t.root_key dh_out in
              Ok ({ key = chain_key; index = 0 }, Some new_root))
    in
    match located with
    | Error e -> Error e
    | Ok (chain, new_root) ->
        if msg_index < chain.index then
          Error "Message index already used or its key was discarded"
        else if msg_index - chain.index > max_message_gap then
          Error "Message index too far ahead"
        else begin
          (* Walk forward to [msg_index], remembering keys we step over. *)
          let rec skip c acc =
            if c.index >= msg_index then (c, acc)
            else
              let mk, next = kdf_ck c in
              skip next (((their_ratchet, c.index), mk) :: acc)
          in
          let chain, skipped = skip chain [] in
          let message_key, next_chain = kdf_ck chain in
          match aes_decrypt message_key ciphertext with
          | Error e -> Error e
          | Ok plaintext ->
              (match new_root with
              | Some root ->
                  t.root_key <- root;
                  t.their_ratchet_key <- Some their_ratchet;
                  (* Peer moved to a new ratchet key; our next send must do
                     a DH step against it. *)
                  t.sending_chain <- None
              | None -> ());
              t.receiving_chains <-
                (their_ratchet, next_chain)
                :: List.filter (fun (k, _) -> k <> their_ratchet) t.receiving_chains;
              t.skipped_keys <-
                List.filteri
                  (fun i _ -> i < max_skipped_keys)
                  (skipped @ t.skipped_keys);
              Ok plaintext
        end

  (** Decrypt a payload produced by {!encrypt}. Returns [Error] (rather than
      raising) for malformed payloads, failed key exchange, replayed or
      too-distant indices, and bad padding. [message_type] is currently
      ignored. *)
  let decrypt t ~message_type:_ ~ciphertext:payload =
    match parse_payload payload with
    | Error e -> Error e
    | Ok (their_ratchet, msg_index, ciphertext) -> (
        let skipped_id = (their_ratchet, msg_index) in
        match List.assoc_opt skipped_id t.skipped_keys with
        | Some message_key ->
            (* Out-of-order message whose key was saved while skipping. *)
            let result = aes_decrypt message_key ciphertext in
            if Result.is_ok result then
              t.skipped_keys <- List.remove_assoc skipped_id t.skipped_keys;
            result
        | None -> decrypt_with_chain t ~their_ratchet ~msg_index ciphertext)

  (** Check if this is a pre-key message (type 0). *)
  let is_pre_key_message message_type = message_type = 0
end

(** Megolm session for room message encryption.

    Megolm uses a ratchet that only moves forward, making it efficient
    for encrypting many messages to many recipients. *)
module Megolm = struct
  (** [ratchet_step ratchet index] applies the step that follows message
      [index]: parts [index land 3 .. 3] are each replaced by
      SHA-256(part ^ position). Mutates [ratchet] in place.
      NOTE(review): this is a simplification, not the Megolm specification's
      re-seeding of lower parts from higher ones. *)
  let ratchet_step ratchet index =
    let hash s = Digestif.SHA256.(digest_string s |> to_raw_string) in
    for j = index land 3 to 3 do
      ratchet.(j) <- hash (ratchet.(j) ^ string_of_int j)
    done

  (** 80 bytes of key material from a ratchet state: AES key (0-31),
      HMAC key (32-63) and IV (64-79). *)
  let derive_key_material ratchet =
    let combined = String.concat "" (Array.to_list ratchet) in
    Hkdf.expand ~hash:`SHA256 ~prk:combined ~info:"MEGOLM_KEYS" 80

  (** First 8 bytes of HMAC-SHA256([key], [data]). *)
  let truncated_mac ~key data =
    let full = Digestif.SHA256.(hmac_string ~key data |> to_raw_string) in
    String.sub full 0 8

  (** Inbound session for decrypting received room messages *)
  module Inbound = struct
    type t = {
      session_id : string;
      sender_key : string; (* Curve25519 key of sender *)
      room_id : string;
      (* Ratchet state - 4 parts of 32 bytes each *)
      mutable ratchet : string array;
      (* Index of the next message the stored ratchet can decrypt *)
      mutable message_index : int;
      (* Indices already decrypted, for replay detection *)
      mutable received_indices : int list;
      (* Ed25519 signing key of the sender *)
      signing_key : string;
      creation_time : Ptime.t;
    }

    (** Largest forward jump in message index accepted in one decryption.
        This bounds the hashing work an attacker-chosen index can cause. *)
    let max_message_gap = 100_000

    (** Advance the stored ratchet by one step *)
    let advance_ratchet t =
      ratchet_step t.ratchet t.message_index;
      t.message_index <- t.message_index + 1

    (** Create from exported session data *)
    let of_export ~session_id ~sender_key ~room_id ~ratchet ~message_index
        ~signing_key =
      {
        session_id;
        sender_key;
        room_id;
        ratchet;
        message_index;
        received_indices = [];
        signing_key;
        creation_time = now ();
      }

    (** Create from a room key event (m.room_key). [session_key] must be
        unpadded base64 of at least 128 bytes (four 32-byte ratchet parts).

        NOTE(review): the session always starts at index 0. A key exported
        after the sender has already encrypted messages will be out of step;
        the export format does not carry the index.

        @raise Failure ["Invalid session key"] on malformed input. The old
        behaviour silently generated a random ratchet that could never
        decrypt anything. *)
    let from_room_key ~sender_key ~room_id ~session_id ~session_key ~signing_key =
      let ratchet =
        match base64_decode session_key with
        | Ok data when String.length data >= 128 ->
            Array.init 4 (fun i -> String.sub data (32 * i) 32)
        | _ -> failwith "Invalid session key"
      in
      of_export ~session_id ~sender_key ~room_id ~ratchet ~message_index:0
        ~signing_key

    let session_id t = t.session_id
    let sender_key t = t.sender_key
    let room_id t = t.room_id
    let first_known_index t = t.message_index

    (** Derive key material from the current stored ratchet state *)
    let derive_key t = derive_key_material t.ratchet

    (** Decrypt raw (not base64) [ciphertext]: CBC data followed by an 8-byte
        truncated HMAC. Returns [Error] for replays, indices behind or too
        far ahead of the ratchet, malformed input, MAC failure or bad
        padding. The stored ratchet only advances on success, so forged
        messages cannot destroy the session. *)
    let decrypt t ~ciphertext ~message_index =
      let ct_len = String.length ciphertext in
      if List.mem message_index t.received_indices then
        Error "Duplicate message index (replay attack)"
      else if message_index < t.message_index then Error "Message index too old"
      else if message_index - t.message_index > max_message_gap then
        Error "Message index too far ahead"
      else if ct_len < 24 then Error "Ciphertext too short"
      else
        let ct_data = String.sub ciphertext 0 (ct_len - 8) in
        let mac = String.sub ciphertext (ct_len - 8) 8 in
        if String.length ct_data mod aes_block_size <> 0 then
          Error "Invalid ciphertext length"
        else begin
          (* Advance a copy; commit only after MAC and padding check out. *)
          let ratchet = Array.copy t.ratchet in
          for index = t.message_index to message_index - 1 do
            ratchet_step ratchet index
          done;
          let key_material = derive_key_material ratchet in
          let aes_key = String.sub key_material 0 32 in
          let hmac_key = String.sub key_material 32 32 in
          let iv = String.sub key_material 64 16 in
          if not (constant_time_equal mac (truncated_mac ~key:hmac_key ct_data))
          then Error "MAC verification failed"
          else
            let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
            match
              pkcs7_unpad (Mirage_crypto.AES.CBC.decrypt ~key:cipher ~iv ct_data)
            with
            | Error e -> Error e
            | Ok plaintext ->
                t.ratchet <- ratchet;
                t.message_index <- message_index;
                t.received_indices <- message_index :: t.received_indices;
                advance_ratchet t;
                Ok plaintext
        end
  end

  (** Outbound session for encrypting messages to send to a room *)
  module Outbound = struct
    type t = {
      session_id : string;
      room_id : string;
      (* Ratchet state *)
      mutable ratchet : string array;
      mutable message_index : int;
      (* Ed25519 signing key *)
      signing_priv : Ed25519.priv;
      signing_pub : Ed25519.pub;
      (* Creation and rotation tracking *)
      creation_time : Ptime.t;
      mutable message_count : int;
      max_messages : int;
      max_age : Ptime.Span.t;
      (* Users this session has been shared with *)
      mutable shared_with : (string * string) list; (* user_id, device_id pairs *)
    }

    (** Create a new outbound session for a room with a random ratchet, a
        fresh signing key, and rotation after 100 messages or one week. *)
    let create ~room_id =
      let session_id = Mirage_crypto_rng.generate 16 |> base64_encode in
      let ratchet = Array.init 4 (fun _ -> Mirage_crypto_rng.generate 32) in
      let signing_priv, signing_pub = Ed25519.generate () in
      {
        session_id;
        room_id;
        ratchet;
        message_index = 0;
        signing_priv;
        signing_pub;
        creation_time = now ();
        message_count = 0;
        max_messages = 100;
        max_age = Ptime.Span.of_int_s (7 * 24 * 60 * 60); (* 1 week *)
        shared_with = [];
      }

    (** Advance the ratchet by one step *)
    let advance_ratchet t =
      ratchet_step t.ratchet t.message_index;
      t.message_index <- t.message_index + 1

    let session_id t = t.session_id
    let room_id t = t.room_id
    let message_index t = t.message_index

    (** Check if the session should be rotated: either the message limit has
        been reached, or it is older than [max_age]. *)
    let needs_rotation t =
      t.message_count >= t.max_messages
      ||
      match Ptime.of_float_s (Unix.gettimeofday ()) with
      | Some now -> Ptime.Span.compare (Ptime.diff now t.creation_time) t.max_age > 0
      | None -> false

    (** Derive key material from the current ratchet state *)
    let derive_key t = derive_key_material t.ratchet

    (** Export the current ratchet state for sharing via m.room_key. The
        current message index is not included. *)
    let export_session_key t =
      String.concat "" (Array.to_list t.ratchet) |> base64_encode

    (** Get the signing key as unpadded base64 *)
    let signing_key t = Ed25519.pub_to_octets t.signing_pub |> base64_encode

    (** Encrypt a message. Returns [(message_index, base64 ciphertext)], where
        the ciphertext is the CBC data followed by an 8-byte truncated HMAC.
        The ratchet then advances. *)
    let encrypt t plaintext =
      let key_material = derive_key t in
      let aes_key = String.sub key_material 0 32 in
      let hmac_key = String.sub key_material 32 32 in
      let iv = String.sub key_material 64 16 in
      let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
      let ct_data =
        Mirage_crypto.AES.CBC.encrypt ~key:cipher ~iv (pkcs7_pad plaintext)
      in
      let ciphertext = ct_data ^ truncated_mac ~key:hmac_key ct_data in
      let msg_index = t.message_index in
      advance_ratchet t;
      t.message_count <- t.message_count + 1;
      (msg_index, base64_encode ciphertext)

    (** Mark session as shared with a user/device *)
    let mark_shared_with t ~user_id ~device_id =
      if not (List.mem (user_id, device_id) t.shared_with) then
        t.shared_with <- (user_id, device_id) :: t.shared_with

    (** Check if already shared with a user/device *)
    let is_shared_with t ~user_id ~device_id =
      List.mem (user_id, device_id) t.shared_with

    (** Get list of users this session is shared with *)
    let shared_with t = t.shared_with
  end
end

(** Olm Machine - high-level state machine for E2EE operations *)
module Machine = struct
  type t = {
    user_id : string;
    device_id : string;
    account : Account.t;
    (* Olm sessions keyed by their base64 curve25519 key, newest first *)
    mutable sessions : (string * Session.t list) list;
    (* Outbound Megolm sessions by room_id *)
    mutable outbound_group_sessions : (string * Megolm.Outbound.t) list;
    (* Inbound Megolm sessions by (room_id, session_id) *)
    mutable inbound_group_sessions : ((string * string) * Megolm.Inbound.t) list;
    (* Device keys we know about: user_id -> device_id -> device_keys *)
    mutable device_keys : (string * (string * Keys.queried_device_keys) list) list;
  }

  (** Create a new OlmMachine with a fresh account *)
  let create ~user_id ~device_id =
    {
      user_id;
      device_id;
      account = Account.create ();
      sessions = [];
      outbound_group_sessions = [];
      inbound_group_sessions = [];
      device_keys = [];
    }

  (** Get identity keys as [(ed25519, curve25519)] *)
  let identity_keys t = Account.identity_keys t.account

  (** Get device keys for upload as [(user_id, device_id, algorithms, keys)] *)
  let device_keys_for_upload t =
    let ed25519, curve25519 = identity_keys t in
    let algorithms =
      [ "m.olm.v1.curve25519-aes-sha2-256"; "m.megolm.v1.aes-sha2-256" ]
    in
    let keys =
      [
        (Printf.sprintf "ed25519:%s" t.device_id, ed25519);
        (Printf.sprintf "curve25519:%s" t.device_id, curve25519);
      ]
    in
    (t.user_id, t.device_id, algorithms, keys)

  (** Generate up to [count] one-time keys *)
  let generate_one_time_keys t count =
    Account.generate_one_time_keys t.account count

  (** Get signed one-time keys for upload *)
  let one_time_keys_for_upload t = Account.signed_one_time_keys t.account

  (** Mark keys as uploaded *)
  let mark_keys_as_published t = Account.mark_keys_as_published t.account

  (** Store device keys from a key query response, replacing any previous
      entry for [user_id] *)
  let receive_device_keys t ~user_id ~devices =
    t.device_keys <-
      (user_id, devices) :: List.filter (fun (uid, _) -> uid <> user_id) t.device_keys

  (** Get the room's outbound Megolm session, creating a new one if none
      exists or the current one needs rotation *)
  let get_outbound_group_session t ~room_id =
    match List.assoc_opt room_id t.outbound_group_sessions with
    | Some session when not (Megolm.Outbound.needs_rotation session) -> session
    | _ ->
        let session = Megolm.Outbound.create ~room_id in
        t.outbound_group_sessions <-
          (room_id, session)
          :: List.filter (fun (rid, _) -> rid <> room_id) t.outbound_group_sessions;
        session

  (** Store an inbound Megolm session from a room key event, replacing any
      existing session with the same (room_id, session_id).
      @raise Failure if [session_key] is malformed. *)
  let receive_room_key t ~sender_key ~room_id ~session_id ~session_key ~signing_key =
    let session =
      Megolm.Inbound.from_room_key ~sender_key ~room_id ~session_id ~session_key
        ~signing_key
    in
    let key = (room_id, session_id) in
    t.inbound_group_sessions <-
      (key, session) :: List.filter (fun (k, _) -> k <> key) t.inbound_group_sessions

  (** Encrypt a room message using Megolm. Returns m.room.encrypted content
      as a JSON string.
      NOTE(review): the message index is not included in the output, yet
      [decrypt_room_message] requires it. Receivers must obtain it some
      other way. *)
  let encrypt_room_message t ~room_id ~content =
    let session = get_outbound_group_session t ~room_id in
    let _msg_index, ciphertext = Megolm.Outbound.encrypt session content in
    let _, curve25519_key = identity_keys t in
    Printf.sprintf
      {|{"algorithm":"m.megolm.v1.aes-sha2-256","sender_key":"%s","ciphertext":"%s","session_id":"%s","device_id":"%s"}|}
      curve25519_key ciphertext
      (Megolm.Outbound.session_id session)
      t.device_id

  (** Decrypt a room message with the matching inbound session.
      NOTE(review): [ciphertext] is passed straight to
      [Megolm.Inbound.decrypt], which expects raw bytes rather than the
      base64 emitted by [encrypt_room_message]. Confirm callers decode it
      first. *)
  let decrypt_room_message t ~room_id ~sender_key ~session_id ~ciphertext
      ~message_index =
    match List.assoc_opt (room_id, session_id) t.inbound_group_sessions with
    | Some session when Megolm.Inbound.sender_key session = sender_key ->
        Megolm.Inbound.decrypt session ~ciphertext ~message_index
    | Some _ -> Error "Sender key mismatch"
    | None -> Error "Unknown session"

  (** Get the most recent Olm session for a device, if any *)
  let get_olm_session t ~their_identity_key =
    match List.assoc_opt their_identity_key t.sessions with
    | Some (session :: _) -> Some session
    | _ -> None

  (** Record [session] as the newest session for [their_identity_key] *)
  let add_olm_session t ~their_identity_key session =
    let existing =
      Option.value (List.assoc_opt their_identity_key t.sessions) ~default:[]
    in
    t.sessions <-
      (their_identity_key, session :: existing)
      :: List.filter (fun (k, _) -> k <> their_identity_key) t.sessions

  (** Create an outbound Olm session.
      @raise Failure on malformed keys or a failed key exchange. *)
  let create_olm_session t ~their_identity_key ~their_one_time_key =
    let session =
      Session.create_outbound t.account ~their_identity_key ~their_one_time_key
    in
    add_olm_session t ~their_identity_key session;
    session

  (** Create an inbound Olm session. [their_ephemeral_key] is the ratchet key
      carried in the sender's first (pre-key) payload.
      @raise Failure if the one-time key is unknown or a key is malformed. *)
  let create_inbound_session t ~their_identity_key ~their_ephemeral_key
      ~one_time_key_id =
    let session =
      Session.create_inbound t.account ~their_identity_key ~their_ephemeral_key
        ~one_time_key_id
    in
    add_olm_session t ~their_identity_key session;
    session

  (** Encrypt a to-device message, creating an outbound session if needed *)
  let encrypt_to_device t ~their_identity_key ~their_one_time_key ~plaintext =
    let session =
      match get_olm_session t ~their_identity_key with
      | Some s -> s
      | None -> create_olm_session t ~their_identity_key ~their_one_time_key
    in
    Session.encrypt session plaintext

  (** Decrypt a to-device message with the most recent session for the sender *)
  let decrypt_to_device t ~their_identity_key ~message_type ~ciphertext =
    match get_olm_session t ~their_identity_key with
    | Some session -> Session.decrypt session ~message_type ~ciphertext
    | None -> Error "No session for sender"

  (** Number of one-time keys remaining *)
  let one_time_keys_count t = Account.one_time_keys_count t.account

  (** Should upload more one-time keys? True when fewer than half the
      maximum remain. *)
  let should_upload_keys t =
    one_time_keys_count t < Account.max_one_time_keys t.account / 2
end