(** Olm/Megolm cryptographic session management.

    This module implements the Olm double-ratchet algorithm for encrypted
    to-device messages, and Megolm for encrypted room messages.

    NOTE(review): the message and key formats produced here (for example the
    '|'-separated Olm payload) are specific to this module; do not assume they
    interoperate with other Olm/Megolm implementations. *)

module Ed25519 = Mirage_crypto_ec.Ed25519
module X25519 = Mirage_crypto_ec.X25519

(* Base64 encoding/decoding - Matrix uses unpadded base64 *)
let base64_encode s = Base64.encode_string ~pad:false s
let base64_decode s = Base64.decode ~pad:false s

(** Size in bytes of an AES block (and of a CBC initialisation vector). *)
let aes_block_size = 16

(** Number of leading HMAC-SHA-256 bytes appended to ciphertexts as a MAC. *)
let truncated_mac_length = 8

(** Current wall-clock time as a [Ptime.t]; falls back to [Ptime.epoch] if the
    clock value is not representable. *)
let current_time () =
  match Ptime.of_float_s (Unix.gettimeofday ()) with
  | Some t -> t
  | None -> Ptime.epoch

(** [pkcs7_pad s] appends PKCS#7 padding so that the result length is a
    non-zero multiple of [aes_block_size]. A whole block of padding is added
    when [s] is already aligned. *)
let pkcs7_pad s =
  let pad_len = aes_block_size - (String.length s mod aes_block_size) in
  s ^ String.make pad_len (Char.chr pad_len)

(** [pkcs7_unpad s] strips PKCS#7 padding from a decrypted buffer.

    Returns [Error "Invalid padding"] unless [s] is a non-empty multiple of
    [aes_block_size] whose last [n] bytes all equal [n], with
    [1 <= n <= aes_block_size]. *)
let pkcs7_unpad s =
  let len = String.length s in
  if len = 0 || len mod aes_block_size <> 0 then Error "Invalid padding"
  else
    let pad_len = Char.code s.[len - 1] in
    if pad_len < 1 || pad_len > aes_block_size then Error "Invalid padding"
    else
      let rec padding_ok i =
        i >= len || (Char.code s.[i] = pad_len && padding_ok (i + 1))
      in
      if padding_ok (len - pad_len) then Ok (String.sub s 0 (len - pad_len))
      else Error "Invalid padding"

(** [constant_time_equal a b] compares two strings without stopping at the
    first differing byte, so that MAC checks do not reveal how many leading
    bytes matched. Lengths are compared first (MAC lengths are public). *)
let constant_time_equal a b =
  String.length a = String.length b
  && (let diff = ref 0 in
      String.iteri
        (fun i c -> diff := !diff lor (Char.code c lxor Char.code b.[i]))
        a;
      !diff = 0)

(** Olm account - manages identity keys and one-time keys.

    An Olm account contains:
    - Ed25519 identity key (for signing)
    - Curve25519 identity key (for key exchange)
    - One-time keys (for session establishment) *)
