(** Finite State Entropy (FSE) coding for Zstandard.

    FSE is an entropy coding method based on ANS (Asymmetric Numeral
    Systems). This module contains both the decoder (table construction,
    header parsing, state transitions) and the encoder (count normalization,
    compression tables, header writing). FSE bit streams are read backwards,
    from the end towards the beginning. *)

(** One state of an FSE decoding table. *)
type entry = {
  symbol : int;  (* symbol emitted while in this state *)
  num_bits : int;  (* number of bits read to compute the next state *)
  new_state_base : int;  (* added to the bits read to form the next state *)
}

(** FSE decoding table: one [entry] per state, indexed by state. *)
type dtable = {
  entries : entry array;
  accuracy_log : int;  (* log2 of the number of states *)
}

(** [highest_set_bit n] is [floor (log2 n)] for [n > 0], and [-1] for
    [n <= 0]. *)
let[@inline] highest_set_bit n =
  if n = 0 then -1
  else
    let rec loop i = if 1 lsl i <= n then loop (i + 1) else i - 1 in
    loop 0

(** [build_dtable frequencies accuracy_log] builds an FSE decoding table of
    [1 lsl accuracy_log] states from normalized frequencies.
    [frequencies.(s)] is the normalized count of symbol [s]; [-1] marks a
    low-probability symbol (probability below 1), which takes exactly one
    state at the top of the table.
    @raise Constants.Zstd_error [Invalid_fse_table] if a frequency is below
    [-1], if the absolute frequencies do not sum to [1 lsl accuracy_log], or
    if the symbol spread does not end back at position 0. *)
