(* Upstream: https://github.com/mirage/mirage-crypto (branch main; 462 lines,
   16 kB). *)
open Crypto.Uncommon
open Common
open Result.Syntax

let two = Z.(~$2)
and three = Z.(~$3)

(* A constant-time [find_uint8] with a default value. [Eqaf.select_int]
   picks [default] when [res + 1] is zero, i.e. when [Eqaf.find_uint8]
   returned -1 (no match), and [res] otherwise, without branching on the
   result. *)
let ct_find_uint8 ~default ?off ~f cs =
  let res = Eqaf.find_uint8 ?off ~f cs in
  Eqaf.select_int (res + 1) default res

(* Function composition: [(f &. g) h = f (g h)]. *)
let ( &. ) f g = fun h -> f (g h)

(* Input of the hash-based operations: either a message that still has to be
   hashed, or a digest computed by the caller. *)
type 'a or_digest = [ `Message of 'a | `Digest of string ]

(* [digest_or] hashes a [`Message] with [H], and passes a [`Digest] through
   unchanged after checking that it is exactly [H.digest_size] bytes long
   (raising through [invalid_arg] otherwise). *)
module Digest_or (H : Digestif.S) = struct
  let digest_or = function
    | `Message msg -> H.(digest_string msg |> to_raw_string)
    | `Digest digest ->
        let n = String.length digest and m = H.digest_size in
        if n = m then digest
        else invalid_arg "(`Digest _): %d bytes, expecting %d" n m
end

(* Raised when an input does not fit the key: a message representative that
   is not below the modulus, or an encoding longer than the key allows. *)
exception Insufficient_key

type pub = { e : Z.t; n : Z.t }

(* Smallest accepted modulus, due to PKCS1: [minimum_bits] is the smallest
   bit length that still needs [minimum_octets] bytes. *)
let minimum_octets = 12
let minimum_bits = (8 * minimum_octets) - 7

