Matrix protocol in OCaml, Eio specialised
1(** E2EE key management for Matrix.
2
3 This module handles device keys, one-time keys, and key exchange
4 for end-to-end encryption using Olm/Megolm protocols. *)
5
6module Ed25519 = Mirage_crypto_ec.Ed25519
7module X25519 = Mirage_crypto_ec.X25519
8
9(* Base64 encoding/decoding - Matrix uses unpadded base64 *)
(** [base64_encode raw] renders the bytes [raw] as base64 with padding
    disabled ([~pad:false]), the unpadded form Matrix uses. *)
let base64_encode (raw : string) =
  Base64.encode_string ~pad:false raw
12
(** [base64_decode encoded] decodes base64 text with [~pad:false],
    returning the raw bytes or the [Base64] library's error value. *)
let base64_decode (encoded : string) =
  Base64.decode ~pad:false encoded
15
16(* Key types *)
(** An Ed25519 signing key pair held as mirage-crypto-ec values; see
    [ed25519_pub_to_base64] / [ed25519_priv_to_base64] for serialisation. *)
type ed25519_keypair = {
  priv : Ed25519.priv; (* signing half, e.g. passed to [sign_json] *)
  pub : Ed25519.pub;
}
21
(** A Curve25519 (X25519) key-agreement key pair. *)
type curve25519_keypair = {
  secret : X25519.secret;
  public : string; (* raw public-key bytes as returned by [X25519.gen_key]; not base64 *)
}
26
(* A device's identity keys in typed form (typed user id, one field per key).
   The JSON codecs further down work with [device_keys_json] instead. *)
type device_keys = {
  user_id : Matrix_proto.Id.User_id.t;
  device_id : string;
  algorithms : string list;
  ed25519_key : string; (* base64 *)
  curve25519_key : string; (* base64 *)
  signatures : (string * (string * string) list) list; (* user_id -> key_id -> sig *)
}
36
(* A one-time key record, as produced by [generate_one_time_keys]. *)
type one_time_key = {
  key_id : string;
  key : string; (* base64 curve25519 public key *)
  signature : string option; (* base64 Ed25519 signature when a signing key was supplied *)
}
43
(* A fallback key; same shape as [one_time_key].
   NOTE(review): nothing in this part of the module constructs one. *)
type fallback_key = {
  key_id : string;
  key : string; (* presumably base64 curve25519 public key, as in [one_time_key] *)
  signature : string option;
}
50
51(* Key generation *)
(** Draws a fresh Ed25519 key pair via [Ed25519.generate]. *)
let generate_ed25519 () : ed25519_keypair =
  let (priv, pub) = Ed25519.generate () in
  { pub; priv }
55
(** Draws a fresh X25519 key pair; [public] keeps the raw public-key bytes
    returned by [X25519.gen_key]. *)
let generate_curve25519 () : curve25519_keypair =
  let (secret, public) = X25519.gen_key () in
  { public; secret }
59
60(* Serialize keys to base64 *)
(** Unpadded base64 of the raw Ed25519 public-key bytes. *)
let ed25519_pub_to_base64 pub =
  base64_encode (Ed25519.pub_to_octets pub)
63
(** Unpadded base64 of the raw Ed25519 private-key bytes. *)
let ed25519_priv_to_base64 priv =
  base64_encode (Ed25519.priv_to_octets priv)
66
(** Unpadded base64 of raw Curve25519 public-key bytes.  The key is already
    a byte string, so this is plain [base64_encode]. *)
let curve25519_pub_to_base64 raw_public =
  base64_encode raw_public
69
(** Unpadded base64 of the raw X25519 secret-key bytes. *)
let curve25519_secret_to_base64 secret =
  base64_encode (X25519.secret_to_octets secret)
72
73(* Deserialize keys from base64 *)
(** [ed25519_pub_of_base64 s] decodes base64 [s] into an Ed25519 public key.
    @return [Error "Invalid base64"] when [s] does not decode, or
    [Error "Invalid Ed25519 public key"] when the bytes are rejected by
    [Ed25519.pub_of_octets]. *)
let ed25519_pub_of_base64 s =
  let decoded = Result.map_error (fun _ -> "Invalid base64") (base64_decode s) in
  Result.bind decoded (fun octets ->
      Result.map_error
        (fun _ -> "Invalid Ed25519 public key")
        (Ed25519.pub_of_octets octets))
81
(** [ed25519_priv_of_base64 s] decodes base64 [s] into an Ed25519 private key.
    @return [Error "Invalid base64"] when [s] does not decode, or
    [Error "Invalid Ed25519 private key"] when the bytes are rejected by
    [Ed25519.priv_of_octets]. *)
let ed25519_priv_of_base64 s =
  match base64_decode s with
  | Ok octets ->
    Result.map_error
      (fun _ -> "Invalid Ed25519 private key")
      (Ed25519.priv_of_octets octets)
  | Error _ -> Error "Invalid base64"
89
(** [curve25519_pub_of_base64 s] decodes base64 [s] into raw Curve25519
    public-key bytes.

    The decoded value must be exactly 32 bytes (the X25519 public-key size).
    This mirrors [ed25519_pub_of_base64], which rejects malformed keys at
    decode time.  Without the check, a bad key would only fail later, in
    [key_exchange].
    @return [Error "Invalid base64"] when [s] does not decode, or
    [Error "Invalid Curve25519 public key"] when the length is wrong. *)
let curve25519_pub_of_base64 s =
  let x25519_public_key_length = 32 in
  match base64_decode s with
  | Error _ -> Error "Invalid base64"
  | Ok octets ->
    if String.length octets = x25519_public_key_length then Ok octets
    else Error "Invalid Curve25519 public key"
92
(** [curve25519_secret_of_base64 s] decodes base64 [s] into an X25519 secret,
    discarding the public key that [X25519.secret_of_octets] also derives.
    @return [Error "Invalid base64"] or [Error "Invalid Curve25519 secret key"]. *)
let curve25519_secret_of_base64 s =
  match base64_decode s with
  | Error _ -> Error "Invalid base64"
  | Ok octets ->
    X25519.secret_of_octets octets
    |> Result.map fst
    |> Result.map_error (fun _ -> "Invalid Curve25519 secret key")
100
101(* Signing *)
(** [sign_json priv text] signs the exact bytes of [text] with Ed25519 and
    returns the signature as unpadded base64.  No JSON normalisation is done
    here; callers must pass the precise text to be signed. *)
let sign_json ed25519_priv json_str =
  base64_encode (Ed25519.sign ~key:ed25519_priv json_str)
105
(** [verify_signature pub sig_b64 text] checks a base64 Ed25519 signature
    over [text].  A signature that is not valid base64 yields [false]. *)
let verify_signature ed25519_pub signature_b64 json_str =
  Result.fold
    ~ok:(fun raw_signature ->
        Ed25519.verify ~key:ed25519_pub raw_signature ~msg:json_str)
    ~error:(fun _ -> false)
    (base64_decode signature_b64)
110
(** [key_exchange ~secret ~their_public] runs X25519 between our [secret]
    and the peer's raw public-key bytes.  It returns the shared secret, or
    [Error "Key exchange failed"] for any error reported by
    [X25519.key_exchange]. *)
let key_exchange ~secret ~their_public =
  X25519.key_exchange secret their_public
  |> Result.map_error (fun _ -> "Key exchange failed")
116
(** [generate_one_time_keys ~count ?sign_with ()] creates [count] fresh
    Curve25519 key pairs.  Each upload record is paired with its key pair, so
    the caller keeps the secret half.

    Key ids are the placeholder sequence AAAAAA0, AAAAAA1, ... indexed from
    zero.  They are unique within one batch only; separate calls reuse the
    same ids.  When [sign_with] is given, each record's [signature] is
    [sign_json] applied to the single-member JSON object whose [key] member
    is the base64 public key.
    @raise Invalid_argument if [count] is negative (from [List.init]). *)
let generate_one_time_keys ~count ?sign_with () : (one_time_key * curve25519_keypair) list =
  let signed_payload key = Printf.sprintf "{\"key\":\"%s\"}" key in
  let make_entry index =
    let keypair = generate_curve25519 () in
    (* Would use proper ID generation *)
    let key_id = Printf.sprintf "AAAAAA%d" index in
    let key = curve25519_pub_to_base64 keypair.public in
    let signature =
      Option.map (fun ed_priv -> sign_json ed_priv (signed_payload key)) sign_with
    in
    (({ key_id; key; signature } : one_time_key), keypair)
  in
  List.init count make_entry
131
132(* JSON codecs for key upload/query *)
133
(* Wire shapes for the /keys/upload request body and the objects nested in it.
   JSON objects used as maps are held as association lists; the
   StringMap-based codecs below convert between the two.
   [@@warning "-69"] silences OCaml's unused-record-field warning. *)
type upload_keys_request = {
  device_keys : device_keys_json option;
  one_time_keys : (string * one_time_key_json) list; (* key id -> key; presumably "algorithm:id" ids — confirm *)
  fallback_keys : (string * one_time_key_json) list; (* same shape as [one_time_keys] *)
} [@@warning "-69"]

(* Device keys as plain JSON data: string user id, generic key map. *)
and device_keys_json = {
  user_id : string;
  device_id : string;
  algorithms : string list;
  keys : (string * string) list; (* key_id -> key *)
  signatures : (string * (string * string) list) list; (* user_id -> key_id -> sig *)
} [@@warning "-69"]

(* One one-time or fallback key as sent in an upload request. *)
and one_time_key_json = {
  key : string;
  signatures : (string * (string * string) list) list option; (* [None] when unsigned *)
} [@@warning "-69"]
153
(* String-keyed map bridging [Jsont.Object.as_string_map] and the association
   lists stored in the records.  The codecs below apply StringMap functions
   directly to [as_string_map] results, so this must stay a
   [Map.Make(String)] instance whose [t] equals that codec's map type. *)
module StringMap = Map.Make(String)
155
(** Codec between a JSON object whose members are all strings and an
    association list in key order (as given by [StringMap.bindings]).
    When encoding, a repeated key keeps its last value. *)
let string_string_map_jsont : (string * string) list Jsont.t =
  let to_assoc by_key = StringMap.bindings by_key in
  let of_assoc pairs = StringMap.of_seq (List.to_seq pairs) in
  Jsont.map ~dec:to_assoc ~enc:of_assoc
    (Jsont.Object.as_string_map Jsont.string)
162
(** Codec for a two-level JSON object of strings, mapped to nested
    association lists in key order.  Used for [signatures] members, which are
    presumably user id -> key id -> base64 signature. *)
let signatures_jsont : (string * (string * string) list) list Jsont.t =
  let by_key_id = Jsont.Object.as_string_map Jsont.string in
  let by_user = Jsont.Object.as_string_map by_key_id in
  let dec users = StringMap.bindings (StringMap.map StringMap.bindings users) in
  let enc pairs =
    List.to_seq pairs
    |> Seq.map (fun (user, sigs) -> (user, StringMap.of_seq (List.to_seq sigs)))
    |> StringMap.of_seq
  in
  Jsont.map ~dec ~enc by_user
174
(** Codec for the device-keys object of a /keys/upload request.

    Every member carries an [~enc] projection.  This codec is used for
    encoding, through [upload_keys_request_jsont], and Jsont has no way to
    encode a member declared without one.  Decoding behaves exactly as
    before. *)
let device_keys_json_jsont : device_keys_json Jsont.t =
  Jsont.Object.(
    map (fun user_id device_id algorithms keys signatures ->
        { user_id; device_id; algorithms; keys; signatures })
    |> mem "user_id" Jsont.string
         ~enc:(fun (d : device_keys_json) -> d.user_id)
    |> mem "device_id" Jsont.string
         ~enc:(fun (d : device_keys_json) -> d.device_id)
    |> mem "algorithms" (Jsont.list Jsont.string)
         ~enc:(fun (d : device_keys_json) -> d.algorithms)
    |> mem "keys" string_string_map_jsont
         ~enc:(fun (d : device_keys_json) -> d.keys)
    |> mem "signatures" signatures_jsont
         ~enc:(fun (d : device_keys_json) -> d.signatures)
    |> finish)
185
(** Codec for one uploaded or claimed one-time key: a required [key] string
    and optional [signatures].  The [key] member now has an [~enc]
    projection, so the codec can encode upload requests as well as decode
    responses. *)
let one_time_key_json_jsont : one_time_key_json Jsont.t =
  Jsont.Object.(
    map (fun key signatures -> { key; signatures })
    |> mem "key" Jsont.string ~enc:(fun (t : one_time_key_json) -> t.key)
    |> opt_mem "signatures" signatures_jsont
         ~enc:(fun (t : one_time_key_json) -> t.signatures)
    |> finish)
192
(** Codec between a JSON object of one-time-key objects (keyed by key id)
    and an association list in key order. *)
let one_time_keys_map_jsont : (string * one_time_key_json) list Jsont.t =
  Jsont.map
    ~dec:StringMap.bindings
    ~enc:(fun entries -> StringMap.of_seq (List.to_seq entries))
    (Jsont.Object.as_string_map one_time_key_json_jsont)
199
(** Codec for the /keys/upload request body.

    The [one_time_keys] and [fallback_keys] members now carry [~enc]
    projections, without which encoding the request fails.  Empty lists are
    left out of the encoded body, mirroring [~dec_absent:[]], which decodes
    a missing member as an empty list. *)
let upload_keys_request_jsont : upload_keys_request Jsont.t =
  let no_entries = function [] -> true | _ :: _ -> false in
  let make device_keys one_time_keys fallback_keys =
    { device_keys; one_time_keys; fallback_keys }
  in
  Jsont.Object.map make
  |> Jsont.Object.opt_mem "device_keys" device_keys_json_jsont
       ~enc:(fun (r : upload_keys_request) -> r.device_keys)
  |> Jsont.Object.mem "one_time_keys" one_time_keys_map_jsont ~dec_absent:[]
       ~enc:(fun (r : upload_keys_request) -> r.one_time_keys)
       ~enc_omit:no_entries
  |> Jsont.Object.mem "fallback_keys" one_time_keys_map_jsont ~dec_absent:[]
       ~enc:(fun (r : upload_keys_request) -> r.fallback_keys)
       ~enc_omit:no_entries
  |> Jsont.Object.finish
208
(* Decoded /keys/upload response. *)
type upload_keys_response = {
  one_time_key_counts : (string * int) list; (* presumably algorithm -> keys left on the server; confirm *)
}
213
(** Codec between a JSON object of integers and an association list in
    key order. *)
let one_time_key_counts_jsont : (string * int) list Jsont.t =
  let counts_codec = Jsont.Object.as_string_map Jsont.int in
  let to_counts by_key = StringMap.bindings by_key in
  let of_counts counts = StringMap.of_seq (List.to_seq counts) in
  Jsont.map ~dec:to_counts ~enc:of_counts counts_codec
220
(** Decoder for the /keys/upload response.  A missing [one_time_key_counts]
    member decodes as an empty list. *)
let upload_keys_response_jsont =
  let make one_time_key_counts : upload_keys_response = { one_time_key_counts } in
  Jsont.Object.map make
  |> Jsont.Object.mem "one_time_key_counts" one_time_key_counts_jsont ~dec_absent:[]
  |> Jsont.Object.finish
226
(** [upload_keys client ?device_keys ?one_time_keys ?fallback_keys ()] POSTs
    the given keys to /keys/upload and decodes the reply.  Omitted key lists
    default to empty.  The first error from encoding, the request or decoding
    is returned unchanged. *)
let upload_keys client ?device_keys ?(one_time_keys = []) ?(fallback_keys = []) () =
  let request : upload_keys_request = { device_keys; one_time_keys; fallback_keys } in
  Result.bind (Client.encode_body upload_keys_request_jsont request) (fun body ->
      Result.bind (Client.post client ~path:"/keys/upload" ~body ()) (fun reply ->
          Client.decode_response upload_keys_response_jsont reply))
236
237(* Query keys request/response *)
(* Body of a /keys/query request (built by [query_keys]). *)
type query_keys_request = {
  timeout : int option; (* NOTE(review): unit presumably milliseconds per the Matrix API — confirm *)
  device_keys : (string * string list) list; (* user_id -> device_ids *)
} [@@warning "-69"]
242
(** Codec between a JSON object of string arrays (user id -> device ids)
    and an association list in key order. *)
let device_keys_query_jsont : (string * string list) list Jsont.t =
  let per_user = Jsont.Object.as_string_map (Jsont.list Jsont.string) in
  Jsont.map per_user
    ~dec:(fun by_user -> StringMap.bindings by_user)
    ~enc:(fun users -> StringMap.of_seq (List.to_seq users))
249
(** Codec for the /keys/query request body.  [device_keys] now has an
    [~enc] projection.  Without it the request, which this codec exists to
    encode, could not be serialised. *)
let query_keys_request_jsont : query_keys_request Jsont.t =
  Jsont.Object.(
    map (fun timeout device_keys -> { timeout; device_keys })
    |> opt_mem "timeout" Jsont.int
         ~enc:(fun (r : query_keys_request) -> r.timeout)
    |> mem "device_keys" device_keys_query_jsont
         ~enc:(fun (r : query_keys_request) -> r.device_keys)
    |> finish)
256
(* One device's keys as returned by /keys/query.  When absent from the
   response, [algorithms], [keys] and [signatures] decode to [[]] and
   [unsigned] decodes to [None] (see [queried_device_keys_jsont]). *)
type queried_device_keys = {
  user_id : string;
  device_id : string;
  algorithms : string list;
  keys : (string * string) list; (* key_id -> key *)
  signatures : (string * (string * string) list) list; (* user_id -> key_id -> sig *)
  unsigned : Jsont.json option; (* kept as raw JSON *)
}
265
(** Decoder for a device-keys object in a /keys/query response.  The
    list-valued members default to [[]] when missing, and [unsigned] is
    kept as raw JSON. *)
let queried_device_keys_jsont =
  let make user_id device_id algorithms keys signatures unsigned =
    { user_id; device_id; algorithms; keys; signatures; unsigned }
  in
  let unsigned_of (dk : queried_device_keys) = dk.unsigned in
  Jsont.Object.map make
  |> Jsont.Object.mem "user_id" Jsont.string
  |> Jsont.Object.mem "device_id" Jsont.string
  |> Jsont.Object.mem "algorithms" (Jsont.list Jsont.string) ~dec_absent:[]
  |> Jsont.Object.mem "keys" string_string_map_jsont ~dec_absent:[]
  |> Jsont.Object.mem "signatures" signatures_jsont ~dec_absent:[]
  |> Jsont.Object.opt_mem "unsigned" Jsont.json ~enc:unsigned_of
  |> Jsont.Object.finish
277
(* Decoded /keys/query response. *)
type query_keys_response = {
  failures : (string * Jsont.json) list; (* raw JSON per entry; presumably keyed by server name — confirm *)
  device_keys : (string * (string * queried_device_keys) list) list; (* presumably user_id -> device_id -> keys *)
}
282
(** Codec for a two-level JSON object of device-keys objects (presumably
    user id -> device id -> keys), as nested association lists in key order. *)
let device_keys_map_jsont =
  let devices = Jsont.Object.as_string_map queried_device_keys_jsont in
  let users = Jsont.Object.as_string_map devices in
  let dec by_user = StringMap.bindings (StringMap.map StringMap.bindings by_user) in
  let enc pairs =
    List.to_seq pairs
    |> Seq.map (fun (user, devs) -> (user, StringMap.of_seq (List.to_seq devs)))
    |> StringMap.of_seq
  in
  Jsont.map ~dec ~enc users
294
(** Codec between a JSON object of arbitrary JSON values and an association
    list in key order; the values stay uninterpreted. *)
let failures_jsont =
  let to_assoc by_key = StringMap.bindings by_key in
  let of_assoc entries = StringMap.of_seq (List.to_seq entries) in
  Jsont.map ~dec:to_assoc ~enc:of_assoc (Jsont.Object.as_string_map Jsont.json)
301
(** Decoder for the /keys/query response.  Missing members decode as empty
    lists. *)
let query_keys_response_jsont =
  let make failures device_keys : query_keys_response = { failures; device_keys } in
  Jsont.Object.map make
  |> Jsont.Object.mem "failures" failures_jsont ~dec_absent:[]
  |> Jsont.Object.mem "device_keys" device_keys_map_jsont ~dec_absent:[]
  |> Jsont.Object.finish
308
(** [query_keys client ?timeout ~users ()] POSTs to /keys/query, asking for
    the listed device ids of each user, and decodes the reply.  The first
    error from encoding, the request or decoding is returned unchanged. *)
let query_keys client ?timeout ~users () =
  let to_wire (user_id, device_ids) =
    (Matrix_proto.Id.User_id.to_string user_id, device_ids)
  in
  let request : query_keys_request = { timeout; device_keys = List.map to_wire users } in
  Result.bind (Client.encode_body query_keys_request_jsont request) (fun body ->
      Result.bind (Client.post client ~path:"/keys/query" ~body ()) (fun reply ->
          Client.decode_response query_keys_response_jsont reply))
321
322(* Claim one-time keys request/response *)
(* Body of a /keys/claim request (built by [claim_keys]). *)
type claim_keys_request = {
  timeout : int option; (* NOTE(review): unit presumably milliseconds — confirm *)
  one_time_keys : (string * (string * string) list) list; (* user_id -> device_id -> algorithm *)
} [@@warning "-69"]
327
(** Codec for the two-level claim map (user id -> device id -> algorithm)
    as nested association lists in key order. *)
let one_time_keys_claim_jsont =
  let assoc_of_map m = StringMap.bindings m in
  let map_of_assoc l = StringMap.of_seq (List.to_seq l) in
  let codec =
    Jsont.Object.as_string_map (Jsont.Object.as_string_map Jsont.string)
  in
  Jsont.map
    ~dec:(fun users ->
        List.map (fun (user, devices) -> (user, assoc_of_map devices))
          (assoc_of_map users))
    ~enc:(fun users ->
        map_of_assoc
          (List.map (fun (user, devices) -> (user, map_of_assoc devices)) users))
    codec
339
(** Codec for the /keys/claim request body.  [one_time_keys] now has an
    [~enc] projection.  Without it the request, which this codec exists to
    encode, could not be serialised. *)
let claim_keys_request_jsont : claim_keys_request Jsont.t =
  Jsont.Object.(
    map (fun timeout one_time_keys -> { timeout; one_time_keys })
    |> opt_mem "timeout" Jsont.int
         ~enc:(fun (r : claim_keys_request) -> r.timeout)
    |> mem "one_time_keys" one_time_keys_claim_jsont
         ~enc:(fun (r : claim_keys_request) -> r.one_time_keys)
    |> finish)
346
(* Decoded /keys/claim response. *)
type claim_keys_response = {
  failures : (string * Jsont.json) list; (* raw JSON per entry; presumably keyed by server name — confirm *)
  one_time_keys : (string * (string * (string * one_time_key_json) list) list) list; (* presumably user_id -> device_id -> key_id -> key *)
}
351
(** Codec for the three-level claimed-keys object (presumably user id ->
    device id -> key id -> key object) as nested association lists in key
    order. *)
let claimed_keys_map_jsont =
  let keys_of_device = Jsont.Object.as_string_map one_time_key_json_jsont in
  let devices_of_user = Jsont.Object.as_string_map keys_of_device in
  let users = Jsont.Object.as_string_map devices_of_user in
  let assoc_to_map l = StringMap.of_seq (List.to_seq l) in
  let dec by_user =
    by_user
    |> StringMap.map (fun by_device ->
           StringMap.bindings (StringMap.map StringMap.bindings by_device))
    |> StringMap.bindings
  in
  let enc l =
    assoc_to_map
      (List.map
         (fun (user, devices) ->
            (user,
             assoc_to_map
               (List.map (fun (dev, keys) -> (dev, assoc_to_map keys)) devices)))
         l)
  in
  Jsont.map ~dec ~enc users
369
(** Decoder for the /keys/claim response.  Missing members decode as empty
    lists. *)
let claim_keys_response_jsont =
  let make failures one_time_keys : claim_keys_response = { failures; one_time_keys } in
  Jsont.Object.map make
  |> Jsont.Object.mem "failures" failures_jsont ~dec_absent:[]
  |> Jsont.Object.mem "one_time_keys" claimed_keys_map_jsont ~dec_absent:[]
  |> Jsont.Object.finish
376
(** [claim_keys client ?timeout ~keys ()] POSTs to /keys/claim, requesting
    one key of the given algorithm for each (user, device) pair, and decodes
    the reply.  [keys] maps each user to a list of (device id, algorithm)
    pairs.  The first error from encoding, the request or decoding is
    returned unchanged.

    The per-user device list is passed through as-is; the former identity
    [List.map] that copied it did nothing. *)
let claim_keys client ?timeout ~keys () =
  let one_time_keys =
    List.map
      (fun (user_id, device_algorithms) ->
         (Matrix_proto.Id.User_id.to_string user_id, device_algorithms))
      keys
  in
  let request : claim_keys_request = { timeout; one_time_keys } in
  match Client.encode_body claim_keys_request_jsont request with
  | Error e -> Error e
  | Ok body ->
    (match Client.post client ~path:"/keys/claim" ~body () with
     | Error e -> Error e
     | Ok body -> Client.decode_response claim_keys_response_jsont body)
393
394(* Key changes tracking *)
(* Decoded /keys/changes response; both lists default to [[]] when absent. *)
type key_changes_response = {
  changed : string list; (* presumably user ids whose device keys changed — confirm *)
  left : string list; (* presumably user ids no longer tracked — confirm *)
}
399
(** Decoder for the /keys/changes response.  Missing lists decode as
    [[]]. *)
let key_changes_response_jsont =
  let user_list = Jsont.list Jsont.string in
  let make changed left : key_changes_response = { changed; left } in
  Jsont.Object.map make
  |> Jsont.Object.mem "changed" user_list ~dec_absent:[]
  |> Jsont.Object.mem "left" user_list ~dec_absent:[]
  |> Jsont.Object.finish
406
(** [get_key_changes client ~from ~to_] GETs /keys/changes with the [from]
    and [to] query parameters and decodes the reply.  A request or decoding
    error is returned unchanged. *)
let get_key_changes client ~from ~to_ =
  let response =
    Client.get client ~path:"/keys/changes" ~query:[ ("from", from); ("to", to_) ] ()
  in
  Result.bind response (fun body ->
      Client.decode_response key_changes_response_jsont body)
412
(** [create_device_keys ~user_id ~device_id ~ed25519_keypair ~curve25519_keypair]
    builds the device-keys object for /keys/upload and signs it with the
    device's Ed25519 key.  The object advertises both public keys, under the
    ids [ed25519:<device_id>] and [curve25519:<device_id>], plus the Olm and
    Megolm algorithms.

    The signature covers the Matrix canonical JSON form of the object without
    its [signatures] member: members sorted by key (byte order, which for
    UTF-8 is code-point order) and no insignificant whitespace.  Signing the
    output of a general-purpose encoder would instead cover declaration
    member order, the encoder's formatting and an empty [signatures] object.
    No other client reproduces that text, so the signature would not verify.

    Always returns [Ok]; the [result] type is kept for existing callers. *)
let create_device_keys ~user_id ~device_id ~ed25519_keypair ~curve25519_keypair =
  (* JSON string literal in canonical form: backslash-escape double quotes
     and backslashes, use the short escapes JSON defines for control
     characters, lowercase \u00xx for the remaining ones, and pass every
     other byte (including UTF-8 sequences) through unchanged. *)
  let json_string s =
    let buf = Buffer.create (String.length s + 2) in
    Buffer.add_char buf '"';
    String.iter
      (fun c ->
         match c with
         | '"' -> Buffer.add_string buf "\\\""
         | '\\' -> Buffer.add_string buf "\\\\"
         | '\b' -> Buffer.add_string buf "\\b"
         | '\012' -> Buffer.add_string buf "\\f"
         | '\n' -> Buffer.add_string buf "\\n"
         | '\r' -> Buffer.add_string buf "\\r"
         | '\t' -> Buffer.add_string buf "\\t"
         | c ->
           if Char.code c < 0x20 then
             Buffer.add_string buf (Printf.sprintf "\\u%04x" (Char.code c))
           else Buffer.add_char buf c)
      s;
    Buffer.add_char buf '"';
    Buffer.contents buf
  in
  (* Object from (name, already-encoded value) pairs, members sorted by name. *)
  let json_object members =
    let sorted = List.sort (fun (a, _) (b, _) -> String.compare a b) members in
    "{"
    ^ String.concat ","
        (List.map (fun (name, value) -> json_string name ^ ":" ^ value) sorted)
    ^ "}"
  in
  (* Array from already-encoded items; element order is significant. *)
  let json_array items = "[" ^ String.concat "," items ^ "]" in
  let user_str = Matrix_proto.Id.User_id.to_string user_id in
  let ed25519_key_id = Printf.sprintf "ed25519:%s" device_id in
  let keys = [
    (ed25519_key_id, ed25519_pub_to_base64 ed25519_keypair.pub);
    (Printf.sprintf "curve25519:%s" device_id,
     curve25519_pub_to_base64 curve25519_keypair.public);
  ] in
  let algorithms = [
    "m.olm.v1.curve25519-aes-sha2-256";
    "m.megolm.v1.aes-sha2-256";
  ] in
  (* The exact bytes that get signed: the object minus [signatures]. *)
  let signed_payload =
    json_object [
      ("algorithms", json_array (List.map json_string algorithms));
      ("device_id", json_string device_id);
      ("keys", json_object (List.map (fun (id, key) -> (id, json_string key)) keys));
      ("user_id", json_string user_str);
    ]
  in
  let signature = sign_json ed25519_keypair.priv signed_payload in
  let device_keys : device_keys_json = {
    user_id = user_str;
    device_id;
    algorithms;
    keys;
    signatures = [ (user_str, [ (ed25519_key_id, signature) ]) ];
  } in
  (Ok device_keys : (device_keys_json, string) result)