Zstd compression in pure OCaml
at main 435 lines 15 kB view raw
(** Huffman coding for Zstandard literals.

    Zstd uses canonical Huffman codes for literal compression.
    Huffman streams are read backwards like FSE streams. *)

(** Huffman decoding table entry *)
type entry = {
  symbol : int;    (* literal value stored in this table slot *)
  num_bits : int;  (* code length of [symbol], i.e. bits consumed per decode *)
}

(** Huffman decoding table *)
type dtable = {
  entries : entry array;  (* [1 lsl max_bits] slots, indexed by decoder state *)
  max_bits : int;         (* length of the longest code in the table *)
}

let highest_set_bit = Fse.highest_set_bit

(** Build Huffman table from bit lengths.
    Uses canonical Huffman coding: codes of the longest length occupy the
    lowest table slots, and within one length symbols are placed in
    increasing symbol order.

    [bits.(i)] is the code length of symbol [i] (0 = symbol unused) for
    [i < num_symbols].

    @raise Constants.Zstd_error with [Invalid_huffman_table] if there are too
    many symbols, a length is negative or above
    [Constants.max_huffman_bits], every length is zero, or the lengths do
    not tile the table exactly (over- or under-subscribed code). *)
let build_dtable_from_bits bits num_symbols =
  if num_symbols > Constants.max_huffman_symbols then
    raise (Constants.Zstd_error Constants.Invalid_huffman_table);

  (* Find max bits and count symbols per bit length *)
  let longest = ref 0 in
  let rank_count = Array.make (Constants.max_huffman_bits + 1) 0 in
  for i = 0 to num_symbols - 1 do
    let b = bits.(i) in
    (* The lower bound matters too: a negative length would otherwise
       escape as an out-of-bounds array access. *)
    if b < 0 || b > Constants.max_huffman_bits then
      raise (Constants.Zstd_error Constants.Invalid_huffman_table);
    if b > !longest then longest := b;
    rank_count.(b) <- rank_count.(b) + 1
  done;

  if !longest = 0 then
    raise (Constants.Zstd_error Constants.Invalid_huffman_table);

  let max_bits = !longest in
  let table_size = 1 lsl max_bits in

  (* Calculate starting indices for each rank, longest codes first.
       rank_idx.(0) ends up holding the total number of slots claimed. *)
  let rank_idx = Array.make (Constants.max_huffman_bits + 1) 0 in
  for i = max_bits downto 1 do
    rank_idx.(i - 1) <- rank_idx.(i) + (rank_count.(i) * (1 lsl (max_bits - i)))
  done;

  (* Validate BEFORE touching the table so that a malformed set of lengths
     is reported as a table error rather than an index error. *)
  if rank_idx.(0) <> table_size then
    raise (Constants.Zstd_error Constants.Invalid_huffman_table);

  (* Assign symbols to table entries. Because the ranges tile the table
     exactly, every slot is written once with both fields. *)
  let entries = Array.make table_size { symbol = 0; num_bits = 0 } in
  for i = 0 to num_symbols - 1 do
    let b = bits.(i) in
    if b <> 0 then begin
      let first = rank_idx.(b) in
      let span = 1 lsl (max_bits - b) in
      Array.fill entries first span { symbol = i; num_bits = b };
      rank_idx.(b) <- first + span
    end
  done;

  { entries; max_bits }

(** Build table from weights (as decoded from zstd format).

    [weights.(0 .. num_symbols - 1)] are the explicit weights; the weight of
    symbol [num_symbols] is implicit and derived so that the total reaches
    the next power of two.

    @raise Constants.Zstd_error with [Invalid_huffman_table] if there are too
    many symbols, a weight exceeds [Constants.max_huffman_bits], all weights
    are zero, or the implied last weight is not a power of two. *)
