Forked from gazagnaire.org/ocaml-crypto
Upstream: https://github.com/mirage/mirage-crypto
1open Wycheproof
2open Crypto_ec
3
(** Binding operator for [result], used to chain the fallible decoding and
    key-exchange steps in this file. *)
let ( let* ) res f = Result.bind res f

(** Alcotest testable for byte strings, built from Wycheproof's [pp_hex]
    printer and [equal_hex] equality. *)
let hex = Alcotest.testable Wycheproof.pp_hex Wycheproof.equal_hex
6
(** Small ASN.1 decoders for the encoded blobs found in Wycheproof vectors.
    Inside this structure the name [Asn] still refers to the external ASN.1
    combinator library, not to this module. *)
module Asn = struct
  (** [parse_point curve s] BER-decodes [s] as
      [SEQUENCE { SEQUENCE { OID, OID }, BIT STRING }]. It checks that the
      first OID is 1.2.840.10045.2.1 and that the second is the OID for
      [curve], then returns the bit-string payload.

      Returns [Error] with a short message on a decoding failure, trailing
      bytes, or an OID mismatch.
      @raise Assert_failure if [curve] is not ["secp256r1"], ["secp384r1"]
      or ["secp521r1"]. *)
  let parse_point curve s =
    let seq2 a b = Asn.S.(sequence2 (required a) (required b)) in
    let spki = Asn.S.(seq2 (seq2 oid oid) bit_string_octets) in
    let id_ec_public_key = Asn.OID.(base 1 2 <|| [ 840; 10045; 2; 1 ]) in
    (* OID expected in the parameters slot for each supported curve. *)
    let named_curve =
      match curve with
      | "secp256r1" -> Asn.OID.(base 1 2 <|| [ 840; 10045; 3; 1; 7 ])
      | "secp384r1" -> Asn.OID.(base 1 3 <|| [ 132; 0; 34 ])
      | "secp521r1" -> Asn.OID.(base 1 3 <|| [ 132; 0; 35 ])
      | _ -> assert false
    in
    match Asn.decode (Asn.codec Asn.ber spki) s with
    | Error _ -> Error "ASN1 parse error"
    | Ok (_, rest) when String.length rest <> 0 -> Error "ASN1 leftover"
    | Ok (((algorithm, _), _), _)
      when not (Asn.OID.equal algorithm id_ec_public_key) ->
        Error "ASN1: wrong oid 1"
    | Ok (((_, parameters), _), _)
      when not (Asn.OID.equal parameters named_curve) ->
        Error "ASN1: wrong oid 2"
    | Ok ((_, point), _) -> Ok point

  (** [parse_signature cs] DER-decodes [cs] as a two-element sequence of
      unsigned integers and returns the pair [(r, s)] as decoded by
      [Asn.S.unsigned_integer].

      Returns [Error] on a decoding failure or trailing bytes. *)
  let parse_signature cs =
    let r_and_s =
      Asn.S.(sequence2 (required unsigned_integer) (required unsigned_integer))
    in
    match Asn.decode (Asn.codec Asn.der r_and_s) cs with
    | Error _ -> Error "ASN1 parse error"
    | Ok (_, rest) when String.length rest <> 0 -> Error "ASN1 leftover"
    | Ok (signature, _) -> Ok signature
end
38
(** [to_string_result ~pp_error r] returns [Ok] values unchanged. It turns
    [Error e] into [Error msg], where [msg] is [e] rendered with
    [pp_error]. *)
let to_string_result ~pp_error result =
  Result.map_error (fun e -> Format.asprintf "%a" pp_error e) result
44
(** [pad ~total_len buf] normalises [buf] to exactly [total_len] bytes.

    - If [buf] already has that length, it is returned unchanged.
    - If it is shorter, it is left-padded with zero bytes.
    - If it is longer, the surplus leading bytes are dropped, provided they
      are all zero. Otherwise the result is [Error "input is too long"]. *)
