(* Upstream: https://github.com/mirage/mirage-crypto
   (branch main; 309 lines, 11 kB) *)
open Wycheproof
open Crypto_ec

let ( let* ) = Result.bind

(* Alcotest testable for byte strings, printed and compared with the
   Wycheproof hex helpers. *)
let hex = Alcotest.testable Wycheproof.pp_hex Wycheproof.equal_hex

(* Small ASN.1 decoders for the encodings used by the Wycheproof vectors.
   Inside this structure, [Asn] still refers to the external ASN.1 library
   module (the local module only shadows it once its definition is complete). *)
module Asn = struct
  (* [parse_point curve s] decodes
       SEQUENCE { SEQUENCE { OID; OID }; BIT STRING }
     using the (lenient) BER codec. It requires the first OID to be
     1.2.840.10045.2.1 (id-ecPublicKey) and the second to be the named-curve
     OID of [curve]. On success it returns the BIT STRING payload, which
     callers use as the public point.
     Returns [Error] on decode failure, trailing bytes, or an OID mismatch.
     Raises [Assert_failure] for a curve other than secp256r1/384r1/521r1. *)
  let parse_point curve s =
    let seq2 a b = Asn.S.(sequence2 (required a) (required b)) in
    let term = Asn.S.(seq2 (seq2 oid oid) bit_string_octets) in
    let ec_public_key = Asn.OID.(base 1 2 <|| [ 840; 10045; 2; 1 ]) in
    let prime_oid =
      match curve with
      | "secp256r1" -> Asn.OID.(base 1 2 <|| [ 840; 10045; 3; 1; 7 ])
      | "secp384r1" -> Asn.OID.(base 1 3 <|| [ 132; 0; 34 ])
      | "secp521r1" -> Asn.OID.(base 1 3 <|| [ 132; 0; 35 ])
      | _ -> assert false
    in
    match Asn.decode (Asn.codec Asn.ber term) s with
    | Error _ -> Error "ASN1 parse error"
    | Ok (((oid1, oid2), data), rest) ->
        if String.length rest <> 0 then Error "ASN1 leftover"
        else if not (Asn.OID.equal oid1 ec_public_key) then
          Error "ASN1: wrong oid 1"
        else if not (Asn.OID.equal oid2 prime_oid) then
          Error "ASN1: wrong oid 2"
        else Ok data

  (* [parse_signature cs] decodes a DER SEQUENCE { INTEGER r; INTEGER s }
     and returns [(r, s)]. Trailing bytes after the sequence are rejected. *)
  let parse_signature cs =
    let asn =
      Asn.S.(sequence2 (required unsigned_integer) (required unsigned_integer))
    in
    match Asn.(decode (codec der asn) cs) with
    | Error _ -> Error "ASN1 parse error"
    | Ok (r_s, rest) ->
        if String.length rest <> 0 then Error "ASN1 leftover" else Ok r_s
end

(* Turn a result with a structured error into one with a string error,
   rendering the error with [pp_error]. *)
let to_string_result ~pp_error = function
  | Ok _ as ok -> ok
  | Error e ->
      let msg = Format.asprintf "%a" pp_error e in
      Error msg

(* [pad ~total_len buf] brings [buf] to exactly [total_len] bytes.
   - A shorter buffer is left-padded with '\000' bytes.
   - A longer buffer is accepted only if every surplus leading byte is
     '\000'. Those bytes are then stripped.
   - Otherwise the result is [Error "input is too long"]. *)
let pad ~total_len buf =
  let buf_len = String.length buf in
  if buf_len = total_len then Ok buf
  else if buf_len < total_len then
    Ok (String.make (total_len - buf_len) '\000' ^ buf)
  else
    let excess = buf_len - total_len in
    (* Short-circuits on the first non-zero byte; reads the immutable string
       directly instead of going through [Bytes.unsafe_of_string]. *)
    let rec leading_zeros i =
      i >= excess || (buf.[i] = '\000' && leading_zeros (i + 1))
    in
    if leading_zeros 0 then Ok (String.sub buf excess total_len)
    else Error "input is too long"

(* Scalar/coordinate byte length for each supported curve name.
   Raises [Assert_failure] for any other name. *)
let len = function
  | "secp256r1" -> 32
  | "secp384r1" -> 48
  | "secp521r1" -> 66
  | _ -> assert false

