upstream: https://github.com/mirage/mirage-crypto

ec: expose low-level point arithmetic

+163 -8
+38 -8
ec/mirage_crypto_ec.ml
··· 65 65 end 66 66 end 67 67 68 + module type Point = sig 69 + type point 70 + type scalar 71 + val of_octets : string -> (point, error) result 72 + val to_octets : ?compress:bool -> point -> string 73 + val scalar_of_octets : string -> (scalar, error) result 74 + val scalar_to_octets : scalar -> string 75 + val generator : point 76 + val add : point -> point -> point 77 + val scalar_mult : scalar -> point -> point 78 + end 79 + 68 80 module type Dh_dsa = sig 69 81 module Dh : Dh 70 82 module Dsa : Dsa 83 + module Point : Point 71 84 end 72 85 73 86 type field_element = string ··· 225 238 out_p_to_p tmp 226 239 end 227 240 228 - module type Point = sig 241 + module type Point_ops = sig 229 242 val at_infinity : unit -> point 230 243 val is_infinity : point -> bool 231 244 val add : point -> point -> point ··· 239 252 val scalar_mult_base : scalar -> point 240 253 end 241 254 242 - module Make_point (P : Parameters) (F : Foreign) : Point = struct 255 + module Make_point_ops (P : Parameters) (F : Foreign) : Point_ops = struct 243 256 module Fe = Make_field_element(P)(F) 244 257 245 258 let at_infinity () = ··· 431 444 val generator_tables : unit -> field_element array array array 432 445 end 433 446 434 - module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct 447 + module Make_scalar (Param : Parameters) (P : Point_ops) : Scalar = struct 435 448 let not_zero = 436 449 let zero = String.make Param.byte_length '\000' in 437 450 fun buf -> not (Eqaf.equal buf zero) ··· 486 499 Array.map (Array.map convert) table 487 500 end 488 501 489 - module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct 502 + module Make_dh (Param : Parameters) (P : Point_ops) (S : Scalar) : Dh = struct 490 503 let point_of_octets c = 491 504 match P.of_octets c with 492 505 | Ok p when not (P.is_infinity p) -> Ok p ··· 597 610 b_uts tmp 598 611 end 599 612 600 - module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Digestif.S) = struct 613 + module Make_dsa (Param : Parameters) (F : Fn) (P : Point_ops) (S : Scalar) (H : Digestif.S) = struct 601 614 type priv = scalar 602 615 603 616 let byte_length = Param.byte_length ··· 774 787 end 775 788 end 776 789 790 + module Make_point (P : Point_ops) (S : Scalar) : Point 791 + with type point = point and type scalar = scalar 792 + = struct 793 + type nonrec point = point 794 + type nonrec scalar = scalar 795 + let of_octets = P.of_octets 796 + let to_octets ?(compress = false) p = P.to_octets ~compress p 797 + let scalar_of_octets = S.of_octets 798 + let scalar_to_octets = S.to_octets 799 + let generator = P.params_g 800 + let add = P.add 801 + let scalar_mult = S.scalar_mult 802 + end 803 + 777 804 module P256 : Dh_dsa = struct 778 805 module Params = struct 779 806 let a = "\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFC" ··· 818 845 external to_montgomery : out_field_element -> field_element -> unit = "mc_np256_to_montgomery" [@@noalloc] 819 846 end 820 847 821 - module P = Make_point(Params)(Foreign) 848 + module P = Make_point_ops(Params)(Foreign) 822 849 module S = Make_scalar(Params)(P) 823 850 module Dh = Make_dh(Params)(P)(S) 824 851 module Fn = Make_Fn(Params)(Foreign_n) 825 852 module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA256) 853 + module Point = Make_point(P)(S) 826 854 end 827 855 828 856 module P384 : Dh_dsa = struct ··· 870 898 external to_montgomery : out_field_element -> field_element -> unit = "mc_np384_to_montgomery" [@@noalloc] 871 899 end 872 900 873 - module P = Make_point(Params)(Foreign) 901 + module P = Make_point_ops(Params)(Foreign) 874 902 module S = Make_scalar(Params)(P) 875 903 module Dh = Make_dh(Params)(P)(S) 876 904 module Fn = Make_Fn(Params)(Foreign_n) 877 905 module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA384) 906 + module Point = Make_point(P)(S) 878 907 end 879 908 880 909 module P521 : Dh_dsa = struct ··· 923 952 external to_montgomery : out_field_element -> field_element -> unit = "mc_np521_to_montgomery" [@@noalloc] 924 953 end 925 954 926 - module P = Make_point(Params)(Foreign) 955 + module P = Make_point_ops(Params)(Foreign) 927 956 module S = Make_scalar(Params)(P) 928 957 module Dh = Make_dh(Params)(P)(S) 929 958 module Fn = Make_Fn(Params)(Foreign_n) 930 959 module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA512) 960 + module Point = Make_point(P)(S) 931 961 end 932 962 933 963 module X25519 = struct
+37
ec/mirage_crypto_ec.mli
··· 155 155 end 156 156 end 157 157 158 + (** Low-level point arithmetic. *) 159 + module type Point = sig 160 + type point 161 + (** The type for points on the elliptic curve. *) 162 + 163 + type scalar 164 + (** The type for scalars. *) 165 + 166 + val of_octets : string -> (point, error) result 167 + (** [of_octets buf] decodes a point from [buf] in uncompressed or compressed 168 + SEC 1 format. Returns an error if the point is not on the curve. *) 169 + 170 + val to_octets : ?compress:bool -> point -> string 171 + (** [to_octets ~compress point] encodes [point] to SEC 1 format. If 172 + [compress] is [true] (default [false]), the compressed format is used. *) 173 + 174 + val scalar_of_octets : string -> (scalar, error) result 175 + (** [scalar_of_octets buf] decodes a scalar from [buf]. Returns an error if 176 + the scalar is not in the valid range \[1, n-1\] where n is the group 177 + order. *) 178 + 179 + val scalar_to_octets : scalar -> string 180 + (** [scalar_to_octets scalar] encodes [scalar] to a byte string. *) 181 + 182 + val generator : point 183 + (** [generator] is the generator point (base point) of the curve. *) 184 + 185 + val add : point -> point -> point 186 + (** [add p q] is the sum of points [p] and [q]. *) 187 + 188 + val scalar_mult : scalar -> point -> point 189 + (** [scalar_mult s p] is the scalar multiplication of [p] by [s]. *) 190 + end 191 + 158 192 (** Elliptic curve with Diffie-Hellman and DSA. *) 159 193 module type Dh_dsa = sig 160 194 ··· 163 197 164 198 (** Digital signature algorithm. *) 165 199 module Dsa : Dsa 200 + 201 + (** Low-level point arithmetic. *) 202 + module Point : Point 166 203 end 167 204 168 205 (** The NIST P-256 curve, also known as SECP256R1. *)
+88
tests/test_ec.ml
··· 803 803 |}; 804 804 ] 805 805 806 + let point_module_tests (module C : Mirage_crypto_ec.Dh_dsa) name = 807 + let open C in 808 + let test_generator_not_identity () = 809 + (* Generator should not be the identity (at infinity) *) 810 + let g = Point.generator in 811 + let g_bytes = Point.to_octets g in 812 + (* Generator serialized should not be just the identity point *) 813 + Alcotest.(check bool) "generator has non-trivial encoding" 814 + true (String.length g_bytes > 1) 815 + in 816 + let test_point_serialization_roundtrip () = 817 + (* Generate a key pair and check that the public key roundtrips through Point *) 818 + let _priv, pub = Dsa.generate () in 819 + let pub_bytes = Dsa.pub_to_octets pub in 820 + match Point.of_octets pub_bytes with 821 + | Ok point -> 822 + let point_bytes = Point.to_octets point in 823 + Alcotest.(check string) "point roundtrip" pub_bytes point_bytes 824 + | Error e -> Alcotest.failf "of_octets failed: %a" pp_error e 825 + in 826 + let test_point_compressed_serialization () = 827 + let _priv, pub = Dsa.generate () in 828 + let pub_bytes = Dsa.pub_to_octets pub in 829 + match Point.of_octets pub_bytes with 830 + | Ok point -> 831 + let compressed = Point.to_octets ~compress:true point in 832 + (* Compressed form should be shorter *) 833 + Alcotest.(check bool) "compressed is shorter" 834 + true (String.length compressed < String.length pub_bytes); 835 + (* Should be able to decode compressed form *) 836 + (match Point.of_octets compressed with 837 + | Ok point' -> 838 + let uncompressed = Point.to_octets point' in 839 + Alcotest.(check string) "compressed roundtrip" pub_bytes uncompressed 840 + | Error e -> Alcotest.failf "compressed of_octets failed: %a" pp_error e) 841 + | Error e -> Alcotest.failf "of_octets failed: %a" pp_error e 842 + in 843 + let test_scalar_serialization_roundtrip () = 844 + (* Generate a key and check scalar roundtrip *) 845 + let secret, _pub = Dh.gen_key () in 846 + let secret_bytes = Dh.secret_to_octets secret in 847 + match Point.scalar_of_octets secret_bytes with 848 + | Ok scalar -> 849 + let scalar_bytes = Point.scalar_to_octets scalar in 850 + Alcotest.(check string) "scalar roundtrip" secret_bytes scalar_bytes 851 + | Error e -> Alcotest.failf "scalar_of_octets failed: %a" pp_error e 852 + in 853 + let test_scalar_mult_with_generator () = 854 + (* scalar_mult with generator should give the same result as pub_of_priv *) 855 + let priv, pub = Dsa.generate () in 856 + let priv_bytes = Dsa.priv_to_octets priv in 857 + let pub_bytes = Dsa.pub_to_octets pub in 858 + match Point.scalar_of_octets priv_bytes with 859 + | Ok scalar -> 860 + let computed_pub = Point.scalar_mult scalar Point.generator in 861 + let computed_bytes = Point.to_octets computed_pub in 862 + Alcotest.(check string) "scalar_mult generator" pub_bytes computed_bytes 863 + | Error e -> Alcotest.failf "scalar_of_octets failed: %a" pp_error e 864 + in 865 + let test_point_add () = 866 + (* Test that P + P = 2P (scalar_mult 2 P) *) 867 + let g = Point.generator in 868 + let g_plus_g = Point.add g g in 869 + (* scalar 2 in big-endian encoding *) 870 + let two = 871 + let buf = Bytes.make Dsa.byte_length '\000' in 872 + Bytes.set_uint8 buf (Dsa.byte_length - 1) 2; 873 + Bytes.to_string buf 874 + in 875 + match Point.scalar_of_octets two with 876 + | Ok scalar_2 -> 877 + let two_g = Point.scalar_mult scalar_2 g in 878 + Alcotest.(check string) "G + G = 2G" 879 + (Point.to_octets g_plus_g) (Point.to_octets two_g) 880 + | Error e -> Alcotest.failf "scalar_of_octets 2 failed: %a" pp_error e 881 + in 882 + [ 883 + name ^ " Point generator", `Quick, test_generator_not_identity; 884 + name ^ " Point serialization roundtrip", `Quick, test_point_serialization_roundtrip; 885 + name ^ " Point compressed serialization", `Quick, test_point_compressed_serialization; 886 + name ^ " Scalar serialization roundtrip", `Quick, test_scalar_serialization_roundtrip; 887 + name ^ " scalar_mult with generator", `Quick, test_scalar_mult_with_generator; 888 + name ^ " Point add", `Quick, test_point_add; 889 + ] 890 + 806 891 let p521_regression () = 807 892 let key = of_hex 808 893 "04 01 e4 f8 8a 40 3d fe 2f 65 a0 20 50 01 9b 87 ··· 853 938 ("X25519", [ "RFC 7748", `Quick, x25519 ]); 854 939 ("ED25519", ed25519); 855 940 ("ECDSA P521 regression", [ "regreesion1", `Quick, p521_regression ]); 941 + ("P256 Point module", point_module_tests (module P256) "P256"); 942 + ("P384 Point module", point_module_tests (module P384) "P384"); 943 + ("P521 Point module", point_module_tests (module P521) "P521"); 856 944 ]