let pad ~total_len buf =
  let buf_len = String.length buf in
  if buf_len = total_len then Ok buf
  else if buf_len < total_len then
    Ok (String.make (total_len - buf_len) '\000' ^ buf)
  else
    let excess = buf_len - total_len in
    (* Every surplus byte is read (no early exit); only their conjunction
       matters. *)
    let rec surplus_is_zero i acc =
      if i >= excess then acc
      else surplus_is_zero (i + 1) (buf.[i] = '\000' && acc)
    in
    if surplus_is_zero 0 true then Ok (String.sub buf excess total_len)
    else Error "input is too long"
57
(** [len curve] is the fixed byte length used for [curve]. Private keys are
    padded to it, and message digests are truncated to it.
    P-521 uses 66 bytes, i.e. ceil(521 / 8).
    @raise Assert_failure for any curve name other than ["secp256r1"],
    ["secp384r1"] or ["secp521r1"]. *)
let len curve =
  match curve with
  | "secp521r1" -> 66
  | "secp384r1" -> 48
  | "secp256r1" -> 32
  | _ -> assert false
63
(** [parse_secret curve raw] normalises the raw private key [raw] to
    exactly [len curve] bytes using [pad]. *)
let parse_secret curve raw = pad raw ~total_len:(len curve)
67
(** A decoded ECDH vector whose key exchange is expected to succeed. *)
type test = {
  public_key : string;  (** Point payload extracted by [Asn.parse_point]. *)
  raw_private_key : string;  (** Private key padded to the curve length. *)
  expected : string;  (** Expected shared secret. *)
}
69
(** [perform_key_exchange curve ~public_key ~raw_private_key] decodes
    [raw_private_key] as a secret for [curve] and runs ECDH against
    [public_key], returning the shared secret.

    A failure in either step (secret decoding or key exchange) is rendered
    with [pp_error] into an [Error] string. A rejected private key is
    therefore reported as a result, instead of aborting the test with an
    assertion.
    @raise Assert_failure if [curve] is not ["secp256r1"], ["secp384r1"]
    or ["secp521r1"]. *)
let perform_key_exchange curve ~public_key ~raw_private_key =
  let result =
    match curve with
    | "secp256r1" ->
        let* (secret, _) = P256.Dh.secret_of_octets raw_private_key in
        P256.Dh.key_exchange secret public_key
    | "secp384r1" ->
        let* (secret, _) = P384.Dh.secret_of_octets raw_private_key in
        P384.Dh.key_exchange secret public_key
    | "secp521r1" ->
        let* (secret, _) = P521.Dh.secret_of_octets raw_private_key in
        P521.Dh.key_exchange secret public_key
    | _ -> assert false
  in
  to_string_result ~pp_error result
