(** Huffman coding for Zstandard literals (decoding and encoding).

    Zstd uses canonical Huffman codes for literal compression.
    Huffman streams are read backwards like FSE streams. *)

(** Huffman decoding table entry *)
type entry = {
  symbol : int;    (* literal value produced by this table slot *)
  num_bits : int;  (* length of the code that maps to this slot *)
}

(** Huffman decoding table, indexed by the next [max_bits] bits of the
    stream. *)
type dtable = {
  entries : entry array;
  max_bits : int;
}

let highest_set_bit = Fse.highest_set_bit

(* Raise the error used for every malformed Huffman table description. *)
let invalid_table () =
  raise (Constants.Zstd_error Constants.Invalid_huffman_table)

(** Build a decoding table from per-symbol bit lengths (0 = symbol unused),
    using canonical Huffman coding: the longest codes occupy the lowest
    table slots, and symbols of equal length are laid out in increasing
    symbol order.

    @raise Constants.Zstd_error with [Invalid_huffman_table] if there are
    too many symbols, a length is negative or exceeds
    [Constants.max_huffman_bits], no symbol is used, or the lengths do not
    fill the table exactly. *)
let build_dtable_from_bits bits num_symbols =
  if num_symbols > Constants.max_huffman_symbols then invalid_table ();
  (* Find max bits and count symbols per bit length *)
  let max_bits = ref 0 in
  let rank_count = Array.make (Constants.max_huffman_bits + 1) 0 in
  for i = 0 to num_symbols - 1 do
    let b = bits.(i) in
    if b < 0 || b > Constants.max_huffman_bits then invalid_table ();
    if b > !max_bits then max_bits := b;
    rank_count.(b) <- rank_count.(b) + 1
  done;
  let max_bits = !max_bits in
  if max_bits = 0 then invalid_table ();
  let table_size = 1 lsl max_bits in
  (* rank_idx.(b) is the first table slot used by codes of length b *)
  let rank_idx = Array.make (Constants.max_huffman_bits + 1) 0 in
  for i = max_bits downto 1 do
    rank_idx.(i - 1) <-
      rank_idx.(i) + (rank_count.(i) * (1 lsl (max_bits - i)))
  done;
  (* Validate before touching the table: an over-subscribed set of lengths
     would otherwise index past the end of [entries]. *)
  if rank_idx.(0) <> table_size then invalid_table ();
  let entries = Array.make table_size { symbol = 0; num_bits = 0 } in
  (* Each symbol owns 2^(max_bits - b) consecutive slots of its rank; since
     the ranks tile the table exactly, every slot gets assigned. *)
  for i = 0 to num_symbols - 1 do
    let b = bits.(i) in
    if b <> 0 then begin
      let first = rank_idx.(b) in
      let span = 1 lsl (max_bits - b) in
      Array.fill entries first span { symbol = i; num_bits = b };
      rank_idx.(b) <- first + span
    end
  done;
  { entries; max_bits }

(** Build a decoding table from weights (as decoded from the zstd format).
    [weights] holds [num_symbols] explicit weights; the weight of one extra,
    final symbol is implicit and is deduced so that the total reaches a
    power of two.

    @raise Constants.Zstd_error with [Invalid_huffman_table] if the weights
    do not describe a valid table. *)
let build_dtable_from_weights weights num_symbols =
  if num_symbols + 1 > Constants.max_huffman_symbols then invalid_table ();
  (* Calculate weight sum to find max_bits and last weight *)
  let weight_sum = ref 0 in
  for i = 0 to num_symbols - 1 do
    let w = weights.(i) in
    if w > Constants.max_huffman_bits then invalid_table ();
    if w > 0 then weight_sum := !weight_sum + (1 lsl (w - 1))
  done;
  let weight_sum = !weight_sum in
  (* With no explicit non-zero weight there is nothing to deduce the last
     weight from; reject instead of relying on [highest_set_bit 0]. *)
  if weight_sum = 0 then invalid_table ();
  (* Find max_bits (first power of 2 > weight_sum) *)
  let max_bits = highest_set_bit weight_sum + 1 in
  let left_over = (1 lsl max_bits) - weight_sum in
  (* left_over must be a power of 2 *)
  if left_over land (left_over - 1) <> 0 then invalid_table ();
  let last_weight = highest_set_bit left_over + 1 in
  (* Convert weights to bit lengths *)
  let bits = Array.make (num_symbols + 1) 0 in
  for i = 0 to num_symbols - 1 do
    let w = weights.(i) in
    if w > 0 then bits.(i) <- max_bits + 1 - w
  done;
  bits.(num_symbols) <- max_bits + 1 - last_weight;
  build_dtable_from_bits bits (num_symbols + 1)

