Zstd compression in pure OCaml
at main 468 lines 16 kB view raw
(** Finite State Entropy (FSE) decoding and encoding for Zstandard.

    FSE is an entropy coding method based on ANS (Asymmetric Numeral Systems).
    FSE streams are read backwards (from end to beginning). *)

(** FSE decoding table entry.

    [update_state] reads [num_bits] bits from the stream and adds them to
    [new_state_base] to obtain the next state; [symbol] is the symbol emitted
    for the state this entry belongs to. *)
type entry = {
  symbol : int;
  num_bits : int;
  new_state_base : int;
}

(** FSE decoding table: one [entry] per state ([1 lsl accuracy_log] entries
    when built by [build_dtable], a single entry for an RLE table). *)
type dtable = {
  entries : entry array;
  accuracy_log : int;
}

(** [highest_set_bit n] is the index of the highest set bit of [n], i.e.
    [floor (log2 n)] for [n > 0]. Returns [-1] when [n = 0] (and also for
    negative [n], since [1 <= n] fails on the first iteration). *)
let[@inline] highest_set_bit n =
  if n = 0 then -1
  else
    let rec loop i =
      if (1 lsl i) <= n then loop (i + 1)
      else i - 1
    in
    loop 0

(** [build_dtable frequencies accuracy_log] builds an FSE decoding table of
    [1 lsl accuracy_log] states from normalized frequencies.

    Frequencies can be negative: [-1] means "probability < 1" and occupies
    exactly one state at the end of the table. The frequencies are expected to
    sum (counting [-1] as 1) to the table size.

    @raise Constants.Zstd_error with [Constants.Invalid_fse_table] when the
    symbol spreading does not end back on position 0. *)
let build_dtable frequencies accuracy_log =
  let table_size = 1 lsl accuracy_log in
  let num_symbols = Array.length frequencies in

  (* Create entries array *)
  let entries = Array.init table_size (fun _ ->
    { symbol = 0; num_bits = 0; new_state_base = 0 }
  ) in

  (* Track state descriptors for each symbol: the next "occurrence rank" to
     hand out when the third pass meets that symbol. *)
  let state_desc = Array.make num_symbols 0 in

  (* First pass: place symbols with prob < 1 at the end *)
  let high_threshold = ref table_size in
  for s = 0 to num_symbols - 1 do
    if frequencies.(s) = -1 then begin
      decr high_threshold;
      entries.(!high_threshold) <- { symbol = s; num_bits = 0; new_state_base = 0 };
      state_desc.(s) <- 1
    end
  done;

  (* Second pass: distribute remaining symbols using the step formula *)
  let step = (table_size lsr 1) + (table_size lsr 3) + 3 in
  let mask = table_size - 1 in
  let pos = ref 0 in

  for s = 0 to num_symbols - 1 do
    if frequencies.(s) > 0 then begin
      state_desc.(s) <- frequencies.(s);
      for _ = 0 to frequencies.(s) - 1 do
        entries.(!pos) <- { entries.(!pos) with symbol = s };
        (* Skip positions occupied by prob < 1 symbols *)
        pos := (!pos + step) land mask;
        while !pos >= !high_threshold do
          pos := (!pos + step) land mask
        done
      done
    end
  done;

  (* A consistent distribution walks the whole table and lands on 0 again. *)
  if !pos <> 0 then
    raise (Constants.Zstd_error Constants.Invalid_fse_table);

  (* Third pass: fill in num_bits and new_state_base *)
  for i = 0 to table_size - 1 do
    let s = entries.(i).symbol in
    let next_state_desc = state_desc.(s) in
    state_desc.(s) <- next_state_desc + 1;

    (* Number of bits is accuracy_log - log2(next_state_desc) *)
    let num_bits = accuracy_log - highest_set_bit next_state_desc in
    (* new_state_base = (next_state_desc << num_bits) - table_size *)
    let new_state_base = (next_state_desc lsl num_bits) - table_size in

    entries.(i) <- { entries.(i) with num_bits; new_state_base }
  done;

  { entries; accuracy_log }

(** [build_dtable_rle symbol] builds an RLE table: a single state that always
    yields [symbol] and consumes no bits ([accuracy_log = 0]). *)
let build_dtable_rle symbol =
  {
    entries = [| { symbol; num_bits = 0; new_state_base = 0 } |];
    accuracy_log = 0;
  }

(** [peek_symbol dtable state] is the symbol for [state]; it reads nothing
    from any stream and does not update the state. *)
let[@inline] peek_symbol dtable state =
  dtable.entries.(state).symbol