89
(** [interpret_test ~tcId curve t ()] runs the key exchange described by
    [t] and checks that the shared secret equals [t.expected]. If the key
    exchange returns an error, the test fails with a message naming the
    vector's [tcId] and the error text. *)
let interpret_test ~tcId curve { public_key; raw_private_key; expected } () =
  match perform_key_exchange curve ~public_key ~raw_private_key with
  | Ok got -> Alcotest.check hex __LOC__ expected got
  | Error err -> Alcotest.failf "Test %d: key exchange failed: %s" tcId err
95
(** An ECDH vector whose processing must fail. It keeps the raw inputs from
    the vector file, which are decoded only when the test runs. *)
type invalid_test = {
  public : string;  (** Encoded public key, passed to [Asn.parse_point]. *)
  private_ : string;  (** Private key bytes, passed to [parse_secret]. *)
}
97
(** [is_ok r] is [true] exactly when [r] is an [Ok]. *)
let is_ok result =
  match result with
  | Error _ -> false
  | Ok _ -> true
99
(** [interpret_invalid_test curve t ()] checks that the invalid vector [t]
    is rejected. At least one step must return [Error]: decoding the
    public point, normalising the private key, or the key exchange
    itself. *)
let interpret_invalid_test curve { public; private_ } () =
  let outcome =
    Result.bind (Asn.parse_point curve public) (fun public_key ->
        Result.bind (parse_secret curve private_) (fun raw_private_key ->
            perform_key_exchange curve ~public_key ~raw_private_key))
  in
  Alcotest.(check bool) __LOC__ false (is_ok outcome)
107
(** How a single Wycheproof ECDH vector is turned into a test case. *)
type strategy =
  | Test of test  (** Run the exchange and compare the shared secret. *)
  | Invalid_test of invalid_test  (** Expect the vector to be rejected. *)
  | Skip  (** Generate no test case. *)
109
(** [make_ecdh_test curve test] decides how to handle one ECDH vector:

    - vectors carrying the ["UnnamedCurve"] flag are skipped;
    - [Invalid] vectors become [Invalid_test] with the raw inputs;
    - [Valid] vectors are decoded into a [Test];
    - [Acceptable] vectors are skipped, except [tcId = 2] on the three NIST
      curves, which is decoded into a [Test]. The original helper called
      this the curve-compression case, presumably a compressed public
      point; verify against the vector files.

    Returns [Error] if decoding the public point or the private key
    fails. *)
let make_ecdh_test curve (test : ecdh_test) =
  let ignored_flags = [ "UnnamedCurve" ] in
  let nist_curves = [ "secp256r1"; "secp384r1"; "secp521r1" ] in
  let is_compression_case () =
    test.tcId = 2 && List.exists (String.equal curve) nist_curves
  in
  let decode () =
    let* public_key = Asn.parse_point curve test.public in
    let* raw_private_key = parse_secret curve test.private_ in
    Ok (Test { public_key; raw_private_key; expected = test.shared })
  in
  if has_ignored_flag test ~ignored_flags then Ok Skip
  else
    match test.result with
    | Invalid ->
        Ok (Invalid_test { public = test.public; private_ = test.private_ })
    | Valid -> decode ()
    | Acceptable -> if is_compression_case () then decode () else Ok Skip
129
(** [to_ecdh_tests curve x] returns zero or one Alcotest case for the ECDH
    vector [x], named ["<tcId> - <comment>"].
    @raise Failure if the vector's inputs cannot be decoded. *)
let to_ecdh_tests curve (x : ecdh_test) =
  let name = Printf.sprintf "%d - %s" x.tcId x.comment in
  let single f = [ (name, `Quick, f) ] in
  match make_ecdh_test curve x with
  | Error e -> Printf.ksprintf failwith "While parsing %d: %s" x.tcId e
  | Ok Skip -> []
  | Ok (Invalid_test t) -> single (interpret_invalid_test curve t)
  | Ok (Test t) -> single (interpret_test ~tcId:x.tcId curve t)
137
(** [ecdh_tests file] loads the Wycheproof ECDH vector [file] and returns
    the test cases generated from all its groups, in file order. *)
let ecdh_tests file =
  let data = load_file_exn file in
  data.testGroups
  |> List.map ecdh_test_group_exn
  |> List.concat_map (fun (group : ecdh_test_group) ->
         List.concat_map (to_ecdh_tests group.curve) group.tests)
147
(** [make_ecdsa_test curve key hash tst] builds an Alcotest case for one
    ECDSA vector. The message is hashed with [hash] and truncated to at most
    [len curve] bytes. [key] is decoded with the curve's
    [Dsa.pub_of_octets] when the case runs.

    - [Valid]: the DER signature must parse and must verify.
    - [Acceptable] / [Invalid]: a signature that fails to parse passes; one
      that parses must not verify.

    @raise Assert_failure while building, for an unknown [curve] or [hash];
    and when run, if [key] cannot be decoded. *)
let make_ecdsa_test curve key hash (tst : dsa_test) =
  let name = Printf.sprintf "%d - %s" tst.tcId tst.comment in
  let size = len curve in
  let digest =
    match hash with
    | "SHA-224" -> Digestif.SHA224.(to_raw_string (digest_string tst.msg))
    | "SHA-256" -> Digestif.SHA256.(to_raw_string (digest_string tst.msg))
    | "SHA-384" -> Digestif.SHA384.(to_raw_string (digest_string tst.msg))
    | "SHA-512" -> Digestif.SHA512.(to_raw_string (digest_string tst.msg))
    | _ -> assert false
  in
  let msg = String.sub digest 0 (min size (String.length digest)) in
  let verify signature =
    match curve with
    | "secp256r1" -> (
        match P256.Dsa.pub_of_octets key with
        | Ok key -> P256.Dsa.verify ~key signature msg
        | Error _ -> assert false)
    | "secp384r1" -> (
        match P384.Dsa.pub_of_octets key with
        | Ok key -> P384.Dsa.verify ~key signature msg
        | Error _ -> assert false)
    | "secp521r1" -> (
        match P521.Dsa.pub_of_octets key with
        | Ok key -> P521.Dsa.verify ~key signature msg
        | Error _ -> assert false)
    | _ -> assert false
  in
  (* Expected verification outcome, and what to do when the DER signature
     itself does not parse. *)
  let should_verify, on_unparsable =
    match tst.result with
    | Valid -> (true, fun err -> Alcotest.fail err)
    | Acceptable | Invalid -> (false, fun _ -> ())
  in
  let run () =
    match Asn.parse_signature tst.sig_ with
    | Ok signature ->
        Alcotest.(check bool) __LOC__ should_verify (verify signature)
    | Error err -> on_unparsable err
  in
  (name, `Quick, run)
196
(** [to_ecdsa_tests group] turns every vector of [group] into a test case,
    using the group's curve, uncompressed public key and hash name. *)
let to_ecdsa_tests (x : ecdsa_test_group) =
  x.tests |> List.map (make_ecdsa_test x.key.curve x.key.uncompressed x.sha)
199
(** [ecdsa_tests file] loads the Wycheproof ECDSA vector [file] and returns
    the test cases generated from all its groups, in file order. *)
let ecdsa_tests file =
  let data = load_file_exn file in
  data.testGroups
  |> List.map ecdsa_test_group_exn
  |> List.concat_map to_ecdsa_tests
206
(** [to_x25519_test x] builds an Alcotest case for one X25519 vector. The
    private key is decoded while the suite is being built; an undecodable
    key raises [Assert_failure] at that point.

    - [Valid]: the exchange must succeed and yield [x.shared].
    - [Acceptable]: vectors flagged ["LowOrderPublic"] must be rejected;
      all others must succeed and yield [x.shared].
    - [Invalid]: the exchange may be rejected. If it succeeds, the result
      must differ from [x.shared]. *)
let to_x25519_test (x : ecdh_test) =
  let name = Printf.sprintf "%d - %s" x.tcId x.comment
  and priv =
    match X25519.secret_of_octets x.private_ with
    | Ok (p, _) -> p
    | Error _ -> assert false
  in
  match x.result with
  | Acceptable ->
      let f () =
        match
          ( X25519.key_exchange priv x.public,
            has_ignored_flag x ~ignored_flags:[ "LowOrderPublic" ] )
        with
        | Ok _, true -> Alcotest.fail "acceptable should have errored"
        | Ok r, false ->
            Alcotest.(check bool __LOC__ true (String.equal r x.shared))
        | Error _, true -> ()
        | Error e, false -> Alcotest.failf "acceptable errored %a" pp_error e
      in
      (name, `Quick, f)
  | Invalid ->
      let f () =
        match X25519.key_exchange priv x.public with
        | Ok r -> Alcotest.(check bool __LOC__ false (String.equal r x.shared))
        (* Rejecting an invalid vector is a correct outcome, not a failure. *)
        | Error _ -> ()
      in
      (name, `Quick, f)
  | Valid ->
      let f () =
        match X25519.key_exchange priv x.public with
        | Ok r -> Alcotest.(check bool __LOC__ true (String.equal r x.shared))
        | Error e -> Alcotest.failf "valid errored %a" pp_error e
      in
      (name, `Quick, f)
242
(** All test cases generated from ["x25519_test.json"]. The file is loaded
    when this module is initialised. *)
let x25519_tests =
  let data = load_file_exn "x25519_test.json" in
  data.testGroups
  |> List.map ecdh_test_group_exn
  |> List.concat_map (fun (group : ecdh_test_group) ->
         List.map to_x25519_test group.tests)
251
(** [to_ed25519_test (priv, pub) x] builds an Alcotest case for one Ed25519
    vector. [Ed25519.verify] with [pub] must return [true] for [Valid]
    vectors and [false] for [Invalid] ones. Likewise, re-signing [x.msg]
    with [priv] must reproduce [x.sig_] exactly for [Valid] vectors and
    must not for [Invalid] ones.
    @raise Assert_failure while building the suite for [Acceptable]
    vectors, which this harness does not handle. *)
let to_ed25519_test (priv, pub) (x : dsa_test) =
  let name = Printf.sprintf "%d - %s" x.tcId x.comment in
  let expected =
    match x.result with
    | Valid -> true
    | Invalid -> false
    | Acceptable -> assert false
  in
  let run () =
    let verified = Ed25519.verify ~key:pub x.sig_ ~msg:x.msg in
    Alcotest.(check bool __LOC__ expected verified);
    let signature = Ed25519.sign ~key:priv x.msg in
    Alcotest.(check bool __LOC__ expected (String.equal signature x.sig_))
  in
  (name, `Quick, run)
272
(** [to_ed25519_keys key] decodes the group's secret and public keys. It
    also checks that the public key derived from the secret encodes to
    [key.pk].
    @raise Assert_failure if either key fails to decode or the derived
    public key does not match. *)
let to_ed25519_keys (key : eddsa_key) =
  let decoded =
    (Ed25519.priv_of_octets key.sk, Ed25519.pub_of_octets key.pk)
  in
  match decoded with
  | Ok priv, Ok pub ->
      let derived_pk = Ed25519.pub_to_octets (Ed25519.pub_of_priv priv) in
      assert (String.equal derived_pk key.pk);
      (priv, pub)
  | _ -> assert false
279
(** All test cases generated from ["eddsa_test.json"]. Each group's key
    pair is decoded once and shared by its vectors. The file is loaded when
    this module is initialised. *)
let ed25519_tests =
  let data = load_file_exn "eddsa_test.json" in
  let tests_of_group (group : eddsa_test_group) =
    let keys = to_ed25519_keys group.key in
    List.map (to_ed25519_test keys) group.tests
  in
  data.testGroups
  |> List.map eddsa_test_group_exn
  |> List.concat_map tests_of_group
290
(** Entry point: registers one Alcotest suite per Wycheproof vector file and
    runs them. *)
let () =
  let suites =
    [
      ("ECDH P256 test vectors", ecdh_tests "ecdh_secp256r1_test.json");
      ( "ECDSA P256 test vectors (SHA256)",
        ecdsa_tests "ecdsa_secp256r1_sha256_test.json" );
      ( "ECDSA P256 test vectors (SHA512)",
        ecdsa_tests "ecdsa_secp256r1_sha512_test.json" );
      ("ECDH P384 test vectors", ecdh_tests "ecdh_secp384r1_test.json");
      ( "ECDSA P384 test vectors (SHA384)",
        ecdsa_tests "ecdsa_secp384r1_sha384_test.json" );
      ( "ECDSA P384 test vectors (SHA512)",
        ecdsa_tests "ecdsa_secp384r1_sha512_test.json" );
      ("ECDH P521 test vectors", ecdh_tests "ecdh_secp521r1_test.json");
      ( "ECDSA P521 test vectors (SHA512)",
        ecdsa_tests "ecdsa_secp521r1_sha512_test.json" );
      ("X25519 test vectors", x25519_tests);
      ("ED25519 test vectors", ed25519_tests);
    ]
  in
  Alcotest.run "Wycheproof NIST curves" suites