(* Zstd compression in pure OCaml *)
(** Finite State Entropy (FSE) decoding and encoding for Zstandard.

    FSE is an entropy coding method based on ANS (Asymmetric Numeral Systems).
    FSE streams are read backwards (from end to beginning). *)
5
(** One cell of an FSE decoding table, indexed by decoder state. *)
type entry = {
  symbol : int;  (** Symbol emitted while the decoder is in this state. *)
  num_bits : int;  (** Bits read from the stream to form the next state. *)
  new_state_base : int;  (** Added to those bits to give the next state. *)
}
11}
12
(** An FSE decoding table. The builders in this module give [entries]
    exactly [1 lsl accuracy_log] cells, one per decoder state. *)
type dtable = {
  entries : entry array;  (** Cell for each state. *)
  accuracy_log : int;  (** Number of bits read to initialise a state. *)
}
17}
18
(** [highest_set_bit n] is [floor (log2 n)] for [n > 0], and [-1] for
    [n <= 0].

    Works by shifting [n] right rather than probing [1 lsl i] upwards:
    the probe overflows once [i] reaches [Sys.int_size - 1], which made
    the previous loop spin forever for [n >= 2^(Sys.int_size - 2)]. *)
let[@inline] highest_set_bit n =
  if n <= 0 then -1
  else begin
    let rec count_shifts v acc =
      if v > 1 then count_shifts (v lsr 1) (acc + 1) else acc
    in
    count_shifts n 0
  end
28
(** Build an FSE decoding table from normalized frequencies.

    [frequencies.(s)] is the normalized count of symbol [s]:
    - a positive count;
    - [0] for an absent symbol;
    - [-1] for a "probability below 1" symbol, which gets a single cell
      at the top of the table.

    The absolute values must add up to [1 lsl accuracy_log].

    Each resulting cell records the symbol it decodes to, plus [num_bits]
    and [new_state_base]. The next state is [new_state_base] plus the
    next [num_bits] bits of the stream.

    @raise Constants.Zstd_error [Invalid_fse_table] if:
    - a frequency is below [-1];
    - the absolute frequencies do not add up to the table size;
    - the symbol spread does not end back at position 0. *)
let build_dtable frequencies accuracy_log =
  let table_size = 1 lsl accuracy_log in
  let num_symbols = Array.length frequencies in
  let invalid () = raise (Constants.Zstd_error Constants.Invalid_fse_table) in

  (* Validate before touching the table. Without this, an inconsistent
     distribution could overwrite cells, leave [state_desc] out of step
     with the table, or index outside [entries]. The last case would
     raise Invalid_argument rather than the codec error. *)
  let total =
    Array.fold_left
      (fun acc f -> if f < -1 then invalid () else acc + abs f)
      0 frequencies
  in
  if total <> table_size then invalid ();

  (* Cells are immutable records that get replaced, so sharing the
     initial value is safe. *)
  let entries =
    Array.make table_size { symbol = 0; num_bits = 0; new_state_base = 0 }
  in

  (* Per-symbol counter of the next state descriptor to hand out. *)
  let state_desc = Array.make num_symbols 0 in

  (* Low-probability symbols take the top cells, one each, filled from
     the last cell downwards. *)
  let high_threshold = ref table_size in
  Array.iteri
    (fun s f ->
      if f = -1 then begin
        decr high_threshold;
        entries.(!high_threshold) <-
          { symbol = s; num_bits = 0; new_state_base = 0 };
        state_desc.(s) <- 1
      end)
    frequencies;

  (* Spread the remaining symbols with the standard FSE step, skipping
     the cells reserved above. *)
  let step = (table_size lsr 1) + (table_size lsr 3) + 3 in
  let mask = table_size - 1 in
  let pos = ref 0 in
  Array.iteri
    (fun s f ->
      if f > 0 then begin
        state_desc.(s) <- f;
        for _ = 1 to f do
          entries.(!pos) <- { (entries.(!pos)) with symbol = s };
          pos := (!pos + step) land mask;
          while !pos >= !high_threshold do
            pos := (!pos + step) land mask
          done
        done
      end)
    frequencies;
  if !pos <> 0 then invalid ();

  (* Each cell takes its symbol's next descriptor d, which runs from the
     symbol's frequency (1 for a -1 symbol) upwards, one per cell:
       num_bits       = accuracy_log - floor(log2 d)
       new_state_base = (d lsl num_bits) - table_size *)
  for i = 0 to table_size - 1 do
    let symbol = entries.(i).symbol in
    let d = state_desc.(symbol) in
    state_desc.(symbol) <- d + 1;
    let num_bits = accuracy_log - highest_set_bit d in
    let new_state_base = (d lsl num_bits) - table_size in
    entries.(i) <- { symbol; num_bits; new_state_base }
  done;

  { entries; accuracy_log }