(* Normalise a raw private key to the fixed byte length of [curve]. *)
let parse_secret curve s =
  let total_len = len curve in
  pad ~total_len s

type test = { public_key : string; raw_private_key : string; expected : string }

(* Run ECDH on [curve] with a raw private scalar and an encoded public point.
   Returns the shared secret, or an error rendered with [pp_error].
   Errors come from either step:
   - Decoding the private key. Previously this hit [assert false]. That
     crashed [Invalid] vectors that are supposed to be rejected, and hid the
     reason on [Valid] ones.
   - The key exchange itself.
   Raises [Assert_failure] only for an unsupported curve name. *)
let perform_key_exchange curve ~public_key ~raw_private_key =
  to_string_result ~pp_error
    (match curve with
    | "secp256r1" ->
        let* secret, _ = P256.Dh.secret_of_octets raw_private_key in
        P256.Dh.key_exchange secret public_key
    | "secp384r1" ->
        let* secret, _ = P384.Dh.secret_of_octets raw_private_key in
        P384.Dh.key_exchange secret public_key
    | "secp521r1" ->
        let* secret, _ = P521.Dh.secret_of_octets raw_private_key in
        P521.Dh.key_exchange secret public_key
    | _ -> assert false)

(* Positive ECDH case: the exchange must succeed and yield [expected]. *)
let interpret_test ~tcId curve { public_key; raw_private_key; expected } () =
  match perform_key_exchange curve ~public_key ~raw_private_key with
  | Ok got -> Alcotest.check hex __LOC__ expected got
  | Error err -> Alcotest.failf "While parsing %d: %s" tcId err

type invalid_test = { public : string; private_ : string }

let is_ok = function Ok _ -> true | Error _ -> false

(* Negative ECDH case: any of the following counts as a pass:
   - the public key fails to decode;
   - the private key cannot be normalised;
   - the exchange itself fails. *)
let interpret_invalid_test curve { public; private_ } () =
  let result =
    let* public_key = Asn.parse_point curve public in
    let* raw_private_key = parse_secret curve private_ in
    perform_key_exchange curve ~public_key ~raw_private_key
  in
  Alcotest.check Alcotest.bool __LOC__ false (is_ok result)

type strategy = Test of test | Invalid_test of invalid_test | Skip

(* Decide how to run one Wycheproof ECDH vector:
   - Vectors flagged "UnnamedCurve" are skipped.
   - [Invalid] vectors become negative tests. Their inputs are decoded only
     when the test runs.
   - [Acceptable] vectors are run as positive tests only for tcId 2 on the
     three NIST curves. NOTE(review): presumably the compressed-point
     vector, going by the helper's name; other [Acceptable] vectors are
     skipped.
   - [Valid] vectors are decoded now and become positive tests.
   Returns [Error] if a vector that must be decoded up front fails to parse. *)
let make_ecdh_test curve (test : ecdh_test) =
  let ignored_flags = [ "UnnamedCurve" ] in
  let curve_compression_test curve =
    let curves = [ "secp256r1"; "secp384r1"; "secp521r1" ] in
    test.tcId = 2 && List.exists (fun x -> String.equal x curve) curves
  in
  match test.result with
  | _ when has_ignored_flag test ~ignored_flags -> Ok Skip
  | Invalid ->
      Ok (Invalid_test { public = test.public; private_ = test.private_ })
  | Acceptable when curve_compression_test curve ->
      let* public_key = Asn.parse_point curve test.public in
      let* raw_private_key = parse_secret curve test.private_ in
      Ok (Test { public_key; raw_private_key; expected = test.shared })
  | Acceptable -> Ok Skip
  | Valid ->
      let* public_key = Asn.parse_point curve test.public in
      let* raw_private_key = parse_secret curve test.private_ in
      Ok (Test { public_key; raw_private_key; expected = test.shared })

(* Turn one ECDH vector into zero or one Alcotest cases.
   Raises [Failure] if the vector cannot be decoded. *)
let to_ecdh_tests curve (x : ecdh_test) =
  let name = Printf.sprintf "%d - %s" x.tcId x.comment in
  match make_ecdh_test curve x with
  | Ok (Test t) -> [ (name, `Quick, interpret_test ~tcId:x.tcId curve t) ]
  | Ok (Invalid_test t) -> [ (name, `Quick, interpret_invalid_test curve t) ]
  | Ok Skip -> []
  | Error e -> Printf.ksprintf failwith "While parsing %d: %s" x.tcId e

