upstream: https://github.com/mirage/mirage-crypto
1open Crypto.Uncommon
2open Common
3open Result.Syntax
4
(* Small Zarith constants used throughout this module. *)
let two = Z.of_int 2
and three = Z.of_int 3
7
(* Constant-time search: like [Eqaf.find_uint8], but yields [default] when
   nothing matches. The choice is made with [Eqaf.select_int] rather than an
   [if], so it does not branch on the (possibly secret) search result.
   NOTE(review): relies on [Eqaf.find_uint8] reporting a miss as [-1], so
   that [pos + 1] is [0] exactly when nothing matched. *)
let ct_find_uint8 ~default ?off ~f cs =
  let pos = Eqaf.find_uint8 ?off ~f cs in
  let miss_flag = pos + 1 in
  Eqaf.select_int miss_flag default pos
12
(* Function composition: [(f &. g) x] is [f (g x)]. *)
let ( &. ) f g x = f (g x)
14
(* Input to the hashing helpers (see [Digest_or]): either the message
   itself, which gets hashed, or an already computed digest of it. *)
type 'a or_digest = [ `Message of 'a | `Digest of string ]
16
(* [Digest_or (H)] turns an [or_digest] value into raw digest bytes for the
   hash [H]. A [`Message] is hashed with [H]. A [`Digest] is returned
   unchanged after checking that it is exactly [H.digest_size] bytes long; a
   digest of the wrong length is rejected through [invalid_arg]. *)
module Digest_or (H : Digestif.S) = struct
  let digest_or = function
    | `Digest digest ->
        let got = String.length digest and expected = H.digest_size in
        if got <> expected then
          invalid_arg "(`Digest _): %d bytes, expecting %d" got expected
        else digest
    | `Message msg -> H.to_raw_string (H.digest_string msg)
end
25
(* Raised when the key is too small for the requested operation. Examples:
   the message, read as an integer, is not below the modulus; or the padded /
   encoded message does not fit in the key size. *)
exception Insufficient_key
27
(* An RSA public key: public exponent [e] and modulus [n]. The smart
   constructor [pub] below performs sanity checks on both fields. *)
type pub = { e : Z.t; n : Z.t }
29
(* Smallest supported modulus size in octets, dictated by PKCS1 padding. *)
let minimum_octets = 12

(* Smallest supported modulus size in bits. A modulus of this many bits still
   spans [minimum_octets] octets: only the lowest bit of its top octet needs
   to be in use. *)
let minimum_bits = (minimum_octets * 8) - 7
33
(* [pub ~e ~n] builds a public key after sanity checks.

   A public key cannot be proven good: that would require checking that [n]
   is the product of two primes, and recovering those primes is exactly what
   RSA's security relies on being infeasible.

   We still validate enough for two purposes: our use of [powm_sec] must not
   raise, and tiny moduli must be rejected, because PKCS1 / PSS padding would
   loop forever on them or have no room for its header. The checks are:
   - [n] must be positive, odd and at least [minimum_bits] bits long;
   - [e] must satisfy [1 < e < n].

   Checking that [e] is odd, prime or 2^16+1 is deliberately skipped: none of
   these is required by RSA or by [powm_sec].

   Returns [Error (`Msg _)] for the first failing check. *)
let pub ~e ~n =
  let modulus_ok = Z.(n > zero && is_odd n && numbits n >= minimum_bits) in
  let* () = guard modulus_ok (`Msg "invalid modulus") in
  let exponent_ok = Z.(one < e && e < n) in
  let* () = guard exponent_ok (`Msg "invalid exponent") in
  Ok { e; n }