90
(** Build the table for RLE mode: a single cell that always decodes to
    [symbol], with [num_bits = 0] and [new_state_base = 0]. *)
let build_dtable_rle symbol =
  let only_cell = { symbol; num_bits = 0; new_state_base = 0 } in
  { accuracy_log = 0; entries = Array.make 1 only_cell }
97
(** [peek_symbol dtable state] is the symbol decoded in [state].
    Neither the state nor any stream is modified. *)
let[@inline] peek_symbol dtable state =
  let { symbol; _ } = Array.get dtable.entries state in
  symbol
101
(** Compute the state that follows [state]. It reads the cell's
    [num_bits] bits from [stream] and adds them to the cell's
    [new_state_base]. *)
let[@inline] update_state dtable state (stream : Bit_reader.Backward.t) =
  let { num_bits; new_state_base; _ } = dtable.entries.(state) in
  let low_bits = Bit_reader.Backward.read_bits stream num_bits in
  new_state_base + low_bits
107
(** Decode one symbol. Returns the symbol for [state] paired with the
    next state; computing the next state consumes bits from [stream]. *)
let[@inline] decode_symbol dtable state stream =
  let decoded = peek_symbol dtable state in
  let next = update_state dtable state stream in
  decoded, next
113
(** Read the initial decoder state: the next [dtable.accuracy_log] bits
    of [stream]. *)
let[@inline] init_state dtable (stream : Bit_reader.Backward.t) =
  let width = dtable.accuracy_log in
  Bit_reader.Backward.read_bits stream width
117
(** Read an FSE table description from [stream] and build the matching
    decoding table. On return the stream has been aligned to the next
    byte boundary.

    Layout:
    - 4 bits holding [accuracy_log - 5];
    - one variable-width field per symbol, holding its probability plus
      one;
    - after a field that decodes to probability 0, 2-bit repeat flags for
      the symbols that follow (3 means "three more zeros, keep reading").

    Reading stops once the absolute probabilities add up to
    [1 lsl accuracy_log], or once [Constants.max_fse_symbols] symbols
    have been read.

    @raise Constants.Zstd_error [Invalid_fse_table] if the accuracy log
    exceeds [max_accuracy_log], or if the probabilities do not add up
    exactly to the table size. *)
let decode_header (stream : Bit_reader.Forward.t) max_accuracy_log =
  let accuracy_log = Bit_reader.Forward.read_bits stream 4 + 5 in
  if accuracy_log > max_accuracy_log then
    raise (Constants.Zstd_error Constants.Invalid_fse_table);
  let symbol_limit = Constants.max_fse_symbols in
  let frequencies = Array.make symbol_limit 0 in

  (* Read one "probability + 1" field while [left] probability points
     remain unassigned. [width] bits are read. When the low [width - 1]
     bits alone fall below [cutoff], the short form applies and the
     surplus bit is handed back to the stream. *)
  let read_field left =
    let width = highest_set_bit (left + 1) + 1 in
    let raw = Bit_reader.Forward.read_bits stream width in
    let half = 1 lsl (width - 1) in
    let cutoff = (2 * half) - 1 - (left + 1) in
    let low = raw land (half - 1) in
    if low < cutoff then begin
      Bit_reader.Forward.rewind_bits stream 1;
      low
    end
    else if raw >= half then raw - cutoff
    else raw
  in

  (* Consume the repeat flags that follow a zero probability and return
     the next symbol index, capped at [symbol_limit]. The skipped entries
     are already 0 in [frequencies]. *)
  let rec skip_zeroes next =
    let repeat = Bit_reader.Forward.read_bits stream 2 in
    let next = min symbol_limit (next + repeat) in
    if repeat = 3 then skip_zeroes next else next
  in

  let rec read_all left next =
    if left <= 0 || next >= symbol_limit then (left, next)
    else begin
      let prob = read_field left - 1 in
      frequencies.(next) <- prob;
      let next = next + 1 in
      let next = if prob = 0 then skip_zeroes next else next in
      read_all (left - abs prob) next
    end
  in

  let left, symbol_count = read_all (1 lsl accuracy_log) 0 in
  Bit_reader.Forward.align stream;
  if left <> 0 then
    raise (Constants.Zstd_error Constants.Invalid_fse_table);
  build_dtable (Array.sub frequencies 0 symbol_count) accuracy_log