(** Initialize Huffman state by reading max_bits *)
let[@inline] init_state dtable (stream : Bit_reader.Backward.t) =
  Bit_reader.Backward.read_bits stream dtable.max_bits

(** Decode a symbol and update state. Returns [(symbol, new_state)]. *)
let[@inline] decode_symbol dtable state (stream : Bit_reader.Backward.t) =
  let entry = dtable.entries.(state) in
  let symbol = entry.symbol in
  let bits_used = entry.num_bits in
  (* Shift out used bits and read new ones *)
  let mask = (1 lsl dtable.max_bits) - 1 in
  let rest = Bit_reader.Backward.read_bits stream bits_used in
  let new_state = ((state lsl bits_used) + rest) land mask in
  (symbol, new_state)

(** Decompress a single Huffman stream located at [src.[pos .. pos+len-1]]
    into [output] starting at [out_pos]; at most [out_len] bytes may be
    produced. Returns the number of bytes written.

    @raise Constants.Zstd_error with [Output_too_small] if the stream holds
    more than [out_len] symbols, or [Corruption] if the stream is not
    consumed exactly. *)
let decompress_1stream dtable src ~pos ~len output ~out_pos ~out_len =
  let stream = Bit_reader.Backward.of_bytes src ~pos ~len in
  let state = ref (init_state dtable stream) in
  let written = ref 0 in
  (* The state always holds max_bits bits read ahead, so the stream is
     finished once the reader is exactly max_bits bits past its start. *)
  let drained = -dtable.max_bits in
  while Bit_reader.Backward.remaining stream > drained do
    if !written >= out_len then
      raise (Constants.Zstd_error Constants.Output_too_small);
    let symbol, new_state = decode_symbol dtable !state stream in
    Bytes.set_uint8 output (out_pos + !written) symbol;
    incr written;
    state := new_state
  done;
  (* Verify stream is exactly consumed *)
  if Bit_reader.Backward.remaining stream <> drained then
    raise (Constants.Zstd_error Constants.Corruption);
  !written

(** Decompress 4 Huffman streams preceded by a 6-byte jump table. The first
    three streams each regenerate [(regen_size + 3) / 4] bytes and the
    fourth regenerates the remainder. Returns the total number of bytes
    written.

    @raise Constants.Zstd_error with [Corruption] if the input is too short
    for the jump table or the stream sizes are inconsistent with [len]. *)
let decompress_4stream dtable src ~pos ~len output ~out_pos ~regen_size =
  (* The jump table must lie inside the section before it is read. *)
  if len < 6 then raise (Constants.Zstd_error Constants.Corruption);
  (* Read stream sizes from jump table (6 bytes) *)
  let size1 = Bit_reader.get_u16_le src pos in
  let size2 = Bit_reader.get_u16_le src (pos + 2) in
  let size3 = Bit_reader.get_u16_le src (pos + 4) in
  let size4 = len - 6 - size1 - size2 - size3 in
  if size4 < 1 then raise (Constants.Zstd_error Constants.Corruption);
  (* Calculate output sizes *)
  let out_size = (regen_size + 3) / 4 in
  let out_size4 = regen_size - (3 * out_size) in
  (* Decompress each stream *)
  let stream_pos = pos + 6 in
  let written1 =
    decompress_1stream dtable src ~pos:stream_pos ~len:size1 output ~out_pos
      ~out_len:out_size
  in
  let written2 =
    decompress_1stream dtable src ~pos:(stream_pos + size1) ~len:size2 output
      ~out_pos:(out_pos + out_size) ~out_len:out_size
  in
  let written3 =
    decompress_1stream dtable src
      ~pos:(stream_pos + size1 + size2)
      ~len:size3 output
      ~out_pos:(out_pos + (2 * out_size))
      ~out_len:out_size
  in
  let written4 =
    decompress_1stream dtable src
      ~pos:(stream_pos + size1 + size2 + size3)
      ~len:size4 output
      ~out_pos:(out_pos + (3 * out_size))
      ~out_len:out_size4
  in
  written1 + written2 + written3 + written4

