Zstd compression in pure OCaml
(** Huffman coding for Zstandard literals compression and decompression.

    Zstd uses canonical Huffman codes for literal compression.
    Huffman streams are read backwards, like FSE streams. *)
5
(** One slot of a Huffman decoding table: the symbol produced when the
    decoding state indexes this slot, and the length of that symbol's
    code. *)
type entry = {
  symbol : int;  (** decoded byte value *)
  num_bits : int;  (** code length in bits, i.e. bits consumed *)
}
11
(** Huffman decoding table, indexed by a [max_bits]-wide decoding state. *)
type dtable = {
  entries : entry array;  (** [2^max_bits] slots *)
  max_bits : int;  (** longest code length, also the state width *)
}
17
(** Local alias of [Fse.highest_set_bit]. Callers in this file assume it
    returns the index of the most significant set bit (floor of log2). *)
let highest_set_bit n = Fse.highest_set_bit n
19
(** Build a Huffman decoding table from per-symbol code lengths.

    [bits.(s)] is the code length of symbol [s] ([0] means the symbol is
    absent); only indices [0 .. num_symbols - 1] are read. Codes are
    assigned canonically in the Zstandard order: longer codes occupy the
    lowest table indices, and within one length symbols appear in
    increasing order. The table has [2^max_bits] slots, where [max_bits]
    is the longest code length; a symbol with an [n]-bit code fills
    [2^(max_bits - n)] consecutive slots.

    @raise Constants.Zstd_error [Invalid_huffman_table] if [num_symbols]
    exceeds [Constants.max_huffman_symbols], a length is negative or
    exceeds [Constants.max_huffman_bits], every length is zero, or the
    lengths do not describe a complete prefix code (under- or
    over-subscribed). *)
let build_dtable_from_bits bits num_symbols =
  let invalid () =
    raise (Constants.Zstd_error Constants.Invalid_huffman_table)
  in
  if num_symbols > Constants.max_huffman_symbols then invalid ();

  (* Histogram of code lengths and the longest length in use. *)
  let rank_count = Array.make (Constants.max_huffman_bits + 1) 0 in
  let max_bits = ref 0 in
  for i = 0 to num_symbols - 1 do
    let b = bits.(i) in
    if b < 0 || b > Constants.max_huffman_bits then invalid ();
    if b > !max_bits then max_bits := b;
    rank_count.(b) <- rank_count.(b) + 1
  done;
  let max_bits = !max_bits in
  if max_bits = 0 then invalid ();
  let table_size = 1 lsl max_bits in

  (* rank_idx.(n) is the first slot used by n-bit codes; the longest
     codes start at slot 0, so rank_idx.(0) ends up as the number of
     slots the whole code covers. *)
  let rank_idx = Array.make (Constants.max_huffman_bits + 1) 0 in
  for n = max_bits downto 1 do
    rank_idx.(n - 1) <- rank_idx.(n) + rank_count.(n) * (1 lsl (max_bits - n))
  done;

  (* Validate before writing anything: an over-subscribed code would
     otherwise index past the end of the table and surface as
     Invalid_argument instead of a table error. *)
  if rank_idx.(0) <> table_size then invalid ();

  let entries = Array.make table_size { symbol = 0; num_bits = 0 } in
  for s = 0 to num_symbols - 1 do
    let b = bits.(s) in
    if b <> 0 then begin
      let start = rank_idx.(b) in
      let span = 1 lsl (max_bits - b) in
      Array.fill entries start span { symbol = s; num_bits = b };
      rank_idx.(b) <- start + span
    end
  done;
  { entries; max_bits }
74
(** Build a Huffman decoding table from Zstandard weights.

    [weights.(0 .. num_symbols - 1)] are the explicitly stored weights:
    [0] means the symbol is absent, otherwise its code length is
    [max_bits + 1 - weight]. The weight of symbol [num_symbols] is
    implicit: it is the one whose [2^(weight - 1)] completes the sum of
    the explicit [2^(weight - 1)] terms to the next power of two.

    @raise Constants.Zstd_error [Invalid_huffman_table] if there are too
    many symbols, a weight is out of range, every stored weight is zero,
    or the implied last weight is not a power of two. *)
