upstream: https://github.com/mirage/mirage-crypto

ec: expose low-level point arithmetic

+163 -8
+38 -8
ec/mirage_crypto_ec.ml
··· 65 end 66 end 67 68 module type Dh_dsa = sig 69 module Dh : Dh 70 module Dsa : Dsa 71 end 72 73 type field_element = string ··· 225 out_p_to_p tmp 226 end 227 228 - module type Point = sig 229 val at_infinity : unit -> point 230 val is_infinity : point -> bool 231 val add : point -> point -> point ··· 239 val scalar_mult_base : scalar -> point 240 end 241 242 - module Make_point (P : Parameters) (F : Foreign) : Point = struct 243 module Fe = Make_field_element(P)(F) 244 245 let at_infinity () = ··· 431 val generator_tables : unit -> field_element array array array 432 end 433 434 - module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct 435 let not_zero = 436 let zero = String.make Param.byte_length '\000' in 437 fun buf -> not (Eqaf.equal buf zero) ··· 486 Array.map (Array.map convert) table 487 end 488 489 - module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct 490 let point_of_octets c = 491 match P.of_octets c with 492 | Ok p when not (P.is_infinity p) -> Ok p ··· 597 b_uts tmp 598 end 599 600 - module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Digestif.S) = struct 601 type priv = scalar 602 603 let byte_length = Param.byte_length ··· 774 end 775 end 776 777 module P256 : Dh_dsa = struct 778 module Params = struct 779 let a = "\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFC" ··· 818 external to_montgomery : out_field_element -> field_element -> unit = "mc_np256_to_montgomery" [@@noalloc] 819 end 820 821 - module P = Make_point(Params)(Foreign) 822 module S = Make_scalar(Params)(P) 823 module Dh = Make_dh(Params)(P)(S) 824 module Fn = Make_Fn(Params)(Foreign_n) 825 module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA256) 826 end 827 828 module P384 : Dh_dsa = struct ··· 870 external to_montgomery : out_field_element -> field_element -> unit = "mc_np384_to_montgomery" [@@noalloc] 871 end 872 873 - module P = Make_point(Params)(Foreign) 874 module S = Make_scalar(Params)(P) 875 module Dh = Make_dh(Params)(P)(S) 876 module Fn = Make_Fn(Params)(Foreign_n) 877 module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA384) 878 end 879 880 module P521 : Dh_dsa = struct ··· 923 external to_montgomery : out_field_element -> field_element -> unit = "mc_np521_to_montgomery" [@@noalloc] 924 end 925 926 - module P = Make_point(Params)(Foreign) 927 module S = Make_scalar(Params)(P) 928 module Dh = Make_dh(Params)(P)(S) 929 module Fn = Make_Fn(Params)(Foreign_n) 930 module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA512) 931 end 932 933 module X25519 = struct
··· 65 end 66 end 67 68 + module type Point = sig 69 + type point 70 + type scalar 71 + val of_octets : string -> (point, error) result 72 + val to_octets : ?compress:bool -> point -> string 73 + val scalar_of_octets : string -> (scalar, error) result 74 + val scalar_to_octets : scalar -> string 75 + val generator : point 76 + val add : point -> point -> point 77 + val scalar_mult : scalar -> point -> point 78 + end 79 + 80 module type Dh_dsa = sig 81 module Dh : Dh 82 module Dsa : Dsa 83 + module Point : Point 84 end 85 86 type field_element = string ··· 238 out_p_to_p tmp 239 end 240 241 + module type Point_ops = sig 242 val at_infinity : unit -> point 243 val is_infinity : point -> bool 244 val add : point -> point -> point ··· 252 val scalar_mult_base : scalar -> point 253 end 254 255 + module Make_point_ops (P : Parameters) (F : Foreign) : Point_ops = struct 256 module Fe = Make_field_element(P)(F) 257 258 let at_infinity () = ··· 444 val generator_tables : unit -> field_element array array array 445 end 446 447 + module Make_scalar (Param : Parameters) (P : Point_ops) : Scalar = struct 448 let not_zero = 449 let zero = String.make Param.byte_length '\000' in 450 fun buf -> not (Eqaf.equal buf zero) ··· 499 Array.map (Array.map convert) table 500 end 501 502 + module Make_dh (Param : Parameters) (P : Point_ops) (S : Scalar) : Dh = struct 503 let point_of_octets c = 504 match P.of_octets c with 505 | Ok p when not (P.is_infinity p) -> Ok p ··· 610 b_uts tmp 611 end 612 613 + module Make_dsa (Param : Parameters) (F : Fn) (P : Point_ops) (S : Scalar) (H : Digestif.S) = struct 614 type priv = scalar 615 616 let byte_length = Param.byte_length ··· 787 end 788 end 789 790 + module Make_point (P : Point_ops) (S : Scalar) : Point 791 + with type point = point and type scalar = scalar 792 + = struct 793 + type nonrec point = point 794 + type nonrec scalar = scalar 795 + let of_octets = P.of_octets 796 + let to_octets ?(compress = false) p = P.to_octets ~compress p 797 + let scalar_of_octets = S.of_octets 798 + let scalar_to_octets = S.to_octets 799 + let generator = P.params_g 800 + let add = P.add 801 + let scalar_mult = S.scalar_mult 802 + end 803 + 804 module P256 : Dh_dsa = struct 805 module Params = struct 806 let a = "\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFC" ··· 845 external to_montgomery : out_field_element -> field_element -> unit = "mc_np256_to_montgomery" [@@noalloc] 846 end 847 848 + module P = Make_point_ops(Params)(Foreign) 849 module S = Make_scalar(Params)(P) 850 module Dh = Make_dh(Params)(P)(S) 851 module Fn = Make_Fn(Params)(Foreign_n) 852 module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA256) 853 + module Point = Make_point(P)(S) 854 end 855 856 module P384 : Dh_dsa = struct ··· 898 external to_montgomery : out_field_element -> field_element -> unit = "mc_np384_to_montgomery" [@@noalloc] 899 end 900 901 + module P = Make_point_ops(Params)(Foreign) 902 module S = Make_scalar(Params)(P) 903 module Dh = Make_dh(Params)(P)(S) 904 module Fn = Make_Fn(Params)(Foreign_n) 905 module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA384) 906 + module Point = Make_point(P)(S) 907 end 908 909 module P521 : Dh_dsa = struct ··· 952 external to_montgomery : out_field_element -> field_element -> unit = "mc_np521_to_montgomery" [@@noalloc] 953 end 954 955 + module P = Make_point_ops(Params)(Foreign) 956 module S = Make_scalar(Params)(P) 957 module Dh = Make_dh(Params)(P)(S) 958 module Fn = Make_Fn(Params)(Foreign_n) 959 module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA512) 960 + module Point = Make_point(P)(S) 961 end 962 963 module X25519 = struct
+37
ec/mirage_crypto_ec.mli
··· 155 end 156 end 157 158 (** Elliptic curve with Diffie-Hellman and DSA. *) 159 module type Dh_dsa = sig 160 ··· 163 164 (** Digital signature algorithm. *) 165 module Dsa : Dsa 166 end 167 168 (** The NIST P-256 curve, also known as SECP256R1. *)
··· 155 end 156 end 157 158 + (** Low-level point arithmetic. *) 159 + module type Point = sig 160 + type point 161 + (** The type for points on the elliptic curve. *) 162 + 163 + type scalar 164 + (** The type for scalars. *) 165 + 166 + val of_octets : string -> (point, error) result 167 + (** [of_octets buf] decodes a point from [buf] in uncompressed or compressed 168 + SEC 1 format. Returns an error if the point is not on the curve. *) 169 + 170 + val to_octets : ?compress:bool -> point -> string 171 + (** [to_octets ~compress point] encodes [point] to SEC 1 format. If 172 + [compress] is [true] (default [false]), the compressed format is used. *) 173 + 174 + val scalar_of_octets : string -> (scalar, error) result 175 + (** [scalar_of_octets buf] decodes a scalar from [buf]. Returns an error if 176 + the scalar is not in the valid range \[1, n-1\] where n is the group 177 + order. *) 178 + 179 + val scalar_to_octets : scalar -> string 180 + (** [scalar_to_octets scalar] encodes [scalar] to a byte string. *) 181 + 182 + val generator : point 183 + (** [generator] is the generator point (base point) of the curve. *) 184 + 185 + val add : point -> point -> point 186 + (** [add p q] is the sum of points [p] and [q]. *) 187 + 188 + val scalar_mult : scalar -> point -> point 189 + (** [scalar_mult s p] is the scalar multiplication of [p] by [s]. *) 190 + end 191 + 192 (** Elliptic curve with Diffie-Hellman and DSA. *) 193 module type Dh_dsa = sig 194 ··· 197 198 (** Digital signature algorithm. *) 199 module Dsa : Dsa 200 + 201 + (** Low-level point arithmetic. *) 202 + module Point : Point 203 end 204 205 (** The NIST P-256 curve, also known as SECP256R1. *)
+88
tests/test_ec.ml
··· 803 |}; 804 ] 805 806 let p521_regression () = 807 let key = of_hex 808 "04 01 e4 f8 8a 40 3d fe 2f 65 a0 20 50 01 9b 87 ··· 853 ("X25519", [ "RFC 7748", `Quick, x25519 ]); 854 ("ED25519", ed25519); 855 ("ECDSA P521 regression", [ "regreesion1", `Quick, p521_regression ]); 856 ]
··· 803 |}; 804 ] 805 806 + let point_module_tests (module C : Mirage_crypto_ec.Dh_dsa) name = 807 + let open C in 808 + let test_generator_not_identity () = 809 + (* Generator should not be the identity (at infinity) *) 810 + let g = Point.generator in 811 + let g_bytes = Point.to_octets g in 812 + (* Generator serialized should not be just the identity point *) 813 + Alcotest.(check bool) "generator has non-trivial encoding" 814 + true (String.length g_bytes > 1) 815 + in 816 + let test_point_serialization_roundtrip () = 817 + (* Generate a key pair and check that the public key roundtrips through Point *) 818 + let _priv, pub = Dsa.generate () in 819 + let pub_bytes = Dsa.pub_to_octets pub in 820 + match Point.of_octets pub_bytes with 821 + | Ok point -> 822 + let point_bytes = Point.to_octets point in 823 + Alcotest.(check string) "point roundtrip" pub_bytes point_bytes 824 + | Error e -> Alcotest.failf "of_octets failed: %a" pp_error e 825 + in 826 + let test_point_compressed_serialization () = 827 + let _priv, pub = Dsa.generate () in 828 + let pub_bytes = Dsa.pub_to_octets pub in 829 + match Point.of_octets pub_bytes with 830 + | Ok point -> 831 + let compressed = Point.to_octets ~compress:true point in 832 + (* Compressed form should be shorter *) 833 + Alcotest.(check bool) "compressed is shorter" 834 + true (String.length compressed < String.length pub_bytes); 835 + (* Should be able to decode compressed form *) 836 + (match Point.of_octets compressed with 837 + | Ok point' -> 838 + let uncompressed = Point.to_octets point' in 839 + Alcotest.(check string) "compressed roundtrip" pub_bytes uncompressed 840 + | Error e -> Alcotest.failf "compressed of_octets failed: %a" pp_error e) 841 + | Error e -> Alcotest.failf "of_octets failed: %a" pp_error e 842 + in 843 + let test_scalar_serialization_roundtrip () = 844 + (* Generate a key and check scalar roundtrip *) 845 + let secret, _pub = Dh.gen_key () in 846 + let secret_bytes = Dh.secret_to_octets secret in 847 + match Point.scalar_of_octets secret_bytes with 848 + | Ok scalar -> 849 + let scalar_bytes = Point.scalar_to_octets scalar in 850 + Alcotest.(check string) "scalar roundtrip" secret_bytes scalar_bytes 851 + | Error e -> Alcotest.failf "scalar_of_octets failed: %a" pp_error e 852 + in 853 + let test_scalar_mult_with_generator () = 854 + (* scalar_mult with generator should give the same result as pub_of_priv *) 855 + let priv, pub = Dsa.generate () in 856 + let priv_bytes = Dsa.priv_to_octets priv in 857 + let pub_bytes = Dsa.pub_to_octets pub in 858 + match Point.scalar_of_octets priv_bytes with 859 + | Ok scalar -> 860 + let computed_pub = Point.scalar_mult scalar Point.generator in 861 + let computed_bytes = Point.to_octets computed_pub in 862 + Alcotest.(check string) "scalar_mult generator" pub_bytes computed_bytes 863 + | Error e -> Alcotest.failf "scalar_of_octets failed: %a" pp_error e 864 + in 865 + let test_point_add () = 866 + (* Test that P + P = 2P (scalar_mult 2 P) *) 867 + let g = Point.generator in 868 + let g_plus_g = Point.add g g in 869 + (* scalar 2 in big-endian encoding *) 870 + let two = 871 + let buf = Bytes.make Dsa.byte_length '\000' in 872 + Bytes.set_uint8 buf (Dsa.byte_length - 1) 2; 873 + Bytes.to_string buf 874 + in 875 + match Point.scalar_of_octets two with 876 + | Ok scalar_2 -> 877 + let two_g = Point.scalar_mult scalar_2 g in 878 + Alcotest.(check string) "G + G = 2G" 879 + (Point.to_octets g_plus_g) (Point.to_octets two_g) 880 + | Error e -> Alcotest.failf "scalar_of_octets 2 failed: %a" pp_error e 881 + in 882 + [ 883 + name ^ " Point generator", `Quick, test_generator_not_identity; 884 + name ^ " Point serialization roundtrip", `Quick, test_point_serialization_roundtrip; 885 + name ^ " Point compressed serialization", `Quick, test_point_compressed_serialization; 886 + name ^ " Scalar serialization roundtrip", `Quick, test_scalar_serialization_roundtrip; 887 + name ^ " scalar_mult with generator", `Quick, test_scalar_mult_with_generator; 888 + name ^ " Point add", `Quick, test_point_add; 889 + ] 890 + 891 let p521_regression () = 892 let key = of_hex 893 "04 01 e4 f8 8a 40 3d fe 2f 65 a0 20 50 01 9b 87 ··· 938 ("X25519", [ "RFC 7748", `Quick, x25519 ]); 939 ("ED25519", ed25519); 940 ("ECDSA P521 regression", [ "regreesion1", `Quick, p521_regression ]); 941 + ("P256 Point module", point_module_tests (module P256) "P256"); 942 + ("P384 Point module", point_module_tests (module P384) "P384"); 943 + ("P521 Point module", point_module_tests (module P521) "P521"); 944 ]