(** [update_state dtable state stream] reads the number of bits required by
    [state]'s entry from [stream] and returns the next state. *)
let[@inline] update_state dtable state (stream : Bit_reader.Backward.t) =
  let entry = dtable.entries.(state) in
  let bits = Bit_reader.Backward.read_bits stream entry.num_bits in
  entry.new_state_base + bits

(** [decode_symbol dtable state stream] returns [(symbol, new_state)]: the
    symbol for [state] and the state obtained after reading from [stream]. *)
let[@inline] decode_symbol dtable state stream =
  let symbol = peek_symbol dtable state in
  let new_state = update_state dtable state stream in
  (symbol, new_state)

(** [init_state dtable stream] reads [dtable.accuracy_log] bits from [stream]
    and returns them as the initial state. *)
let[@inline] init_state dtable (stream : Bit_reader.Backward.t) =
  Bit_reader.Backward.read_bits stream dtable.accuracy_log

(** [decode_header stream max_accuracy_log] decodes an FSE table description
    from the forward [stream] and builds the decoding table.

    Returns the table; the stream is advanced and then aligned via
    [Bit_reader.Forward.align].

    @raise Constants.Zstd_error with [Constants.Invalid_fse_table] when the
    accuracy log exceeds [max_accuracy_log], when the probabilities do not
    consume the table exactly, or when [build_dtable] rejects the
    distribution. *)
let decode_header (stream : Bit_reader.Forward.t) max_accuracy_log =
  (* Accuracy log is first 4 bits + 5 *)
  let accuracy_log = (Bit_reader.Forward.read_bits stream 4) + 5 in
  if accuracy_log > max_accuracy_log then
    raise (Constants.Zstd_error Constants.Invalid_fse_table);

  let table_size = 1 lsl accuracy_log in
  let frequencies = Array.make Constants.max_fse_symbols 0 in

  let remaining = ref table_size in
  let symbol = ref 0 in

  while !remaining > 0 && !symbol < Constants.max_fse_symbols do
    (* Determine how many bits we might need *)
    let bits_needed = highest_set_bit (!remaining + 1) + 1 in
    let value = Bit_reader.Forward.read_bits stream bits_needed in

    (* Small value optimization: values < threshold use one less bit *)
    let threshold = (1 lsl bits_needed) - 1 - (!remaining + 1) in
    let lower_mask = (1 lsl (bits_needed - 1)) - 1 in

    let (actual_value, bits_consumed) =
      if (value land lower_mask) < threshold then
        (value land lower_mask, bits_needed - 1)
      else if value > lower_mask then
        (value - threshold, bits_needed)
      else
        (value, bits_needed)
    in

    (* Rewind if we read too many bits *)
    if bits_consumed < bits_needed then
      Bit_reader.Forward.rewind_bits stream 1;

    (* Probability = value - 1 (so value 0 means prob = -1) *)
    let prob = actual_value - 1 in
    frequencies.(!symbol) <- prob;
    (* A "less than 1" probability (-1) still uses up one table slot. *)
    remaining := !remaining - abs prob;
    incr symbol;

    (* Handle zero probability with repeat flags: each 2-bit flag gives the
       number of additional zero-probability symbols; a flag of 3 means
       another flag follows. *)
    if prob = 0 then begin
      let rec read_zeroes () =
        let repeat = Bit_reader.Forward.read_bits stream 2 in
        for _ = 1 to repeat do
          if !symbol < Constants.max_fse_symbols then begin
            frequencies.(!symbol) <- 0;
            incr symbol
          end
        done;
        if repeat = 3 then read_zeroes ()
      in
      read_zeroes ()
    end
  done;

  (* Align to byte boundary *)
  Bit_reader.Forward.align stream;

  if !remaining <> 0 then
    raise (Constants.Zstd_error Constants.Invalid_fse_table);

  (* Build the decoding table *)
  let freq_slice = Array.sub frequencies 0 !symbol in
  build_dtable freq_slice accuracy_log

(** [decompress_interleaved2 dtable src ~pos ~len output] decompresses an
    interleaved 2-state FSE stream (used for Huffman weight encoding) from
    [src] into [output], starting at offset 0 of [output].

    Returns the number of symbols written to [output].

    @raise Constants.Zstd_error with [Constants.Output_too_small] when
    [output] is full before a symbol that requires reading from the stream. *)
let decompress_interleaved2 dtable src ~pos ~len output =
  let stream = Bit_reader.Backward.of_bytes src ~pos ~len in

  (* Initialize two states *)
  let state1 = ref (init_state dtable stream) in
  let state2 = ref (init_state dtable stream) in

  let out_pos = ref 0 in
  let out_len = Bytes.length output in

  (* Decode symbols alternating between states.
     Presumably [Bit_reader.Backward.remaining] goes negative once the stream
     has been read past its start; that is the end-of-stream signal used
     here (confirm against Bit_reader). *)
  while Bit_reader.Backward.remaining stream >= 0 do
    if !out_pos >= out_len then
      raise (Constants.Zstd_error Constants.Output_too_small);

    let (sym1, new_state1) = decode_symbol dtable !state1 stream in
    Bytes.set_uint8 output !out_pos sym1;
    incr out_pos;
    state1 := new_state1;

    if Bit_reader.Backward.remaining stream < 0 then begin
      (* Stream exhausted, output final symbol from state2.
         NOTE(review): if [output] is already full the final symbol is
         dropped silently instead of raising Output_too_small; confirm this
         is intended. *)
      if !out_pos < out_len then begin
        Bytes.set_uint8 output !out_pos (peek_symbol dtable !state2);
        incr out_pos
      end
    end else begin
      if !out_pos >= out_len then
        raise (Constants.Zstd_error Constants.Output_too_small);

      let (sym2, new_state2) = decode_symbol dtable !state2 stream in
      Bytes.set_uint8 output !out_pos sym2;
      incr out_pos;
      state2 := new_state2;

      if Bit_reader.Backward.remaining stream < 0 then begin
        (* Stream exhausted, output final symbol from state1 (same silent
           drop as above when [output] is full). *)
        if !out_pos < out_len then begin
          Bytes.set_uint8 output !out_pos (peek_symbol dtable !state1);
          incr out_pos
        end
      end
    end
  done;

  !out_pos

(** [build_predefined_table distribution accuracy_log] builds a decoding table
    from a predefined distribution; it is [build_dtable] under another name. *)
let build_predefined_table distribution accuracy_log =
  build_dtable distribution accuracy_log

(* ========== ENCODING ========== *)

(** FSE compression table - matches C zstd's FSE_symbolCompressionTransform format.
    deltaNbBits is encoded as (maxBitsOut << 16) - minStatePlus
    This allows computing nbBitsOut = (state + deltaNbBits) >> 16 *)
type symbol_transform = {
  delta_nb_bits : int;     (* (maxBitsOut << 16) - minStatePlus *)
  delta_find_state : int;  (* Cumulative offset to find next state *)
}

(** FSE compression table *)
type ctable = {
  symbol_tt : symbol_transform array;  (* Symbol compression transforms *)
  state_table : int array;             (* Next state lookup table *)
  accuracy_log : int;
  table_size : int;
}

(** FSE compression state - matches C zstd's FSE_CState_t *)
type cstate = {
  mutable value : int;  (* Current state value *)
  ctable : ctable;      (* Reference to compression table *)
}

(** [count_symbols src ~pos ~len max_symbol] counts byte frequencies in
    [src.[pos .. pos + len - 1]].

    Returns an array of [max_symbol + 1] counters; bytes greater than
    [max_symbol] are ignored. *)
let count_symbols src ~pos ~len max_symbol =
  let counts = Array.make (max_symbol + 1) 0 in
  for i = pos to pos + len - 1 do
    let s = Bytes.get_uint8 src i in
    if s <= max_symbol then
      counts.(s) <- counts.(s) + 1
  done;
  counts

(** [normalize_counts counts total accuracy_log] scales [counts] so that they
    sum to [1 lsl accuracy_log].

    Every symbol with a non-zero count first receives at least 1. Any excess
    is then removed one unit at a time from the currently largest entry, and
    any shortfall is added one unit at a time to the currently smallest
    non-zero entry. Returns an all-zero array when [total = 0]. *)
let normalize_counts counts total accuracy_log =
  let table_size = 1 lsl accuracy_log in
  let num_symbols = Array.length counts in
  let norm = Array.make num_symbols 0 in

  if total = 0 then norm
  else begin
    (* Fixed-point (8 fractional bits) scale factor. *)
    let scale = table_size * 256 / total in
    let distributed = ref 0 in

    for s = 0 to num_symbols - 1 do
      if counts.(s) > 0 then begin
        (* Round to nearest, but never let a present symbol drop to 0. *)
        let proba = (counts.(s) * scale + 128) / 256 in
        let proba = max 1 proba in
        norm.(s) <- proba;
        distributed := !distributed + proba
      end
    done;

    (* Too much handed out: take it back from the largest entries. *)
    while !distributed > table_size do
      let max_val = ref 0 in
      let max_idx = ref 0 in
      for s = 0 to num_symbols - 1 do
        if norm.(s) > !max_val then begin
          max_val := norm.(s);
          max_idx := s
        end
      done;
      norm.(!max_idx) <- norm.(!max_idx) - 1;
      decr distributed
    done;

    (* Too little handed out: give the rest to the smallest non-zero entries. *)
    while !distributed < table_size do
      let min_val = ref max_int in
      let min_idx = ref 0 in
      for s = 0 to num_symbols - 1 do
        if norm.(s) > 0 && norm.(s) < !min_val then begin
          min_val := norm.(s);
          min_idx := s
        end
      done;
      norm.(!min_idx) <- norm.(!min_idx) + 1;
      incr distributed
    done;

    norm
  end

(** [build_ctable norm_counts accuracy_log] builds an FSE compression table
    from normalized counts ([-1] meaning "probability < 1").
    Matches C zstd's FSE_buildCTable_wksp algorithm exactly. *)
let build_ctable norm_counts accuracy_log =
  let table_size = 1 lsl accuracy_log in
  let table_mask = table_size - 1 in
  let num_symbols = Array.length norm_counts in
  let step = (table_size lsr 1) + (table_size lsr 3) + 3 in

  (* Symbol distribution table - which symbol at each state *)
  let table_symbol = Array.make table_size 0 in

  (* Cumulative counts for state table indexing *)
  let cumul = Array.make (num_symbols + 1) 0 in
  cumul.(0) <- 0;
  for s = 0 to num_symbols - 1 do
    (* A -1 ("less than 1") symbol still owns exactly one state. *)
    let count = if norm_counts.(s) = -1 then 1 else max 0 norm_counts.(s) in
    cumul.(s + 1) <- cumul.(s) + count
  done;

  (* Place low probability symbols at the end *)
  let high_threshold = ref (table_size - 1) in
  for s = 0 to num_symbols - 1 do
    if norm_counts.(s) = -1 then begin
      table_symbol.(!high_threshold) <- s;
      decr high_threshold
    end
  done;

  (* Spread remaining symbols using step formula *)
  let pos = ref 0 in
  for s = 0 to num_symbols - 1 do
    let count = norm_counts.(s) in
    if count > 0 then begin
      for _ = 0 to count - 1 do
        table_symbol.(!pos) <- s;
        pos := (!pos + step) land table_mask;
        (* Skip the area reserved for low probability symbols. *)
        while !pos > !high_threshold do
          pos := (!pos + step) land table_mask
        done
      done
    end
  done;

  (* Build state table - for each position, compute next state.
     Entries end up grouped by symbol, in symbol order. *)
  let state_table = Array.make table_size 0 in
  let cumul_copy = Array.copy cumul in
  for u = 0 to table_size - 1 do
    let s = table_symbol.(u) in
    state_table.(cumul_copy.(s)) <- table_size + u;
    cumul_copy.(s) <- cumul_copy.(s) + 1
  done;

  (* Build symbol compression transforms *)
  let symbol_tt = Array.init num_symbols (fun s ->
    let count = norm_counts.(s) in
    match count with
    | 0 ->
      (* Zero probability - use max bits (shouldn't be encoded) *)
      { delta_nb_bits = ((accuracy_log + 1) lsl 16) - (1 lsl accuracy_log);
        delta_find_state = 0 }
    | -1 | 1 ->
      (* Low probability symbol *)
      { delta_nb_bits = (accuracy_log lsl 16) - (1 lsl accuracy_log);
        delta_find_state = cumul.(s) - 1 }
    | _ ->
      (* Normal symbol *)
      let max_bits_out = accuracy_log - highest_set_bit (count - 1) in
      let min_state_plus = count lsl max_bits_out in
      { delta_nb_bits = (max_bits_out lsl 16) - min_state_plus;
        delta_find_state = cumul.(s) - count }
  ) in

  { symbol_tt; state_table; accuracy_log; table_size }

(** [init_cstate ctable] is a fresh compression state whose value is
    [1 lsl ctable.accuracy_log] - matches C's FSE_initCState. *)
let init_cstate ctable =
  { value = 1 lsl ctable.accuracy_log; ctable }

(** [init_cstate2 ctable symbol] initializes the compression state with the
    first symbol - matches C's FSE_initCState2.
    This saves bits by using the smallest valid state for the first symbol. *)
let init_cstate2 ctable symbol =
  let st = ctable.symbol_tt.(symbol) in
  let nb_bits_out = (st.delta_nb_bits + (1 lsl 15)) lsr 16 in
  let init_value = (nb_bits_out lsl 16) - st.delta_nb_bits in
  let state_idx = (init_value lsr nb_bits_out) + st.delta_find_state in
  { value = ctable.state_table.(state_idx); ctable }

(** [encode_symbol stream cstate symbol] encodes a single symbol - matches C's
    FSE_encodeSymbol exactly.
    Writes the state-transition bits to [stream] and updates [cstate.value]
    in place. *)
let[@inline] encode_symbol (stream : Bit_writer.Backward.t) cstate symbol =
  let st = cstate.ctable.symbol_tt.(symbol) in
  let nb_bits_out = (cstate.value + st.delta_nb_bits) lsr 16 in
  Bit_writer.Backward.write_bits stream cstate.value nb_bits_out;
  let state_idx = (cstate.value lsr nb_bits_out) + st.delta_find_state in
  cstate.value <- cstate.ctable.state_table.(state_idx)

(** [flush_cstate stream cstate] flushes the compression state - matches C's
    FSE_flushCState.
    Writes the final state value on [accuracy_log] bits so that the decoder
    can initialize from it. *)
let[@inline] flush_cstate (stream : Bit_writer.Backward.t) cstate =
  Bit_writer.Backward.write_bits stream cstate.value cstate.ctable.accuracy_log

(** [write_header stream norm_counts accuracy_log] writes the FSE table
    description (normalized counts) to [stream], in the format read back by
    [decode_header].

    Layout: 4 bits of [accuracy_log - 5], then one variable-width value per
    symbol ([count + 1]) until the table is fully described. Each
    zero-probability symbol is followed by 2-bit repeat flags giving the
    number of further zero-probability symbols (a flag of 3 chains to another
    flag). The stream is not aligned or flushed here. *)
let write_header (stream : Bit_writer.Forward.t) norm_counts accuracy_log =
  Bit_writer.Forward.write_bits stream (accuracy_log - 5) 4;

  let table_size = 1 lsl accuracy_log in
  let num_symbols = Array.length norm_counts in
  let remaining = ref table_size in
  let symbol = ref 0 in

  while !remaining > 0 && !symbol < num_symbols do
    let count = norm_counts.(!symbol) in
    (* Stored value is probability + 1, so that -1 is stored as 0. *)
    let value = count + 1 in

    let bits_needed = highest_set_bit (!remaining + 1) + 1 in
    let threshold = (1 lsl bits_needed) - 1 - (!remaining + 1) in
    let lower_mask = (1 lsl (bits_needed - 1)) - 1 in

    (* Three ranges, mirroring decode_header:
       - value < threshold: small value, written on one bit less;
       - threshold <= value <= lower_mask: written unchanged on
         [bits_needed] bits (the decoder only subtracts [threshold] from
         values above [lower_mask]);
       - value > lower_mask: shifted up by [threshold] so that its low bits
         can never be mistaken for a small value.
       E.g. with remaining = 32: bits_needed = 6, threshold = 30, and the
       value 30 must be written as 30, not 60 (60 land 31 = 28 would decode
       as the small value 28). *)
    if value < threshold then
      Bit_writer.Forward.write_bits stream value (bits_needed - 1)
    else if value > lower_mask then
      Bit_writer.Forward.write_bits stream (value + threshold) bits_needed
    else
      Bit_writer.Forward.write_bits stream value bits_needed;

    remaining := !remaining - abs count;
    incr symbol;

    if count = 0 then begin
      (* Consume the run of zero-probability symbols that follows. *)
      let rec count_zeroes acc =
        if !symbol < num_symbols && norm_counts.(!symbol) = 0 then begin
          incr symbol;
          count_zeroes (acc + 1)
        end else acc
      in
      let zeroes = count_zeroes 0 in
      (* Emit the run length as 2-bit flags; a final flag in 0..2 ends it. *)
      let rec write_repeats n =
        if n >= 3 then begin
          Bit_writer.Forward.write_bits stream 3 2;
          write_repeats (n - 3)
        end else
          Bit_writer.Forward.write_bits stream n 2
      in
      write_repeats zeroes
    end
  done

(** [build_predefined_ctable distribution accuracy_log] builds an encoding
    table from a predefined distribution; it is [build_ctable] under another
    name. *)
let build_predefined_ctable distribution accuracy_log =
  build_ctable distribution accuracy_log