(** Decode a Huffman table description from [stream] and return the
    decoding table. The number of bytes consumed is reflected only in the
    position of [stream]. *)
let decode_table (stream : Bit_reader.Forward.t) =
  let header = Bit_reader.Forward.read_byte stream in
  let weights = Array.make Constants.max_huffman_symbols 0 in
  let num_symbols =
    if header >= 128 then begin
      (* Direct representation: 4 bits per weight *)
      let count = header - 127 in
      let bytes_needed = (count + 1) / 2 in
      let data = Bit_reader.Forward.get_bytes stream bytes_needed in
      for i = 0 to count - 1 do
        let byte = Bytes.get_uint8 data (i / 2) in
        weights.(i) <- (if i mod 2 = 0 then byte lsr 4 else byte land 0xf)
      done;
      count
    end
    else begin
      (* FSE compressed weights *)
      let compressed_size = header in
      let fse_data = Bit_reader.Forward.get_bytes stream compressed_size in
      (* Decode FSE table for weights (max accuracy 7) *)
      let fse_stream = Bit_reader.Forward.of_bytes fse_data in
      let fse_table = Fse.decode_header fse_stream 7 in
      (* Remaining bytes are the compressed weights *)
      let weights_pos = Bit_reader.Forward.byte_position fse_stream in
      let weights_len = compressed_size - weights_pos in
      let weight_bytes = Bytes.create Constants.max_huffman_symbols in
      let decoded =
        Fse.decompress_interleaved2 fse_table fse_data ~pos:weights_pos
          ~len:weights_len weight_bytes
      in
      for i = 0 to decoded - 1 do
        weights.(i) <- Bytes.get_uint8 weight_bytes i
      done;
      decoded
    end
  in
  build_dtable_from_weights weights num_symbols

(* ========== ENCODING ========== *)

(** Huffman encoding table *)
type ctable = {
  codes : int array;     (* Canonical code for each symbol *)
  num_bits : int array;  (* Bit length for each symbol *)
  max_bits : int;
  num_symbols : int;
}

(* Package-merge: optimal prefix-code lengths, none longer than [limit], for
   the [(frequency, symbol)] pairs in [active] (positive frequencies, at
   least two entries, listed in increasing symbol order). The result is
   indexed by symbol and always satisfies the Kraft equality, i.e. the code
   is complete. *)
let limited_code_lengths active num_symbols limit =
  let n = List.length active in
  if limit < 1 || n > 1 lsl limit then
    invalid_arg "build_ctable: max_bits_limit too small for the alphabet";
  let by_weight (w1, _) (w2, _) = Int.compare w1 w2 in
  (* An item is (total weight, symbols it contains, with multiplicity). *)
  let leaves =
    active
    |> List.stable_sort by_weight
    |> List.map (fun (freq, sym) -> (freq, [ sym ]))
  in
  (* Package consecutive pairs of a sorted item list; an odd item out is
     dropped. *)
  let rec pair_up = function
    | (w1, s1) :: (w2, s2) :: rest ->
      (w1 + w2, List.rev_append s1 s2) :: pair_up rest
    | [] | [ _ ] -> []
  in
  (* One item list per code length; each list is the leaves merged with the
     packages of the previous (deeper) list. *)
  let rec widen level items =
    if level >= limit then items
    else widen (level + 1) (List.merge by_weight leaves (pair_up items))
  in
  let final = widen 1 leaves in
  (* A symbol's code length is the number of times it occurs in the 2n - 2
     cheapest items of the final list. *)
  let lengths = Array.make num_symbols 0 in
  let bump sym = lengths.(sym) <- lengths.(sym) + 1 in
  let rec take k items =
    if k > 0 then (
      match items with
      | [] -> assert false (* cannot happen when n <= 2^limit *)
      | (_, syms) :: rest ->
        List.iter bump syms;
        take (k - 1) rest)
  in
  take (2 * (n - 1)) final;
  lengths