let build_dtable_from_weights weights num_symbols =
  let invalid () =
    raise (Constants.Zstd_error Constants.Invalid_huffman_table)
  in
  if num_symbols + 1 > Constants.max_huffman_symbols then invalid ();

  (* Sum of 2^(w-1) over the explicit weights. *)
  let weight_sum = ref 0 in
  for i = 0 to num_symbols - 1 do
    let w = weights.(i) in
    if w < 0 || w > Constants.max_huffman_bits then invalid ();
    if w > 0 then weight_sum := !weight_sum + (1 lsl (w - 1))
  done;
  let weight_sum = !weight_sum in
  (* Without any non-zero explicit weight the table size cannot be
     derived (highest_set_bit 0 has no meaningful answer). *)
  if weight_sum = 0 then invalid ();

  (* max_bits: exponent of the first power of two strictly greater than
     weight_sum (assumes highest_set_bit returns floor(log2 n)). *)
  let max_bits = highest_set_bit weight_sum + 1 in
  let left_over = (1 lsl max_bits) - weight_sum in
  (* The implicit last weight must contribute exactly a power of two. *)
  if left_over land (left_over - 1) <> 0 then invalid ();
  let last_weight = highest_set_bit left_over + 1 in

  let bits = Array.make (num_symbols + 1) 0 in
  for i = 0 to num_symbols - 1 do
    let w = weights.(i) in
    if w > 0 then bits.(i) <- max_bits + 1 - w
  done;
  bits.(num_symbols) <- max_bits + 1 - last_weight;
  build_dtable_from_bits bits (num_symbols + 1)
110
(** Prime the decoding state with the first [dtable.max_bits] bits
    read from [stream]. *)
let[@inline] init_state dtable (stream : Bit_reader.Backward.t) =
  let width = dtable.max_bits in
  Bit_reader.Backward.read_bits stream width
114
(** Decode one symbol from [state] and slide the state window.

    The slot at [state] supplies the symbol and its code length [n]. The
    state is then shifted left by [n] bits, refilled with [n] bits read
    from [stream], and truncated back to [max_bits] bits.
    Returns [(symbol, next_state)]. *)
let[@inline] decode_symbol dtable state (stream : Bit_reader.Backward.t) =
  let { symbol; num_bits } = dtable.entries.(state) in
  let window_mask = (1 lsl dtable.max_bits) - 1 in
  let refill = Bit_reader.Backward.read_bits stream num_bits in
  let shifted = state lsl num_bits in
  (symbol, (shifted + refill) land window_mask)
125
(** Decode one backward Huffman stream, [src.(pos .. pos + len - 1)], into
    [output] starting at [out_pos], writing at most [out_len] bytes.

    Returns the number of bytes written.

    @raise Constants.Zstd_error [Output_too_small] if the stream holds more
    than [out_len] symbols.
    @raise Constants.Zstd_error [Corruption] if the stream is not consumed
    exactly. *)
let decompress_1stream dtable src ~pos ~len output ~out_pos ~out_len =
  let stream = Bit_reader.Backward.of_bytes src ~pos ~len in
  (* The state window reads max_bits ahead of the decoded symbols, so a
     fully consumed stream stops at position -max_bits. *)
  let stop = -dtable.max_bits in
  let rec loop state count =
    if Bit_reader.Backward.remaining stream <= stop then count
    else begin
      if count >= out_len then
        raise (Constants.Zstd_error Constants.Output_too_small);
      let symbol, next_state = decode_symbol dtable state stream in
      Bytes.set_uint8 output (out_pos + count) symbol;
      loop next_state (count + 1)
    end
  in
  let written = loop (init_state dtable stream) 0 in
  if Bit_reader.Backward.remaining stream <> stop then
    raise (Constants.Zstd_error Constants.Corruption);
  written
147
(** Decompress the 4-stream Huffman literals layout.

    [src.(pos .. pos + len - 1)] starts with a 6-byte jump table holding
    the little-endian 16-bit sizes of streams 1-3; stream 4 takes the
    remaining bytes. Streams 1-3 each regenerate [(regen_size + 3) / 4]
    bytes and stream 4 the rest. They are written one after another
    starting at [out_pos].

    Returns the total number of bytes written, always [regen_size] on
    success.

    @raise Constants.Zstd_error [Corruption] if the input is too short,
    the jump table is inconsistent with [len], any stream is empty,
    [regen_size] cannot be split four ways, or a stream does not
    regenerate exactly its share.
    @raise Constants.Zstd_error [Output_too_small] if a stream holds more
    symbols than its share. *)
