Matrix protocol implementation in OCaml, specialised for Eio
1(** Olm/Megolm cryptographic session management.
2
3 This module implements the Olm double-ratchet algorithm for encrypted
4 to-device messages, and Megolm for encrypted room messages. *)
5
6module Ed25519 = Mirage_crypto_ec.Ed25519
7module X25519 = Mirage_crypto_ec.X25519
8
(* Matrix transmits binary keys and ciphertexts as unpadded base64, so both
   helpers pass [~pad:false]. [base64_decode] returns the library's result
   value unchanged. *)
let base64_encode raw = Base64.encode_string ~pad:false raw

let base64_decode encoded = Base64.decode ~pad:false encoded
12
(** Olm account - manages identity keys and one-time keys.

    An Olm account contains:
    - an Ed25519 identity key (for signing)
    - a Curve25519 identity key (for key exchange)
    - one-time keys and an optional fallback key (for session establishment) *)
module Account = struct
  type t = {
    (* Identity keys *)
    ed25519_priv : Ed25519.priv;
    ed25519_pub : Ed25519.pub;
    curve25519_secret : X25519.secret;
    curve25519_public : string; (* raw public key bytes *)
    (* One-time keys, newest first: key_id -> (secret, raw public key) *)
    mutable one_time_keys : (string * (X25519.secret * string)) list;
    (* Current fallback key, if any: key_id -> (secret, raw public key) *)
    mutable fallback_key : (string * (X25519.secret * string)) option;
    (* Counter for key IDs; shared by one-time keys and the fallback key *)
    mutable next_key_id : int;
    (* Upper bound on the number of one-time keys held at once *)
    max_one_time_keys : int;
  }

  (** [create ?max_one_time_keys ()] generates a new account with fresh
      identity keys and no one-time or fallback keys.

      @param max_one_time_keys upper bound on the number of one-time keys
        the account will hold at once (default [100]).
      @raise Invalid_argument if [max_one_time_keys] is negative. *)
  let create ?(max_one_time_keys = 100) () =
    if max_one_time_keys < 0 then
      invalid_arg "Account.create: max_one_time_keys must be non-negative";
    let ed25519_priv, ed25519_pub = Ed25519.generate () in
    let curve25519_secret, curve25519_public = X25519.gen_key () in
    {
      ed25519_priv;
      ed25519_pub;
      curve25519_secret;
      curve25519_public;
      one_time_keys = [];
      fallback_key = None;
      next_key_id = 0;
      max_one_time_keys;
    }

  (** The Ed25519 identity public key, unpadded base64. *)
  let ed25519_key t = Ed25519.pub_to_octets t.ed25519_pub |> base64_encode

  (** The Curve25519 identity public key, unpadded base64. *)
  let curve25519_key t = base64_encode t.curve25519_public

  (** The identity keys as a pair [(ed25519, curve25519)], both base64. *)
  let identity_keys t = (ed25519_key t, curve25519_key t)

  (** [sign t message] signs [message] with the Ed25519 identity key and
      returns the signature as unpadded base64. *)
  let sign t message =
    let signature = Ed25519.sign ~key:t.ed25519_priv message in
    base64_encode signature

  (** Return a fresh key ID (["AAAA<nn>AA"], [nn] = counter zero-padded to
      at least two digits) and bump the counter. *)
  let generate_key_id t =
    let id = Printf.sprintf "AAAA%02dAA" t.next_key_id in
    t.next_key_id <- t.next_key_id + 1;
    id

  (** [generate_one_time_keys t count] adds up to [count] new one-time keys.
      The number generated is capped so that the total never exceeds
      [max_one_time_keys]. A non-positive [count] generates nothing. *)
  let generate_one_time_keys t count =
    let capacity = t.max_one_time_keys - List.length t.one_time_keys in
    for _ = 1 to min count capacity do
      let secret, public = X25519.gen_key () in
      let key_id = generate_key_id t in
      t.one_time_keys <- (key_id, (secret, public)) :: t.one_time_keys
    done

  (** All held one-time keys as [(key_id, base64 public key)] pairs. *)
  let one_time_keys t =
    List.map
      (fun (key_id, (_secret, public)) -> (key_id, base64_encode public))
      t.one_time_keys

  (** All held one-time keys as [(key_id, base64 public key, signature)],
      where the signature covers the JSON [{"key":"<base64 public key>"}]. *)
  let signed_one_time_keys t =
    List.map
      (fun (key_id, (_secret, public)) ->
        let pub_b64 = base64_encode public in
        let to_sign = Printf.sprintf {|{"key":"%s"}|} pub_b64 in
        (key_id, pub_b64, sign t to_sign))
      t.one_time_keys

  (** Currently a no-op. Keys are not tracked as published. They are only
      dropped by [remove_one_time_key] once used, so [one_time_keys] and
      [signed_one_time_keys] keep returning them. *)
  let mark_keys_as_published _t = ()

  (** Replace the fallback key with a freshly generated one. *)
  let generate_fallback_key t =
    let secret, public = X25519.gen_key () in
    let key_id = generate_key_id t in
    t.fallback_key <- Some (key_id, (secret, public))

  (** The fallback key as [(key_id, base64 public key)], if one exists. *)
  let fallback_key t =
    Option.map
      (fun (key_id, (_secret, public)) -> (key_id, base64_encode public))
      t.fallback_key

  (** Remove a one-time key by ID (after session creation). Removing an
      unknown ID, including the fallback key's ID, has no effect. *)
  let remove_one_time_key t key_id =
    t.one_time_keys <- List.filter (fun (id, _) -> id <> key_id) t.one_time_keys

  (** The secret for [key_id], looked up among the one-time keys first and
      then the fallback key; [None] if neither matches. *)
  let get_one_time_key_secret t key_id =
    match List.assoc_opt key_id t.one_time_keys with
    | Some (secret, _public) -> Some secret
    | None -> (
        match t.fallback_key with
        | Some (id, (secret, _public)) when id = key_id -> Some secret
        | _ -> None)

  (** Number of one-time keys currently held, published or not. Keys are
      only removed when used. *)
  let one_time_keys_count t = List.length t.one_time_keys

  (** Maximum number of one-time keys this account can hold. *)
  let max_one_time_keys t = t.max_one_time_keys
end
135
(** Olm session state for the double-ratchet algorithm.

    NOTE(review): this is a simplified, self-contained variant of Olm. The
    KDF labels, the ["ratchet_key|index|ciphertext"] framing and the MAC
    layout are specific to this module, so both ends must run this code. *)
module Session = struct
  (** A symmetric chain key together with the index of the next message key
      it will produce. *)
  type chain_key = {
    key : string; (* 32 bytes *)
    index : int;
  }

  (** Root key used for deriving new chain keys (32 bytes). *)
  type root_key = string

  (** Key material for a single message (32 bytes). *)
  type message_key = string

  (** Session state. *)
  type t = {
    (* Base64 SHA-256 of the initial root key *)
    session_id : string;
    (* Their Curve25519 identity key (raw bytes) *)
    their_identity_key : string;
    (* Latest ratchet public key received from them (raw bytes) *)
    mutable their_ratchet_key : string option;
    (* Our current ratchet key pair; the public half is sent with every
       message we encrypt *)
    mutable our_ratchet_secret : X25519.secret;
    mutable our_ratchet_public : string;
    (* Root key for deriving chain keys *)
    mutable root_key : root_key;
    (* Sending chain; [None] means the next [encrypt] first performs a DH
       ratchet step against [their_ratchet_key] *)
    mutable sending_chain : chain_key option;
    (* Receiving chains, most recent first (their_ratchet_key -> chain) *)
    mutable receiving_chains : (string * chain_key) list;
    (* Keys of skipped messages, most recent first, keyed by
       (their_ratchet_key, message index) *)
    mutable skipped_keys : ((string * int) * message_key) list;
    (* Creation time *)
    creation_time : Ptime.t;
  }

  (** Largest number of message keys [decrypt] will derive to reach the
      index of one incoming message. This bounds the work an attacker can
      cause by sending a huge index. *)
  let max_message_gap = 2000

  (** Maximum number of skipped message keys retained; the oldest are
      dropped first. *)
  let max_skipped_keys = 40

  (** Maximum number of receiving chains retained; the oldest are dropped
      first. *)
  let max_receiving_chains = 5

  (** Length in bytes of the truncated HMAC-SHA256 tag on each ciphertext. *)
  let mac_length = 8

  (** The first [n] elements of [l] (all of [l] if it is shorter). *)
  let take n l = List.filteri (fun i _ -> i < n) l

  (** HKDF (extract then expand) with SHA-256, producing [length] bytes. *)
  let hkdf_sha256 ~salt ~info ~ikm length =
    let prk = Hkdf.extract ~hash:`SHA256 ~salt ikm in
    Hkdf.expand ~hash:`SHA256 ~prk ~info length

  (** [kdf_rk root_key shared_secret] derives [(new_root_key, chain_key)]
      from a DH output, salted with the current root key. *)
  let kdf_rk root_key shared_secret =
    let derived =
      hkdf_sha256 ~salt:root_key ~info:"OLM_ROOT" ~ikm:shared_secret 64
    in
    (String.sub derived 0 32, String.sub derived 32 32)

  (** [kdf_ck chain] derives [(message_key, next_chain)]. The message key
      belongs to index [chain.index], and [next_chain.index = chain.index + 1]. *)
  let kdf_ck chain_key =
    let mk =
      hkdf_sha256 ~salt:"" ~info:"OLM_CHAIN_MESSAGE" ~ikm:chain_key.key 32
    in
    let new_ck = hkdf_sha256 ~salt:"" ~info:"OLM_CHAIN_KEY" ~ikm:chain_key.key 32 in
    (mk, { key = new_ck; index = chain_key.index + 1 })

  (** X25519 agreement.
      @raise Failure ["Key exchange failed"] if the exchange is rejected. *)
  let key_exchange_exn secret public =
    match X25519.key_exchange secret public with
    | Ok shared -> shared
    | Error _ -> failwith "Key exchange failed"

  (** Decode a base64 key.
      @raise Failure ["Invalid <what>"] on malformed input. *)
  let decode_key_exn what b64 =
    match base64_decode b64 with
    | Ok key -> key
    | Error _ -> failwith ("Invalid " ^ what)

  (** Current wall-clock time, or [Ptime.epoch] if it cannot be represented. *)
  let current_time () =
    match Ptime.of_float_s (Unix.gettimeofday ()) with
    | Some t -> t
    | None -> Ptime.epoch

  (** Session ID: unpadded base64 of SHA-256 of the initial root key. *)
  let session_id_of_root root_key =
    Digestif.SHA256.digest_string root_key
    |> Digestif.SHA256.to_raw_string
    |> base64_encode

  (** Perform X3DH key agreement for an outbound session. Returns
      DH1 || DH2 || DH3, where DH1 = (our identity, their one-time key),
      DH2 = (our ephemeral, their identity) and DH3 = (our ephemeral, their
      one-time key).
      @raise Failure if any exchange fails. *)
  let x3dh_outbound ~our_identity_secret ~our_ephemeral_secret
      ~their_identity_key ~their_one_time_key =
    let dh1 = key_exchange_exn our_identity_secret their_one_time_key in
    let dh2 = key_exchange_exn our_ephemeral_secret their_identity_key in
    let dh3 = key_exchange_exn our_ephemeral_secret their_one_time_key in
    dh1 ^ dh2 ^ dh3

  (** Create a new outbound session (when sending the first message).

      The X3DH ephemeral key pair doubles as our first ratchet key pair.
      Every message sent before the first reply therefore carries the
      ephemeral public key as its ratchet key, which is exactly the
      [their_ephemeral_key] the receiver needs for [create_inbound].

      @param their_identity_key their Curve25519 identity key, base64.
      @param their_one_time_key one of their one-time keys, base64.
      @raise Failure on malformed keys or a failed key exchange. *)
  let create_outbound account ~their_identity_key ~their_one_time_key =
    let their_identity = decode_key_exn "identity key" their_identity_key in
    let their_otk = decode_key_exn "one-time key" their_one_time_key in
    let ephemeral_secret, ephemeral_public = X25519.gen_key () in
    let shared_secret =
      x3dh_outbound ~our_identity_secret:account.Account.curve25519_secret
        ~our_ephemeral_secret:ephemeral_secret ~their_identity_key:their_identity
        ~their_one_time_key:their_otk
    in
    let root_key = hkdf_sha256 ~salt:"" ~info:"OLM_ROOT" ~ikm:shared_secret 32 in
    {
      session_id = session_id_of_root root_key;
      their_identity_key = their_identity;
      their_ratchet_key = None;
      our_ratchet_secret = ephemeral_secret;
      our_ratchet_public = ephemeral_public;
      root_key;
      sending_chain = Some { key = root_key; index = 0 };
      receiving_chains = [];
      skipped_keys = [];
      creation_time = current_time ();
    }

  (** Create a new inbound session (when receiving the first message).

      @param their_identity_key the sender's Curve25519 identity key, base64.
      @param their_ephemeral_key the ratchet key from the sender's pre-key
        message, base64.
      @param one_time_key_id the ID of our one-time (or fallback) key the
        sender used. A one-time key is removed from the account on success.
      @raise Failure if the key is unknown, a key is malformed, or a key
        exchange fails. *)
  let create_inbound account ~their_identity_key ~their_ephemeral_key
      ~one_time_key_id =
    let our_otk_secret =
      match Account.get_one_time_key_secret account one_time_key_id with
      | Some secret -> secret
      | None -> failwith "One-time key not found"
    in
    let their_identity = decode_key_exn "identity key" their_identity_key in
    let their_ephemeral = decode_key_exn "ephemeral key" their_ephemeral_key in
    (* Mirror image of [x3dh_outbound]. *)
    let dh1 = key_exchange_exn our_otk_secret their_identity in
    let dh2 = key_exchange_exn account.Account.curve25519_secret their_ephemeral in
    let dh3 = key_exchange_exn our_otk_secret their_ephemeral in
    let root_key =
      hkdf_sha256 ~salt:"" ~info:"OLM_ROOT" ~ikm:(dh1 ^ dh2 ^ dh3) 32
    in
    let our_ratchet_secret, our_ratchet_public = X25519.gen_key () in
    Account.remove_one_time_key account one_time_key_id;
    {
      session_id = session_id_of_root root_key;
      their_identity_key = their_identity;
      their_ratchet_key = Some their_ephemeral;
      our_ratchet_secret;
      our_ratchet_public;
      root_key;
      sending_chain = None;
      receiving_chains = [ (their_ephemeral, { key = root_key; index = 0 }) ];
      skipped_keys = [];
      creation_time = current_time ();
    }

  (** Get session ID *)
  let session_id t = t.session_id

  (** Their identity key, unpadded base64. *)
  let their_identity_key t = base64_encode t.their_identity_key

  (** Split a message key into [(aes_key, mac_key, iv)] using HKDF. *)
  let expand_message_key key =
    let material = hkdf_sha256 ~salt:"" ~info:"OLM_KEYS" ~ikm:key 80 in
    (String.sub material 0 32, String.sub material 32 32, String.sub material 64 16)

  (** HMAC-SHA256 of [data] under [key], truncated to [mac_length] bytes. *)
  let truncated_mac ~key data =
    let full = Digestif.SHA256.hmac_string ~key data |> Digestif.SHA256.to_raw_string in
    String.sub full 0 mac_length

  (** String equality whose running time does not depend on where the
      inputs differ (only on their length). *)
  let constant_time_equal a b =
    String.length a = String.length b
    &&
    let diff = ref 0 in
    String.iteri (fun i c -> diff := !diff lor (Char.code c lxor Char.code b.[i])) a;
    !diff = 0

  (** PKCS7-pad [plaintext] to a multiple of the 16-byte AES block. *)
  let pkcs7_pad plaintext =
    let pad_len = 16 - (String.length plaintext mod 16) in
    plaintext ^ String.make pad_len (Char.chr pad_len)

  (** Strip PKCS7 padding, checking that every padding byte is valid. *)
  let pkcs7_unpad padded =
    let len = String.length padded in
    if len = 0 then Error "Empty plaintext"
    else
      let pad_len = Char.code padded.[len - 1] in
      let valid =
        pad_len >= 1 && pad_len <= 16 && pad_len <= len
        && String.for_all
             (fun c -> Char.code c = pad_len)
             (String.sub padded (len - pad_len) pad_len)
      in
      if valid then Ok (String.sub padded 0 (len - pad_len))
      else Error "Invalid padding"

  (** Encrypt with AES-256-CBC and authenticate with HMAC-SHA256.

      The AES key, MAC key and IV are derived from [key] via HKDF. The
      result is [iv ^ ciphertext ^ mac], where [mac] is a [mac_length]-byte
      truncated HMAC over [iv ^ ciphertext]. *)
  let aes_encrypt key plaintext =
    let aes_key, mac_key, iv = expand_message_key key in
    let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
    let body = iv ^ Mirage_crypto.AES.CBC.encrypt ~key:cipher ~iv (pkcs7_pad plaintext) in
    body ^ truncated_mac ~key:mac_key body

  (** Inverse of [aes_encrypt]. The MAC is verified before any decryption;
      errors are returned, never raised. *)
  let aes_decrypt key ciphertext =
    let total = String.length ciphertext in
    if total < 16 + 16 + mac_length then Error "Ciphertext too short"
    else if (total - 16 - mac_length) mod 16 <> 0 then
      Error "Invalid ciphertext length"
    else
      let aes_key, mac_key, _iv = expand_message_key key in
      let body_len = total - mac_length in
      let body = String.sub ciphertext 0 body_len in
      let mac = String.sub ciphertext body_len mac_length in
      if not (constant_time_equal mac (truncated_mac ~key:mac_key body)) then
        Error "MAC verification failed"
      else
        let iv = String.sub body 0 16 in
        let data = String.sub body 16 (body_len - 16) in
        let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
        pkcs7_unpad (Mirage_crypto.AES.CBC.decrypt ~key:cipher ~iv data)

  (** Encrypt a plaintext message. Returns [(message_type, payload)].

      [message_type] is [0] (pre-key) until this session has decrypted a
      message from the peer, and [1] afterwards. [payload] is
      ["<our ratchet key b64>|<message index>|<ciphertext b64>"], where the
      index is that of the chain key used for this message.

      @raise Failure if a DH ratchet step is needed but no ratchet key has
        been received, or the key exchange fails. *)
  let encrypt t plaintext =
    let chain =
      match t.sending_chain with
      | Some chain -> chain
      | None ->
          (* DH ratchet step: new key pair, mix DH(new, theirs) into the root. *)
          let their_ratchet =
            match t.their_ratchet_key with
            | Some key -> key
            | None -> failwith "Cannot ratchet: no ratchet key received yet"
          in
          let secret, public = X25519.gen_key () in
          let new_root, chain_key =
            kdf_rk t.root_key (key_exchange_exn secret their_ratchet)
          in
          t.root_key <- new_root;
          t.our_ratchet_secret <- secret;
          t.our_ratchet_public <- public;
          { key = chain_key; index = 0 }
    in
    let message_key, next_chain = kdf_ck chain in
    t.sending_chain <- Some next_chain;
    let ciphertext = aes_encrypt message_key plaintext in
    let msg_type = match t.receiving_chains with [] -> 0 | _ :: _ -> 1 in
    let payload =
      Printf.sprintf "%s|%d|%s"
        (base64_encode t.our_ratchet_public)
        chain.index (base64_encode ciphertext)
    in
    (msg_type, payload)

  (** Decrypt [ciphertext] with the key at [msg_index] in the chain for
      [their_ratchet]. A new chain is created via a DH ratchet step if
      needed. Session state (root key, chains, skipped keys) is only
      updated if decryption succeeds. *)
  let decrypt_in_chain t ~their_ratchet ~msg_index ciphertext =
    let ( let* ) = Result.bind in
    let* chain, new_root =
      match List.assoc_opt their_ratchet t.receiving_chains with
      | Some chain -> Ok (chain, None)
      | None -> (
          match X25519.key_exchange t.our_ratchet_secret their_ratchet with
          | Error _ -> Error "Key exchange failed"
          | Ok dh ->
              let root, chain_key = kdf_rk t.root_key dh in
              Ok ({ key = chain_key; index = 0 }, Some root))
    in
    if msg_index < chain.index then Error "Message key already used"
    else if msg_index - chain.index > max_message_gap then
      Error "Message index too far ahead"
    else begin
      (* Derive and remember keys for the messages we jump over. *)
      let rec skip chain acc =
        if chain.index >= msg_index then (chain, acc)
        else
          let mk, next = kdf_ck chain in
          skip next (((their_ratchet, chain.index), mk) :: acc)
      in
      let chain, skipped = skip chain [] in
      let message_key, next_chain = kdf_ck chain in
      let* plaintext = aes_decrypt message_key ciphertext in
      (match new_root with
       | Some root ->
           t.root_key <- root;
           t.their_ratchet_key <- Some their_ratchet;
           (* Our next message must ratchet against their new key. *)
           t.sending_chain <- None
       | None -> ());
      t.receiving_chains <-
        take max_receiving_chains
          ((their_ratchet, next_chain)
          :: List.filter (fun (k, _) -> k <> their_ratchet) t.receiving_chains);
      t.skipped_keys <- take max_skipped_keys (skipped @ t.skipped_keys);
      Ok plaintext
    end

  (** Decrypt a payload produced by [encrypt].

      Messages may arrive out of order: keys skipped over are retained
      (bounded by [max_skipped_keys]) and consumed when used. Malformed
      input, replays and verification failures are all reported as
      [Error]. [message_type] is accepted for API compatibility; both
      types use the same framing. *)
  let decrypt t ~message_type ~ciphertext:payload =
    ignore message_type;
    let ( let* ) = Result.bind in
    let decode what b64 =
      Result.map_error (fun _ -> "Invalid " ^ what) (base64_decode b64)
    in
    match String.split_on_char '|' payload with
    | [ ratchet_b64; index_str; ciphertext_b64 ] -> (
        let* their_ratchet = decode "ratchet key" ratchet_b64 in
        let* msg_index =
          match int_of_string_opt index_str with
          | Some i when i >= 0 -> Ok i
          | _ -> Error "Invalid message index"
        in
        let* ciphertext = decode "ciphertext" ciphertext_b64 in
        let skipped_id = (their_ratchet, msg_index) in
        match List.assoc_opt skipped_id t.skipped_keys with
        | Some message_key ->
            let* plaintext = aes_decrypt message_key ciphertext in
            t.skipped_keys <-
              List.filter (fun (id, _) -> id <> skipped_id) t.skipped_keys;
            Ok plaintext
        | None -> decrypt_in_chain t ~their_ratchet ~msg_index ciphertext)
    | _ -> Error "Invalid message format"

  (** Check if this is a pre-key message type ([0]). *)
  let is_pre_key_message message_type = message_type = 0
end
472
(** Megolm sessions for room message encryption.

    Megolm uses a ratchet that only moves forward, making it efficient
    for encrypting many messages to many recipients.

    NOTE(review): the ratchet step below is a simplified stand-in (each part
    is hashed with its own index), and messages are not Ed25519-signed.
    Inbound and outbound sessions share the same helpers, so they agree
    with each other but not with other Megolm implementations. *)
module Megolm = struct
  (** Size in bytes of a serialised ratchet: 4 parts of 32 bytes. *)
  let ratchet_bytes = 128

  (** Length in bytes of the truncated HMAC-SHA256 appended to ciphertexts. *)
  let mac_length = 8

  (** Current wall-clock time, or [Ptime.epoch] if it cannot be represented. *)
  let current_time () =
    match Ptime.of_float_s (Unix.gettimeofday ()) with
    | Some t -> t
    | None -> Ptime.epoch

  (** Move [parts] in place from ratchet position [index] to [index + 1].
      Every part [j] with [j >= index land 3] becomes
      [SHA256(part_j ^ string_of_int j)]. *)
  let step_ratchet parts index =
    let hash s = Digestif.SHA256.(digest_string s |> to_raw_string) in
    for j = index land 3 to 3 do
      parts.(j) <- hash (parts.(j) ^ string_of_int j)
    done

  (** 80 bytes of key material for the ratchet state [parts]:
      AES-256 key (bytes 0-31), HMAC key (32-63), IV (64-79). *)
  let key_material parts =
    let combined = String.concat "" (Array.to_list parts) in
    Hkdf.expand ~hash:`SHA256 ~prk:combined ~info:"MEGOLM_KEYS" 80

  (** HMAC-SHA256 of [data] under [key], truncated to [mac_length] bytes. *)
  let truncated_mac ~key data =
    let full = Digestif.SHA256.hmac_string ~key data |> Digestif.SHA256.to_raw_string in
    String.sub full 0 mac_length

  (** String equality whose running time does not depend on where the
      inputs differ (only on their length). *)
  let constant_time_equal a b =
    String.length a = String.length b
    &&
    let diff = ref 0 in
    String.iteri (fun i c -> diff := !diff lor (Char.code c lxor Char.code b.[i])) a;
    !diff = 0

  (** PKCS7-pad [plaintext] to a multiple of the 16-byte AES block. *)
  let pkcs7_pad plaintext =
    let pad_len = 16 - (String.length plaintext mod 16) in
    plaintext ^ String.make pad_len (Char.chr pad_len)

  (** Strip PKCS7 padding, checking that every padding byte is valid. *)
  let pkcs7_unpad padded =
    let len = String.length padded in
    let pad_len = if len = 0 then 0 else Char.code padded.[len - 1] in
    let valid =
      pad_len >= 1 && pad_len <= 16 && pad_len <= len
      && String.for_all
           (fun c -> Char.code c = pad_len)
           (String.sub padded (len - pad_len) pad_len)
    in
    if valid then Ok (String.sub padded 0 (len - pad_len))
    else Error "Invalid padding"

  (** Verify and decrypt a raw ciphertext ([body ^ mac]) with the keys
      derived from the ratchet state [parts]. The MAC is checked before
      decrypting. Does not modify [parts]. *)
  let open_ciphertext parts ciphertext =
    let ct_len = String.length ciphertext in
    if ct_len < 16 + mac_length then Error "Ciphertext too short"
    else begin
      let material = key_material parts in
      let aes_key = String.sub material 0 32 in
      let hmac_key = String.sub material 32 32 in
      let iv = String.sub material 64 16 in
      let body = String.sub ciphertext 0 (ct_len - mac_length) in
      let mac = String.sub ciphertext (ct_len - mac_length) mac_length in
      if not (constant_time_equal mac (truncated_mac ~key:hmac_key body)) then
        Error "MAC verification failed"
      else if String.length body mod 16 <> 0 then
        Error "Invalid ciphertext length"
      else
        let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
        pkcs7_unpad (Mirage_crypto.AES.CBC.decrypt ~key:cipher ~iv body)
    end

  (** Inbound session for decrypting received room messages *)
  module Inbound = struct
    type t = {
      session_id : string;
      sender_key : string; (* Curve25519 key of sender *)
      room_id : string;
      (* Ratchet state at [message_index] - 4 parts of 32 bytes each *)
      mutable ratchet : string array;
      mutable message_index : int;
      (* Indices already decrypted, for replay detection *)
      mutable received_indices : int list;
      (* Ed25519 signing key of the sender *)
      signing_key : string;
      creation_time : Ptime.t;
      (* Ratchet state at [initial_index]. Kept so that messages older
         than the cursor can still be decrypted. *)
      initial_ratchet : string array;
      initial_index : int;
    }

    (** Largest number of ratchet steps [decrypt] will take for a single
        message. This bounds the work an attacker can cause with a huge
        index. *)
    let max_ratchet_advance = 1 lsl 14

    (** Advance the current ratchet state by one step. *)
    let advance_ratchet t =
      step_ratchet t.ratchet t.message_index;
      t.message_index <- t.message_index + 1

    (** Create from exported session data. [ratchet] is copied, so the
        caller's array is never modified. [message_index] is the ratchet's
        position and becomes the first decryptable index. *)
    let of_export ~session_id ~sender_key ~room_id ~ratchet ~message_index
        ~signing_key =
      {
        session_id;
        sender_key;
        room_id;
        ratchet = Array.copy ratchet;
        message_index;
        received_indices = [];
        signing_key;
        creation_time = current_time ();
        initial_ratchet = Array.copy ratchet;
        initial_index = message_index;
      }

    (** Create from a room key event (m.room_key).

        [session_key] is base64 of the 128-byte ratchet, optionally followed
        by a 4-byte big-endian message index (as written by
        [Outbound.export_session_key]). If the index is absent it is taken
        as [0]. If the key cannot be parsed, a session with random state is
        created (best effort, as before); such a session fails MAC
        verification for every message. *)
    let from_room_key ~sender_key ~room_id ~session_id ~session_key ~signing_key =
      let ratchet, message_index =
        match base64_decode session_key with
        | Ok data when String.length data >= ratchet_bytes ->
            let parts = Array.init 4 (fun i -> String.sub data (32 * i) 32) in
            let index =
              if String.length data >= ratchet_bytes + 4 then
                max 0
                  (Int32.to_int (Bytes.get_int32_be (Bytes.of_string data) ratchet_bytes))
              else 0
            in
            (parts, index)
        | _ -> (Array.init 4 (fun _ -> Mirage_crypto_rng.generate 32), 0)
      in
      of_export ~session_id ~sender_key ~room_id ~ratchet ~message_index ~signing_key

    let session_id t = t.session_id
    let sender_key t = t.sender_key
    let room_id t = t.room_id

    (** The lowest message index this session can decrypt. *)
    let first_known_index t = t.initial_index

    (** Derive encryption key material from the current ratchet state. *)
    let derive_key t = key_material t.ratchet

    (** Decrypt the raw ciphertext of message [message_index].

        Any index at or after [first_known_index] can be decrypted, in any
        order, but each index only once. All work is done on a copy of the
        ratchet, so a message that fails verification leaves the session
        unchanged. *)
    let decrypt t ~ciphertext ~message_index =
      if List.mem message_index t.received_indices then
        Error "Duplicate message index (replay attack)"
      else if message_index < t.initial_index then Error "Message index too old"
      else begin
        (* Start from the cursor when possible; otherwise replay from the
           initial snapshot. *)
        let from_cursor = message_index >= t.message_index in
        let parts, start =
          if from_cursor then (Array.copy t.ratchet, t.message_index)
          else (Array.copy t.initial_ratchet, t.initial_index)
        in
        if message_index - start > max_ratchet_advance then
          Error "Message index too far ahead"
        else begin
          for i = start to message_index - 1 do
            step_ratchet parts i
          done;
          match open_ciphertext parts ciphertext with
          | Error _ as error -> error
          | Ok plaintext ->
              t.received_indices <- message_index :: t.received_indices;
              if from_cursor then begin
                (* Leave the cursor just past the decrypted message. *)
                step_ratchet parts message_index;
                t.ratchet <- parts;
                t.message_index <- message_index + 1
              end;
              Ok plaintext
        end
      end
  end

  (** Outbound session for encrypting messages to send to a room *)
  module Outbound = struct
    type t = {
      session_id : string;
      room_id : string;
      (* Ratchet state *)
      mutable ratchet : string array;
      mutable message_index : int;
      (* Ed25519 signing key *)
      signing_priv : Ed25519.priv;
      signing_pub : Ed25519.pub;
      (* Creation and rotation tracking *)
      creation_time : Ptime.t;
      mutable message_count : int;
      max_messages : int;
      max_age : Ptime.Span.t;
      (* Users this session has been shared with *)
      mutable shared_with : (string * string) list; (* user_id, device_id pairs *)
    }

    (** Create a new outbound session for a room, with a random 16-byte
        session ID, random ratchet state and a fresh Ed25519 key. It rotates
        after 100 messages or one week. *)
    let create ~room_id =
      let session_id = Mirage_crypto_rng.generate 16 |> base64_encode in
      let ratchet = Array.init 4 (fun _ -> Mirage_crypto_rng.generate 32) in
      let signing_priv, signing_pub = Ed25519.generate () in
      {
        session_id;
        room_id;
        ratchet;
        message_index = 0;
        signing_priv;
        signing_pub;
        creation_time = current_time ();
        message_count = 0;
        max_messages = 100;
        max_age = Ptime.Span.of_int_s (7 * 24 * 60 * 60); (* 1 week *)
        shared_with = [];
      }

    (** Advance the ratchet by one step. *)
    let advance_ratchet t =
      step_ratchet t.ratchet t.message_index;
      t.message_index <- t.message_index + 1

    let session_id t = t.session_id
    let room_id t = t.room_id
    let message_index t = t.message_index

    (** [true] once the session has sent [max_messages] messages or is
        older than [max_age]. *)
    let needs_rotation t =
      t.message_count >= t.max_messages
      || Ptime.Span.compare (Ptime.diff (current_time ()) t.creation_time) t.max_age > 0

    (** Derive encryption key material from the current ratchet state. *)
    let derive_key t = key_material t.ratchet

    (** Export the session key for sharing via m.room_key: base64 of the
        current 128-byte ratchet followed by the current message index as
        4 big-endian bytes. Readers that only consume the first 128 bytes
        still parse the ratchet correctly. *)
    let export_session_key t =
      let index = Bytes.create 4 in
      Bytes.set_int32_be index 0 (Int32.of_int t.message_index);
      base64_encode (String.concat "" (Array.to_list t.ratchet) ^ Bytes.to_string index)

    (** The Ed25519 signing public key, unpadded base64. *)
    let signing_key t = Ed25519.pub_to_octets t.signing_pub |> base64_encode

    (** Encrypt [plaintext] at the current index, then advance the ratchet.
        Returns [(message_index, base64 (aes_cbc_ciphertext ^ mac))]. *)
    let encrypt t plaintext =
      let material = key_material t.ratchet in
      let aes_key = String.sub material 0 32 in
      let hmac_key = String.sub material 32 32 in
      let iv = String.sub material 64 16 in
      let cipher = Mirage_crypto.AES.CBC.of_secret aes_key in
      let body = Mirage_crypto.AES.CBC.encrypt ~key:cipher ~iv (pkcs7_pad plaintext) in
      let ciphertext = body ^ truncated_mac ~key:hmac_key body in
      let msg_index = t.message_index in
      advance_ratchet t;
      t.message_count <- t.message_count + 1;
      (msg_index, base64_encode ciphertext)

    (** Mark session as shared with a user/device (idempotent). *)
    let mark_shared_with t ~user_id ~device_id =
      if not (List.mem (user_id, device_id) t.shared_with) then
        t.shared_with <- (user_id, device_id) :: t.shared_with

    (** Check if already shared with a user/device *)
    let is_shared_with t ~user_id ~device_id =
      List.mem (user_id, device_id) t.shared_with

    (** Get list of (user_id, device_id) pairs this session is shared with *)
    let shared_with t = t.shared_with
  end
end
728
(** Olm Machine - high-level state machine for E2EE operations *)
module Machine = struct
  type t = {
    user_id : string;
    device_id : string;
    account : Account.t;
    (* Olm sessions keyed by the peer's base64 Curve25519 identity key,
       newest session first *)
    mutable sessions : (string * Session.t list) list;
    (* Outbound Megolm sessions by room_id *)
    mutable outbound_group_sessions : (string * Megolm.Outbound.t) list;
    (* Inbound Megolm sessions by (room_id, session_id) *)
    mutable inbound_group_sessions : ((string * string) * Megolm.Inbound.t) list;
    (* Device keys we know about: user_id -> device_id -> device_keys *)
    mutable device_keys : (string * (string * Keys.queried_device_keys) list) list;
  }

  (** Create a new OlmMachine with a freshly generated account. *)
  let create ~user_id ~device_id =
    {
      user_id;
      device_id;
      account = Account.create ();
      sessions = [];
      outbound_group_sessions = [];
      inbound_group_sessions = [];
      device_keys = [];
    }

  (** Identity keys as [(ed25519, curve25519)], both base64. *)
  let identity_keys t = Account.identity_keys t.account

  (** Device keys for upload: [(user_id, device_id, algorithms, keys)]. *)
  let device_keys_for_upload t =
    let ed25519, curve25519 = identity_keys t in
    let algorithms =
      [ "m.olm.v1.curve25519-aes-sha2-256"; "m.megolm.v1.aes-sha2-256" ]
    in
    let keys =
      [
        (Printf.sprintf "ed25519:%s" t.device_id, ed25519);
        (Printf.sprintf "curve25519:%s" t.device_id, curve25519);
      ]
    in
    (t.user_id, t.device_id, algorithms, keys)

  (** Generate up to [count] one-time keys (capped by the account). *)
  let generate_one_time_keys t count =
    Account.generate_one_time_keys t.account count

  (** Signed one-time keys for upload. *)
  let one_time_keys_for_upload t = Account.signed_one_time_keys t.account

  (** Mark keys as uploaded (delegates to [Account.mark_keys_as_published]). *)
  let mark_keys_as_published t = Account.mark_keys_as_published t.account

  (** Store device keys from a key query response, replacing any previous
      entry for [user_id]. *)
  let receive_device_keys t ~user_id ~devices =
    t.device_keys <-
      (user_id, devices) :: List.filter (fun (uid, _) -> uid <> user_id) t.device_keys

  (** Escape a string for embedding inside a JSON string literal. *)
  let json_escape s =
    let buf = Buffer.create (String.length s) in
    String.iter
      (function
        | '"' -> Buffer.add_string buf "\\\""
        | '\\' -> Buffer.add_string buf "\\\\"
        | '\n' -> Buffer.add_string buf "\\n"
        | '\r' -> Buffer.add_string buf "\\r"
        | '\t' -> Buffer.add_string buf "\\t"
        | c when Char.code c < 0x20 ->
            Buffer.add_string buf (Printf.sprintf "\\u%04x" (Char.code c))
        | c -> Buffer.add_char buf c)
      s;
    Buffer.contents buf

  (** Get the outbound Megolm session for a room, replacing it with a new
      one if it is missing or [Megolm.Outbound.needs_rotation] says so. *)
  let get_outbound_group_session t ~room_id =
    match List.assoc_opt room_id t.outbound_group_sessions with
    | Some session when not (Megolm.Outbound.needs_rotation session) -> session
    | _ ->
        let session = Megolm.Outbound.create ~room_id in
        t.outbound_group_sessions <-
          (room_id, session)
          :: List.filter (fun (rid, _) -> rid <> room_id) t.outbound_group_sessions;
        session

  (** Store an inbound Megolm session from a room key event.

      An existing session with the same (room_id, session_id) is kept, and
      the new key ignored, if it belongs to a different sender or already
      starts at an equal or earlier index. Otherwise the new session
      replaces it, so there is never more than one entry per key. *)
  let receive_room_key t ~sender_key ~room_id ~session_id ~session_key ~signing_key =
    let session =
      Megolm.Inbound.from_room_key ~sender_key ~room_id ~session_id ~session_key
        ~signing_key
    in
    let id = (room_id, session_id) in
    let keep_existing =
      match List.assoc_opt id t.inbound_group_sessions with
      | None -> false
      | Some existing ->
          Megolm.Inbound.sender_key existing <> sender_key
          || Megolm.Inbound.first_known_index existing
             <= Megolm.Inbound.first_known_index session
    in
    if not keep_existing then
      t.inbound_group_sessions <-
        (id, session) :: List.filter (fun (k, _) -> k <> id) t.inbound_group_sessions

  (** Encrypt a room message with Megolm and return the m.room.encrypted
      content as a JSON string.

      NOTE(review): the content carries a non-standard ["message_index"]
      field, because the Megolm ciphertext produced here does not embed the
      index and the receiver needs it for [decrypt_room_message]. *)
  let encrypt_room_message t ~room_id ~content =
    let session = get_outbound_group_session t ~room_id in
    let message_index, ciphertext = Megolm.Outbound.encrypt session content in
    let _, curve25519_key = identity_keys t in
    Printf.sprintf
      {|{"algorithm":"m.megolm.v1.aes-sha2-256","sender_key":"%s","ciphertext":"%s","session_id":"%s","device_id":"%s","message_index":%d}|}
      (json_escape curve25519_key) (json_escape ciphertext)
      (json_escape (Megolm.Outbound.session_id session))
      (json_escape t.device_id) message_index

  (** Decrypt a room message.

      [ciphertext] may be the base64 string from the event content (as
      produced by [encrypt_room_message]) or the already-decoded bytes.
      Input that is not valid base64 is passed through as raw bytes. *)
  let decrypt_room_message t ~room_id ~sender_key ~session_id ~ciphertext
      ~message_index =
    match List.assoc_opt (room_id, session_id) t.inbound_group_sessions with
    | None -> Error "Unknown session"
    | Some session when Megolm.Inbound.sender_key session <> sender_key ->
        Error "Sender key mismatch"
    | Some session ->
        let raw =
          match base64_decode ciphertext with Ok raw -> raw | Error _ -> ciphertext
        in
        Megolm.Inbound.decrypt session ~ciphertext:raw ~message_index

  (** Most recent Olm session with the device, if any. *)
  let get_olm_session t ~their_identity_key =
    match List.assoc_opt their_identity_key t.sessions with
    | Some (session :: _) -> Some session
    | _ -> None

  (** Record [session] as the newest session for [their_identity_key]. *)
  let add_olm_session t ~their_identity_key session =
    let existing =
      Option.value ~default:[] (List.assoc_opt their_identity_key t.sessions)
    in
    t.sessions <-
      (their_identity_key, session :: existing)
      :: List.filter (fun (k, _) -> k <> their_identity_key) t.sessions

  (** Create an outbound Olm session and make it the newest for the device.
      @raise Failure as [Session.create_outbound]. *)
  let create_olm_session t ~their_identity_key ~their_one_time_key =
    let session =
      Session.create_outbound t.account ~their_identity_key ~their_one_time_key
    in
    add_olm_session t ~their_identity_key session;
    session

  (** Create an inbound Olm session from a pre-key message and make it the
      newest for the device.
      @raise Failure as [Session.create_inbound]. *)
  let create_inbound_session t ~their_identity_key ~their_ephemeral_key
      ~one_time_key_id =
    let session =
      Session.create_inbound t.account ~their_identity_key ~their_ephemeral_key
        ~one_time_key_id
    in
    add_olm_session t ~their_identity_key session;
    session

  (** Encrypt a to-device message with the newest session for the device,
      creating an outbound session from [their_one_time_key] if none exists. *)
  let encrypt_to_device t ~their_identity_key ~their_one_time_key ~plaintext =
    let session =
      match get_olm_session t ~their_identity_key with
      | Some s -> s
      | None -> create_olm_session t ~their_identity_key ~their_one_time_key
    in
    Session.encrypt session plaintext

  (** Decrypt a to-device message with the newest session for the sender. *)
  let decrypt_to_device t ~their_identity_key ~message_type ~ciphertext =
    match get_olm_session t ~their_identity_key with
    | Some session -> Session.decrypt session ~message_type ~ciphertext
    | None -> Error "No session for sender"

  (** Number of one-time keys currently held by the account. *)
  let one_time_keys_count t = Account.one_time_keys_count t.account

  (** [true] when fewer than half of the account's one-time key capacity
      is held. *)
  let should_upload_keys t =
    one_time_keys_count t < Account.max_one_time_keys t.account / 2
end