185
(** Decode a stream that interleaves two FSE states sharing [dtable], as
    used for Huffman weight encoding. Writes symbols into [output] and
    returns how many were written.

    The two states are initialised in order, then take turns emitting a
    symbol and updating. As soon as [Bit_reader.Backward.remaining] goes
    negative after an update, decoding stops. The other state's current
    symbol is emitted as the final one.

    @raise Constants.Zstd_error [Output_too_small] if any symbol,
    including the final one, does not fit in [output]. *)
let decompress_interleaved2 dtable src ~pos ~len output =
  let stream = Bit_reader.Backward.of_bytes src ~pos ~len in

  (* Initialise the two states in order. *)
  let state1 = ref (init_state dtable stream) in
  let state2 = ref (init_state dtable stream) in

  let out_pos = ref 0 in
  let out_len = Bytes.length output in
  (* The final symbol is checked like every other one. Silently dropping
     it would return a short count that looks like success. *)
  let emit sym =
    if !out_pos >= out_len then
      raise (Constants.Zstd_error Constants.Output_too_small);
    Bytes.set_uint8 output !out_pos sym;
    incr out_pos
  in
  (* NOTE(review): presumably a negative [remaining] means the reader
     consumed past the start of the stream; confirm in Bit_reader. *)
  let overflowed () = Bit_reader.Backward.remaining stream < 0 in

  while Bit_reader.Backward.remaining stream >= 0 do
    let sym1, next1 = decode_symbol dtable !state1 stream in
    emit sym1;
    state1 := next1;
    if overflowed () then
      (* Stream exhausted: the last symbol comes from state2. *)
      emit (peek_symbol dtable !state2)
    else begin
      let sym2, next2 = decode_symbol dtable !state2 stream in
      emit sym2;
      state2 := next2;
      if overflowed () then
        (* Stream exhausted: the last symbol comes from state1. *)
        emit (peek_symbol dtable !state1)
    end
  done;

  !out_pos
234
(** Build a decoding table from a predefined (built-in) distribution.
    This is a plain alias of [build_dtable]. *)
let build_predefined_table = build_dtable
238
(* ========== ENCODING ========== *)
240
(** Per-symbol encoding transform, laid out like zstd's
    [FSE_symbolCompressionTransform].

    [delta_nb_bits] packs [(maxBitsOut lsl 16) - minStatePlus]. The
    number of bits to emit for a state is therefore
    [(state + delta_nb_bits) lsr 16]. *)