51
(* An RSA private key including its CRT parameters. The relations that the
   smart constructor [priv] checks are:
   - [n = p * q] with [p] and [q] distinct (pseudo)primes;
   - [e] is the public exponent, and [d] the private exponent with
     [d * e mod lcm (p - 1) (q - 1) = 1];
   - [dp = d mod (p - 1)] and [dq = d mod (q - 1)];
   - [q' = q^-1 mod p]. *)
type priv = {
  e : Z.t;
  d : Z.t;
  n : Z.t;
  p : Z.t;
  q : Z.t;
  dp : Z.t;
  dq : Z.t;
  q' : Z.t;
}
62
(* [valid_prime name p] is [Ok ()] when [p] is positive, odd and passes
   [Z_extra.pseudoprime]. Otherwise it is an error mentioning [name], the
   parameter being checked. *)
let valid_prime name p =
  let looks_prime = Z.(p > zero && is_odd p) && Z_extra.pseudoprime p in
  guard looks_prime (`Msg ("invalid prime " ^ name))
67
(* [rprime a b] holds when [a] and [b] are coprime (their gcd is 1). *)
let rprime a b = Z.equal (Z.gcd a b) Z.one
69
(* [valid_e ~e ~p ~q] checks the public exponent against the primes:
   - [e] must be coprime to both [p - 1] and [q - 1], so that it is
     invertible modulo [lcm (p - 1) (q - 1)];
   - [e] must itself pass [Z_extra.pseudoprime].
   The first failing check determines the error. *)
let valid_e ~e ~p ~q =
  let coprime_with m = rprime e m in
  let* () =
    guard
      (coprime_with (Z.pred p) && coprime_with (Z.pred q))
      (`Msg "e is not coprime of p and q")
  in
  let e_is_prime = Z_extra.pseudoprime e in
  guard e_is_prime (`Msg "exponent e is not a pseudoprime")
77
(* [priv ~e ~d ~n ~p ~q ~dp ~dq ~q'] builds a private key from all of its
   components. It checks, in this order and stopping at the first failure:
   - [e] and [n] form an acceptable public key (see [pub]);
   - [p] and [q] are valid, distinct primes, and [e] is compatible with them;
   - [q' = q^-1 mod p] (the inverse exists since [p] and [q] are distinct
     primes);
   - [n = p * q];
   - [1 < d < n];
   - [dp = d mod (p - 1)] and [dq = d mod (q - 1)];
   - [d * e = 1 mod lcm (p - 1) (q - 1)]. [valid_e] made [e] coprime to
     [p - 1] and [q - 1], so such a [d] exists.
   Each failure is reported as [Error (`Msg _)]. *)
let priv ~e ~d ~n ~p ~q ~dp ~dq ~q' =
  let* _ = pub ~e ~n in
  let* () = valid_prime "p" p in
  let* () = valid_prime "q" q in
  let* () = guard (p <> q) (`Msg "p and q are the same number") in
  let* () = valid_e ~e ~p ~q in
  (* The remaining consistency checks. They are thunks so that each one runs
     only if all the previous ones passed. *)
  let checks =
    [
      ((fun () -> Z.(q' = invert q p)), "q' <> q ^ -1 mod p");
      ((fun () -> Z.(n = p * q)), "modulus is not the product of p and q");
      ((fun () -> Z.(one < d && d < n)), "invalid private exponent");
      ((fun () -> Z.(dp = d mod pred p)), "dp <> d mod (p - 1)");
      ((fun () -> Z.(dq = d mod pred q)), "dq <> d mod (q - 1)");
      ( (fun () -> Z.(one = d * e mod lcm (pred p) (pred q))),
        "1 <> d * e mod lcm (p - 1) (q - 1)" );
    ]
  in
  let rec run = function
    | [] -> Ok { e; d; n; p; q; dp; dq; q' }
    | (holds, msg) :: rest ->
        let* () = guard (holds ()) (`Msg msg) in
        run rest
  in
  run checks