let decompress_4stream dtable src ~pos ~len output ~out_pos ~regen_size =
  let corrupt () = raise (Constants.Zstd_error Constants.Corruption) in
  (* Jump table (6 bytes) plus at least one byte per backward stream. *)
  if len < 10 then corrupt ();

  let size1 = Bit_reader.get_u16_le src pos in
  let size2 = Bit_reader.get_u16_le src (pos + 2) in
  let size3 = Bit_reader.get_u16_le src (pos + 4) in
  let size4 = len - 6 - size1 - size2 - size3 in
  if size1 < 1 || size2 < 1 || size3 < 1 || size4 < 1 then corrupt ();

  let out_size = (regen_size + 3) / 4 in
  let out_size4 = regen_size - 3 * out_size in
  (* Happens for regen_size in {1, 2, 5}: no valid 4-way split exists. *)
  if out_size4 < 0 then corrupt ();

  let decode_stream stream_pos stream_len stream_out_pos expected =
    let written =
      decompress_1stream dtable src ~pos:stream_pos ~len:stream_len
        output ~out_pos:stream_out_pos ~out_len:expected
    in
    (* A short stream would otherwise leave a gap in [output]. *)
    if written <> expected then corrupt ();
    written
  in
  let start1 = pos + 6 in
  let start2 = start1 + size1 in
  let start3 = start2 + size2 in
  let start4 = start3 + size3 in
  let written1 = decode_stream start1 size1 out_pos out_size in
  let written2 = decode_stream start2 size2 (out_pos + out_size) out_size in
  let written3 =
    decode_stream start3 size3 (out_pos + 2 * out_size) out_size
  in
  let written4 =
    decode_stream start4 size4 (out_pos + 3 * out_size) out_size4
  in
  written1 + written2 + written3 + written4
183
(** Decode a Huffman tree description from [stream] and build its
    decoding table. [stream] is advanced past the description.

    The first byte selects the encoding:
    - [>= 128]: [header - 127] weights stored directly as 4-bit nibbles,
      two per byte, high nibble first;
    - [< 128]: the byte size of an FSE-compressed weight sequence (FSE
      header followed by the weight bitstream).

    The weight of the last symbol is implicit (see
    [build_dtable_from_weights]).

    Returns the decoding table (only; the position is tracked by
    [stream]).

    @raise Constants.Zstd_error [Invalid_huffman_table] if the FSE size is
    zero, the FSE header leaves no bytes for the weights, or the weights
    do not form a valid table. *)
let decode_table (stream : Bit_reader.Forward.t) =
  let invalid () =
    raise (Constants.Zstd_error Constants.Invalid_huffman_table)
  in
  let header = Bit_reader.Forward.read_byte stream in

  let weights = Array.make Constants.max_huffman_symbols 0 in
  let num_symbols =
    if header >= 128 then begin
      (* Direct representation: 4 bits per weight. *)
      let count = header - 127 in
      let bytes_needed = (count + 1) / 2 in
      let data = Bit_reader.Forward.get_bytes stream bytes_needed in
      for i = 0 to count - 1 do
        let byte = Bytes.get_uint8 data (i / 2) in
        weights.(i) <- (if i mod 2 = 0 then byte lsr 4 else byte land 0xf)
      done;
      count
    end
    else begin
      (* FSE-compressed weights. *)
      let compressed_size = header in
      if compressed_size = 0 then invalid ();
      let fse_data = Bit_reader.Forward.get_bytes stream compressed_size in

      (* RFC 8878: the FSE table for Huffman weights has an accuracy log
         of at most 6. *)
      let fse_stream = Bit_reader.Forward.of_bytes fse_data in
      let fse_table = Fse.decode_header fse_stream 6 in

      (* Bytes after the FSE header are the compressed weights. *)
      let weights_pos = Bit_reader.Forward.byte_position fse_stream in
      let weights_len = compressed_size - weights_pos in
      if weights_len <= 0 then invalid ();

      let weight_bytes = Bytes.create Constants.max_huffman_symbols in
      let decoded =
        Fse.decompress_interleaved2 fse_table fse_data ~pos:weights_pos
          ~len:weights_len weight_bytes
      in
      for i = 0 to decoded - 1 do
        weights.(i) <- Bytes.get_uint8 weight_bytes i
      done;
      decoded
    end
  in
  build_dtable_from_weights weights num_symbols