(* All Alcotest cases for an ECDH vector file. *)
let ecdh_tests file =
  let data = load_file_exn file in
  let groups : ecdh_test_group list =
    List.map ecdh_test_group_exn data.testGroups
  in
  List.concat_map
    (fun (group : ecdh_test_group) ->
      List.concat_map (to_ecdh_tests group.curve) group.tests)
    groups

(* Build the Alcotest case for one ECDSA vector.
   The message is hashed with [hash], and the digest is truncated to the
   curve's byte length. The signature is verified with the public [key]
   decoded via [pub_of_octets].
   - [Valid]: the DER signature must parse and verify.
   - [Acceptable] / [Invalid]: a signature that parses must NOT verify; a
     parse failure also passes. *)
let make_ecdsa_test curve key hash (tst : dsa_test) =
  let name = Printf.sprintf "%d - %s" tst.tcId tst.comment in
  let size = len curve in
  let msg =
    let dgst =
      match hash with
      | "SHA-256" -> Digestif.SHA256.(digest_string tst.msg |> to_raw_string)
      | "SHA-384" -> Digestif.SHA384.(digest_string tst.msg |> to_raw_string)
      | "SHA-512" -> Digestif.SHA512.(digest_string tst.msg |> to_raw_string)
      | "SHA-224" -> Digestif.SHA224.(digest_string tst.msg |> to_raw_string)
      | _ -> assert false
    in
    String.sub dgst 0 (min size (String.length dgst))
  in
  let verified (r, s) =
    match curve with
    | "secp256r1" -> begin
        match P256.Dsa.pub_of_octets key with
        | Ok key -> P256.Dsa.verify ~key (r, s) msg
        | Error _ -> assert false
      end
    | "secp384r1" -> begin
        match P384.Dsa.pub_of_octets key with
        | Ok key -> P384.Dsa.verify ~key (r, s) msg
        | Error _ -> assert false
      end
    | "secp521r1" -> begin
        match P521.Dsa.pub_of_octets key with
        | Ok key -> P521.Dsa.verify ~key (r, s) msg
        | Error _ -> assert false
      end
    | _ -> assert false
  in
  match tst.result with
  | Acceptable | Invalid ->
      let f () =
        match Asn.parse_signature tst.sig_ with
        | Ok (r, s) -> Alcotest.(check bool __LOC__ false (verified (r, s)))
        | Error _s -> ()
      in
      (name, `Quick, f)
  | Valid ->
      let f () =
        match Asn.parse_signature tst.sig_ with
        | Ok (r, s) -> Alcotest.(check bool __LOC__ true (verified (r, s)))
        | Error s -> Alcotest.fail s
      in
      (name, `Quick, f)