100
(* [priv_of_primes ~e ~p ~q] derives a complete private key from the public
   exponent [e] and the primes [p] and [q]:
   [n = p * q], [d = e^-1 mod lcm (p - 1) (q - 1)], [dp = d mod (p - 1)],
   [dq = d mod (q - 1)] and [q' = q^-1 mod p].

   Preconditions: [p] and [q] must be valid, distinct primes; [e] must pass
   [valid_e]; and [e], [n] must form an acceptable public key. Otherwise the
   result is [Error (`Msg _)]. The key is consistent by construction, so the
   checks done by [priv] are not repeated. *)
let priv_of_primes ~e ~p ~q =
  let* () = valid_prime "p" p in
  let* () = valid_prime "q" q in
  let* () = guard (p <> q) (`Msg "p and q are the same prime") in
  let* () = valid_e ~e ~p ~q in
  let n = Z.mul p q in
  let* _ = pub ~e ~n in
  let p_1 = Z.pred p and q_1 = Z.pred q in
  (* Both inverses exist: [valid_e] made [e] coprime to [p - 1] and [q - 1],
     and [p] and [q] are distinct primes. *)
  let d = Z.invert e (Z.lcm p_1 q_1) in
  let dp = Z.rem d p_1 and dq = Z.rem d q_1 in
  let q' = Z.invert q p in
  Ok { e; d; n; p; q; dp; dq; q' }
116
(* Handbook of applied cryptography, 8.2.2 (i).

   [priv_of_exp ?g ?attempts ~e ~d ~n ()] recovers a full private key from
   the public key ([e], [n]) and the private exponent [d] by factoring [n].

   Method: [e * d - 1] is split by [Z_extra.strip_factor ~f:two] into
   [(s, t)], presumably with [e * d - 1 = 2^s * t] and [t] odd. For a random
   base [a], the values [a^t, a^(2t), ..., a^(2^(s-1) t)] (mod [n]) are
   scanned for a non-trivial square root of 1: an [x] with [x <> 1],
   [x <> n - 1] and [x^2 = 1 mod n]. Then [gcd n (x - 1)] is a factor of [n].

   [?g] is the generator used to draw bases. [?attempts] (default 100) bounds
   how many bases are tried.

   Returns [Error (`Msg _)] if any of the following holds:
   - [e] and [n] do not form a valid public key;
   - [d] is not in [(1, n)];
   - [e * d - 1] has no factor of two;
   - no base yields a factor within [attempts] tries.
   On success, the result is that of [priv_of_primes] with [p >= q].

   [e * d - 1] and its factorisation do not depend on the attempt, so they
   are computed once, before the retry loop. *)
let priv_of_exp ?g ?(attempts = 100) ~e ~d ~n () =
  let* _ = pub ~e ~n in
  let* () = guard Z.(one < d && d < n) (`Msg "invalid private exponent") in
  if attempts <= 0 then Error (`Msg "attempts exceeded")
  else
    let* s, t = Z_extra.strip_factor ~f:two Z.(e * d |> pred) in
    match s with
    | 0 -> Error (`Msg "invalid factor 0")
    | _ ->
        (* One attempt with a fresh random base: [Some p] when a non-trivial
           square root of 1 was found, [p] being [gcd n (root - 1)]. *)
        let factor () =
          let rec go ax = function
            | 0 -> None
            | i' ->
                let ax2 = Z.(ax * ax mod n) in
                if Z.(ax <> one && ax <> pred n && ax2 = one) then Some ax
                else go ax2 (i' - 1)
          in
          Option.map Z.(gcd n &. pred) (go Z.(powm (Z_extra.gen ?g n) t n) s)
        in
        let rec doit ~attempts =
          if attempts > 0 then
            match factor () with
            | None -> doit ~attempts:(attempts - 1)
            | Some p ->
                let q = Z.(div n p) in
                priv_of_primes ~e ~p:(max p q) ~q:(min p q)
          else Error (`Msg "attempts exceeded")
        in
        doit ~attempts
145
(* [generate ?g ?e ~bits ()] creates a fresh private key with a [bits]-bit
   modulus and public exponent [e] (default 0x10001), drawing randomness from
   [?g].

   [invalid_arg] is raised in any of these cases:
   - [bits < minimum_bits];
   - [e < 3];
   - [e] has at least [bits] bits;
   - [e] is not a pseudoprime.

   If [priv_of_primes] rejects a candidate prime pair, generation is simply
   retried. *)
let rec generate ?g ?(e = Z.(~$0x10001)) ~bits () =
  let params_ok =
    bits >= minimum_bits
    && (not (e < three))
    && bits > Z.numbits e
    && Z_extra.pseudoprime e
  in
  if not params_ok then
    invalid_arg "Rsa.generate: e: %a, bits: %d" Z.pp_print e bits;
  let half = bits / 2 in
  let pb, qb = (half, bits - half) in
  (* Kept as a single tuple expression so that the RNG is consumed in the
     same order as before. *)
  let p, q = Z_extra.(prime ?g ~msb:2 pb, prime ?g ~msb:2 qb) in
  match priv_of_primes ~e ~p:(max p q) ~q:(min p q) with
  | Ok key -> key
  | Error _ -> generate ?g ~e ~bits ()
157
(* The public half ([e], [n]) of a private key. *)
let pub_of_priv (key : priv) : pub = { e = key.e; n = key.n }

(* Size of the modulus in bits, for a public and a private key
   respectively. *)
let pub_bits (key : pub) = Z.numbits key.n
and priv_bits (key : priv) = Z.numbits key.n
162
(* Blinding policy for private-key operations, as interpreted by [decrypt_z]:
   - [`No]: no blinding;
   - [`Yes]: blinding without an explicit generator;
   - [`Yes_with g]: blinding that draws randomness from [g]. *)
type mask = [ `No | `Yes | `Yes_with of Crypto_rng.g ]
164
(* Textbook RSA public operation [msg^e mod n]. No checks on [msg]. *)
let encrypt_unsafe ~(key : pub) msg = Z.powm msg key.e key.n
166
(* RSA private operation using the CRT; no range checks are done on [c].

   It computes [m1 = c^dp mod p] and [m2 = c^dq mod q], then recombines them
   as [m = m2 + q * (q' * (m1 - m2) mod p)]. For a consistent key this equals
   [c^d mod n].

   With [~crt_hardening:true], [m] is re-raised to [e] and compared with [c].
   On a mismatch (for example a fault during the CRT computation), the result
   is recomputed without the CRT as [c^d mod n]. *)
let decrypt_unsafe ~crt_hardening ~key:({ e; d; n; p; q; dp; dq; q' } : priv) c
    =
  let m1 = Z.(powm_sec c dp p) and m2 = Z.(powm_sec c dq q) in
  (* NOTE: neither erem, nor the multiplications (addition, subtraction) are
     guaranteed to be constant time by gmp *)
  let h = Z.(erem (q' * (m1 - m2)) p) in
  let m = Z.((h * q) + m2) in
  (* counter Arjen Lenstra's CRT attack by verifying the signature. Since the
     public exponent is small, this is not very expensive. Mentioned again
     "Factoring RSA keys with TLS Perfect Forward Secrecy" (Weimer, 2015). *)
  if (not crt_hardening) || Z.(powm_sec m e n) = c then m
  else Z.(powm_sec c d n)
179
(* Blinded variant of [decrypt_unsafe]. Steps:
   1. Draw a random [r] with [Z_extra.gen_r ?g two n]; [until] presumably
      resamples until [rprime n r] holds (r coprime to n).
   2. Decrypt the blinded value [r^e * c mod n].
   3. Unblind by multiplying the result by [r^-1 mod n].
   The exponentiation with the private key therefore never operates on [c]
   directly. *)
let decrypt_blinded_unsafe ~crt_hardening ?g ~key:({ e; n; _ } as key : priv) c
    =
  let r = until (rprime n) (fun _ -> Z_extra.gen_r ?g two n) in
  (* since r and n are coprime, there must be a multiplicative inverse *)
  let r' = Z.(invert r n) in
  let c' = Z.(powm_sec r e n * c mod n) in
  let x = decrypt_unsafe ~crt_hardening ~key c' in
  Z.(r' * x mod n)
188
(* Integer-level RSA with input validation, for both directions:
   - a message below 2 is rejected with [invalid_arg];
   - a message not smaller than the modulus raises [Insufficient_key].
   [decrypt_z] also applies blinding according to [~mask] (see [mask]). *)
let encrypt_z, decrypt_z =
  let check_params n msg =
    if msg < two then invalid_arg "Rsa: message: %a" Z.pp_print msg;
    if n <= msg then raise Insufficient_key
  in
  let encrypt_checked ~(key : pub) msg =
    check_params key.n msg;
    encrypt_unsafe ~key msg
  in
  let decrypt_checked ~crt_hardening ~mask ~(key : priv) msg =
    check_params key.n msg;
    match mask with
    | `Yes_with g -> decrypt_blinded_unsafe ~crt_hardening ~g ~key msg
    | `Yes -> decrypt_blinded_unsafe ~crt_hardening ~key msg
    | `No -> decrypt_unsafe ~crt_hardening ~key msg
  in
  (encrypt_checked, decrypt_checked)
203
(* [reformat out f msg]:
   1. decode [msg] as a big-endian integer;
   2. apply [f];
   3. encode the result big-endian on [out // 8] octets. *)
let reformat out f msg =
  Z_extra.(msg |> of_octets_be |> f |> to_octets_be ~size:(out // 8))
206
(* Raw RSA public operation on a byte string. The output has
   [pub_bits key // 8] bytes. *)
let encrypt ~key msg = reformat (pub_bits key) (encrypt_z ~key) msg
208
(* Raw RSA private operation on a byte string. Defaults: CRT hardening is
   off, and [mask] is [`Yes]. The output has [priv_bits key // 8] bytes. *)
let decrypt ?(crt_hardening = false) ?(mask = `Yes) ~key msg =
  let size_bits = priv_bits key in
  reformat size_bits (decrypt_z ~crt_hardening ~mask ~key) msg
211
(* One-byte strings used as markers and separators in the encodings below. *)
let bx00 = "\x00"
and bx01 = "\x01"
213
(* PKCS #1 v1.5 (RFC 8017):
   - padding type 1 for signatures: [sig_encode], [sig_decode], [sign],
     [verify];
   - padding type 2 for encryption: [encrypt], [decrypt]. *)
module PKCS1 = struct
  (* Minimum number of padding bytes (PS) in an encoded block. *)
  let min_pad = 8

  (* [generate_with ?g ~f n] returns [n] random bytes, each satisfying [f].

     Random data is drawn from [Crypto_rng] in chunks of [k] bytes, where [k]
     is [n // b * b] and [b] is the generator's block size (presumably [n]
     rounded up to a multiple of [b]). Bytes rejected by [f] are skipped, and
     a fresh chunk is drawn once the current one is used up. *)
  (* XXX Generalize this into `Rng.samplev` or something. *)
  let generate_with ?g ~f n =
    let buf = Bytes.create n
    and k =
      let b = Crypto_rng.block g in
      n // b * b
    in
    let rec go nonce i j =
      if i = n then Bytes.unsafe_to_string buf
      else if j = k then go Crypto_rng.(generate ?g k) i 0
      else
        match String.get_uint8 nonce j with
        | b when f b ->
            Bytes.set_uint8 buf i b;
            go nonce (succ i) (succ j)
        | _ -> go nonce i (succ j)
    in
    go Crypto_rng.(generate ?g k) 0 0

  (* [pad ~mark ~padding k msg] builds [0x00 || mark || PS || 0x00 || msg]
     with [PS = padding len].

     [len] is the room left in a [k]-byte block, clamped from below to
     [min_pad] via [imax]. If [msg] is too long, the result is therefore
     longer than [k] bytes, which [padded] turns into [Insufficient_key]. *)
  let pad ~mark ~padding k msg =
    let pad = padding (k - String.length msg - 3 |> imax min_pad) in
    String.concat "" [ bx00; mark; pad; bx00; msg ]

  (* [unpad ~mark ~is_pad buf] parses [0x00 || mark || PS || 0x00 || msg],
     where every byte of PS satisfies [is_pad] and PS has at least [min_pad]
     bytes. Returns [Some msg], or [None] if the layout does not match.

     The end of PS is located with [ct_find_uint8]. When no non-pad byte
     exists, it defaults to index 2, which then fails [c4].

     All four conditions are computed before being combined, presumably so
     that timing does not reveal which one failed. *)
  let unpad ~mark ~is_pad buf =
    let f = not &. is_pad in
    let i = ct_find_uint8 ~default:2 ~off:2 ~f buf in
    let c1 = String.get_uint8 buf 0 = 0x00
    and c2 = String.get_uint8 buf 1 = mark
    and c3 = String.get_uint8 buf i = 0x00
    and c4 = min_pad <= i - 2 in
    if c1 && c2 && c3 && c4 then
      Some (String.sub buf (i + 1) (String.length buf - i - 1))
    else None

  (* Type 1 (signature) padding: PS is all 0xff bytes. *)
  let pad_01 =
    let padding size = String.make size '\xff' in
    pad ~mark:"\x01" ~padding

  (* Type 2 (encryption) padding: PS is random non-zero bytes from [?g]. *)
  let pad_02 ?g = pad ~mark:"\x02" ~padding:(generate_with ?g ~f:(( <> ) 0x00))
  let unpad_01 = unpad ~mark:0x01 ~is_pad:(( = ) 0xff)
  let unpad_02 = unpad ~mark:0x02 ~is_pad:(( <> ) 0x00)

  (* Pad [msg] into a block of [keybits // 8] bytes and apply [transform].
     Raises [Insufficient_key] if the padded block does not have that size. *)
  let padded pad transform keybits msg =
    let n = keybits // 8 in
    let p = pad n msg in
    if String.length p = n then transform p else raise Insufficient_key

  (* Apply [transform] to a [keybits // 8]-byte [msg], then strip the padding.
     Returns [None] on a length mismatch, on [Insufficient_key], or on
     invalid padding. *)
  let unpadded unpad transform keybits msg =
    if String.length msg = keybits // 8 then
      try unpad (transform msg) with Insufficient_key -> None
    else None

  (* Type 1 padding followed by the private-key operation. CRT hardening
     defaults to [true] here. *)
  let sig_encode ?(crt_hardening = true) ?mask ~key msg =
    padded pad_01 (decrypt ~crt_hardening ?mask ~key) (priv_bits key) msg

  (* Public-key operation followed by removal of type 1 padding. The outer
     [encrypt] is used here; it is shadowed only by the next definition. *)
  let sig_decode ~key msg = unpadded unpad_01 (encrypt ~key) (pub_bits key) msg

  (* PKCS #1 v1.5 encryption: type 2 padding with randomness from [?g], then
     the public-key operation.
     @raise Insufficient_key if [msg] does not fit. *)
  let encrypt ?g ~key msg = padded (pad_02 ?g) (encrypt ~key) (pub_bits key) msg

  (* PKCS #1 v1.5 decryption: private-key operation (CRT hardening off by
     default), then removal of type 2 padding. [None] on any failure. *)
  let decrypt ?(crt_hardening = false) ?mask ~key msg =
    unpadded unpad_02 (decrypt ~crt_hardening ?mask ~key) (priv_bits key) msg

  (* The table below maps each supported hash to its DER-encoded DigestInfo
     prefix: the algorithm identifier followed by the OCTET STRING header for
     the digest. From it, two functions are built:
     - [asn_of_hash h] returns the prefix for [h]; [List.assoc] raises
       [Not_found] for hashes missing from the table;
     - [detect buf] returns the first entry whose prefix starts [buf]. *)
  let asn_of_hash, detect =
    let map =
      [
        ( `MD5,
          "\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10"
        );
        (`SHA1, "\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14");
        ( `SHA224,
          "\x30\x2d\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x04\x05\x00\x04\x1c"
        );
        ( `SHA256,
          "\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20"
        );
        ( `SHA384,
          "\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30"
        );
        ( `SHA512,
          "\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40"
        );
      ]
    in
    ( (fun h -> List.assoc h map),
      fun buf ->
        List.find_opt (fun (_, d) -> String.starts_with ~prefix:d buf) map )

  (* PKCS #1 v1.5 signature over [DigestInfo prefix || digest]. [msg] is
     either a message to be hashed with [hash] or a precomputed digest. *)
  let sign ?(crt_hardening = true) ?mask ~hash ~key msg =
    let module H = (val Digestif.module_of_hash' (hash :> Digestif.hash')) in
    let module D = Digest_or (H) in
    let msg' = asn_of_hash hash ^ D.digest_or msg in
    sig_encode ~crt_hardening ?mask ~key msg'

  (* [true] iff all of the following hold:
     - [signature] decodes with type 1 padding;
     - it starts with a known DigestInfo prefix;
     - that prefix's hash is accepted by [hashp];
     - the embedded digest matches [msg].
     A malformed signature yields [false]. *)
  let verify ~hashp ~key ~signature msg =
    let ( >>= ) = Option.bind and ( >>| ) = Fun.flip Option.map in
    Option.value
      ( sig_decode ~key signature >>= fun buf ->
        detect buf >>| fun (hash, asn) ->
        let module H = (val Digestif.module_of_hash' (hash :> Digestif.hash'))
        in
        let module D = Digest_or (H) in
        hashp hash && Eqaf.equal (asn ^ D.digest_or msg) buf )
      ~default:false

  (* Smallest key size, in bits, for which [sign] with [hash] fits. The block
     needs the DigestInfo prefix, the digest, [min_pad] padding bytes and 3
     framing bytes. *)
  let min_key hash =
    let module H = (val Digestif.module_of_hash' (hash :> Digestif.hash')) in
    ((String.length (asn_of_hash hash) + H.digest_size + min_pad + 2) * 8) + 1
end
324
(* MGF1 mask generation function (RFC 8017, appendix B.2.1) over hash [H]. *)
module MGF1 (H : Digestif.S) = struct
  (* [repr c] is the 4-byte big-endian encoding of the counter [c]. *)
  let repr n =
    let octets = Bytes.make 4 '\x00' in
    Bytes.set_int32_be octets 0 n;
    Bytes.to_string octets

  (* [mgf ~seed len] concatenates [H (seed || repr c)] for counters
     [c = 0, 1, ...], producing enough whole digests to cover [len] bytes.
     It returns the first [len] bytes as fresh bytes.
     Assumes len < 2^32 * H.digest_size. *)
  let mgf ~seed len =
    let hlen = H.digest_size in
    let blocks = len // hlen in
    let out = Bytes.create (blocks * hlen) in
    let block = Bytes.create hlen in
    for i = 0 to blocks - 1 do
      let ctx = H.feedi_string H.empty (iter2 seed (repr (Int32.of_int i))) in
      H.get_into_bytes ctx block;
      Bytes.blit block 0 out (i * hlen) hlen
    done;
    Bytes.sub out 0 len

  (* [mask ~seed buf] generates [String.length buf] bytes of MGF1 output for
     [seed], xors [buf] into them (via [unsafe_xor_into]) and returns the
     resulting fresh bytes. *)
  let mask ~seed buf =
    let len = String.length buf in
    let stream = mgf ~seed len in
    unsafe_xor_into buf ~src_off:0 stream ~dst_off:0 len;
    stream
end
347
(* RSAES-OAEP (RFC 8017, section 7.1) with hash [H] and MGF1 over [H]. *)
module OAEP (H : Digestif.S) = struct
  module MGF = MGF1 (H)

  let hlen = H.digest_size

  (* Longest message that fits in a [k]-byte encoded block. Negative when
     [k] is too small for OAEP at all. *)
  let max_msg_bytes k = k - (2 * hlen) - 2

  (* EME-OAEP encoding (RFC 8017, 7.1.1 step 2) into a [k]-byte block
     [EM = 0x00 || maskedSeed || maskedDB]:
     - [DB = lHash || PS || 0x01 || msg], where [lHash = H label];
     - [PS] is zero padding;
     - [seed] is [hlen] random bytes drawn from [?g].
     The caller must ensure [String.length msg <= max_msg_bytes k]. *)
  let eme_oaep_encode ?g ?(label = "") k msg =
    let seed = Crypto_rng.generate ?g hlen
    and pad = String.make (max_msg_bytes k - String.length msg) '\x00' in
    let db =
      String.concat ""
        [ H.(digest_string label |> to_raw_string); pad; bx01; msg ]
    in
    let mdb = Bytes.unsafe_to_string (MGF.mask ~seed db) in
    let mseed = Bytes.unsafe_to_string (MGF.mask ~seed:mdb seed) in
    String.concat "" [ bx00; mseed; mdb ]

  (* EME-OAEP decoding (RFC 8017, 7.1.2 step 3). Returns [Some msg] when the
     block [msg] is well formed for [label], and [None] otherwise. All checks
     are computed before being combined.

     Requires [String.length msg >= 2 * hlen + 2] (checked by [decrypt]), so
     that [db] extends past [lHash].

     The search for the 0x01 separator starts right after [lHash]. If only
     zero bytes follow [lHash], the search falls back to index [hlen]. That
     byte is then zero, so the separator check fails as RFC 8017 requires.
     Falling back to index 0 instead would let a first [lHash] byte of 0x01
     pass as the separator. *)
  let eme_oaep_decode ?(label = "") msg =
    let b0 = String.sub msg 0 1
    and ms = String.sub msg 1 hlen
    and mdb = String.sub msg (1 + hlen) (String.length msg - 1 - hlen) in
    let db =
      Bytes.unsafe_to_string
        (MGF.mask ~seed:(Bytes.unsafe_to_string (MGF.mask ~seed:mdb ms)) mdb)
    in
    let i = ct_find_uint8 ~default:hlen ~off:hlen ~f:(( <> ) 0x00) db in
    let c1 =
      Eqaf.equal (String.sub db 0 hlen) H.(digest_string label |> to_raw_string)
    and c2 = String.get_uint8 b0 0 = 0x00
    and c3 = String.get_uint8 db i = 0x01 in
    if c1 && c2 && c3 then
      Some (String.sub db (i + 1) (String.length db - i - 1))
    else None

  (* OAEP encryption of [msg] with optional [label].
     @raise Insufficient_key if [msg] is longer than the key allows. *)
  let encrypt ?g ?label ~key msg =
    let k = pub_bits key // 8 in
    if String.length msg > max_msg_bytes k then raise Insufficient_key
    else encrypt ~key @@ eme_oaep_encode ?g ?label k msg

  (* OAEP decryption of [em]. Returns [None] in any of these cases:
     - [em] has the wrong length;
     - the key is too small for OAEP;
     - the encoding is invalid. *)
  let decrypt ?(crt_hardening = false) ?mask ?label ~key em =
    let k = priv_bits key // 8 in
    if String.length em <> k || max_msg_bytes k < 0 then None
    else
      try eme_oaep_decode ?label @@ decrypt ~crt_hardening ?mask ~key em
      with Insufficient_key -> None

  (* XXX Review rfc3447 7.1.2 and
   * http://archiv.infsec.ethz.ch/education/fs08/secsem/Manger01.pdf
   * again for timing properties. *)

  (* XXX expose seed for deterministic testing? *)
end
400
(* RSASSA-PSS (RFC 8017, section 8.1) with hash [H] and MGF1 over [H]. *)
module PSS (H : Digestif.S) = struct
  module MGF = MGF1 (H)
  module H1 = Digest_or (H)

  let hlen = H.digest_size
  let bxbc = "\xbc"

  (* Mask for the first encoded byte. It keeps only the low [embits mod 8]
     bits (all eight when [embits] is a multiple of 8), so that the encoded
     message has at most [embits] significant bits. *)
  let b0mask embits = 0xff lsr ((8 - (embits mod 8)) mod 8)
  let zero_8 = String.make 8 '\x00'

  (* [H (0x00 * 8 || mHash || salt)], where [mHash] is the digest of [msg]
     ([msg] itself when it is already a [`Digest]). *)
  let digest ~salt msg =
    H.to_raw_string @@ H.digesti_string @@ iter3 zero_8 (H1.digest_or msg) salt

  (* EMSA-PSS encoding (RFC 8017, 9.1.1) of [msg] into [emlen // 8] bytes:
     [maskedDB || h || 0xbc].
     - [DB = PS || 0x01 || salt];
     - [salt] is [slen] fresh random bytes from [?g];
     - the top bits of [maskedDB] beyond [emlen] are cleared. *)
  let emsa_pss_encode ?g slen emlen msg =
    let n = emlen // 8 and salt = Crypto_rng.generate ?g slen in
    let h = digest ~salt msg in
    let db =
      String.concat "" [ String.make (n - slen - hlen - 2) '\x00'; bx01; salt ]
    in
    let mdb = MGF.mask ~seed:h db in
    Bytes.set_uint8 mdb 0 @@ (Bytes.get_uint8 mdb 0 land b0mask emlen);
    String.concat "" [ Bytes.unsafe_to_string mdb; h; bxbc ]

  (* EMSA-PSS verification (RFC 8017, 9.1.2) of the encoded message [em]
     against [msg], for salt length [slen] and an encoding of [emlen] bits.
     All checks are computed before being combined.
     Requires [String.length em >= hlen + slen + 2]. *)
  let emsa_pss_verify slen emlen em msg =
    let mdb = String.sub em 0 (String.length em - hlen - 1)
    and h = String.sub em (String.length em - hlen - 1) hlen
    and bxx = String.get_uint8 em (String.length em - 1) in
    let db = MGF.mask ~seed:h mdb in
    Bytes.set_uint8 db 0 (Bytes.get_uint8 db 0 land b0mask emlen);
    let db = Bytes.unsafe_to_string db in
    let salt = String.sub db (String.length db - slen) slen in
    let h' = digest ~salt msg
    and i = ct_find_uint8 ~default:0 ~f:(( <> ) 0x00) db in
    let c1 = lnot (b0mask emlen) land String.get_uint8 mdb 0 = 0x00
    and c2 = i = String.length em - hlen - slen - 2
    and c3 = String.get_uint8 db i = 0x01
    and c4 = bxx = 0xbc
    and c5 = Eqaf.equal h h' in
    c1 && c2 && c3 && c4 && c5

  (* Whether a [kbits]-bit key leaves room for [hlen + slen + 2] encoded
     bytes. The check is conservative: [kbits / 8] never exceeds the actual
     encoded length, which is ceil((kbits - 1) / 8). *)
  let sufficient_key ~slen kbits = hlen + slen + 2 <= kbits / 8

  (* PSS signature of [msg] with a random salt of [slen] bytes.
     - [slen] defaults to [hlen]; negative values are treated as 0.
     - CRT hardening is off by default.
     @raise Insufficient_key if the key is too small for the encoding. *)
  let sign ?g ?(crt_hardening = false) ?mask ?(slen = hlen) ~key msg =
    (* Clamp before the size check, so that a negative [slen] cannot pass
       [sufficient_key] and then fail inside the encoder. *)
    let slen = imax 0 slen in
    let b = priv_bits key in
    if not (sufficient_key ~slen b) then raise Insufficient_key
    else
      let msg' = emsa_pss_encode ?g slen (b - 1) msg in
      decrypt ~crt_hardening ?mask ~key msg'

  (* [true] iff [signature] is a valid PSS signature of [msg] for salt length
     [slen] (default [hlen]; negative values are treated as 0). Malformed
     signatures yield [false]. *)
  let verify ?(slen = hlen) ~key ~signature msg =
    let slen = imax 0 slen in
    let b = pub_bits key and s = String.length signature in
    s = b // 8
    && sufficient_key ~slen b
    &&
    try
      let em = encrypt ~key signature in
      let to_see = s - ((b - 1) // 8) in
      (* The encoded message has (b - 1) bits, which can be one octet shorter
         than the [s]-byte result. Any octet dropped in front must be zero;
         otherwise I2OSP(m, emLen) overflows and the signature is invalid
         (RFC 8017, 8.1.2 step 2c). *)
      String.for_all (Char.equal '\x00') (String.sub em 0 to_see)
      && emsa_pss_verify slen (b - 1)
           (String.sub em to_see (String.length em - to_see))
           msg
    with Insufficient_key -> false
end