let build_dtable frequencies accuracy_log =
  let table_size = 1 lsl accuracy_log in
  let num_symbols = Array.length frequencies in
  let invalid () = raise (Constants.Zstd_error Constants.Invalid_fse_table) in
  (* Validate before building: a distribution that does not exactly fill
     the table would leave states unassigned, index [state_desc] out of
     bounds (empty input), or make the spread's skip loop below spin forever
     when every slot is taken by low-probability symbols. *)
  let total =
    Array.fold_left
      (fun acc f -> if f < -1 then invalid () else acc + abs f)
      0 frequencies
  in
  if total <> table_size then invalid ();
  let entries =
    Array.make table_size { symbol = 0; num_bits = 0; new_state_base = 0 }
  in
  (* Next state descriptor per symbol; starts at the symbol's count. *)
  let state_desc = Array.make num_symbols 0 in
  (* First pass: low-probability symbols take one state each, filled
     downward from the end of the table. *)
  let high_threshold = ref table_size in
  Array.iteri
    (fun s f ->
      if f = -1 then begin
        decr high_threshold;
        entries.(!high_threshold) <-
          { symbol = s; num_bits = 0; new_state_base = 0 };
        state_desc.(s) <- 1
      end)
    frequencies;
  (* Second pass: spread the remaining symbols with the standard step,
     skipping positions [high_threshold, table_size) already occupied. *)
  let step = (table_size lsr 1) + (table_size lsr 3) + 3 in
  let mask = table_size - 1 in
  let pos = ref 0 in
  Array.iteri
    (fun s f ->
      if f > 0 then begin
        state_desc.(s) <- f;
        for _ = 1 to f do
          entries.(!pos) <- { entries.(!pos) with symbol = s };
          pos := (!pos + step) land mask;
          while !pos >= !high_threshold do
            pos := (!pos + step) land mask
          done
        done
      end)
    frequencies;
  (* The spread must end back at position 0; anything else means the
     distribution is corrupt. *)
  if !pos <> 0 then invalid ();
  (* Third pass: for each state, derive how many bits to read and the base
     of the next state from the symbol's running state descriptor. *)
  for i = 0 to table_size - 1 do
    let s = entries.(i).symbol in
    let next_state_desc = state_desc.(s) in
    state_desc.(s) <- next_state_desc + 1;
    let num_bits = accuracy_log - highest_set_bit next_state_desc in
    let new_state_base = (next_state_desc lsl num_bits) - table_size in
    entries.(i) <- { entries.(i) with num_bits; new_state_base }
  done;
  { entries; accuracy_log }

(** Build a single-state table that always yields [symbol] and reads no
    bits (RLE mode). *)
let build_dtable_rle symbol =
  {
    entries = [| { symbol; num_bits = 0; new_state_base = 0 } |];
    accuracy_log = 0;
  }

(** Symbol emitted in [state]; the state itself is not updated. *)
let[@inline] peek_symbol dtable state = dtable.entries.(state).symbol

(** Next state after [state]: reads the entry's [num_bits] bits from
    [stream] and adds them to the entry's [new_state_base]. *)
let[@inline] update_state dtable state (stream : Bit_reader.Backward.t) =
  let entry = dtable.entries.(state) in
  let bits = Bit_reader.Backward.read_bits stream entry.num_bits in
  entry.new_state_base + bits

(** Decode the symbol for [state] and advance; returns
    [(symbol, new_state)]. *)
let[@inline] decode_symbol dtable state stream =
  let symbol = peek_symbol dtable state in
  let new_state = update_state dtable state stream in
  (symbol, new_state)

(** Initial state: the next [accuracy_log] bits of the backward stream. *)
let[@inline] init_state dtable (stream : Bit_reader.Backward.t) =
  Bit_reader.Backward.read_bits stream dtable.accuracy_log

(** Decode FSE header and build decoding table.
    Returns the table and advances the forward stream.
@raise Constants.Zstd_error [Invalid_fse_table] if the accuracy log
    exceeds [max_accuracy_log] or the decoded probabilities do not sum to
    the table size. *)
let decode_header (stream : Bit_reader.Forward.t) max_accuracy_log =
  let invalid () = raise (Constants.Zstd_error Constants.Invalid_fse_table) in
  (* The first 4 bits hold accuracy_log - 5. *)
  let accuracy_log = Bit_reader.Forward.read_bits stream 4 + 5 in
  if accuracy_log > max_accuracy_log then invalid ();
  let table_size = 1 lsl accuracy_log in
  let max_symbols = Constants.max_fse_symbols in
  let frequencies = Array.make max_symbols 0 in
  let remaining = ref table_size in
  let symbol = ref 0 in
  while !remaining > 0 && !symbol < max_symbols do
    (* Each value is probability + 1, so it lies in [0, remaining + 1];
       the field is sized for the largest possible value. *)
    let bits_needed = highest_set_bit (!remaining + 1) + 1 in
    let value = Bit_reader.Forward.read_bits stream bits_needed in
    (* Values whose low bits fall below [threshold] were written on one bit
       less; values at or above half the field were shifted up by
       [threshold] when written. *)
    let threshold = (1 lsl bits_needed) - 1 - (!remaining + 1) in
    let lower_mask = (1 lsl (bits_needed - 1)) - 1 in
    let low = value land lower_mask in
    let actual_value =
      if low < threshold then begin
        (* Only bits_needed - 1 bits belonged to this value. *)
        Bit_reader.Forward.rewind_bits stream 1;
        low
      end
      else if value > lower_mask then value - threshold
      else value
    in
    (* A value of 0 encodes probability -1 (less than 1). *)
    let prob = actual_value - 1 in
    frequencies.(!symbol) <- prob;
    remaining := !remaining - abs prob;
    incr symbol;
    if prob = 0 then begin
      (* 2-bit repeat flags give the number of further zero-probability
         symbols; a flag of 3 means another flag follows. *)
      let rec read_zeroes () =
        let repeat = Bit_reader.Forward.read_bits stream 2 in
        for _ = 1 to repeat do
          if !symbol < max_symbols then begin
            frequencies.(!symbol) <- 0;
            incr symbol
          end
        done;
        if repeat = 3 then read_zeroes ()
      in
      read_zeroes ()
    end
  done;
  (* The header ends on a byte boundary. *)
  Bit_reader.Forward.align stream;
  if !remaining <> 0 then invalid ();
  build_dtable (Array.sub frequencies 0 !symbol) accuracy_log

(** Decompress an interleaved two-state FSE stream, as used for Huffman
    weight encoding. Symbols are written to [output] from index 0,
    alternating between the two states, until a state update reads past the
    end of the stream; the other state's current symbol is then emitted as
    the final one. Returns the number of symbols decoded.
    @raise Constants.Zstd_error [Output_too_small] if [output] cannot hold
    every decoded symbol, including the final one. *)
let decompress_interleaved2 dtable src ~pos ~len output =
  let stream = Bit_reader.Backward.of_bytes src ~pos ~len in
  let state1 = ref (init_state dtable stream) in
  let state2 = ref (init_state dtable stream) in
  let out_len = Bytes.length output in
  let out_pos = ref 0 in
  (* Append one symbol. The final symbol is checked too: dropping it
     silently would return a truncated result instead of an error. *)
  let emit sym =
    if !out_pos >= out_len then
      raise (Constants.Zstd_error Constants.Output_too_small);
    Bytes.set_uint8 output !out_pos sym;
    incr out_pos
  in
  let exhausted () = Bit_reader.Backward.remaining stream < 0 in
  (* Decode one symbol from [cur] and advance it. If that update read past
     the end of the stream, emit the symbol of [other] as the last one and
     report that decoding is finished. *)
  let step cur other =
    let sym, next_state = decode_symbol dtable !cur stream in
    emit sym;
    cur := next_state;
    if exhausted () then begin
      emit (peek_symbol dtable !other);
      true
    end
    else false
  in
  let rec loop () =
    if
      (not (exhausted ()))
      && (not (step state1 state2))
      && not (step state2 state1)
    then loop ()
  in
  loop ();
  !out_pos

(** Build decoding table from predefined distribution. *)
let build_predefined_table distribution accuracy_log =
  build_dtable distribution accuracy_log

(* ========== ENCODING ========== *)

(** Per-symbol FSE compression transform, mirroring C zstd's
    FSE_symbolCompressionTransform.
[delta_nb_bits] is stored as [(maxBitsOut lsl 16) - minStatePlus], so
    the number of bits to emit for a state is
    [(state + delta_nb_bits) lsr 16]. *)
type symbol_transform = {
  delta_nb_bits : int;  (* (maxBitsOut << 16) - minStatePlus *)
  delta_find_state : int;  (* added to [state lsr nb_bits] to index state_table *)
}

(** FSE compression table. *)
type ctable = {
  symbol_tt : symbol_transform array;  (* per-symbol transforms *)
  state_table : int array;  (* next-state lookup, grouped by symbol *)
  accuracy_log : int;
  table_size : int;  (* 1 lsl accuracy_log *)
}

(** FSE compression state, mirroring C zstd's FSE_CState_t. *)
type cstate = {
  mutable value : int;  (* current state value *)
  ctable : ctable;  (* table driving the state transitions *)
}

(** [count_symbols src ~pos ~len max_symbol] counts occurrences of each byte
    value among bytes [pos] to [pos + len - 1] of [src]. The result has
    [max_symbol + 1] entries; bytes greater than [max_symbol] are ignored. *)
let count_symbols src ~pos ~len max_symbol =
  let counts = Array.make (max_symbol + 1) 0 in
  for i = pos to pos + len - 1 do
    let s = Bytes.get_uint8 src i in
    if s <= max_symbol then counts.(s) <- counts.(s) + 1
  done;
  counts

(** [normalize_counts counts total accuracy_log] scales [counts] so that they
    sum to [1 lsl accuracy_log], keeping every present symbol (count > 0) at
    a normalized count of at least 1. [total] is expected to be the sum of
    [counts]. Returns all zeros when [total <= 0] or no symbol is present.
    @raise Invalid_argument if there are more present symbols than table
    slots, since each of them needs at least one slot. *)
let normalize_counts counts total accuracy_log =
  let table_size = 1 lsl accuracy_log in
  let num_symbols = Array.length counts in
  let norm = Array.make num_symbols 0 in
  if total > 0 then begin
    (* Scale each count directly with round-to-nearest. A precomputed
       fixed-point factor (table_size * 256 / total) truncates to 0 once
       total exceeds table_size * 256, collapsing every symbol to 1. *)
    let distributed = ref 0 in
    Array.iteri
      (fun s c ->
        if c > 0 then begin
          let p = max 1 (((c * table_size) + (total / 2)) / total) in
          norm.(s) <- p;
          distributed := !distributed + p
        end)
      counts;
    (* Index of the largest normalized count (first one on ties). *)
    let largest () =
      let best = ref 0 in
      for s = 1 to num_symbols - 1 do
        if norm.(s) > norm.(!best) then best := s
      done;
      !best
    in
    if !distributed > 0 && !distributed < table_size then begin
      (* Give the rounding shortfall to the most probable symbol, where it
         distorts the distribution least (as C zstd's FSE_normalizeCount
         does). *)
      let s = largest () in
      norm.(s) <- norm.(s) + (table_size - !distributed);
      distributed := table_size
    end;
    (* Remove any excess from the currently largest symbol, one unit at a
       time, never taking a present symbol below 1. *)
    while !distributed > table_size do
      let s = largest () in
      if norm.(s) <= 1 then
        invalid_arg "normalize_counts: more distinct symbols than table slots";
      norm.(s) <- norm.(s) - 1;
      decr distributed
    done
  end;
  norm

(** Build FSE
compression table from normalized counts, following C zstd's
    FSE_buildCTable_wksp. The symbol spread is the same as the one in
    [build_dtable], so encoder and decoder agree on every state.
    @raise Invalid_argument if a count is below [-1] or the absolute counts
    do not sum to [1 lsl accuracy_log]. *)
let build_ctable norm_counts accuracy_log =
  let table_size = 1 lsl accuracy_log in
  let table_mask = table_size - 1 in
  let num_symbols = Array.length norm_counts in
  (* Reject malformed distributions up front: otherwise the spread can
     overwrite slots, or spin forever when the low-probability region fills
     the whole table. *)
  let total =
    Array.fold_left
      (fun acc c ->
        if c < -1 then invalid_arg "build_ctable: normalized count below -1";
        acc + abs c)
      0 norm_counts
  in
  if total <> table_size then
    invalid_arg "build_ctable: normalized counts do not sum to table size";
  let step = (table_size lsr 1) + (table_size lsr 3) + 3 in
  (* Symbol occupying each table position. *)
  let table_symbol = Array.make table_size 0 in
  (* cumul.(s) is the first index of symbol s's block in the state table. *)
  let cumul = Array.make (num_symbols + 1) 0 in
  for s = 0 to num_symbols - 1 do
    let slots = if norm_counts.(s) = -1 then 1 else max 0 norm_counts.(s) in
    cumul.(s + 1) <- cumul.(s) + slots
  done;
  (* Low-probability symbols take one slot each, filled downward from the
     end; high_threshold ends as the last free position. *)
  let high_threshold = ref (table_size - 1) in
  Array.iteri
    (fun s c ->
      if c = -1 then begin
        table_symbol.(!high_threshold) <- s;
        decr high_threshold
      end)
    norm_counts;
  (* Spread the remaining symbols with the decoder's step, skipping the
     low-probability region above high_threshold. *)
  let pos = ref 0 in
  Array.iteri
    (fun s c ->
      for _ = 1 to c do
        table_symbol.(!pos) <- s;
        pos := (!pos + step) land table_mask;
        while !pos > !high_threshold do
          pos := (!pos + step) land table_mask
        done
      done)
    norm_counts;
  (* Next-state table grouped by symbol: each occurrence of a symbol at
     position u, in position order, stores the state table_size + u. *)
  let state_table = Array.make table_size 0 in
  let next_slot = Array.copy cumul in
  for u = 0 to table_size - 1 do
    let s = table_symbol.(u) in
    state_table.(next_slot.(s)) <- table_size + u;
    next_slot.(s) <- next_slot.(s) + 1
  done;
  (* Per-symbol transforms. *)
  let symbol_tt =
    Array.mapi
      (fun s count ->
        match count with
        | 0 ->
            (* Never encoded; filled like C zstd for compatibility. *)
            {
              delta_nb_bits = ((accuracy_log + 1) lsl 16) - table_size;
              delta_find_state = 0;
            }
        | -1 | 1 ->
            (* Single-state symbol: always emits accuracy_log bits. *)
            {
              delta_nb_bits = (accuracy_log lsl 16) - table_size;
              delta_find_state = cumul.(s) - 1;
            }
        | _ ->
            let max_bits_out = accuracy_log - highest_set_bit (count - 1) in
            let min_state_plus = count lsl max_bits_out in
            {
              delta_nb_bits = (max_bits_out lsl 16) - min_state_plus;
              delta_find_state = cumul.(s) - count;
            })
      norm_counts
  in
  { symbol_tt; state_table; accuracy_log; table_size }

(** Initial compression state, mirroring C's FSE_initCState. *)
let init_cstate ctable = { value = 1 lsl ctable.accuracy_log; ctable }

(** Initialize the compression state from the first symbol, mirroring C's
    FSE_initCState2: this picks the smallest valid state for [symbol],
    which saves bits. *)
let init_cstate2 ctable symbol =
  let st = ctable.symbol_tt.(symbol) in
  let nb_bits_out = (st.delta_nb_bits + (1 lsl 15)) lsr 16 in
  let init_value = (nb_bits_out lsl 16) - st.delta_nb_bits in
  let state_idx = (init_value lsr nb_bits_out) + st.delta_find_state in
  { value = ctable.state_table.(state_idx); ctable }

(** Encode one symbol, mirroring C's FSE_encodeSymbol: emit the low bits of
    the current state, then move to the next state for [symbol]. *)
let[@inline] encode_symbol (stream : Bit_writer.Backward.t) cstate symbol =
  let st = cstate.ctable.symbol_tt.(symbol) in
  let nb_bits_out = (cstate.value + st.delta_nb_bits) lsr 16 in
  (* Emit only the low nb_bits_out bits, as C's BIT_addBits does; masking
     here keeps the output correct whether or not the writer masks. *)
  let low_bits = cstate.value land ((1 lsl nb_bits_out) - 1) in
  Bit_writer.Backward.write_bits stream low_bits nb_bits_out;
  let state_idx = (cstate.value lsr nb_bits_out) + st.delta_find_state in
  cstate.value <- cstate.ctable.state_table.(state_idx)

(** Flush compression state, mirroring C's FSE_flushCState: writes the
    final state so the decoder can initialize from it.
Only the low [accuracy_log] bits of the state are written; the state
    itself lies between table_size and 2 * table_size - 1. The decoder reads
    those bits back as its initial state. *)
let[@inline] flush_cstate (stream : Bit_writer.Backward.t) cstate =
  let table_log = cstate.ctable.accuracy_log in
  let low_bits = cstate.value land ((1 lsl table_log) - 1) in
  Bit_writer.Backward.write_bits stream low_bits table_log

(** [write_header stream norm_counts accuracy_log] writes an FSE table
    description, the inverse of [decode_header]:
    - 4 bits of [accuracy_log - 5];
    - then each symbol's count + 1 in a variable number of bits;
    - runs of zero-probability symbols are coded as 2-bit repeat flags.

    Writing stops once the counts written account for the whole table.
    NOTE(review): no byte-alignment padding is written here, while the
    decoder aligns after the header; presumably the caller pads — confirm.
    @raise Invalid_argument if [accuracy_log] is outside 5..20, a count is
    below [-1], or the absolute counts do not sum to [1 lsl accuracy_log]. *)
let write_header (stream : Bit_writer.Forward.t) norm_counts accuracy_log =
  (* The 4-bit field can only represent accuracy logs 5..20. *)
  if accuracy_log < 5 || accuracy_log > 20 then
    invalid_arg "write_header: accuracy_log must be between 5 and 20";
  let table_size = 1 lsl accuracy_log in
  let total =
    Array.fold_left
      (fun acc c ->
        if c < -1 then invalid_arg "write_header: normalized count below -1";
        acc + abs c)
      0 norm_counts
  in
  if total <> table_size then
    invalid_arg "write_header: normalized counts do not sum to table size";
  Bit_writer.Forward.write_bits stream (accuracy_log - 5) 4;
  let num_symbols = Array.length norm_counts in
  let remaining = ref table_size in
  let symbol = ref 0 in
  while !remaining > 0 && !symbol < num_symbols do
    let count = norm_counts.(!symbol) in
    (* Values (count + 1) lie in [0, remaining + 1]; size the field for the
       largest one. *)
    let bits_needed = highest_set_bit (!remaining + 1) + 1 in
    let half = 1 lsl (bits_needed - 1) in
    (* Values below small_limit fit on one bit less. *)
    let small_limit = (1 lsl bits_needed) - 1 - (!remaining + 1) in
    let value = count + 1 in
    if value < small_limit then
      Bit_writer.Forward.write_bits stream value (bits_needed - 1)
    else if value < half then
      (* Mid-range values go out unchanged on the full width. Shifting them
         by small_limit (as the upper range is) would make the decoder read
         them back as short values. *)
      Bit_writer.Forward.write_bits stream value bits_needed
    else
      Bit_writer.Forward.write_bits stream (value + small_limit) bits_needed;
    remaining := !remaining - abs count;
    incr symbol;
    if count = 0 then begin
      (* Following zero-probability symbols are coded as 2-bit repeat
         flags; a flag of 3 means three zeros and another flag follows. *)
      let run_start = !symbol in
      while !symbol < num_symbols && norm_counts.(!symbol) = 0 do
        incr symbol
      done;
      let rec write_repeats n =
        if n >= 3 then begin
          Bit_writer.Forward.write_bits stream 3 2;
          write_repeats (n - 3)
        end
        else Bit_writer.Forward.write_bits stream n 2
      in
      write_repeats (!symbol - run_start)
    end
  done

(** Build encoding table from predefined distribution. *)
let build_predefined_ctable distribution accuracy_log =
  build_ctable distribution accuracy_log