let build_dtable_from_weights weights num_symbols =
  if num_symbols + 1 > Constants.max_huffman_symbols then
    raise (Constants.Zstd_error Constants.Invalid_huffman_table);

  let bits = Array.make (num_symbols + 1) 0 in

  (* Calculate weight sum to find max_bits and last weight *)
  let weight_sum = ref 0 in
  for i = 0 to num_symbols - 1 do
    let w = weights.(i) in
    if w > Constants.max_huffman_bits then
      raise (Constants.Zstd_error Constants.Invalid_huffman_table);
    if w > 0 then weight_sum := !weight_sum + (1 lsl (w - 1))
  done;

  (* No explicit symbol at all: reject here instead of depending on what
     [highest_set_bit] returns for 0. *)
  if !weight_sum = 0 then
    raise (Constants.Zstd_error Constants.Invalid_huffman_table);

  (* Find max_bits (first power of 2 > weight_sum) *)
  let max_bits = highest_set_bit !weight_sum + 1 in
  let left_over = (1 lsl max_bits) - !weight_sum in

  (* left_over must be a power of 2 *)
  if left_over land (left_over - 1) <> 0 then
    raise (Constants.Zstd_error Constants.Invalid_huffman_table);

  let last_weight = highest_set_bit left_over + 1 in

  (* Convert weights to bit lengths *)
  for i = 0 to num_symbols - 1 do
    let w = weights.(i) in
    bits.(i) <- (if w > 0 then max_bits + 1 - w else 0)
  done;
  bits.(num_symbols) <- max_bits + 1 - last_weight;

  build_dtable_from_bits bits (num_symbols + 1)

(** Initialize Huffman state by reading max_bits *)
let[@inline] init_state dtable (stream : Bit_reader.Backward.t) =
  Bit_reader.Backward.read_bits stream dtable.max_bits

(** Decode a symbol and update state.
    Returns [(symbol, new_state)]. *)