(* Number the codes exactly as [build_dtable_from_bits] lays out its table:
   the longest codes take the lowest values, and symbols of equal length are
   numbered in increasing symbol order. Returns the codes and the longest
   length in use. *)
let canonical_codes bit_lengths =
  let max_len = Array.fold_left max 0 bit_lengths in
  let per_len = Array.make (max_len + 1) 0 in
  Array.iter
    (fun len -> if len > 0 then per_len.(len) <- per_len.(len) + 1)
    bit_lengths;
  (* slot.(len) is the next free decoding-table index for length [len] *)
  let slot = Array.make (max_len + 1) 0 in
  for len = max_len downto 2 do
    slot.(len - 1) <- slot.(len) + (per_len.(len) * (1 lsl (max_len - len)))
  done;
  let codes = Array.make (Array.length bit_lengths) 0 in
  Array.iteri
    (fun sym len ->
      if len > 0 then begin
        let shift = max_len - len in
        (* The code is the table index with the don't-care low bits removed *)
        codes.(sym) <- slot.(len) lsr shift;
        slot.(len) <- slot.(len) + (1 lsl shift)
      end)
    bit_lengths;
  (codes, max_len)

(** Build a Huffman code from frequencies using the package-merge
    algorithm. Only symbols [0 .. max_symbol] of [counts] are considered,
    and no code is longer than [max_bits_limit] (further capped at
    [Constants.max_huffman_bits], the longest length the decoder accepts).

    With no used symbol an empty table ([num_symbols = 0]) is returned; with
    a single used symbol that symbol gets a 1-bit code. Otherwise the code
    is complete and its numbering matches [build_dtable_from_bits].

    @raise Invalid_argument if [counts] has fewer than [max_symbol + 1]
    elements or the length limit is too small to give every used symbol a
    code. *)
let build_ctable counts max_symbol max_bits_limit =
  let num_symbols = max_symbol + 1 in
  let freqs = Array.sub counts 0 num_symbols in
  (* (frequency, symbol) for every used symbol, in increasing symbol order *)
  let active =
    Array.to_list freqs
    |> List.mapi (fun sym freq -> (freq, sym))
    |> List.filter (fun (freq, _) -> freq > 0)
  in
  match active with
  | [] -> { codes = [||]; num_bits = [||]; max_bits = 0; num_symbols = 0 }
  | [ (_, sym) ] ->
    (* Single symbol case *)
    let num_bits = Array.make num_symbols 0 in
    num_bits.(sym) <- 1;
    { codes = Array.make num_symbols 0; num_bits; max_bits = 1; num_symbols }
  | _ :: _ :: _ ->
    let limit = min max_bits_limit Constants.max_huffman_bits in
    let bit_lengths = limited_code_lengths active num_symbols limit in
    let codes, max_bits = canonical_codes bit_lengths in
    { codes; num_bits = bit_lengths; max_bits; num_symbols }

(** Convert bit lengths to weights (zstd format): a used symbol of length
    [l] gets weight [max_bits + 1 - l]; unused symbols get weight 0. *)
let bits_to_weights num_bits num_symbols max_bits =
  Array.init num_symbols (fun i ->
      if num_bits.(i) > 0 then max_bits + 1 - num_bits.(i) else 0)

(** Write the Huffman table header using the direct representation: a header
    byte equal to [127 + number of explicit weights], followed by the
    weights packed two per byte (first weight in the high nibble). The
    weight of the last used symbol is implicit and is not written.

    Returns the number of symbols described (explicit weights plus the
    implicit one), or 0 without writing anything for an empty table.

    @raise Constants.Zstd_error with [Invalid_huffman_table] if the table
    cannot be expressed in direct form, which holds between 1 and 128
    explicit weights. *)
let write_header (stream : Bit_writer.Forward.t) ctable =
  if ctable.num_symbols = 0 then 0
  else begin
    let weights =
      bits_to_weights ctable.num_bits ctable.num_symbols ctable.max_bits
    in
    (* Find last non-zero weight (implicit last symbol) *)
    let last_nonzero = ref (ctable.num_symbols - 1) in
    while !last_nonzero > 0 && weights.(!last_nonzero) = 0 do
      decr last_nonzero
    done;
    let num_weights = !last_nonzero in
    (* The decoder reads [header - 127] explicit weights, so the header byte
       can only announce 1 to 128 of them. *)
    if num_weights < 1 || num_weights > 128 then invalid_table ();
    Bit_writer.Forward.write_byte stream (127 + num_weights);
    (* Write weights packed as pairs (high nibble, low nibble); an odd count
       leaves the final low nibble at 0. *)
    for i = 0 to (num_weights - 1) / 2 do
      let hi = weights.(2 * i) in
      let lo =
        if (2 * i) + 1 < num_weights then weights.((2 * i) + 1) else 0
      in
      Bit_writer.Forward.write_byte stream ((hi lsl 4) lor lo)
    done;
    num_weights + 1
  end

(** Encode a single symbol (write to backward stream). Symbols without a
    code (bit length 0) write nothing. *)
let[@inline] encode_symbol ctable (stream : Bit_writer.Backward.t) symbol =
  let code = ctable.codes.(symbol) in
  let bits = ctable.num_bits.(symbol) in
  if bits > 0 then Bit_writer.Backward.write_bits stream code bits

(** Compress [literals.[pos .. pos+len-1]] to a single Huffman stream and
    return the finalized stream bytes. *)
let compress_1stream ctable literals ~pos ~len =
  let stream = Bit_writer.Backward.create ((len * 2) + 16) in
  (* Encode symbols in reverse order *)
  for i = pos + len - 1 downto pos do
    let sym = Bytes.get_uint8 literals i in
    encode_symbol ctable stream sym
  done;
  Bit_writer.Backward.finalize stream

(** Compress [literals.[pos .. pos+len-1]] to 4 Huffman streams preceded by
    the 6-byte jump table holding the sizes of the first three streams. The
    first three streams each cover [(len + 3) / 4] literals and the fourth
    covers the remainder.

    @raise Invalid_argument if [len] cannot be split that way (the fourth
    chunk would have a negative size, which happens for [len] of 1, 2 or 5
    and for negative [len]). *)
let compress_4stream ctable literals ~pos ~len =
  let chunk_size = (len + 3) / 4 in
  let chunk4_size = len - (3 * chunk_size) in
  (* Without this check the first three chunks would read literals past
     [pos + len]. *)
  if chunk4_size < 0 then
    invalid_arg "compress_4stream: len too small to split into 4 streams";
  (* Compress each stream *)
  let stream1 = compress_1stream ctable literals ~pos ~len:chunk_size in
  let stream2 =
    compress_1stream ctable literals ~pos:(pos + chunk_size) ~len:chunk_size
  in
  let stream3 =
    compress_1stream ctable literals
      ~pos:(pos + (2 * chunk_size))
      ~len:chunk_size
  in
  let stream4 =
    compress_1stream ctable literals
      ~pos:(pos + (3 * chunk_size))
      ~len:chunk4_size
  in
  (* Build output with jump table *)
  let size1 = Bytes.length stream1 in
  let size2 = Bytes.length stream2 in
  let size3 = Bytes.length stream3 in
  let size4 = Bytes.length stream4 in
  let output = Bytes.create (6 + size1 + size2 + size3 + size4) in
  Bytes.set_uint16_le output 0 size1;
  Bytes.set_uint16_le output 2 size2;
  Bytes.set_uint16_le output 4 size3;
  Bytes.blit stream1 0 output 6 size1;
  Bytes.blit stream2 0 output (6 + size1) size2;
  Bytes.blit stream3 0 output (6 + size1 + size2) size3;
  Bytes.blit stream4 0 output (6 + size1 + size2 + size3) size4;
  output