(* [pub ~e ~n] validates and builds a public key.

   We cannot verify a public key being good (this would require verifying
   that [n] is the product of two prime numbers - figuring out which primes
   were used is the security property of RSA).

   But we validate to ensure our usage of [powm_sec] does not lead to
   exceptions, and we avoid tiny public keys where PKCS1 / PSS would lead to
   infinite loops or not work due to insufficient space for the header.

   Returns [Error (`Msg _)] unless [n] is positive, odd and at least
   [minimum_bits] long, and [1 < e < n]. *)
let pub ~e ~n =
  let* () =
    guard
      Z.(n > zero && is_odd n && numbits n >= minimum_bits)
      (`Msg "invalid modulus")
  in
  let* () = guard Z.(one < e && e < n) (`Msg "invalid exponent") in
  (* NOTE that we could check for e being odd, or a prime, or 2^16+1, but
     these are not requirements, neither for RSA nor for powm_sec *)
  Ok { e; n }

(* A private key. Besides [e], [d] and [n = p * q] it carries the CRT
   parameters [dp = d mod (p - 1)], [dq = d mod (q - 1)] and
   [q' = q^-1 mod p]; [priv] checks all of these relations. *)
type priv = {
  e : Z.t;
  d : Z.t;
  n : Z.t;
  p : Z.t;
  q : Z.t;
  dp : Z.t;
  dq : Z.t;
  q' : Z.t;
}

(* [Ok ()] iff [p] is positive, odd and passes [Z_extra.pseudoprime];
   [name] only appears in the error message. *)
let valid_prime name p =
  guard
    Z.(p > zero && is_odd p && Z_extra.pseudoprime p)
    (`Msg ("invalid prime " ^ name))

(* [rprime a b] holds when [a] and [b] are relatively prime. *)
let rprime a b = Z.(gcd a b = one)

(* [e] must be coprime to both [p - 1] and [q - 1] (so that a private
   exponent exists) and must itself pass [Z_extra.pseudoprime]. *)
let valid_e ~e ~p ~q =
  let* () =
    guard
      (rprime e (Z.pred p) && rprime e (Z.pred q))
      (`Msg "e is not coprime of p and q")
  in
  guard (Z_extra.pseudoprime e) (`Msg "exponent e is not a pseudoprime")

(* [priv ~e ~d ~n ~p ~q ~dp ~dq ~q'] checks every component of a private key
   for consistency and returns the key, or the first [`Msg] error found. *)
let priv ~e ~d ~n ~p ~q ~dp ~dq ~q' =
  let* _ = pub ~e ~n in
  let* () = valid_prime "p" p in
  let* () = valid_prime "q" q in
  let* () = guard (not (Z.equal p q)) (`Msg "p and q are the same number") in
  let* () = valid_e ~e ~p ~q in
  (* p and q are prime, and not equal -> multiplicative inverse exists *)
  let* () = guard Z.(q' = invert q p) (`Msg "q' <> q ^ -1 mod p") in
  let* () =
    guard Z.(n = p * q) (`Msg "modulus is not the product of p and q")
  in
  let* () = guard Z.(one < d && d < n) (`Msg "invalid private exponent") in
  let* () = guard Z.(dp = d mod pred p) (`Msg "dp <> d mod (p - 1)") in
  let* () = guard Z.(dq = d mod pred q) (`Msg "dq <> d mod (q - 1)") in
  (* e has been checked (valid_e) to be coprime to p-1 and q-1 ->
     multiplicative inverse exists *)
  let* () =
    guard
      Z.(one = d * e mod lcm (pred p) (pred q))
      (`Msg "1 <> d * e mod lcm (p - 1) (q - 1)")
  in
  Ok { e; d; n; p; q; dp; dq; q' }

(* [priv_of_primes ~e ~p ~q] derives a complete private key from two
   distinct primes and the public exponent; [d] is computed as the inverse
   of [e] modulo [lcm (p - 1) (q - 1)]. *)
let priv_of_primes ~e ~p ~q =
  let* () = valid_prime "p" p in
  let* () = valid_prime "q" q in
  let* () = guard (not (Z.equal p q)) (`Msg "p and q are the same prime") in
  let* () = valid_e ~e ~p ~q in
  let n = Z.(p * q) in
  let* _ = pub ~e ~n in
  (* valid_e checks e coprime to p-1 and q-1, a multiplicative inverse exists *)
  let d = Z.(invert e (lcm (pred p) (pred q))) in
  let dp = Z.(d mod pred p) and dq = Z.(d mod pred q) in
  (* above we checked that p and q both are primes and not equal -> there
     should be a multiplicative inverse *)
  let q' = Z.invert q p in
  (* does not need to check valid_priv, since it is valid by construction *)
  Ok { e; d; n; p; q; dp; dq; q' }

(* Handbook of applied cryptography, 8.2.2 (i).

   [priv_of_exp ?g ?attempts ~e ~d ~n ()] recovers the primes of [n] from the
   key pair exponents. It writes [e * d - 1 = 2^s * t] (via
   [Z_extra.strip_factor]), then, for up to [attempts] random bases [a]
   drawn with [Z_extra.gen ?g n], squares [a^t mod n] up to [s] times looking
   for a square root of 1 other than 1 and n - 1. Such a root [x] gives the
   factor [gcd n (x - 1)]. *)
let priv_of_exp ?g ?(attempts = 100) ~e ~d ~n () =
  let* _ = pub ~e ~n in
  let* () = guard Z.(one < d && d < n) (`Msg "invalid private exponent") in
  if attempts <= 0 then Error (`Msg "attempts exceeded")
  else
    (* [e * d - 1] is the same for every attempt: decompose it only once. *)
    let* s, t = Z_extra.strip_factor ~f:two Z.(e * d |> pred) in
    if s = 0 then Error (`Msg "invalid factor 0")
    else
      (* Square [ax] at most [i] more times; [Some ax] when [ax] is a
         non-trivial square root of 1 modulo [n]. *)
      let rec go ax = function
        | 0 -> None
        | i ->
            let ax2 = Z.(ax * ax mod n) in
            if Z.(ax <> one && ax <> pred n && ax2 = one) then Some ax
            else go ax2 (i - 1)
      in
      (* One attempt with a fresh random base. *)
      let factor () =
        Option.map Z.(gcd n &. pred) (go Z.(powm (Z_extra.gen ?g n) t n) s)
      in
      let rec doit attempts =
        if attempts <= 0 then Error (`Msg "attempts exceeded")
        else
          match factor () with
          | None -> doit (attempts - 1)
          | Some p ->
              let q = Z.(div n p) in
              priv_of_primes ~e ~p:(Z.max p q) ~q:(Z.min p q)
      in
      doit attempts

(* [generate ?g ?e ~bits ()] builds a fresh key from two primes of
   [bits / 2] and [bits - bits / 2] bits (drawn with [Z_extra.prime ~msb:2]),
   retrying until [priv_of_primes] accepts them. [e] defaults to 65537
   (0x10001). Raises through [invalid_arg] if [bits] is below [minimum_bits]
   or not larger than the bit length of [e], or if [e] is below 3 or not a
   pseudoprime. *)
let rec generate ?g ?(e = Z.(~$0x10001)) ~bits () =
  if
    bits < minimum_bits || Z.lt e three
    || bits <= Z.numbits e
    || not (Z_extra.pseudoprime e)
  then invalid_arg "Rsa.generate: e: %a, bits: %d" Z.pp_print e bits;
  let pb, qb = (bits / 2, bits - (bits / 2)) in
  let p, q = Z_extra.(prime ?g ~msb:2 pb, prime ?g ~msb:2 qb) in
  match priv_of_primes ~e ~p:(Z.max p q) ~q:(Z.min p q) with
  | Error _ -> generate ?g ~e ~bits ()
  | Ok priv -> priv

(* The public half of a private key. *)
let pub_of_priv ({ e; n; _ } : priv) = { e; n }

(* Key sizes, as the bit length of the modulus. *)
let pub_bits ({ n; _ } : pub) = Z.numbits n
and priv_bits ({ n; _ } : priv) = Z.numbits n

(* Blinding mode for private-key operations: none, blinding with the default
   RNG, or blinding with the given generator. *)
type mask = [ `No | `Yes | `Yes_with of Crypto_rng.g ]

(* Raw RSA public operation [msg^e mod n], without range checks. *)
let encrypt_unsafe ~key:({ e; n } : pub) msg = Z.(powm msg e n)

(* Raw RSA private operation via the CRT: [m1 = c^dp mod p],
   [m2 = c^dq mod q], recombined as [m = m2 + q * (q' * (m1 - m2) mod p)].
   No range checks on [c]. *)
let decrypt_unsafe ~crt_hardening ~key:({ e; d; n; p; q; dp; dq; q' } : priv) c
    =
  let m1 = Z.(powm_sec c dp p) and m2 = Z.(powm_sec c dq q) in
  (* NOTE: neither erem, nor the multiplications (addition, subtraction) are
     guaranteed to be constant time by gmp *)
  let h = Z.(erem (q' * (m1 - m2)) p) in
  let m = Z.((h * q) + m2) in
  (* counter Arjen Lenstra's CRT attack by verifying the signature. Since the
     public exponent is small, this is not very expensive. Mentioned again in
     "Factoring RSA keys with TLS Perfect Forward Secrecy" (Weimer, 2015).
     If the check fails, fall back to the plain (non-CRT) exponentiation. *)
  if (not crt_hardening) || Z.(powm_sec m e n = c) then m
  else Z.(powm_sec c d n)

(* Blinded private operation: decrypts [c * r^e mod n] for an [r] drawn with
   [Z_extra.gen_r ?g two n] until it is coprime to [n], then multiplies the
   result by [r^-1 mod n]. *)
let decrypt_blinded_unsafe ~crt_hardening ?g ~key:({ e; n; _ } as key : priv) c
    =
  let r = until (rprime n) (fun _ -> Z_extra.gen_r ?g two n) in
  (* since r and n are coprime, there must be a multiplicative inverse *)
  let r' = Z.(invert r n) in
  let c' = Z.(powm_sec r e n * c mod n) in
  let x = decrypt_unsafe ~crt_hardening ~key c' in
  Z.(r' * x mod n)

(* Checked raw RSA on integers. A [msg] below 2 is rejected through
   [invalid_arg]; a [msg] not below the modulus raises [Insufficient_key].
   [decrypt_z] blinds unless [mask] is [`No]. *)
let encrypt_z, decrypt_z =
  let check_params n msg =
    if Z.lt msg two then invalid_arg "Rsa: message: %a" Z.pp_print msg;
    if Z.leq n msg then raise Insufficient_key
  in
  ( (fun ~(key : pub) msg ->
      check_params key.n msg;
      encrypt_unsafe ~key msg),
    fun ~crt_hardening ~mask ~(key : priv) msg ->
      check_params key.n msg;
      match mask with
      | `No -> decrypt_unsafe ~crt_hardening ~key msg
      | `Yes -> decrypt_blinded_unsafe ~crt_hardening ~key msg
      | `Yes_with g -> decrypt_blinded_unsafe ~crt_hardening ~g ~key msg )

(* [reformat out f msg] reads [msg] as a big-endian integer, applies [f], and
   writes the result big-endian into [out // 8] bytes. *)
let reformat out f msg =
  Z_extra.(of_octets_be msg |> f |> to_octets_be ~size:(out // 8))

(* Raw RSA encryption of the big-endian string [msg] under [key]. *)
let encrypt ~key = reformat (pub_bits key) (encrypt_z ~key)
(* Raw RSA decryption of the big-endian string [msg] under [key], encoded
   back into [priv_bits key // 8] bytes. [crt_hardening] defaults to [false]
   and [mask] (blinding) to [`Yes]. *)
let decrypt ?(crt_hardening = false) ?(mask = `Yes) ~key =
  reformat (priv_bits key) (decrypt_z ~crt_hardening ~mask ~key)

let bx00, bx01 = ("\x00", "\x01")

(* [octets_below_two s] is [true] when the big-endian integer encoded by [s]
   is 0 or 1 (and for the empty string). [encrypt] and [decrypt] reject such
   representatives with [Invalid_argument]. The padding schemes below feed
   them ciphertexts and signatures chosen by untrusted peers, so they check
   this first and report those inputs as invalid ([None] / [false]) instead
   of raising. The inspected data is public, so this need not be constant
   time. *)
let octets_below_two s =
  let last = String.length s - 1 in
  let rec zero_prefix i =
    i >= last || (String.get_uint8 s i = 0x00 && zero_prefix (i + 1))
  in
  last < 0 || (String.get_uint8 s last <= 0x01 && zero_prefix 0)

(* PKCS #1 v1.5 encryption and signature padding. *)
module PKCS1 = struct
  (* Minimum number of padding bytes (PS) in an encoded block. *)
  let min_pad = 8

  (* [generate_with ?g ~f n] is [n] random bytes, each satisfying [f]. Bytes
     come from [Crypto_rng.generate ?g] in batches of [n // b * b] bytes
     ([b = Crypto_rng.block g]); bytes rejected by [f] are skipped.
     XXX Generalize this into `Rng.samplev` or something. *)
  let generate_with ?g ~f n =
    let buf = Bytes.create n
    and k =
      let b = Crypto_rng.block g in
      n // b * b
    in
    let rec go nonce i j =
      if i = n then Bytes.unsafe_to_string buf
      else if j = k then go Crypto_rng.(generate ?g k) i 0
      else
        match String.get_uint8 nonce j with
        | b when f b ->
            Bytes.set_uint8 buf i b;
            go nonce (succ i) (succ j)
        | _ -> go nonce i (succ j)
    in
    go Crypto_rng.(generate ?g k) 0 0

  (* [pad ~mark ~padding k msg] is [0x00 || mark || PS || 0x00 || msg] with
     [PS = padding l], [l = imax min_pad (k - String.length msg - 3)].
     [padded] rejects results that are not exactly [k] bytes long. *)
  let pad ~mark ~padding k msg =
    let pad = padding (k - String.length msg - 3 |> imax min_pad) in
    String.concat "" [ bx00; mark; pad; bx00; msg ]

  (* [unpad ~mark ~is_pad buf] is [Some m] when
     [buf = 0x00 || mark || PS || 0x00 || m] with at least [min_pad] bytes of
     [PS] satisfying [is_pad], and [None] otherwise. The end of [PS] is found
     with [ct_find_uint8]; when nothing is found it falls back to index 2,
     which fails [c4]. All four checks are computed before deciding. *)
  let unpad ~mark ~is_pad buf =
    let f = not &. is_pad in
    let i = ct_find_uint8 ~default:2 ~off:2 ~f buf in
    let c1 = String.get_uint8 buf 0 = 0x00
    and c2 = String.get_uint8 buf 1 = mark
    and c3 = String.get_uint8 buf i = 0x00
    and c4 = min_pad <= i - 2 in
    if c1 && c2 && c3 && c4 then
      Some (String.sub buf (i + 1) (String.length buf - i - 1))
    else None

  (* Block type 1 (signatures): PS is all 0xff bytes. *)
  let pad_01 =
    let padding size = String.make size '\xff' in
    pad ~mark:"\x01" ~padding

  (* Block type 2 (encryption): PS is random non-zero bytes. *)
  let pad_02 ?g = pad ~mark:"\x02" ~padding:(generate_with ?g ~f:(( <> ) 0x00))
  let unpad_01 = unpad ~mark:0x01 ~is_pad:(( = ) 0xff)
  let unpad_02 = unpad ~mark:0x02 ~is_pad:(( <> ) 0x00)

  (* [padded pad transform keybits msg] pads [msg] to [keybits // 8] bytes
     and applies [transform]; raises [Insufficient_key] if the padded block
     has any other length. *)
  let padded pad transform keybits msg =
    let n = keybits // 8 in
    let p = pad n msg in
    if String.length p = n then transform p else raise Insufficient_key

  (* [unpadded unpad transform keybits msg] is [unpad (transform msg)] when
     [msg] is exactly [keybits // 8] bytes long, and [None] otherwise. [msg]
     comes from an untrusted peer: a representative of 0 or 1 (which
     [transform] would reject with [Invalid_argument]) or one that is not
     below the modulus ([Insufficient_key]) also yields [None]. *)
  let unpadded unpad transform keybits msg =
    if String.length msg = keybits // 8 && not (octets_below_two msg) then
      try unpad (transform msg) with Insufficient_key -> None
    else None

  (* Signature primitive over an already-encoded message (type-1 padding,
     private-key operation). CRT hardening is on by default. *)
  let sig_encode ?(crt_hardening = true) ?mask ~key msg =
    padded pad_01 (decrypt ~crt_hardening ?mask ~key) (priv_bits key) msg

  (* Inverse of [sig_encode] using the public key. *)
  let sig_decode ~key msg = unpadded unpad_01 (encrypt ~key) (pub_bits key) msg

  (* Encryption with random type-2 padding. *)
  let encrypt ?g ~key msg = padded (pad_02 ?g) (encrypt ~key) (pub_bits key) msg

  (* Decryption: [Some msg] when the padding checks out, [None] otherwise. *)
  let decrypt ?(crt_hardening = false) ?mask ~key msg =
    unpadded unpad_02 (decrypt ~crt_hardening ?mask ~key) (priv_bits key) msg

  (* DER-encoded DigestInfo prefixes per hash: [asn_of_hash] looks up the
     prefix for a hash, [detect] finds the entry whose prefix starts [buf]. *)
  let asn_of_hash, detect =
    let map =
      [
        ( `MD5,
          "\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10"
        );
        (`SHA1, "\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14");
        ( `SHA224,
          "\x30\x2d\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x04\x05\x00\x04\x1c"
        );
        ( `SHA256,
          "\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20"
        );
        ( `SHA384,
          "\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30"
        );
        ( `SHA512,
          "\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40"
        );
      ]
    in
    ( (fun h -> List.assoc h map),
      fun buf ->
        List.find_opt (fun (_, d) -> String.starts_with ~prefix:d buf) map )

  (* [sign ~hash ~key msg] signs DigestInfo(hash, digest of [msg]). CRT
     hardening is on by default. *)
  let sign ?(crt_hardening = true) ?mask ~hash ~key msg =
    let module H = (val Digestif.module_of_hash' (hash :> Digestif.hash')) in
    let module D = Digest_or (H) in
    let msg' = asn_of_hash hash ^ D.digest_or msg in
    sig_encode ~crt_hardening ?mask ~key msg'

  (* [verify ~hashp ~key ~signature msg] is [true] iff [signature] decodes
     under [key] to a DigestInfo whose hash is accepted by [hashp] and whose
     contents equal the DigestInfo of [msg] for that hash (compared with
     [Eqaf.equal]). *)
  let verify ~hashp ~key ~signature msg =
    let ( >>= ) = Option.bind and ( >>| ) = Fun.flip Option.map in
    Option.value
      ( sig_decode ~key signature >>= fun buf ->
        detect buf >>| fun (hash, asn) ->
        let module H = (val Digestif.module_of_hash' (hash :> Digestif.hash'))
        in
        let module D = Digest_or (H) in
        hashp hash && Eqaf.equal (asn ^ D.digest_or msg) buf )
      ~default:false

  (* Smallest key size in bits whose byte length fits the DigestInfo for
     [hash], [min_pad] padding bytes and the three framing bytes. *)
  let min_key hash =
    let module H = (val Digestif.module_of_hash' (hash :> Digestif.hash')) in
    ((String.length (asn_of_hash hash) + H.digest_size + min_pad + 2) * 8) + 1
end

(* MGF1 mask generation function over hash [H]. *)
module MGF1 (H : Digestif.S) = struct
  (* 4-byte big-endian encoding of the counter [n]. *)
  let repr n =
    let buf = Bytes.create 4 in
    Bytes.set_int32_be buf 0 n;
    Bytes.unsafe_to_string buf

  (* [mgf ~seed len] is the first [len] bytes of
     [H (seed || repr 0) || H (seed || repr 1) || ...].
     Assumes len < 2^32 * H.digest_size. *)
  let mgf ~seed len =
    let rec go acc c = function
      | 0 -> Bytes.sub (Bytes.concat Bytes.empty (List.rev acc)) 0 len
      | n ->
          let h = Bytes.create H.digest_size in
          H.get_into_bytes (H.feedi_string H.empty (iter2 seed (repr c))) h;
          go (h :: acc) Int32.(succ c) (pred n)
    in
    go [] 0l (len // H.digest_size)

  (* [mask ~seed buf] XORs [buf] into a fresh [mgf ~seed (String.length buf)]
     buffer (via [unsafe_xor_into]) and returns that buffer. *)
  let mask ~seed buf =
    let mgf_data = mgf ~seed (String.length buf) in
    unsafe_xor_into buf ~src_off:0 mgf_data ~dst_off:0 (String.length buf);
    mgf_data
end

(* RSAES-OAEP over hash [H], with MGF1 as mask generation function. *)
module OAEP (H : Digestif.S) = struct
  module MGF = MGF1 (H)

  let hlen = H.digest_size

  (* Longest message that fits a [k]-byte key; negative when the key is too
     small for OAEP with [H]. *)
  let max_msg_bytes k = k - (2 * hlen) - 2

  (* [eme_oaep_encode ?g ?label k msg] is the [k]-byte encoding
     [0x00 || maskedSeed || maskedDB] with [DB = lHash || PS || 0x01 || msg],
     [PS] all zero bytes and a random [hlen]-byte seed. Requires
     [String.length msg <= max_msg_bytes k]. *)
  let eme_oaep_encode ?g ?(label = "") k msg =
    let seed = Crypto_rng.generate ?g hlen
    and pad = String.make (max_msg_bytes k - String.length msg) '\x00' in
    let db =
      String.concat ""
        [ H.(digest_string label |> to_raw_string); pad; bx01; msg ]
    in
    let mdb = Bytes.unsafe_to_string (MGF.mask ~seed db) in
    let mseed = Bytes.unsafe_to_string (MGF.mask ~seed:mdb seed) in
    String.concat "" [ bx00; mseed; mdb ]

  (* Inverse of [eme_oaep_encode]: [Some m] when [msg] unmasks to
     [0x00 || seed || lHash || PS || 0x01 || m] with [PS] all zero bytes, and
     [None] otherwise. Assumes [String.length msg >= 2 * hlen + 2], which
     [decrypt] checks. *)
  let eme_oaep_decode ?(label = "") msg =
    let b0 = String.sub msg 0 1
    and ms = String.sub msg 1 hlen
    and mdb = String.sub msg (1 + hlen) (String.length msg - 1 - hlen) in
    let db =
      Bytes.unsafe_to_string
        (MGF.mask ~seed:(Bytes.unsafe_to_string (MGF.mask ~seed:mdb ms)) mdb)
    in
    (* The first non-zero byte after lHash must be the 0x01 separator. When
       every byte after lHash is zero, fall back to index [hlen] (a zero
       byte) so that [c3] fails. A fallback of 0 would point into lHash and
       accept a DB without any separator whenever lHash begins with 0x01. *)
    let i = ct_find_uint8 ~default:hlen ~off:hlen ~f:(( <> ) 0x00) db in
    let c1 =
      Eqaf.equal (String.sub db 0 hlen) H.(digest_string label |> to_raw_string)
    and c2 = String.get_uint8 b0 0 = 0x00
    and c3 = String.get_uint8 db i = 0x01 in
    if c1 && c2 && c3 then
      Some (String.sub db (i + 1) (String.length db - i - 1))
    else None

  (* Raises [Insufficient_key] when [msg] is longer than [max_msg_bytes]. *)
  let encrypt ?g ?label ~key msg =
    let k = pub_bits key // 8 in
    if String.length msg > max_msg_bytes k then raise Insufficient_key
    else encrypt ~key @@ eme_oaep_encode ?g ?label k msg

  (* [None] on any malformed ciphertext: wrong length, key too small for
     OAEP, a representative of 0 or 1, one not below the modulus, or a bad
     encoding. *)
  let decrypt ?(crt_hardening = false) ?mask ?label ~key em =
    let k = priv_bits key // 8 in
    if String.length em <> k || max_msg_bytes k < 0 || octets_below_two em
    then None
    else
      try eme_oaep_decode ?label @@ decrypt ~crt_hardening ?mask ~key em
      with Insufficient_key -> None

  (* XXX Review rfc3447 7.1.2 and
   * http://archiv.infsec.ethz.ch/education/fs08/secsem/Manger01.pdf
   * again for timing properties. *)

  (* XXX expose seed for deterministic testing? *)
end

(* RSASSA-PSS over hash [H], with MGF1 as mask generation function. *)
module PSS (H : Digestif.S) = struct
  module MGF = MGF1 (H)
  module H1 = Digest_or (H)

  let hlen = H.digest_size
  let bxbc = "\xbc"

  (* Keeps the low [embits mod 8] bits of a byte (all 8 bits when [embits]
     is a multiple of 8). *)
  let b0mask embits = 0xff lsr ((8 - (embits mod 8)) mod 8)
  let zero_8 = String.make 8 '\x00'

  (* [digest ~salt msg] is [H (0x00 * 8 || mHash || salt)]. *)
  let digest ~salt msg =
    H.to_raw_string @@ H.digesti_string @@ iter3 zero_8 (H1.digest_or msg) salt

  (* Encodes [msg] into [emlen // 8] bytes: [maskedDB || h || 0xbc], where
     [DB = PS || 0x01 || salt] with a random [slen]-byte salt and the
     excess top bits of the first byte cleared. *)
  let emsa_pss_encode ?g slen emlen msg =
    let n = emlen // 8 and salt = Crypto_rng.generate ?g slen in
    let h = digest ~salt msg in
    let db =
      String.concat "" [ String.make (n - slen - hlen - 2) '\x00'; bx01; salt ]
    in
    let mdb = MGF.mask ~seed:h db in
    Bytes.set_uint8 mdb 0 @@ (Bytes.get_uint8 mdb 0 land b0mask emlen);
    String.concat "" [ Bytes.unsafe_to_string mdb; h; bxbc ]

  (* Checks [em] against [msg]: cleared top bits, zero padding ending in
     0x01 at the position implied by [slen], the trailing 0xbc, and the
     recomputed hash. All checks are computed before combining them. *)
  let emsa_pss_verify slen emlen em msg =
    let mdb = String.sub em 0 (String.length em - hlen - 1)
    and h = String.sub em (String.length em - hlen - 1) hlen
    and bxx = String.get_uint8 em (String.length em - 1) in
    let db = MGF.mask ~seed:h mdb in
    Bytes.set_uint8 db 0 (Bytes.get_uint8 db 0 land b0mask emlen);
    let db = Bytes.unsafe_to_string db in
    let salt = String.sub db (String.length db - slen) slen in
    let h' = digest ~salt msg
    and i = ct_find_uint8 ~default:0 ~f:(( <> ) 0x00) db in
    let c1 = lnot (b0mask emlen) land String.get_uint8 mdb 0 = 0x00
    and c2 = i = String.length em - hlen - slen - 2
    and c3 = String.get_uint8 db i = 0x01
    and c4 = bxx = 0xbc
    and c5 = Eqaf.equal h h' in
    c1 && c2 && c3 && c4 && c5

  let sufficient_key ~slen kbits =
    hlen + slen + 2 <= kbits / 8 (* 8 * (hlen + slen + 1) + 2 <= kbits *)

  (* Raises [Insufficient_key] when the key is too small for [slen].
     [slen] defaults to [hlen], CRT hardening to [false]. *)
  let sign ?g ?(crt_hardening = false) ?mask ?(slen = hlen) ~key msg =
    let b = priv_bits key in
    if not (sufficient_key ~slen b) then raise Insufficient_key
    else
      let msg' = emsa_pss_encode ?g (imax 0 slen) (b - 1) msg in
      decrypt ~crt_hardening ?mask ~key msg'

  (* [false] for any malformed signature, including a representative of 0
     or 1 or one that is not below the modulus. *)
  let verify ?(slen = hlen) ~key ~signature msg =
    let b = pub_bits key and s = String.length signature in
    s = b // 8
    && sufficient_key ~slen b
    && (not (octets_below_two signature))
    &&
    try
      let em = encrypt ~key signature in
      let to_see = s - ((b - 1) // 8) in
      emsa_pss_verify (imax 0 slen) (b - 1)
        (String.sub em to_see (String.length em - to_see))
        msg
    with Insufficient_key -> false
end