(* All cases of one ECDSA group, using the group's uncompressed key. *)
let to_ecdsa_tests (x : ecdsa_test_group) =
  List.map (make_ecdsa_test x.key.curve x.key.uncompressed x.sha) x.tests

(* All Alcotest cases for an ECDSA vector file. *)
let ecdsa_tests file =
  let data = load_file_exn file in
  let groups : ecdsa_test_group list =
    List.map ecdsa_test_group_exn data.testGroups
  in
  List.concat_map to_ecdsa_tests groups

(* Build the Alcotest case for one X25519 vector.
   - [Acceptable]: if the vector carries the "LowOrderPublic" flag, the
     exchange must fail. Otherwise it must succeed with the expected shared
     secret.
   - [Invalid]: the exchange must succeed but NOT produce [x.shared].
   - [Valid]: the exchange must succeed and produce [x.shared].
   The private key is decoded eagerly. A decode failure raises
   [Assert_failure] while the test list is being built. *)
let to_x25519_test (x : ecdh_test) =
  let name = Printf.sprintf "%d - %s" x.tcId x.comment
  and priv =
    match X25519.secret_of_octets x.private_ with
    | Ok (p, _) -> p
    | Error _ -> assert false
  in
  match x.result with
  | Acceptable ->
      let f () =
        match
          ( X25519.key_exchange priv x.public,
            has_ignored_flag x ~ignored_flags:[ "LowOrderPublic" ] )
        with
        | Ok _, true -> Alcotest.fail "acceptable should have errored"
        | Ok r, false ->
            Alcotest.(check bool __LOC__ true (String.equal r x.shared))
        | Error _, true -> ()
        | Error e, false -> Alcotest.failf "acceptable errored %a" pp_error e
      in
      (name, `Quick, f)
  | Invalid ->
      let f () =
        match X25519.key_exchange priv x.public with
        | Ok r -> Alcotest.(check bool __LOC__ false (String.equal r x.shared))
        | Error e -> Alcotest.failf "invalid errored %a" pp_error e
      in
      (name, `Quick, f)
  | Valid ->
      let f () =
        match X25519.key_exchange priv x.public with
        | Ok r -> Alcotest.(check bool __LOC__ true (String.equal r x.shared))
        | Error e -> Alcotest.failf "valid errored %a" pp_error e
      in
      (name, `Quick, f)

(* All Alcotest cases from x25519_test.json. *)
let x25519_tests =
  let data = load_file_exn "x25519_test.json" in
  let groups : ecdh_test_group list =
    List.map ecdh_test_group_exn data.testGroups
  in
  List.concat_map
    (fun (group : ecdh_test_group) -> List.map to_x25519_test group.tests)
    groups

(* Build the Alcotest case for one Ed25519 vector, given the group's keys.
   - [Valid]: the signature must verify, and signing [x.msg] must reproduce
     [x.sig_] exactly.
   - [Invalid]: the signature must not verify, and the freshly computed
     signature must differ from [x.sig_].
   - [Acceptable]: raises [Assert_failure]. NOTE(review): assumes the vector
     file contains no [Acceptable] entries. *)
let to_ed25519_test (priv, pub) (x : dsa_test) =
  let name = Printf.sprintf "%d - %s" x.tcId x.comment in
  match x.result with
  | Invalid ->
      let f () =
        Alcotest.(
          check bool __LOC__ false (Ed25519.verify ~key:pub x.sig_ ~msg:x.msg));
        let s = Ed25519.sign ~key:priv x.msg in
        Alcotest.(check bool __LOC__ false (String.equal s x.sig_))
      in
      (name, `Quick, f)
  | Valid ->
      let f () =
        Alcotest.(
          check bool __LOC__ true (Ed25519.verify ~key:pub x.sig_ ~msg:x.msg));
        let s = Ed25519.sign ~key:priv x.msg in
        Alcotest.(check bool __LOC__ true (String.equal s x.sig_))
      in
      (name, `Quick, f)
  | Acceptable -> assert false

(* Decode a group's Ed25519 key pair. It also asserts that the public key
   derived from the private key matches the encoded one. *)
let to_ed25519_keys (key : eddsa_key) =
  match (Ed25519.priv_of_octets key.sk, Ed25519.pub_of_octets key.pk) with
  | Ok priv, Ok pub ->
      assert (String.equal Ed25519.(pub_to_octets (pub_of_priv priv)) key.pk);
      (priv, pub)
  | _ -> assert false

(* All Alcotest cases from eddsa_test.json. *)
let ed25519_tests =
  let data = load_file_exn "eddsa_test.json" in
  let groups : eddsa_test_group list =
    List.map eddsa_test_group_exn data.testGroups
  in
  List.concat_map
    (fun (group : eddsa_test_group) ->
      let keys = to_ed25519_keys group.key in
      List.map (to_ed25519_test keys) group.tests)
    groups

let () =
  Alcotest.run "Wycheproof NIST curves"
    [
      ("ECDH P256 test vectors", ecdh_tests "ecdh_secp256r1_test.json");
      ( "ECDSA P256 test vectors (SHA256)",
        ecdsa_tests "ecdsa_secp256r1_sha256_test.json" );
      ( "ECDSA P256 test vectors (SHA512)",
        ecdsa_tests "ecdsa_secp256r1_sha512_test.json" );
      ("ECDH P384 test vectors", ecdh_tests "ecdh_secp384r1_test.json");
      ( "ECDSA P384 test vectors (SHA384)",
        ecdsa_tests "ecdsa_secp384r1_sha384_test.json" );
      ( "ECDSA P384 test vectors (SHA512)",
        ecdsa_tests "ecdsa_secp384r1_sha512_test.json" );
      ("ECDH P521 test vectors", ecdh_tests "ecdh_secp521r1_test.json");
      ( "ECDSA P521 test vectors (SHA512)",
        ecdsa_tests "ecdsa_secp521r1_sha512_test.json" );
      ("X25519 test vectors", x25519_tests);
      ("ED25519 test vectors", ed25519_tests);
    ]