227
228(* ========== ENCODING ========== *)
229
(** Huffman encoding table, indexed by symbol value. *)
type ctable = {
  codes : int array;  (** canonical code value per symbol *)
  num_bits : int array;  (** code length per symbol, 0 if unused *)
  max_bits : int;  (** longest code length *)
  num_symbols : int;  (** number of slots in [codes] / [num_bits] *)
}
237
(* Code lengths of an optimal (unlimited) Huffman code for [weights],
   which must be sorted in non-decreasing order and hold at least two
   entries. Uses the two-queue construction: leaves are consumed in
   order and internal nodes are produced in non-decreasing weight
   order, so the two lightest subtrees are always at the queue heads. *)
let huffman_code_lengths weights =
  let n = Array.length weights in
  let node_count = (2 * n) - 1 in
  let weight = Array.make node_count 0 in
  let parent = Array.make node_count 0 in
  Array.blit weights 0 weight 0 n;
  let next_leaf = ref 0 in
  let next_inner = ref n in
  (* Pop the lighter queue head; internal nodes [n, created) exist. *)
  let pop created =
    let take_leaf =
      !next_leaf < n
      && (!next_inner >= created || weight.(!next_leaf) <= weight.(!next_inner))
    in
    let cursor = if take_leaf then next_leaf else next_inner in
    let node = !cursor in
    incr cursor;
    node
  in
  for k = n to node_count - 1 do
    let a = pop k in
    let b = pop k in
    weight.(k) <- weight.(a) + weight.(b);
    parent.(a) <- k;
    parent.(b) <- k
  done;
  (* Parents always have larger indices than their children, so one
     sweep down from the root (last node) yields every depth. *)
  let depth = Array.make node_count 0 in
  for i = node_count - 2 downto 0 do
    depth.(i) <- depth.(parent.(i)) + 1
  done;
  Array.sub depth 0 n

(* Rewrite [lengths] (ordered least frequent first) in place so that no
   length exceeds [limit] and the Kraft sum is exactly 1. Zstandard
   needs a complete code because the last symbol's weight is never
   transmitted; the decoder infers it from completeness.
   Requires [1 <= limit] and [2 <= Array.length lengths <= 2^limit]. *)
let limit_code_lengths lengths limit =
  let n = Array.length lengths in
  let capacity = 1 lsl limit in
  (* Kraft contribution of a code of length [len], in units of 2^-limit. *)
  let cost len = 1 lsl (limit - len) in
  let total = ref 0 in
  for i = 0 to n - 1 do
    if lengths.(i) > limit then lengths.(i) <- limit;
    total := !total + cost lengths.(i)
  done;
  (* Over-subscribed: lengthen the least frequent code below [limit].
     One always exists, since n codes of length [limit] would cost
     n <= capacity. *)
  while !total > capacity do
    let i = ref 0 in
    while lengths.(!i) >= limit do incr i done;
    let l = lengths.(!i) in
    total := !total - cost l + cost (l + 1);
    lengths.(!i) <- l + 1
  done;
  (* Under-subscribed: shorten the most frequent code whose extra cost
     (equal to [cost l]) still fits the deficit. A longest code always
     qualifies: the deficit is a positive multiple of its cost. *)
  while !total < capacity do
    let deficit = capacity - !total in
    let i = ref (n - 1) in
    while not (lengths.(!i) > 1 && cost lengths.(!i) <= deficit) do
      decr i
    done;
    let l = lengths.(!i) in
    total := !total + cost l;
    lengths.(!i) <- l - 1
  done

(* Assign Zstandard canonical codes from per-symbol [lengths] (0 = no
   code): longer codes take the numerically smallest values, and codes of
   equal length increase with the symbol index. This mirrors the
   decoder, which places the longest codes at the start of its table. *)
let assign_zstd_codes lengths max_bits =
  let per_length = Array.make (max_bits + 1) 0 in
  Array.iter
    (fun l -> if l > 0 then per_length.(l) <- per_length.(l) + 1)
    lengths;
  let next_code = Array.make (max_bits + 1) 0 in
  let start = ref 0 in
  for l = max_bits downto 1 do
    next_code.(l) <- !start;
    start := (!start + per_length.(l)) lsr 1
  done;
  let codes = Array.make (Array.length lengths) 0 in
  for s = 0 to Array.length lengths - 1 do
    let l = lengths.(s) in
    if l > 0 then begin
      codes.(s) <- next_code.(l);
      next_code.(l) <- next_code.(l) + 1
    end
  done;
  codes