type symbol_transform = {
  delta_nb_bits : int;  (** [(maxBitsOut lsl 16) - minStatePlus]. *)
  delta_find_state : int;
      (** Offset added to [state lsr nb_bits] to index the state table. *)
}
248
(** FSE compression table. *)
type ctable = {
  symbol_tt : symbol_transform array;  (** Transform for each symbol. *)
  state_table : int array;
      (** Next-state lookup. [build_ctable] stores [table_size + cell]. *)
  accuracy_log : int;  (** Bits per state. *)
  table_size : int;  (** [1 lsl accuracy_log] when built by [build_ctable]. *)
}
256
(** FSE encoder state, mirroring zstd's [FSE_CState_t]. *)
type cstate = {
  mutable value : int;  (** Current state; rewritten by [encode_symbol]. *)
  ctable : ctable;  (** Table this state encodes with. *)
}
262
(** Histogram of the bytes in [src] from [pos] (inclusive) to [pos + len]
    (exclusive). The result has [max_symbol + 1] slots; bytes above
    [max_symbol] are ignored. *)
let count_symbols src ~pos ~len max_symbol =
  let histogram = Array.make (max_symbol + 1) 0 in
  let stop = pos + len in
  let rec scan i =
    if i < stop then begin
      let byte = Bytes.get_uint8 src i in
      if byte <= max_symbol then histogram.(byte) <- histogram.(byte) + 1;
      scan (i + 1)
    end
  in
  scan pos;
  histogram
272
(** Normalize symbol counts so that they sum to [1 lsl accuracy_log].

    Each symbol with a positive count gets a normalized count of at least
    1, proportional to [count / total] and rounded to nearest. The
    result is then adjusted one unit at a time until it sums exactly to
    the table size. Symbols with a non-positive count get 0, and the
    [-1] "low probability" marker is never produced.

    Returns an all-zero array when [total = 0] or when no count is
    positive.

    @raise Invalid_argument if more symbols have a positive count than
    the table has cells, since each of them needs at least one cell. *)
let normalize_counts counts total accuracy_log =
  let table_size = 1 lsl accuracy_log in
  let num_symbols = Array.length counts in
  let norm = Array.make num_symbols 0 in
  let present =
    Array.fold_left (fun n c -> if c > 0 then n + 1 else n) 0 counts
  in
  if total = 0 || present = 0 then norm
  else if present > table_size then
    invalid_arg "normalize_counts: more symbols than table cells"
  else begin
    (* Round count * table_size / total directly. A precomputed integer
       scale such as [table_size * 256 / total] truncates to 0 once total
       exceeds 256 * table_size, flattening every symbol to 1. *)
    let half_total = total / 2 in
    let distributed = ref 0 in
    Array.iteri
      (fun s c ->
        if c > 0 then begin
          let proba = max 1 (((c * table_size) + half_total) / total) in
          norm.(s) <- proba;
          distributed := !distributed + proba
        end)
      counts;

    (* Too many cells handed out (rounding plus the minimum of 1): take
       them back from the current largest entry. Because
       [present <= table_size], the largest entry is at least 2 whenever
       this loop runs, so no present symbol drops to 0. *)
    while !distributed > table_size do
      let max_val = ref 0 in
      let max_idx = ref 0 in
      for s = 0 to num_symbols - 1 do
        if norm.(s) > !max_val then begin
          max_val := norm.(s);
          max_idx := s
        end
      done;
      norm.(!max_idx) <- norm.(!max_idx) - 1;
      decr distributed
    done;

    (* Too few: give the missing cells to the current smallest non-zero
       entry. *)
    while !distributed < table_size do
      let min_val = ref max_int in
      let min_idx = ref 0 in
      for s = 0 to num_symbols - 1 do
        if norm.(s) > 0 && norm.(s) < !min_val then begin
          min_val := norm.(s);
          min_idx := s
        end
      done;
      norm.(!min_idx) <- norm.(!min_idx) + 1;
      incr distributed
    done;

    norm
  end
321
(** Build an FSE compression table from normalized counts, following the
    layout of zstd's [FSE_buildCTable_wksp].

    [norm_counts.(s)] is one of:
    - [-1]: a low-probability symbol, given one cell at the top;
    - [0]: an absent symbol;
    - a positive number of cells.

    The counts are not validated here. *)
let build_ctable norm_counts accuracy_log =
  let table_size = 1 lsl accuracy_log in
  let table_mask = table_size - 1 in
  let num_symbols = Array.length norm_counts in
  let step = (table_size lsr 1) + (table_size lsr 3) + 3 in

  (* Cells owned by a symbol: a -1 count still owns one. *)
  let cells_of c = if c = -1 then 1 else max 0 c in

  (* starts.(s) is the first [state_table] slot reserved for symbol s. *)
  let starts = Array.make (num_symbols + 1) 0 in
  Array.iteri
    (fun s c -> starts.(s + 1) <- starts.(s) + cells_of c)
    norm_counts;

  (* Symbol occupying each cell. Low-probability symbols go at the top,
     in increasing symbol order from the last cell downwards. *)
  let cell_symbol = Array.make table_size 0 in
  let top = ref (table_size - 1) in
  Array.iteri
    (fun s c ->
      if c = -1 then begin
        cell_symbol.(!top) <- s;
        decr top
      end)
    norm_counts;

  (* Scatter the other symbols with the FSE step, skipping reserved cells. *)
  let rec advance p =
    let p' = (p + step) land table_mask in
    if p' > !top then advance p' else p'
  in
  let cursor = ref 0 in
  Array.iteri
    (fun s c ->
      for _ = 1 to c do
        cell_symbol.(!cursor) <- s;
        cursor := advance !cursor
      done)
    norm_counts;

  (* Group cells by symbol. Each symbol's slots hold table_size + cell,
     in increasing cell order. *)
  let next_slot = Array.copy starts in
  let state_table = Array.make table_size 0 in
  Array.iteri
    (fun u s ->
      state_table.(next_slot.(s)) <- table_size + u;
      next_slot.(s) <- next_slot.(s) + 1)
    cell_symbol;

  (* Per-symbol transforms, as in the zstd reference implementation. *)
  let transform s c =
    if c = 0 then
      (* Absent symbol: filled in with the maximum bit count. *)
      { delta_nb_bits = ((accuracy_log + 1) lsl 16) - table_size;
        delta_find_state = 0 }
    else if c = -1 || c = 1 then
      { delta_nb_bits = (accuracy_log lsl 16) - table_size;
        delta_find_state = starts.(s) - 1 }
    else begin
      let max_bits_out = accuracy_log - highest_set_bit (c - 1) in
      { delta_nb_bits = (max_bits_out lsl 16) - (c lsl max_bits_out);
        delta_find_state = starts.(s) - c }
    end
  in

  { symbol_tt = Array.mapi transform norm_counts;
    state_table;
    accuracy_log;
    table_size }
395
(** Create a compression state whose value is [1 lsl accuracy_log],
    mirroring zstd's [FSE_initCState]. *)
let init_cstate ctable =
  let start = 1 lsl ctable.accuracy_log in
  { ctable; value = start }
399
(** Create a compression state primed for [symbol], mirroring zstd's
    [FSE_initCState2]. It starts from the smallest state valid for that
    symbol, which saves bits on the first symbol. *)
let init_cstate2 ctable symbol =
  let { delta_nb_bits; delta_find_state } = ctable.symbol_tt.(symbol) in
  (* Adding 0x8000 before the shift rounds delta_nb_bits / 2^16 to
     nearest. *)
  let nb_bits = (delta_nb_bits + 0x8000) lsr 16 in
  let probe = ((nb_bits lsl 16) - delta_nb_bits) lsr nb_bits in
  { ctable; value = ctable.state_table.(probe + delta_find_state) }
408
(** Encode [symbol], mirroring zstd's [FSE_encodeSymbol]. It writes the
    low bits of the current state to [stream], then moves [cstate] to the
    next state. *)
let[@inline] encode_symbol (stream : Bit_writer.Backward.t) cstate symbol =
  let { delta_nb_bits; delta_find_state } =
    cstate.ctable.symbol_tt.(symbol)
  in
  let current = cstate.value in
  let nb_bits_out = (current + delta_nb_bits) lsr 16 in
  Bit_writer.Backward.write_bits stream current nb_bits_out;
  let slot = (current lsr nb_bits_out) + delta_find_state in
  cstate.value <- cstate.ctable.state_table.(slot)
417
(** Flush the final state, mirroring zstd's [FSE_flushCState]. It writes
    [accuracy_log] bits of the state value so the decoder can initialise
    from them. *)
let[@inline] flush_cstate (stream : Bit_writer.Backward.t) cstate =
  let { value; ctable } = cstate in
  Bit_writer.Backward.write_bits stream value ctable.accuracy_log
422
(** Write an FSE table description for [norm_counts]. This is the
    inverse of [decode_header].

    Layout:
    - 4 bits of [accuracy_log - 5];
    - each symbol's normalized count plus one, in a variable-width field;
    - after a zero count, 2-bit repeat flags covering the zero counts
      that immediately follow it (3 means "three more, keep reading").

    Writing stops once the absolute counts add up to [1 lsl accuracy_log],
    or once [norm_counts] is exhausted. The stream is not byte-aligned
    here. *)
let write_header (stream : Bit_writer.Forward.t) norm_counts accuracy_log =
  Bit_writer.Forward.write_bits stream (accuracy_log - 5) 4;

  let table_size = 1 lsl accuracy_log in
  let num_symbols = Array.length norm_counts in
  let remaining = ref table_size in
  let symbol = ref 0 in

  while !remaining > 0 && !symbol < num_symbols do
    let count = norm_counts.(!symbol) in
    let value = count + 1 in

    (* [half] is the largest power of two <= remaining + 1. There are
       three cases:
       - values below [threshold] use the short [bits_needed - 1]-bit
         form;
       - values at or above [half] are offset by [threshold];
       - values in between are written unchanged on [bits_needed] bits.
       Offsetting the middle values as well would make decode_header
       read back [value + threshold]. *)
    let bits_needed = highest_set_bit (!remaining + 1) + 1 in
    let half = 1 lsl (bits_needed - 1) in
    let threshold = (1 lsl bits_needed) - 1 - (!remaining + 1) in
    if value < threshold then
      Bit_writer.Forward.write_bits stream value (bits_needed - 1)
    else if value >= half then
      Bit_writer.Forward.write_bits stream (value + threshold) bits_needed
    else
      Bit_writer.Forward.write_bits stream value bits_needed;

    remaining := !remaining - abs count;
    incr symbol;

    if count = 0 then begin
      (* Skip the zero counts that follow and encode how many there were
         as 2-bit repeat flags. *)
      let start = !symbol in
      while !symbol < num_symbols && norm_counts.(!symbol) = 0 do
        incr symbol
      done;
      let rec write_repeats n =
        if n >= 3 then begin
          Bit_writer.Forward.write_bits stream 3 2;
          write_repeats (n - 3)
        end
        else Bit_writer.Forward.write_bits stream n 2
      in
      write_repeats (!symbol - start)
    end
  done
465
(** Build an encoding table from a predefined (built-in) distribution.
    This is a plain alias of [build_ctable]. *)
let build_predefined_ctable = build_ctable