module Account = struct
  type t = {
    (* Identity keys *)
    ed25519_priv : Ed25519.priv;
    ed25519_pub : Ed25519.pub;
    curve25519_secret : X25519.secret;
    curve25519_public : string;
    (* One-time keys - key_id -> (secret, public) *)
    mutable one_time_keys : (string * (X25519.secret * string)) list;
    (* Fallback keys *)
    mutable fallback_key : (string * (X25519.secret * string)) option;
    (* Key counter for generating key IDs *)
    mutable next_key_id : int;
    (* Max number of one-time keys to store *)
    max_one_time_keys : int;
  }

  (** Generate a new Olm account with fresh identity keys and no one-time
      keys. The account holds at most 100 one-time keys. *)
  let create () =
    let ed25519_priv, ed25519_pub = Ed25519.generate () in
    let curve25519_secret, curve25519_public = X25519.gen_key () in
    {
      ed25519_priv;
      ed25519_pub;
      curve25519_secret;
      curve25519_public;
      one_time_keys = [];
      fallback_key = None;
      next_key_id = 0;
      max_one_time_keys = 100;
    }

  (** Get the Ed25519 identity key as unpadded base64. *)
  let ed25519_key t = Ed25519.pub_to_octets t.ed25519_pub |> base64_encode

  (** Get the Curve25519 identity key as unpadded base64. *)
  let curve25519_key t = base64_encode t.curve25519_public

  (** Get the identity keys as a pair [(ed25519, curve25519)]. *)
  let identity_keys t = (ed25519_key t, curve25519_key t)

  (** Sign [message] with the account's Ed25519 key; returns the signature
      as unpadded base64. *)
  let sign t message =
    Ed25519.sign ~key:t.ed25519_priv message |> base64_encode

  (** Generate a unique key ID from the account's counter (e.g. ["AAAA00AA"]). *)
  let generate_key_id t =
    let id = Printf.sprintf "AAAA%02dAA" t.next_key_id in
    t.next_key_id <- t.next_key_id + 1;
    id

  (** Generate up to [count] new one-time keys, never exceeding
      [max_one_time_keys] stored keys in total. *)
  let generate_one_time_keys t count =
    let count = min count (t.max_one_time_keys - List.length t.one_time_keys) in
    for _ = 1 to count do
      let secret, public = X25519.gen_key () in
      let key_id = generate_key_id t in
      t.one_time_keys <- (key_id, (secret, public)) :: t.one_time_keys
    done

  (** Get one-time keys for upload as [(key_id, base64_public_key)] pairs. *)
  let one_time_keys t =
    List.map
      (fun (key_id, (_secret, public)) -> (key_id, base64_encode public))
      t.one_time_keys

  (** Get signed one-time keys for upload as
      [(key_id, base64_public_key, signature)] triples. The signed payload is
      the JSON object [{"key":"<base64_public_key>"}]. *)
  let signed_one_time_keys t =
    List.map
      (fun (key_id, (_secret, public)) ->
        let pub_b64 = base64_encode public in
        let to_sign = Printf.sprintf {|{"key":"%s"}|} pub_b64 in
        (key_id, pub_b64, sign t to_sign))
      t.one_time_keys

  (** Mark one-time keys as published. This is a no-op: one-time keys are
      kept until they are consumed by an inbound session. *)
  let mark_keys_as_published _t = ()

  (** Generate a new fallback key, replacing any previous one. *)
  let generate_fallback_key t =
    let secret, public = X25519.gen_key () in
    let key_id = generate_key_id t in
    t.fallback_key <- Some (key_id, (secret, public))

  (** Get the fallback key as [(key_id, base64_public_key)], if one exists. *)
  let fallback_key t =
    match t.fallback_key with
    | Some (key_id, (_secret, public)) -> Some (key_id, base64_encode public)
    | None -> None

  (** Remove a one-time key by ID (after session creation). The fallback key
      is not affected. *)
  let remove_one_time_key t key_id =
    t.one_time_keys <- List.filter (fun (id, _) -> id <> key_id) t.one_time_keys

  (** Get the secret for a one-time key, falling back to the fallback key when
      its ID matches. *)
  let get_one_time_key_secret t key_id =
    match List.assoc_opt key_id t.one_time_keys with
    | Some (secret, _public) -> Some secret
    | None -> (
        match t.fallback_key with
        | Some (id, (secret, _public)) when id = key_id -> Some secret
        | _ -> None)

  (** Number of one-time keys currently stored. *)
  let one_time_keys_count t = List.length t.one_time_keys

  (** Maximum number of one-time keys this account can hold. *)
  let max_one_time_keys t = t.max_one_time_keys
end

(** Olm session state for the double ratchet algorithm. *)
module Session = struct
  (** Ratchet chain key *)
  type chain_key = {
    key : string; (* 32 bytes *)
    index : int; (* index of the message key this chain yields next *)
  }

  (** Root key used for deriving new chain keys *)
  type root_key = string (* 32 bytes *)

  (** Message key for encrypting a single message *)
  type message_key = string (* 32 bytes *)

  (** Session state *)
  type t = {
    (* Session ID *)
    session_id : string;
    (* Their identity key (Curve25519) *)
    their_identity_key : string;
    (* Their current ratchet key *)
    mutable their_ratchet_key : string option;
    (* Our current ratchet key pair *)
    mutable our_ratchet_secret : X25519.secret;
    mutable our_ratchet_public : string;
    (* Root key for deriving chain keys *)
    mutable root_key : root_key;
    (* Sending chain *)
    mutable sending_chain : chain_key option;
    (* Receiving chains (their_ratchet_key -> chain) *)
    mutable receiving_chains : (string * chain_key) list;
    (* Skipped message keys for out-of-order decryption *)
    mutable skipped_keys : ((string * int) * message_key) list;
    (* Creation time *)
    creation_time : Ptime.t;
  }

  (** HKDF-SHA-256: extract with [salt] from [ikm], then expand to [length]
      bytes under [info]. *)
  let hkdf_sha256 ~salt ~info ~ikm length =
    let prk = Hkdf.extract ~hash:`SHA256 ~salt ikm in
    Hkdf.expand ~hash:`SHA256 ~prk ~info length

  (** Derive [(new_root_key, chain_key)] from the current root key and a DH
      output. *)
  let kdf_rk root_key shared_secret =
    let derived =
      hkdf_sha256 ~salt:root_key ~info:"OLM_ROOT" ~ikm:shared_secret 64
    in
    (String.sub derived 0 32, String.sub derived 32 32)

  (** Derive [(message_key, next_chain)]. The message key belongs to index
      [chain_key.index]; the returned chain has index [chain_key.index + 1]. *)
  let kdf_ck chain_key =
    let mk =
      hkdf_sha256 ~salt:"" ~info:"OLM_CHAIN_MESSAGE" ~ikm:chain_key.key 32
    in
    let new_ck =
      hkdf_sha256 ~salt:"" ~info:"OLM_CHAIN_KEY" ~ikm:chain_key.key 32
    in
    (mk, { key = new_ck; index = chain_key.index + 1 })

  (** X25519 key exchange.
      @raise Failure ["Key exchange failed"] if the exchange is rejected. *)
  let dh_exn secret public =
    match X25519.key_exchange secret public with
    | Ok s -> s
    | Error _ -> failwith "Key exchange failed"

  (** Decode an unpadded base64 key.
      @raise Failure [what] if the input is not valid base64. *)
  let decode_exn ~what b64 =
    match base64_decode b64 with
    | Ok k -> k
    | Error _ -> failwith what

  (** Perform X3DH key agreement for an outbound session and return
      DH1 || DH2 || DH3.
      @raise Failure ["Key exchange failed"] on a rejected exchange. *)
  let x3dh_outbound ~our_identity_secret ~our_ephemeral_secret
      ~their_identity_key ~their_one_time_key =
    (* DH1: our identity secret with their one-time key *)
    let dh1 = dh_exn our_identity_secret their_one_time_key in
    (* DH2: our ephemeral secret with their identity key *)
    let dh2 = dh_exn our_ephemeral_secret their_identity_key in
    (* DH3: our ephemeral secret with their one-time key *)
    let dh3 = dh_exn our_ephemeral_secret their_one_time_key in
    dh1 ^ dh2 ^ dh3

  (** Session ID: unpadded base64 of SHA-256 over the initial root key, so
      both ends of a session compute the same ID. *)
  let session_id_of_root_key root_key =
    Digestif.SHA256.(digest_string root_key |> to_raw_string) |> base64_encode

  (** Create a new outbound session (when sending the first message).

      The X3DH ephemeral key pair doubles as our initial ratchet key. It is
      therefore carried in every message payload, and the recipient passes it
      to [create_inbound] as [their_ephemeral_key]. The recipient's receiving
      chain is keyed by that same value.

      @raise Failure on invalid base64 keys or a rejected key exchange. *)
  let create_outbound account ~their_identity_key ~their_one_time_key =
    let their_identity = decode_exn ~what:"Invalid identity key" their_identity_key in
    let their_otk = decode_exn ~what:"Invalid one-time key" their_one_time_key in
    let ephemeral_secret, ephemeral_public = X25519.gen_key () in
    let shared_secret =
      x3dh_outbound ~our_identity_secret:account.Account.curve25519_secret
        ~our_ephemeral_secret:ephemeral_secret ~their_identity_key:their_identity
        ~their_one_time_key:their_otk
    in
    let root_key = hkdf_sha256 ~salt:"" ~info:"OLM_ROOT" ~ikm:shared_secret 32 in
    {
      session_id = session_id_of_root_key root_key;
      their_identity_key = their_identity;
      their_ratchet_key = None;
      our_ratchet_secret = ephemeral_secret;
      our_ratchet_public = ephemeral_public;
      root_key;
      (* Initial sending chain starts from the root key. *)
      sending_chain = Some { key = root_key; index = 0 };
      receiving_chains = [];
      skipped_keys = [];
      creation_time = current_time ();
    }

  (** Create a new inbound session (when receiving the first message).

      The one-time key [one_time_key_id] (or the matching fallback key) is
      consumed. One-time keys are removed from the account afterwards.

      @raise Failure if the one-time key is unknown, a key is not valid
      base64, or a key exchange is rejected. *)
  let create_inbound account ~their_identity_key ~their_ephemeral_key
      ~one_time_key_id =
    let our_otk_secret =
      match Account.get_one_time_key_secret account one_time_key_id with
      | Some s -> s
      | None -> failwith "One-time key not found"
    in
    let their_identity = decode_exn ~what:"Invalid identity key" their_identity_key in
    let their_ephemeral = decode_exn ~what:"Invalid ephemeral key" their_ephemeral_key in
    (* Mirror of [x3dh_outbound]: DH(a, B) = DH(b, A), so each term matches
       the sender's term in the same position. *)
    let dh1 = dh_exn our_otk_secret their_identity in
    let dh2 = dh_exn account.Account.curve25519_secret their_ephemeral in
    let dh3 = dh_exn our_otk_secret their_ephemeral in
    let root_key =
      hkdf_sha256 ~salt:"" ~info:"OLM_ROOT" ~ikm:(dh1 ^ dh2 ^ dh3) 32
    in
    let our_ratchet_secret, our_ratchet_public = X25519.gen_key () in
    Account.remove_one_time_key account one_time_key_id;
    {
      session_id = session_id_of_root_key root_key;
      their_identity_key = their_identity;
      their_ratchet_key = Some their_ephemeral;
      our_ratchet_secret;
      our_ratchet_public;
      root_key;
      sending_chain = None;
      receiving_chains = [ (their_ephemeral, { key = root_key; index = 0 }) ];
      skipped_keys = [];
      creation_time = current_time ();
    }

  (** Get session ID *)
  let session_id t = t.session_id

  (** Get their identity key as unpadded base64. *)
  let their_identity_key t = base64_encode t.their_identity_key

  (** Truncated HMAC-SHA-256 over [data], keyed by a MAC key derived from the
      message key (so the AES key and MAC key are distinct). *)
  let olm_mac message_key data =
    let mac_key = hkdf_sha256 ~salt:"" ~info:"OLM_MAC" ~ikm:message_key 32 in
    let full =
      Digestif.SHA256.hmac_string ~key:mac_key data
      |> Digestif.SHA256.to_raw_string
    in
    String.sub full 0 truncated_mac_length

  (** Encrypt [plaintext] with AES-256-CBC and authenticate it with a
      truncated HMAC-SHA-256.

      Output layout: [iv (16) || ciphertext || mac (8)]. The IV is derived
      from the key; this is acceptable only because each message key is used
      once. *)
  let aes_encrypt key plaintext =
    let aes_key = String.sub key 0 32 in
    let iv =
      String.sub
        Digestif.SHA256.(digest_string (aes_key ^ "IV") |> to_raw_string)
        0 aes_block_size
    in
    let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
    let body =
      iv ^ Mirage_crypto.AES.CBC.encrypt ~key:cipher ~iv (pkcs7_pad plaintext)
    in
    body ^ olm_mac key body

  (** Decrypt the output of [aes_encrypt].

      The MAC is verified before any decryption, and the ciphertext length is
      checked to be block-aligned so that CBC decryption cannot raise. *)
  let aes_decrypt key ciphertext =
    let len = String.length ciphertext in
    let body_len = len - truncated_mac_length in
    if len < aes_block_size + aes_block_size + truncated_mac_length then
      Error "Ciphertext too short"
    else if (body_len - aes_block_size) mod aes_block_size <> 0 then
      Error "Invalid ciphertext length"
    else
      let body = String.sub ciphertext 0 body_len in
      let mac = String.sub ciphertext body_len truncated_mac_length in
      if not (constant_time_equal mac (olm_mac key body)) then
        Error "MAC verification failed"
      else
        let iv = String.sub body 0 aes_block_size in
        let data = String.sub body aes_block_size (body_len - aes_block_size) in
        let cipher = Mirage_crypto.AES.CBC.of_secret (String.sub key 0 32) in
        pkcs7_unpad (Mirage_crypto.AES.CBC.decrypt ~key:cipher ~iv data)

  (** Start a sending chain for a session that has none (an inbound session
      replying for the first time).

      Generates a fresh ratchet key. When their ratchet key is known, it
      performs the DH ratchet step that the peer's [decrypt] mirrors. All
      fallible work happens before [t] is mutated. *)
  let start_sending_chain t =
    let new_secret, new_public = X25519.gen_key () in
    let chain =
      match t.their_ratchet_key with
      | Some their_ratchet ->
          let new_root, chain_key =
            kdf_rk t.root_key (dh_exn new_secret their_ratchet)
          in
          t.root_key <- new_root;
          { key = chain_key; index = 0 }
      | None -> { key = t.root_key; index = 0 }
    in
    t.our_ratchet_secret <- new_secret;
    t.our_ratchet_public <- new_public;
    chain

  (** Encrypt a plaintext message.

      Returns [(message_type, payload)], where [message_type] is [0] (prekey)
      for the first message of a sending chain and [1] otherwise.
      [payload] is ["<ratchet_key_b64>|<index>|<ciphertext_b64>"], and
      [<index>] is the chain index of the message key actually used.

      @raise Failure ["Key exchange failed"] if starting a chain fails. *)
  let encrypt t plaintext =
    let chain =
      match t.sending_chain with
      | Some chain -> chain
      | None -> start_sending_chain t
    in
    let message_key, next_chain = kdf_ck chain in
    t.sending_chain <- Some next_chain;
    let ciphertext = aes_encrypt message_key plaintext in
    let msg_type = if chain.index = 0 then 0 else 1 in
    let payload =
      Printf.sprintf "%s|%d|%s"
        (base64_encode t.our_ratchet_public)
        chain.index (base64_encode ciphertext)
    in
    (msg_type, payload)

  (** Split and decode an encrypted payload into
      [(their_ratchet_key, message_index, ciphertext)]. *)
  let parse_payload payload =
    match String.split_on_char '|' payload with
    | [ ratchet_key_b64; index_str; ciphertext_b64 ] -> (
        match
          ( base64_decode ratchet_key_b64,
            int_of_string_opt index_str,
            base64_decode ciphertext_b64 )
        with
        | Error _, _, _ -> Error "Invalid ratchet key"
        | _, None, _ -> Error "Invalid message index"
        | _, Some i, _ when i < 0 -> Error "Invalid message index"
        | _, _, Error _ -> Error "Invalid ciphertext"
        | Ok ratchet, Some index, Ok ciphertext -> Ok (ratchet, index, ciphertext))
    | _ -> Error "Invalid message format"

  (** Largest forward jump in a receiving chain that [decrypt] will follow.
      This bounds the work and skipped-key storage that a forged index can
      cause. *)
  let max_message_gap = 2000

  (** Decrypt using the receiving chain for [their_ratchet], deriving a new
      chain via the DH ratchet if the key is new. Session state (root key,
      chains, skipped keys) is only updated once decryption has succeeded. *)
  let decrypt_with_chain t ~their_ratchet ~msg_index ciphertext =
    let found =
      match List.assoc_opt their_ratchet t.receiving_chains with
      | Some chain -> Ok (chain, None)
      | None -> (
          match X25519.key_exchange t.our_ratchet_secret their_ratchet with
          | Ok dh_out ->
              let new_root, chain_key = kdf_rk t.root_key dh_out in
              Ok ({ key = chain_key; index = 0 }, Some new_root)
          | Error _ -> Error "Key exchange failed")
    in
    match found with
    | Error msg -> Error msg
    | Ok (chain, new_root) ->
        if msg_index < chain.index then Error "Message key already used or unknown"
        else if msg_index - chain.index > max_message_gap then
          Error "Message index too far ahead"
        else begin
          (* Walk the chain up to [msg_index], remembering the keys we pass
             over so that delayed messages can still be decrypted. *)
          let rec advance c skipped =
            if c.index >= msg_index then (c, skipped)
            else
              let mk, next = kdf_ck c in
              advance next (((their_ratchet, c.index), mk) :: skipped)
          in
          let chain, skipped = advance chain [] in
          let message_key, next_chain = kdf_ck chain in
          match aes_decrypt message_key ciphertext with
          | Error _ as err -> err
          | Ok _ as ok ->
              (match new_root with
               | Some root ->
                   t.root_key <- root;
                   t.their_ratchet_key <- Some their_ratchet
               | None -> ());
              t.skipped_keys <- skipped @ t.skipped_keys;
              t.receiving_chains <-
                (their_ratchet, next_chain)
                :: List.filter (fun (k, _) -> k <> their_ratchet) t.receiving_chains;
              ok
        end

  (** Decrypt a message produced by [encrypt].

      Uses a stored skipped key first (out-of-order delivery); otherwise it
      advances the matching receiving chain. Malformed payloads and failed key
      exchanges are reported as [Error] rather than raised. [message_type]
      is currently not used. *)
  let decrypt t ~message_type ~ciphertext:payload =
    ignore message_type;
    match parse_payload payload with
    | Error msg -> Error msg
    | Ok (their_ratchet, msg_index, ciphertext) -> (
        let skipped_id = (their_ratchet, msg_index) in
        match List.assoc_opt skipped_id t.skipped_keys with
        | Some message_key -> (
            match aes_decrypt message_key ciphertext with
            | Ok _ as ok ->
                (* Each message key is single-use. *)
                t.skipped_keys <- List.remove_assoc skipped_id t.skipped_keys;
                ok
            | Error _ as err -> err)
        | None -> decrypt_with_chain t ~their_ratchet ~msg_index ciphertext)

  (** Check if this is a pre-key message (first message in a chain). *)
  let is_pre_key_message message_type = message_type = 0
end

(** Megolm session for room message encryption.

    Megolm uses a ratchet that only moves forward, making it efficient for
    encrypting many messages to many recipients. *)
module Megolm = struct
  (** Size in bytes of the serialised ratchet state (4 parts of 32 bytes). *)
  let ratchet_length = 128

  (** Advance ratchet [parts] in place by one step from position [index]
      (shared by inbound and outbound sessions so both sides stay in sync). *)
  let ratchet_step parts index =
    let hash s = Digestif.SHA256.(digest_string s |> to_raw_string) in
    (* Simplified ratchet: hash each part from [index land 3] upwards
       together with its position. *)
    for j = index land 3 to 3 do
      parts.(j) <- hash (parts.(j) ^ string_of_int j)
    done

  (** Derive 80 bytes of key material (AES key, HMAC key, IV) from the
      ratchet state. *)
  let derive_key_material parts =
    let combined = String.concat "" (Array.to_list parts) in
    Hkdf.expand ~hash:`SHA256 ~prk:combined ~info:"MEGOLM_KEYS" 80

  (** Truncated HMAC-SHA-256 appended to Megolm ciphertexts. *)
  let megolm_mac ~hmac_key data =
    let full =
      Digestif.SHA256.hmac_string ~key:hmac_key data
      |> Digestif.SHA256.to_raw_string
    in
    String.sub full 0 truncated_mac_length

  (** 4-byte big-endian encoding of a message index. *)
  let encode_index i =
    String.init 4 (fun k -> Char.chr ((i lsr (8 * (3 - k))) land 0xff))

  (** Decode a 4-byte big-endian message index starting at [off]. *)
  let decode_index s off =
    (Char.code s.[off] lsl 24)
    lor (Char.code s.[off + 1] lsl 16)
    lor (Char.code s.[off + 2] lsl 8)
    lor Char.code s.[off + 3]

  (** Inbound session for decrypting received room messages *)
  module Inbound = struct
    type t = {
      session_id : string;
      sender_key : string; (* Curve25519 key of sender *)
      room_id : string;
      (* Ratchet state - 4 parts of 256 bits each *)
      mutable ratchet : string array; (* 4 x 32 bytes *)
      mutable message_index : int;
      (* For detecting replays *)
      mutable received_indices : int list;
      (* Ed25519 signing key of the sender *)
      signing_key : string;
      creation_time : Ptime.t;
    }

    (** Advance the ratchet by one step *)
    let advance_ratchet t =
      ratchet_step t.ratchet t.message_index;
      t.message_index <- t.message_index + 1

    (** Create from exported session data. [ratchet] is stored as given (not
        copied). *)
    let of_export ~session_id ~sender_key ~room_id ~ratchet ~message_index
        ~signing_key =
      {
        session_id;
        sender_key;
        room_id;
        ratchet;
        message_index;
        received_indices = [];
        signing_key;
        creation_time = current_time ();
      }

    (** Create from a room key event ([m.room_key]).

        [session_key] is unpadded base64 of the 128-byte ratchet state. It may
        optionally be followed by the 4-byte big-endian message index written
        by [Outbound.export_session_key]; without it the index is 0. If the
        key does not decode to at least 128 bytes, a random ratchet is used
        (kept for backward compatibility). Such a session cannot decrypt
        anything. *)
    let from_room_key ~sender_key ~room_id ~session_id ~session_key ~signing_key =
      let ratchet, message_index =
        match base64_decode session_key with
        | Ok data when String.length data >= ratchet_length ->
            let parts = Array.init 4 (fun i -> String.sub data (i * 32) 32) in
            let index =
              if String.length data = ratchet_length + 4 then
                decode_index data ratchet_length
              else 0
            in
            (parts, index)
        | _ -> (Array.init 4 (fun _ -> Mirage_crypto_rng.generate 32), 0)
      in
      of_export ~session_id ~sender_key ~room_id ~ratchet ~message_index
        ~signing_key

    let session_id t = t.session_id
    let sender_key t = t.sender_key
    let room_id t = t.room_id

    (** Index of the next message this session can decrypt. It advances as
        messages are decrypted. *)
    let first_known_index t = t.message_index

    (** Derive encryption key material from the current ratchet state. *)
    let derive_key t = derive_key_material t.ratchet

    (** Decrypt a message.

        [ciphertext] must be the raw bytes [ct || mac (8)]. Note that
        [Outbound.encrypt] returns them base64-encoded. The ratchet is
        advanced on a copy, and the session is only updated after the MAC and
        padding checks succeed. A forged message therefore cannot move the
        session forward. *)
    let decrypt t ~ciphertext ~message_index =
      let ct_len = String.length ciphertext in
      if List.mem message_index t.received_indices then
        Error "Duplicate message index (replay attack)"
      else if message_index < t.message_index then Error "Message index too old"
      else if ct_len < 24 then Error "Ciphertext too short"
      else if (ct_len - truncated_mac_length) mod aes_block_size <> 0 then
        Error "Invalid ciphertext length"
      else begin
        let parts = Array.copy t.ratchet in
        let index = ref t.message_index in
        while !index < message_index do
          ratchet_step parts !index;
          incr index
        done;
        let key_material = derive_key_material parts in
        let aes_key = String.sub key_material 0 32 in
        let hmac_key = String.sub key_material 32 32 in
        let iv = String.sub key_material 64 16 in
        let ct_data = String.sub ciphertext 0 (ct_len - truncated_mac_length) in
        let mac =
          String.sub ciphertext (ct_len - truncated_mac_length) truncated_mac_length
        in
        if not (constant_time_equal mac (megolm_mac ~hmac_key ct_data)) then
          Error "MAC verification failed"
        else
          let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
          match pkcs7_unpad (Mirage_crypto.AES.CBC.decrypt ~key:cipher ~iv ct_data) with
          | Error _ as err -> err
          | Ok plaintext ->
              t.ratchet <- parts;
              t.message_index <- !index;
              t.received_indices <- message_index :: t.received_indices;
              advance_ratchet t;
              Ok plaintext
      end
  end

  (** Outbound session for encrypting messages to send to a room *)
  module Outbound = struct
    type t = {
      session_id : string;
      room_id : string;
      (* Ratchet state *)
      mutable ratchet : string array;
      mutable message_index : int;
      (* Ed25519 signing key *)
      signing_priv : Ed25519.priv;
      signing_pub : Ed25519.pub;
      (* Creation and rotation tracking *)
      creation_time : Ptime.t;
      mutable message_count : int;
      max_messages : int;
      max_age : Ptime.Span.t;
      (* Users this session has been shared with *)
      mutable shared_with : (string * string) list; (* user_id, device_id pairs *)
    }

    (** Create a new outbound session for a room. It has a random ratchet and
        rotates after 100 messages or one week. *)
    let create ~room_id =
      let session_id = Mirage_crypto_rng.generate 16 |> base64_encode in
      let ratchet = Array.init 4 (fun _ -> Mirage_crypto_rng.generate 32) in
      let signing_priv, signing_pub = Ed25519.generate () in
      {
        session_id;
        room_id;
        ratchet;
        message_index = 0;
        signing_priv;
        signing_pub;
        creation_time = current_time ();
        message_count = 0;
        max_messages = 100;
        max_age = Ptime.Span.of_int_s (7 * 24 * 60 * 60) (* 1 week *);
        shared_with = [];
      }

    (** Advance the ratchet *)
    let advance_ratchet t =
      ratchet_step t.ratchet t.message_index;
      t.message_index <- t.message_index + 1

    let session_id t = t.session_id
    let room_id t = t.room_id
    let message_index t = t.message_index

    (** True once the session has sent [max_messages] messages or is older
        than [max_age]. *)
    let needs_rotation t =
      t.message_count >= t.max_messages
      || (match Ptime.of_float_s (Unix.gettimeofday ()) with
          | Some now ->
              Ptime.Span.compare (Ptime.diff now t.creation_time) t.max_age > 0
          | None -> false)

    (** Derive encryption key material from the current ratchet state. *)
    let derive_key t = derive_key_material t.ratchet

    (** Export the session key for sharing via [m.room_key]: unpadded base64
        of the 128-byte ratchet state followed by the 4-byte big-endian
        current message index. [Inbound.from_room_key] parses this format. *)
    let export_session_key t =
      let ratchet_data = String.concat "" (Array.to_list t.ratchet) in
      base64_encode (ratchet_data ^ encode_index t.message_index)

    (** Get the Ed25519 signing key as unpadded base64. *)
    let signing_key t = Ed25519.pub_to_octets t.signing_pub |> base64_encode

    (** Encrypt a message and advance the ratchet.

        Returns [(message_index, base64(ct || mac))], where [message_index] is
        the index the message was encrypted at. *)
    let encrypt t plaintext =
      let key_material = derive_key t in
      let aes_key = String.sub key_material 0 32 in
      let hmac_key = String.sub key_material 32 32 in
      let iv = String.sub key_material 64 16 in
      let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
      let ct_data =
        Mirage_crypto.AES.CBC.encrypt ~key:cipher ~iv (pkcs7_pad plaintext)
      in
      let ciphertext = ct_data ^ megolm_mac ~hmac_key ct_data in
      let msg_index = t.message_index in
      advance_ratchet t;
      t.message_count <- t.message_count + 1;
      (msg_index, base64_encode ciphertext)

    (** Mark session as shared with a user/device (idempotent). *)
    let mark_shared_with t ~user_id ~device_id =
      if not (List.mem (user_id, device_id) t.shared_with) then
        t.shared_with <- (user_id, device_id) :: t.shared_with

    (** Check if already shared with a user/device *)
    let is_shared_with t ~user_id ~device_id =
      List.mem (user_id, device_id) t.shared_with

    (** Get list of [(user_id, device_id)] pairs this session is shared with *)
    let shared_with t = t.shared_with
  end
end

(** Olm Machine - high-level state machine for E2EE operations *)
module Machine = struct
  type t = {
    user_id : string;
    device_id : string;
    account : Account.t;
    (* Active Olm sessions indexed by their curve25519 key *)
    mutable sessions : (string * Session.t list) list;
    (* Outbound Megolm sessions by room_id *)
    mutable outbound_group_sessions : (string * Megolm.Outbound.t) list;
    (* Inbound Megolm sessions by (room_id, session_id) *)
    mutable inbound_group_sessions : ((string * string) * Megolm.Inbound.t) list;
    (* Device keys we know about: user_id -> device_id -> device_keys *)
    mutable device_keys :
      (string * (string * Keys.queried_device_keys) list) list;
  }

  (** Create a new OlmMachine with a fresh account *)
  let create ~user_id ~device_id =
    {
      user_id;
      device_id;
      account = Account.create ();
      sessions = [];
      outbound_group_sessions = [];
      inbound_group_sessions = [];
      device_keys = [];
    }

  (** Get identity keys as [(ed25519, curve25519)] *)
  let identity_keys t = Account.identity_keys t.account

  (** Get device keys for upload as [(user_id, device_id, algorithms, keys)] *)
  let device_keys_for_upload t =
    let ed25519, curve25519 = identity_keys t in
    let algorithms =
      [ "m.olm.v1.curve25519-aes-sha2-256"; "m.megolm.v1.aes-sha2-256" ]
    in
    let keys =
      [
        (Printf.sprintf "ed25519:%s" t.device_id, ed25519);
        (Printf.sprintf "curve25519:%s" t.device_id, curve25519);
      ]
    in
    (t.user_id, t.device_id, algorithms, keys)

  (** Generate one-time keys (capped by the account's maximum) *)
  let generate_one_time_keys t count =
    Account.generate_one_time_keys t.account count

  (** Get signed one-time keys for upload *)
  let one_time_keys_for_upload t = Account.signed_one_time_keys t.account

  (** Mark keys as uploaded *)
  let mark_keys_as_published t = Account.mark_keys_as_published t.account

  (** Store device keys from a key query response, replacing any previous
      entry for [user_id] *)
  let receive_device_keys t ~user_id ~devices =
    t.device_keys <-
      (user_id, devices) :: List.filter (fun (uid, _) -> uid <> user_id) t.device_keys

  (** Get the room's outbound Megolm session, creating a new one when none
      exists or the current one needs rotation *)
  let get_outbound_group_session t ~room_id =
    match List.assoc_opt room_id t.outbound_group_sessions with
    | Some session when not (Megolm.Outbound.needs_rotation session) -> session
    | _ ->
        let session = Megolm.Outbound.create ~room_id in
        t.outbound_group_sessions <-
          (room_id, session)
          :: List.filter (fun (rid, _) -> rid <> room_id) t.outbound_group_sessions;
        session

  (** Store an inbound Megolm session from a room key event. A newer key for
      the same [(room_id, session_id)] replaces the stored one; no duplicate
      entries are kept. *)
  let receive_room_key t ~sender_key ~room_id ~session_id ~session_key
      ~signing_key =
    let session =
      Megolm.Inbound.from_room_key ~sender_key ~room_id ~session_id ~session_key
        ~signing_key
    in
    let key = (room_id, session_id) in
    t.inbound_group_sessions <-
      (key, session) :: List.filter (fun (k, _) -> k <> key) t.inbound_group_sessions

  (** Encrypt a room message using Megolm and return the JSON content of an
      [m.room.encrypted] event. The content includes a ["message_index"]
      field, which the receiver needs to call [decrypt_room_message]. *)
  let encrypt_room_message t ~room_id ~content =
    let session = get_outbound_group_session t ~room_id in
    let msg_index, ciphertext = Megolm.Outbound.encrypt session content in
    let _, curve25519_key = identity_keys t in
    Printf.sprintf
      {|{"algorithm":"m.megolm.v1.aes-sha2-256","sender_key":"%s","ciphertext":"%s","session_id":"%s","device_id":"%s","message_index":%d}|}
      curve25519_key ciphertext
      (Megolm.Outbound.session_id session)
      t.device_id msg_index

  (** Decrypt a room message with the stored inbound session for
      [(room_id, session_id)].

      NOTE(review): [ciphertext] is passed unchanged to
      [Megolm.Inbound.decrypt], which expects raw bytes. The event's
      ["ciphertext"] field is base64; confirm that callers decode it first. *)
  let decrypt_room_message t ~room_id ~sender_key ~session_id ~ciphertext
      ~message_index =
    match List.assoc_opt (room_id, session_id) t.inbound_group_sessions with
    | Some session when Megolm.Inbound.sender_key session = sender_key ->
        Megolm.Inbound.decrypt session ~ciphertext ~message_index
    | Some _ -> Error "Sender key mismatch"
    | None -> Error "Unknown session"

  (** Get the most recent Olm session for a device, if any *)
  let get_olm_session t ~their_identity_key =
    match List.assoc_opt their_identity_key t.sessions with
    | Some (session :: _) -> Some session
    | _ -> None

  (** Record [session] as the most recent session for [their_identity_key]. *)
  let add_session t ~their_identity_key session =
    let existing =
      match List.assoc_opt their_identity_key t.sessions with
      | Some sessions -> sessions
      | None -> []
    in
    t.sessions <-
      (their_identity_key, session :: existing)
      :: List.filter (fun (k, _) -> k <> their_identity_key) t.sessions

  (** Create and store an outbound Olm session.
      @raise Failure on invalid keys (see [Session.create_outbound]). *)
  let create_olm_session t ~their_identity_key ~their_one_time_key =
    let session =
      Session.create_outbound t.account ~their_identity_key ~their_one_time_key
    in
    add_session t ~their_identity_key session;
    session

  (** Create and store an inbound Olm session from a pre-key message.
      @raise Failure on invalid or unknown keys (see [Session.create_inbound]). *)
  let create_inbound_session t ~their_identity_key ~their_ephemeral_key
      ~one_time_key_id =
    let session =
      Session.create_inbound t.account ~their_identity_key ~their_ephemeral_key
        ~one_time_key_id
    in
    add_session t ~their_identity_key session;
    session

  (** Encrypt a to-device message, creating an outbound session if none
      exists for [their_identity_key] *)
  let encrypt_to_device t ~their_identity_key ~their_one_time_key ~plaintext =
    let session =
      match get_olm_session t ~their_identity_key with
      | Some s -> s
      | None -> create_olm_session t ~their_identity_key ~their_one_time_key
    in
    Session.encrypt session plaintext

  (** Decrypt a to-device message with the most recent session for the
      sender *)
  let decrypt_to_device t ~their_identity_key ~message_type ~ciphertext =
    match get_olm_session t ~their_identity_key with
    | Some session -> Session.decrypt session ~message_type ~ciphertext
    | None -> Error "No session for sender"

  (** Number of one-time keys remaining *)
  let one_time_keys_count t = Account.one_time_keys_count t.account

  (** Should upload more one-time keys? True when fewer than half of the
      maximum remain. *)
  let should_upload_keys t =
    one_time_keys_count t < Account.max_one_time_keys t.account / 2
end