(** Build a Zstandard-compatible Huffman encoding table from symbol counts.

    [counts.(0 .. max_symbol)] are the symbol frequencies; symbols with a
    count [<= 0] get no code. Code lengths come from an optimal Huffman
    tree and are then limited to [max_bits_limit]. They are adjusted so
    the code is complete (Kraft sum exactly 1), as required by the
    implicit last weight of the Zstandard table description. Codes use
    the Zstandard canonical order (see [assign_zstd_codes]).

    Special cases: no present symbol gives an empty table. A single
    present symbol gets a 1-bit code [0]; such a table has no explicit
    weight and cannot be described by a Zstandard Huffman header.
    [max_bits_limit] values above 30 are treated as 30.

    @raise Invalid_argument if more than [2^max_bits_limit] symbols are
    present (no prefix code of that depth exists). *)
let build_ctable counts max_symbol max_bits_limit =
  let num_symbols = max_symbol + 1 in
  let freqs = Array.sub counts 0 num_symbols in

  (* Present symbols in increasing index order. *)
  let present =
    let acc = ref [] in
    for s = num_symbols - 1 downto 0 do
      if freqs.(s) > 0 then acc := s :: !acc
    done;
    Array.of_list !acc
  in
  let active = Array.length present in

  if active = 0 then
    { codes = [||]; num_bits = [||]; max_bits = 0; num_symbols = 0 }
  else if active = 1 then begin
    let num_bits = Array.make num_symbols 0 in
    num_bits.(present.(0)) <- 1;
    { codes = Array.make num_symbols 0; num_bits; max_bits = 1; num_symbols }
  end
  else begin
    (* Keep shift amounts well inside the native int range. *)
    let limit = min max_bits_limit 30 in
    if limit < 1 || active > 1 lsl limit then
      invalid_arg "build_ctable: too many symbols for max_bits_limit";
    (* Least frequent first; the stable sort keeps ties in symbol order
       so the result is deterministic. *)
    Array.stable_sort (fun a b -> Int.compare freqs.(a) freqs.(b)) present;
    let lengths = huffman_code_lengths (Array.map (fun s -> freqs.(s)) present) in
    limit_code_lengths lengths limit;
    let num_bits = Array.make num_symbols 0 in
    Array.iteri (fun rank s -> num_bits.(s) <- lengths.(rank)) present;
    let max_bits = Array.fold_left max 0 lengths in
    let codes = assign_zstd_codes num_bits max_bits in
    { codes; num_bits; max_bits; num_symbols }
  end
349
(** Convert code lengths into Zstandard weights for the first
    [num_symbols] symbols: a code of length [n] gets weight
    [max_bits + 1 - n], and an unused symbol (length 0) gets weight 0. *)
let bits_to_weights num_bits num_symbols max_bits =
  Array.init num_symbols (fun s ->
      match num_bits.(s) with
      | n when n > 0 -> max_bits + 1 - n
      | _ -> 0)
358
(** Write the Huffman tree description of [ctable] to [stream] using the
    direct (4-bit nibble) representation.

    Weights are [max_bits + 1 - num_bits] for present symbols and [0]
    otherwise. The weight of the last present symbol is not written,
    because the decoder reconstructs it. The header byte is
    [127 + number of explicit weights], the inverse of
    [count = header - 127] in [decode_table]. Weights are packed two per
    byte, first weight in the high nibble.

    Returns the number of symbols described (explicit weights plus the
    implicit last one), or [0] for an empty table, in which case nothing
    is written.

    @raise Invalid_argument if the direct representation cannot express
    the table: there is no explicit weight (a single present symbol at
    index 0), there are more than 128 explicit weights (which would need
    FSE-compressed weights, not implemented here), or an explicit weight
    exceeds 15. *)