let[@inline] decode_symbol dtable state (stream : Bit_reader.Backward.t) =
  let entry = dtable.entries.(state) in
  let bits_used = entry.num_bits in
  (* Shift out used bits and read new ones; the mask keeps the state within
     the table's index range. *)
  let mask = (1 lsl dtable.max_bits) - 1 in
  let rest = Bit_reader.Backward.read_bits stream bits_used in
  let new_state = ((state lsl bits_used) + rest) land mask in
  (entry.symbol, new_state)

(** Decompress a single Huffman stream.

    Decodes [src.[pos .. pos + len - 1]] into [output] starting at
    [out_pos], writing at most [out_len] bytes. Returns the number of bytes
    written.

    @raise Constants.Zstd_error with [Output_too_small] if the stream holds
    more than [out_len] symbols, or with [Corruption] if the stream is not
    consumed exactly. *)
let decompress_1stream dtable src ~pos ~len output ~out_pos ~out_len =
  let stream = Bit_reader.Backward.of_bytes src ~pos ~len in
  let state = ref (init_state dtable stream) in
  (* The initial state read is [max_bits] ahead of the symbols, so a fully
     consumed stream reports exactly this (negative) remainder. *)
  let drained = -dtable.max_bits in

  let written = ref 0 in
  while Bit_reader.Backward.remaining stream > drained do
    if !written >= out_len then
      raise (Constants.Zstd_error Constants.Output_too_small);

    let symbol, new_state = decode_symbol dtable !state stream in
    Bytes.set_uint8 output (out_pos + !written) symbol;
    incr written;
    state := new_state
  done;

  (* Verify stream is exactly consumed *)
  if Bit_reader.Backward.remaining stream <> drained then
    raise (Constants.Zstd_error Constants.Corruption);

  !written

(** Decompress 4 interleaved Huffman streams.

    [src.[pos .. pos + len - 1]] starts with a 6-byte jump table (three
    little-endian 16-bit stream sizes) followed by the four streams. The
    first three streams each regenerate [(regen_size + 3) / 4] bytes and the
    fourth regenerates the remainder. Returns the total number of bytes
    written, which is [regen_size] on success.

    @raise Constants.Zstd_error with [Corruption] if the jump table does not
    fit, the fourth stream would be empty, [regen_size] cannot be split into
    four streams, or a stream regenerates fewer bytes than its share.
    Errors from {!decompress_1stream} propagate unchanged. *)
let decompress_4stream dtable src ~pos ~len output ~out_pos ~regen_size =
  (* The jump table itself needs 6 bytes; check before reading it. *)
  if len < 6 then
    raise (Constants.Zstd_error Constants.Corruption);

  (* Read stream sizes from jump table (6 bytes) *)
  let size1 = Bit_reader.get_u16_le src pos in
  let size2 = Bit_reader.get_u16_le src (pos + 2) in
  let size3 = Bit_reader.get_u16_le src (pos + 4) in
  let size4 = len - 6 - size1 - size2 - size3 in

  if size4 < 1 then
    raise (Constants.Zstd_error Constants.Corruption);

  (* Calculate output sizes *)
  let out_size = (regen_size + 3) / 4 in
  let out_size4 = regen_size - (3 * out_size) in

  (* A negative share means streams 1-3 would already overrun the
     [regen_size] bytes reserved for this literals section. *)
  if out_size4 < 0 then
    raise (Constants.Zstd_error Constants.Corruption);

  (* Decode one sub-stream and insist that it fills its share exactly, so
     that no gap of stale bytes is left between the four output regions. *)
  let decode_exact ~pos ~len ~out_pos ~out_len =
    let written =
      decompress_1stream dtable src ~pos ~len output ~out_pos ~out_len
    in
    if written <> out_len then
      raise (Constants.Zstd_error Constants.Corruption);
    written
  in

  (* Decompress each stream *)
  let stream_pos = pos + 6 in

  let written1 =
    decode_exact ~pos:stream_pos ~len:size1 ~out_pos ~out_len:out_size
  in
  let written2 =
    decode_exact
      ~pos:(stream_pos + size1) ~len:size2
      ~out_pos:(out_pos + out_size) ~out_len:out_size
  in
  let written3 =
    decode_exact
      ~pos:(stream_pos + size1 + size2) ~len:size3
      ~out_pos:(out_pos + (2 * out_size)) ~out_len:out_size
  in
  let written4 =
    decode_exact
      ~pos:(stream_pos + size1 + size2 + size3) ~len:size4
      ~out_pos:(out_pos + (3 * out_size)) ~out_len:out_size4
  in

  written1 + written2 + written3 + written4

(** Decode Huffman table from stream.
    Returns the decoding table; the table description is consumed from
    [stream] as a side effect of the reads.

    A header byte of 128 or more selects the direct representation
    ([header - 127] weights, 4 bits each); a smaller header byte is the byte
    size of an FSE-compressed weight list.

    @raise Constants.Zstd_error with [Corruption] if an FSE-compressed
    description leaves no bytes for the weight bitstream, or with
    [Invalid_huffman_table] if the decoded weights are inconsistent. *)
let decode_table (stream : Bit_reader.Forward.t) =
  let header = Bit_reader.Forward.read_byte stream in

  let weights = Array.make Constants.max_huffman_symbols 0 in
  let num_symbols =
    if header >= 128 then begin
      (* Direct representation: 4 bits per weight *)
      let count = header - 127 in
      let bytes_needed = (count + 1) / 2 in
      let data = Bit_reader.Forward.get_bytes stream bytes_needed in

      for i = 0 to count - 1 do
        let byte = Bytes.get_uint8 data (i / 2) in
        (* Even indices live in the high nibble, odd ones in the low. *)
        weights.(i) <- (if i land 1 = 0 then byte lsr 4 else byte land 0xf)
      done;
      count
    end else begin
      (* FSE compressed weights *)
      let compressed_size = header in
      let fse_data = Bit_reader.Forward.get_bytes stream compressed_size in

      (* Decode FSE table for weights (max accuracy 7) *)
      let fse_stream = Bit_reader.Forward.of_bytes fse_data in
      let fse_table = Fse.decode_header fse_stream 7 in

      (* Remaining bytes are the compressed weights *)
      let weights_pos = Bit_reader.Forward.byte_position fse_stream in
      let weights_len = compressed_size - weights_pos in

      (* The FSE header must leave at least one byte of bitstream. *)
      if weights_len < 1 then
        raise (Constants.Zstd_error Constants.Corruption);

      let weight_bytes = Bytes.create Constants.max_huffman_symbols in
      let decoded =
        Fse.decompress_interleaved2 fse_table
          fse_data ~pos:weights_pos ~len:weights_len weight_bytes
      in

      for i = 0 to decoded - 1 do
        weights.(i) <- Bytes.get_uint8 weight_bytes i
      done;
      decoded
    end
  in

  build_dtable_from_weights weights num_symbols

(* ========== ENCODING ========== *)

(** Huffman encoding table *)
type ctable = {
  codes : int array;     (* Canonical code for each symbol, MSB-first value *)
  num_bits : int array;  (* Bit length for each symbol (0 = symbol unused) *)
  max_bits : int;        (* Longest bit length actually used *)
  num_symbols : int;     (* Length of [codes] and [num_bits] *)
}

(** [huffman_length_histogram sorted_freqs] builds an optimal, depth-unbounded
    Huffman tree over the strictly positive frequencies [sorted_freqs], which
    must be sorted in ascending order and contain at least two elements.
    Returns [hist] where [hist.(d)] is the number of leaves at depth [d]. *)
let huffman_length_histogram sorted_freqs =
  let n = Array.length sorted_freqs in
  let node_count = (2 * n) - 1 in
  (* Nodes 0 .. n-1 are the leaves; n .. 2n-2 are internal nodes, created in
     non-decreasing weight order (the classic two-queue construction). *)
  let weight = Array.make node_count 0 in
  let parent = Array.make node_count (-1) in
  Array.blit sorted_freqs 0 weight 0 n;
  let next_leaf = ref 0 in
  let next_internal = ref n in
  (* Remove and return the lightest unused node; [created] is the index of
     the node being built, so internal nodes below it are the usable ones. *)
  let pop_lightest created =
    let take_leaf =
      !next_leaf < n
      && (!next_internal >= created
          || weight.(!next_leaf) <= weight.(!next_internal))
    in
    let cursor = if take_leaf then next_leaf else next_internal in
    let picked = !cursor in
    incr cursor;
    picked
  in
  for node = n to node_count - 1 do
    let a = pop_lightest node in
    let b = pop_lightest node in
    weight.(node) <- weight.(a) + weight.(b);
    parent.(a) <- node;
    parent.(b) <- node
  done;
  (* A parent always has a larger index than its children, so one downward
     sweep from the root (last node) yields every depth. *)
  let depth = Array.make node_count 0 in
  for node = node_count - 2 downto 0 do
    depth.(node) <- depth.(parent.(node)) + 1
  done;
  let hist = Array.make n 0 in
  for leaf = 0 to n - 1 do
    hist.(depth.(leaf)) <- hist.(depth.(leaf)) + 1
  done;
  hist

(** [limit_length_histogram hist limit] clamps a code-length histogram to at
    most [limit] bits while keeping the code complete (Kraft sum exactly 1).
    Returns an array of size [limit + 1] indexed by code length.
    Requires that the number of codes does not exceed [1 lsl limit].

    @raise Invalid_argument if the histogram cannot be repaired, which does
    not happen for a histogram produced by {!huffman_length_histogram}. *)
let limit_length_histogram hist limit =
  let limited = Array.make (limit + 1) 0 in
  Array.iteri
    (fun depth count ->
      if depth > 0 then begin
        let d = min depth limit in
        limited.(d) <- limited.(d) + count
      end)
    hist;
  (* Kraft sum measured in units of 2^-limit. *)
  let total = ref 0 in
  for d = 1 to limit do
    total := !total + (limited.(d) lsl (limit - d))
  done;
  (* Each round drops one code of maximal length and splits the deepest
     shorter code into two codes one bit longer: the number of codes is
     unchanged and the Kraft sum shrinks by exactly one unit. *)
  while !total > 1 lsl limit do
    limited.(limit) <- limited.(limit) - 1;
    let d = ref (limit - 1) in
    while !d > 0 && limited.(!d) = 0 do
      decr d
    done;
    if !d = 0 then
      invalid_arg "Huffman.limit_length_histogram: cannot limit code lengths";
    limited.(!d) <- limited.(!d) - 1;
    limited.(!d + 1) <- limited.(!d + 1) + 2;
    decr total
  done;
  limited

(** Build a length-limited canonical Huffman code from frequencies.

    [counts.(s)] is the frequency of symbol [s] for [s <= max_symbol];
    [max_bits_limit] bounds the code length (it is additionally capped at
    [Constants.max_huffman_bits]).

    - No symbol with a positive count: returns an empty table
      ([num_symbols = 0]).
    - Exactly one such symbol: it gets a 1-bit code [0]. Note that
      {!write_header} cannot describe such a table.
    - Otherwise: an optimal Huffman code is built, limited to the requested
      length, and kept complete (Kraft sum exactly 1), which the implicit
      last weight of the zstd table description relies on. Codes are numbered
      the way {!build_dtable_from_bits} expects: the longest codes get the
      smallest values and, within one length, symbols are numbered in
      increasing symbol order.

    @raise Invalid_argument if [counts] is shorter than [max_symbol + 1], or
    if the limit is too small to give every used symbol a distinct code. *)
let build_ctable counts max_symbol max_bits_limit =
  let num_symbols = max_symbol + 1 in
  let freqs = Array.sub counts 0 num_symbols in

  (* Count non-zero frequencies *)
  let active =
    Array.fold_left (fun acc f -> if f > 0 then acc + 1 else acc) 0 freqs
  in

  if active = 0 then
    { codes = [||]; num_bits = [||]; max_bits = 0; num_symbols = 0 }
  else if active = 1 then
    (* Single symbol case *)
    { codes = Array.make num_symbols 0;
      num_bits = Array.map (fun f -> if f > 0 then 1 else 0) freqs;
      max_bits = 1;
      num_symbols }
  else begin
    let limit = min max_bits_limit Constants.max_huffman_bits in
    if limit < 1 || active > 1 lsl limit then
      invalid_arg "Huffman.build_ctable: max_bits_limit too small";

    (* Used symbols as (symbol, frequency), least frequent first. Ties are
       broken by symbol index so the result is deterministic. *)
    let order =
      freqs
      |> Array.to_seqi
      |> Seq.filter (fun (_, f) -> f > 0)
      |> Array.of_seq
    in
    Array.sort
      (fun (s1, f1) (s2, f2) ->
        match Int.compare f1 f2 with
        | 0 -> Int.compare s1 s2
        | c -> c)
      order;

    let limited =
      limit_length_histogram
        (huffman_length_histogram (Array.map snd order))
        limit
    in

    (* Hand out the lengths from shortest to longest, walking the symbols
       from most to least frequent. *)
    let bit_lengths = Array.make num_symbols 0 in
    let next = ref (active - 1) in
    for len = 1 to limit do
      for _k = 1 to limited.(len) do
        let sym, _ = order.(!next) in
        bit_lengths.(sym) <- len;
        decr next
      done
    done;

    let actual_max = Array.fold_left max 0 bit_lengths in

    (* Build canonical codes. [slot.(len)] is the next free position, in a
       table of [1 lsl actual_max] slots, for a code of length [len]; the
       code itself is that position with the unused low bits dropped. *)
    let span len = 1 lsl (actual_max - len) in
    let slot = Array.make (actual_max + 1) 0 in
    for len = actual_max downto 2 do
      slot.(len - 1) <- slot.(len) + (limited.(len) * span len)
    done;

    let codes = Array.make num_symbols 0 in
    Array.iteri
      (fun sym len ->
        if len > 0 then begin
          codes.(sym) <- slot.(len) lsr (actual_max - len);
          slot.(len) <- slot.(len) + span len
        end)
      bit_lengths;

    { codes; num_bits = bit_lengths; max_bits = actual_max; num_symbols }
  end

(** Convert bit lengths to weights (zstd format).
    A symbol of length [l > 0] gets weight [max_bits + 1 - l]; unused symbols
    (length 0) get weight 0. *)
let bits_to_weights num_bits num_symbols max_bits =
  Array.init num_symbols (fun i ->
    let len = num_bits.(i) in
    if len > 0 then max_bits + 1 - len else 0)

(** Write Huffman table header using direct representation.

    The header byte is [127 + n] where [n] is the number of explicit weights,
    followed by the weights packed two per byte (high nibble first). The
    weight of the last used symbol is implicit and is not written.

    Returns the number of symbols covered by the table (explicit weights plus
    the implicit one), or 0 without writing anything when the table is empty.

    Direct representation holds at most 128 explicit weights of at most 15
    each, and the format needs at least one non-zero explicit weight (so a
    table with a single used symbol cannot be described).

    @raise Constants.Zstd_error with [Invalid_huffman_table] if the table
    cannot be expressed in direct representation. *)
let write_header (stream : Bit_writer.Forward.t) ctable =
  if ctable.num_symbols = 0 then 0
  else begin
    let weights =
      bits_to_weights ctable.num_bits ctable.num_symbols ctable.max_bits
    in

    (* Find last non-zero weight (implicit last symbol) *)
    let last_nonzero = ref (ctable.num_symbols - 1) in
    while !last_nonzero > 0 && weights.(!last_nonzero) = 0 do
      decr last_nonzero
    done;

    let num_weights = !last_nonzero in (* Last weight is implicit *)

    (* Refuse tables the format cannot carry instead of emitting a header
       that a decoder would misread. *)
    let any_explicit = ref false in
    let fits_nibble = ref true in
    for i = 0 to num_weights - 1 do
      if weights.(i) > 0 then any_explicit := true;
      if weights.(i) > 0xf then fits_nibble := false
    done;
    if (not !any_explicit) || (not !fits_nibble) || num_weights > 128 then
      raise (Constants.Zstd_error Constants.Invalid_huffman_table);

    (* Direct representation: header byte = 127 + num_weights, then 4 bits
       per weight. *)
    Bit_writer.Forward.write_byte stream (127 + num_weights);

    (* Write weights packed as pairs (high nibble, low nibble); an odd count
       leaves the final low nibble at 0. *)
    let i = ref 0 in
    while !i < num_weights do
      let hi = weights.(!i) in
      let lo = if !i + 1 < num_weights then weights.(!i + 1) else 0 in
      Bit_writer.Forward.write_byte stream ((hi lsl 4) lor lo);
      i := !i + 2
    done;

    num_weights + 1
  end

(** Encode a single symbol (write to backward stream).
    A symbol whose bit length is 0 in [ctable] writes nothing, so callers
    must only pass symbols that were counted when the table was built. *)
let[@inline] encode_symbol ctable (stream : Bit_writer.Backward.t) symbol =
  let code = ctable.codes.(symbol) in
  let bits = ctable.num_bits.(symbol) in
  if bits > 0 then
    Bit_writer.Backward.write_bits stream code bits

(** Compress literals to a single Huffman stream.
    Encodes [literals.[pos .. pos + len - 1]] and returns the finalized
    stream bytes. *)
let compress_1stream ctable literals ~pos ~len =
  (* Sized for two bytes per symbol plus slack. *)
  let stream = Bit_writer.Backward.create ((len * 2) + 16) in

  (* Encode symbols in reverse order *)
  for i = pos + len - 1 downto pos do
    encode_symbol ctable stream (Bytes.get_uint8 literals i)
  done;

  Bit_writer.Backward.finalize stream

(** Compress literals to 4 interleaved Huffman streams.

    The first three streams each cover [(len + 3) / 4] literals and the
    fourth covers the remainder. The result is a 6-byte jump table (three
    little-endian 16-bit stream sizes) followed by the four streams.

    @raise Invalid_argument if [len] cannot be split this way (the fourth
    chunk would have a negative length, as for [len] = 1, 2 or 5), or if one
    of the first three streams is too large for a 16-bit jump-table entry. *)
let compress_4stream ctable literals ~pos ~len =
  let chunk_size = (len + 3) / 4 in
  let chunk4_size = len - (3 * chunk_size) in
  if chunk4_size < 0 then
    invalid_arg "Huffman.compress_4stream: input too short for 4 streams";

  (* Compress each stream *)
  let chunk index size =
    compress_1stream ctable literals ~pos:(pos + (index * chunk_size)) ~len:size
  in
  let stream1 = chunk 0 chunk_size in
  let stream2 = chunk 1 chunk_size in
  let stream3 = chunk 2 chunk_size in
  let stream4 = chunk 3 chunk4_size in

  (* Build output with jump table *)
  let size1 = Bytes.length stream1 in
  let size2 = Bytes.length stream2 in
  let size3 = Bytes.length stream3 in
  let size4 = Bytes.length stream4 in
  (* A jump-table entry is 16 bits wide; storing a larger size would
     silently truncate it. *)
  if size1 > 0xFFFF || size2 > 0xFFFF || size3 > 0xFFFF then
    invalid_arg "Huffman.compress_4stream: stream too large for jump table";

  let output = Bytes.create (6 + size1 + size2 + size3 + size4) in
  Bytes.set_uint16_le output 0 size1;
  Bytes.set_uint16_le output 2 size2;
  Bytes.set_uint16_le output 4 size3;
  Bytes.blit stream1 0 output 6 size1;
  Bytes.blit stream2 0 output (6 + size1) size2;
  Bytes.blit stream3 0 output (6 + size1 + size2) size3;
  Bytes.blit stream4 0 output (6 + size1 + size2 + size3) size4;

  output