this repo has no description

tessera-npy: implement header parsing with tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

+247 -7
+180 -6
tessera-npy/lib/npy.ml
··· 4 | Float32 : float dtype 5 | Float64 : float dtype 6 7 type t = { 8 shape : int array; 9 fortran_order : bool; 10 - descr : string; 11 data : string; 12 } 13 14 - let of_string _s = Error "not implemented" 15 16 let shape t = t.shape 17 18 let fortran_order t = t.fortran_order 19 20 - let data_int8 _t = None 21 22 - let data_uint8 _t = None 23 24 - let data_float32 _t = None 25 26 - let data_float64 _t = None
··· 4 | Float32 : float dtype 5 | Float64 : float dtype 6 7 + type descr = D_int8 | D_uint8 | D_float32 | D_float64 8 + 9 type t = { 10 shape : int array; 11 fortran_order : bool; 12 + descr : descr; 13 data : string; 14 } 15 16 + let find_substring haystack needle = 17 + let nlen = String.length needle in 18 + let hlen = String.length haystack in 19 + let rec search i = 20 + if i + nlen > hlen then None 21 + else if String.sub haystack i nlen = needle then Some i 22 + else search (i + 1) 23 + in 24 + search 0 25 + 26 + let extract_quoted_value header key = 27 + let pattern = "'" ^ key ^ "': " in 28 + match find_substring header pattern with 29 + | None -> Error (Printf.sprintf "missing key: %s" key) 30 + | Some i -> 31 + let start = i + String.length pattern in 32 + if start >= String.length header then Error (Printf.sprintf "truncated value for key: %s" key) 33 + else 34 + let c = header.[start] in 35 + if c = '\'' then 36 + (* quoted string value *) 37 + let value_start = start + 1 in 38 + (match find_substring (String.sub header value_start (String.length header - value_start)) "'" with 39 + | None -> Error (Printf.sprintf "unterminated string for key: %s" key) 40 + | Some len -> Ok (String.sub header value_start len)) 41 + else 42 + (* unquoted value - read until comma or } *) 43 + let rec find_end j = 44 + if j >= String.length header then j 45 + else match header.[j] with 46 + | ',' | '}' | ')' -> j 47 + | _ -> find_end (j + 1) 48 + in 49 + let end_pos = find_end start in 50 + let value = String.trim (String.sub header start (end_pos - start)) in 51 + Ok value 52 + 53 + let parse_descr s = 54 + match s with 55 + | "|i1" -> Ok D_int8 56 + | "|u1" -> Ok D_uint8 57 + | "<f4" -> Ok D_float32 58 + | "<f8" -> Ok D_float64 59 + | _ -> Error (Printf.sprintf "unsupported dtype: %s" s) 60 + 61 + let parse_fortran_order s = 62 + match s with 63 + | "True" -> Ok true 64 + | "False" -> Ok false 65 + | _ -> Error (Printf.sprintf "invalid fortran_order: %s" s) 66 + 67 + let parse_shape header = 68 + let pattern = "'shape': (" in 69 + match find_substring header pattern with 70 + | None -> Error "missing shape" 71 + | Some i -> 72 + let start = i + String.length pattern in 73 + (match find_substring (String.sub header start (String.length header - start)) ")" with 74 + | None -> Error "unterminated shape" 75 + | Some len -> 76 + let shape_str = String.sub header start len in 77 + let shape_str = String.trim shape_str in 78 + if shape_str = "" then Ok [||] 79 + else 80 + let parts = String.split_on_char ',' shape_str in 81 + let parts = List.filter (fun s -> String.trim s <> "") parts in 82 + let dims = List.map (fun s -> int_of_string (String.trim s)) parts in 83 + Ok (Array.of_list dims)) 84 + 85 + let of_string s = 86 + let len = String.length s in 87 + if len < 10 then Error "too short for .npy file" 88 + else if String.sub s 0 6 <> "\x93NUMPY" then Error "bad magic number" 89 + else 90 + let major = Char.code s.[6] in 91 + let _minor = Char.code s.[7] in 92 + let header_len, header_offset = 93 + if major = 1 then 94 + let hl = Char.code s.[8] lor (Char.code s.[9] lsl 8) in 95 + (hl, 10) 96 + else if major = 2 then 97 + if len < 12 then (0, 12) 98 + else 99 + let hl = 100 + Char.code s.[8] 101 + lor (Char.code s.[9] lsl 8) 102 + lor (Char.code s.[10] lsl 16) 103 + lor (Char.code s.[11] lsl 24) 104 + in 105 + (hl, 12) 106 + else (0, 10) 107 + in 108 + if header_offset + header_len > len then Error "truncated header" 109 + else 110 + let header = String.sub s header_offset header_len in 111 + match extract_quoted_value header "descr" with 112 + | Error e -> Error e 113 + | Ok descr_str -> 114 + match parse_descr descr_str with 115 + | Error e -> Error e 116 + | Ok descr -> 117 + match extract_quoted_value header "fortran_order" with 118 + | Error e -> Error e 119 + | Ok fo_str -> 120 + match parse_fortran_order fo_str with 121 + | Error e -> Error e 122 + | Ok fortran_order -> 123 + match parse_shape header with 124 + | Error e -> Error e 125 + | Ok shape -> 126 + let data_offset = header_offset + header_len in 127 + let data = String.sub s data_offset (len - data_offset) in 128 + Ok { shape; fortran_order; descr; data } 129 130 let shape t = t.shape 131 132 let fortran_order t = t.fortran_order 133 134 + let data_int8 t = 135 + match t.descr with 136 + | D_int8 -> 137 + let n = String.length t.data in 138 + let ba = Bigarray.Array1.create Bigarray.int8_signed Bigarray.c_layout n in 139 + for i = 0 to n - 1 do 140 + let v = Char.code t.data.[i] in 141 + let v = if v >= 128 then v - 256 else v in 142 + Bigarray.Array1.set ba i v 143 + done; 144 + Some ba 145 + | _ -> None 146 + 147 + let data_uint8 t = 148 + match t.descr with 149 + | D_uint8 -> 150 + let n = String.length t.data in 151 + let ba = Bigarray.Array1.create Bigarray.int8_unsigned Bigarray.c_layout n in 152 + for i = 0 to n - 1 do 153 + Bigarray.Array1.set ba i (Char.code t.data.[i]) 154 + done; 155 + Some ba 156 + | _ -> None 157 158 + let read_le_int32 s off = 159 + let b0 = Char.code s.[off] in 160 + let b1 = Char.code s.[off + 1] in 161 + let b2 = Char.code s.[off + 2] in 162 + let b3 = Char.code s.[off + 3] in 163 + Int32.logor 164 + (Int32.of_int b0) 165 + (Int32.logor 166 + (Int32.shift_left (Int32.of_int b1) 8) 167 + (Int32.logor 168 + (Int32.shift_left (Int32.of_int b2) 16) 169 + (Int32.shift_left (Int32.of_int b3) 24))) 170 171 + let read_le_int64 s off = 172 + let b i = Int64.of_int (Char.code s.[off + i]) in 173 + let ( lor ) = Int64.logor in 174 + let ( lsl ) = Int64.shift_left in 175 + (b 0) lor ((b 1) lsl 8) lor ((b 2) lsl 16) lor ((b 3) lsl 24) 176 + lor ((b 4) lsl 32) lor ((b 5) lsl 40) lor ((b 6) lsl 48) lor ((b 7) lsl 56) 177 178 + let data_float32 t = 179 + match t.descr with 180 + | D_float32 -> 181 + let n = String.length t.data / 4 in 182 + let ba = Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout n in 183 + for i = 0 to n - 1 do 184 + let bits = read_le_int32 t.data (i * 4) in 185 + Bigarray.Array1.set ba i (Int32.float_of_bits bits) 186 + done; 187 + Some ba 188 + | _ -> None 189 + 190 + let data_float64 t = 191 + match t.descr with 192 + | D_float64 -> 193 + let n = String.length t.data / 8 in 194 + let ba = Bigarray.Array1.create Bigarray.float64 Bigarray.c_layout n in 195 + for i = 0 to n - 1 do 196 + let bits = read_le_int64 t.data (i * 8) in 197 + Bigarray.Array1.set ba i (Int64.float_of_bits bits) 198 + done; 199 + Some ba 200 + | _ -> None
+67 -1
tessera-npy/test/test_npy.ml
··· 1 - let () = ()
··· 1 + let make_npy_v1 ~descr ~fortran_order ~shape data = 2 + let header = 3 + Printf.sprintf "{'descr': '%s', 'fortran_order': %s, 'shape': (%s), }" 4 + descr 5 + (if fortran_order then "True" else "False") 6 + (String.concat ", " (List.map string_of_int shape)) 7 + in 8 + let prefix_len = 6 + 2 + 2 in 9 + let raw_header_len = String.length header + 1 in 10 + let padded_len = 11 + let total = prefix_len + raw_header_len in 12 + let rem = total mod 64 in 13 + if rem = 0 then raw_header_len else raw_header_len + (64 - rem) 14 + in 15 + let buf = Buffer.create (prefix_len + padded_len + String.length data) in 16 + Buffer.add_string buf "\x93NUMPY"; 17 + Buffer.add_char buf '\x01'; 18 + Buffer.add_char buf '\x00'; 19 + Buffer.add_char buf (Char.chr (padded_len land 0xff)); 20 + Buffer.add_char buf (Char.chr ((padded_len lsr 8) land 0xff)); 21 + Buffer.add_string buf header; 22 + for _ = 1 to padded_len - raw_header_len do 23 + Buffer.add_char buf ' ' 24 + done; 25 + Buffer.add_char buf '\n'; 26 + Buffer.add_string buf data; 27 + Buffer.contents buf 28 + 29 + let test_parse_int8_header () = 30 + let npy = make_npy_v1 ~descr:"|i1" ~fortran_order:false ~shape:[3; 4] "\x00" in 31 + match Npy.of_string npy with 32 + | Error e -> Alcotest.fail e 33 + | Ok t -> 34 + Alcotest.(check (array int)) "shape" [|3; 4|] (Npy.shape t); 35 + Alcotest.(check bool) "fortran_order" false (Npy.fortran_order t) 36 + 37 + let test_parse_float32_header () = 38 + let npy = make_npy_v1 ~descr:"<f4" ~fortran_order:false ~shape:[2; 3] "\x00" in 39 + match Npy.of_string npy with 40 + | Error e -> Alcotest.fail e 41 + | Ok t -> 42 + Alcotest.(check (array int)) "shape" [|2; 3|] (Npy.shape t) 43 + 44 + let test_parse_3d_shape () = 45 + let npy = make_npy_v1 ~descr:"|i1" ~fortran_order:false ~shape:[10; 20; 128] "\x00" in 46 + match Npy.of_string npy with 47 + | Error e -> Alcotest.fail e 48 + | Ok t -> 49 + Alcotest.(check (array int)) "shape" [|10; 20; 128|] (Npy.shape t) 50 + 51 + let test_bad_magic () = 52 + let npy = "NOT_NPY_DATA" in 53 + match Npy.of_string npy with 54 + | Ok _ -> Alcotest.fail "should have failed" 55 + | Error _ -> () 56 + 57 + let () = 58 + Alcotest.run "tessera-npy" 59 + [ 60 + ( "header", 61 + [ 62 + Alcotest.test_case "int8 header" `Quick test_parse_int8_header; 63 + Alcotest.test_case "float32 header" `Quick test_parse_float32_header; 64 + Alcotest.test_case "3d shape" `Quick test_parse_3d_shape; 65 + Alcotest.test_case "bad magic" `Quick test_bad_magic; 66 + ] ); 67 + ]