let write_header (stream : Bit_writer.Forward.t) ctable =
  if ctable.num_symbols = 0 then 0
  else begin
    let weights =
      bits_to_weights ctable.num_bits ctable.num_symbols ctable.max_bits
    in

    (* The last symbol with a non-zero weight carries the implicit
       weight. *)
    let last_nonzero = ref (ctable.num_symbols - 1) in
    while !last_nonzero > 0 && weights.(!last_nonzero) = 0 do
      decr last_nonzero
    done;
    let num_weights = !last_nonzero in

    let unrepresentable () =
      invalid_arg "write_header: weights do not fit the direct representation"
    in
    if num_weights < 1 || num_weights > 128 then unrepresentable ();
    for i = 0 to num_weights - 1 do
      if weights.(i) > 15 then unrepresentable ()
    done;

    Bit_writer.Forward.write_byte stream (127 + num_weights);

    (* An odd count leaves the final low nibble as zero padding. *)
    for i = 0 to (num_weights - 1) / 2 do
      let w1 = weights.(2 * i) in
      let w2 = if (2 * i) + 1 < num_weights then weights.((2 * i) + 1) else 0 in
      Bit_writer.Forward.write_byte stream ((w1 lsl 4) lor w2)
    done;

    num_weights + 1
  end
389
(** Append the code of [symbol] to the backward bit [stream]. Symbols
    with a non-positive code length emit nothing. *)
let[@inline] encode_symbol ctable (stream : Bit_writer.Backward.t) symbol =
  let code = ctable.codes.(symbol) in
  let width = ctable.num_bits.(symbol) in
  if width <= 0 then ()
  else Bit_writer.Backward.write_bits stream code width
396
(** Huffman-encode [literals.(pos .. pos + len - 1)] into one backward
    bitstream and return its finalized bytes. *)
let compress_1stream ctable literals ~pos ~len =
  let writer = Bit_writer.Backward.create (16 + (2 * len)) in
  (* Emit last-to-first so the backward reader yields first-to-last. *)
  let rec emit i =
    if i >= pos then begin
      encode_symbol ctable writer (Bytes.get_uint8 literals i);
      emit (i - 1)
    end
  in
  emit (pos + len - 1);
  Bit_writer.Backward.finalize writer
408
(** Compress [literals.(pos .. pos + len - 1)] into the 4-stream layout.

    The input is split into three chunks of [(len + 3) / 4] bytes and a
    final chunk holding the remainder. Each chunk becomes an independent
    backward Huffman stream. The output is a 6-byte jump table (the
    little-endian 16-bit sizes of streams 1-3) followed by the four
    streams; the size of stream 4 is implied by the total length.

    @raise Invalid_argument if [len] is too small to split four ways (the
    last chunk would have negative length), or if one of the first three
    streams does not fit a 16-bit jump-table entry. *)
let compress_4stream ctable literals ~pos ~len =
  let chunk_size = (len + 3) / 4 in
  let chunk4_size = len - 3 * chunk_size in
  (* For len in {1, 2, 5} the first three chunks already overrun the
     input, and streams 2-4 would read bytes outside [pos, pos + len). *)
  if chunk4_size < 0 then
    invalid_arg "compress_4stream: input too short to split into 4 streams";

  let stream1 = compress_1stream ctable literals ~pos ~len:chunk_size in
  let stream2 =
    compress_1stream ctable literals ~pos:(pos + chunk_size) ~len:chunk_size
  in
  let stream3 =
    compress_1stream ctable literals ~pos:(pos + 2 * chunk_size)
      ~len:chunk_size
  in
  let stream4 =
    compress_1stream ctable literals ~pos:(pos + 3 * chunk_size)
      ~len:chunk4_size
  in

  let size1 = Bytes.length stream1 in
  let size2 = Bytes.length stream2 in
  let size3 = Bytes.length stream3 in
  let size4 = Bytes.length stream4 in
  (* Bytes.set_uint16_le stores only the low 16 bits, so an oversized
     stream would silently corrupt the jump table. *)
  if size1 > 0xFFFF || size2 > 0xFFFF || size3 > 0xFFFF then
    invalid_arg "compress_4stream: stream too large for the jump table";

  let output = Bytes.create (6 + size1 + size2 + size3 + size4) in
  Bytes.set_uint16_le output 0 size1;
  Bytes.set_uint16_le output 2 size2;
  Bytes.set_uint16_le output 4 size3;
  Bytes.blit stream1 0 output 6 size1;
  Bytes.blit stream2 0 output (6 + size1) size2;
  Bytes.blit stream3 0 output (6 + size1 + size2) size3;
  Bytes.blit stream4 0 output (6 + size1 + size2 + size3) size4;
  output
435 output