(** Zstd compression in pure OCaml.

    Zstandard compression implementation: LZ77 matching, block compression,
    and frame encoding. *)
4
(** Tuning parameters for one compression level; they trade speed against
    compression ratio. *)
type compression_level = {
  window_log : int;  (** Log2 of the window size. *)
  chain_log : int;  (** Log2 of the hash chain length. *)
  hash_log : int;  (** Log2 of the hash table size. *)
  search_log : int;  (** Log2 of the number of candidates tried per position. *)
  min_match : int;  (** Shortest match length worth emitting. *)
  target_len : int;  (** Target match length. *)
  strategy : int;  (** 0 = fast, 1 = greedy, 2 = lazy. *)
}
15
(** Parameter table indexed by compression level. It holds 20 entries
    (indices 0 to 19); entry 0 is a copy of entry 1. *)
let level_params =
  (* Positional constructor; the argument order follows the field order of
     the record type. *)
  let lvl window_log chain_log hash_log search_log min_match target_len strategy =
    { window_log; chain_log; hash_log; search_log; min_match; target_len; strategy }
  in
  [|
    lvl 17 12 11 1 4 0 0;      (* level 0 (fast) *)
    lvl 17 12 11 1 4 0 0;      (* level 1 (fast) *)
    lvl 18 13 12 1 5 4 0;      (* level 2 *)
    lvl 18 14 13 1 5 8 1;      (* level 3 *)
    lvl 18 14 14 2 4 8 1;      (* level 4 *)
    lvl 18 15 14 3 4 16 1;     (* level 5 *)
    lvl 19 16 15 3 4 32 1;     (* level 6 *)
    lvl 19 16 15 4 4 32 2;     (* level 7 *)
    lvl 19 17 16 4 4 64 2;     (* level 8 *)
    lvl 20 17 16 5 4 64 2;     (* level 9 *)
    lvl 20 17 16 6 4 128 2;    (* level 10 *)
    lvl 20 18 17 6 4 128 2;    (* level 11 *)
    lvl 21 18 17 7 4 256 2;    (* level 12 *)
    lvl 21 19 18 7 4 256 2;    (* level 13 *)
    lvl 22 19 18 8 4 256 2;    (* level 14 *)
    lvl 22 20 18 9 4 256 2;    (* level 15 *)
    lvl 22 20 19 10 4 512 2;   (* level 16 *)
    lvl 22 21 19 11 4 512 2;   (* level 17 *)
    lvl 22 21 20 12 4 512 2;   (* level 18 *)
    lvl 23 22 20 12 4 1024 2;  (* level 19 *)
  |]
58
(** [get_level_params level] returns the parameter set for [level]. The
    requested level is first clamped into the range 1..19, so out-of-range
    requests never raise. *)
let get_level_params level =
  let clamped =
    if level < 1 then 1
    else if level > 19 then 19
    else level
  in
  level_params.(clamped)
62
(** One LZ77 sequence: a run of literal bytes followed by a match. The parser
    also uses a record with [match_offset = 0] and [match_length = 0] to carry
    trailing literals that are not followed by any match. *)
type sequence = {
  lit_length : int;  (** Number of literal bytes that precede the match. *)
  match_offset : int;  (** Distance back to the match source; 0 if no match. *)
  match_length : int;  (** Number of matched bytes; 0 if no match. *)
}
69
(** Hash-chain index used by the match finder. All positions are stored as
    plain integers, with -1 standing for an empty slot. *)
type hash_table = {
  table : int array;  (** Most recent position seen for each hash bucket. *)
  chain : int array;  (** For each position, the previous position in the same bucket. *)
  mask : int;  (** Bucket count minus one; reduces a hash to a bucket index. *)
}
76
(** [create_hash_table log_size] builds an empty index with [2^log_size]
    buckets. Every slot starts at -1.

    The chain array always has [1 lsl 20] slots and is indexed by position,
    so positions handed to the match finder must stay below that bound. *)
let create_hash_table log_size =
  let buckets = 1 lsl log_size in
  let table = Array.make buckets (-1) in
  let chain = Array.make (1 lsl 20) (-1) in
  { table; chain; mask = buckets - 1 }
84
(** [hash4 src pos] hashes the four bytes of [src] starting at [pos].
    The little-endian 32-bit word is multiplied by a fixed odd constant
    (32-bit wrapping) and then folded with a shifted copy of itself. The
    result may be negative; callers reduce it with the table mask.
    @raise Invalid_argument if fewer than four bytes are available at [pos]. *)
let[@inline] hash4 src pos =
  let word = Bytes.get_int32_le src pos in
  let mixed = Int32.to_int (Int32.mul word 0xcc9e2d51l) in
  mixed lxor (mixed lsr 15)
91
(** [match_length src pos1 pos2 limit] counts how many consecutive bytes
    starting at [pos1] equal the bytes starting at [pos2]. The count is capped
    both by the distance to [limit] and by [pos1 - pos2], so a reported match
    never overlaps its own source. Returns 0 when the cap is not positive. *)
let match_length src pos1 pos2 limit =
  let cap = min (limit - pos1) (pos1 - pos2) in
  let rec extend n =
    if n < cap
       && Bytes.get_uint8 src (pos1 + n) = Bytes.get_uint8 src (pos2 + n)
    then extend (n + 1)
    else n
  in
  extend 0
101
(** [find_best_match ht src pos limit params] looks for the longest earlier
    occurrence of the data at [pos] and returns [(offset, length)], or
    [(0, 0)] when nothing usable is found.

    As a side effect [pos] is inserted at the head of its hash bucket, before
    the search runs. At most [1 lsl params.search_log] chain entries are
    examined, and the walk stops at the first candidate farther away than
    [1 lsl params.window_log]. Among equally long candidates the first one
    met on the chain wins. *)
let find_best_match ht src pos limit params =
  if pos + 4 > limit then (0, 0)
  else begin
    let bucket = hash4 src pos land ht.mask in
    let head = ht.table.(bucket) in
    (* Link the current position in front of the bucket's previous head. *)
    ht.chain.(pos) <- head;
    ht.table.(bucket) <- pos;
    let window = 1 lsl params.window_log in
    if head < 0 || pos - head > window then (0, 0)
    else begin
      let budget = 1 lsl params.search_log in
      let rec walk candidate tried best_off best_len =
        if candidate < 0 || tried >= budget then (best_off, best_len)
        else begin
          let distance = pos - candidate in
          if distance > window then (best_off, best_len)
          else begin
            let len = match_length src pos candidate limit in
            let best_off, best_len =
              if len >= params.min_match && len > best_len then (distance, len)
              else (best_off, best_len)
            in
            walk ht.chain.(candidate) (tried + 1) best_off best_len
          end
        end
      in
      walk head 0 0 0
    end
  end
142
(** [parse_sequences src ~pos ~len params] splits the [len] bytes of [src]
    starting at [pos] into sequences (literal run + match), in input order.
    A final record with [match_offset = 0] and [match_length = 0] carries any
    trailing literals; it is also emitted when no match was found at all, so
    the result is never empty.

    Matching is greedy: the first position with a match of at least
    [params.min_match] bytes is taken. [params.strategy] is not consulted.

    The hash chain is indexed by position. The range is therefore rebased so
    that positions start at 0: indexing by absolute position overflowed the
    fixed-size chain as soon as a block started beyond 1 MiB into the input.
    Rebasing does not change the result, because the index is rebuilt for
    every call and only offsets and lengths are reported. *)
let parse_sequences src ~pos ~len params =
  (* Work on a block-local copy unless the range already starts at 0. The
     copy is skipped for empty ranges so degenerate calls never touch [src]. *)
  let src = if pos = 0 || len <= 0 then src else Bytes.sub src pos len in
  let limit = len in
  let ht =
    let base = create_hash_table params.hash_log in
    (* Grow the chain if a caller passes a range larger than its default
       capacity. *)
    if len <= Array.length base.chain then base
    else { base with chain = Array.make len (-1) }
  in
  let sequences = ref [] in
  let cur_pos = ref 0 in
  let lit_start = ref 0 in
  while !cur_pos + 4 <= limit do
    let offset, length = find_best_match ht src !cur_pos limit params in
    if length >= params.min_match then begin
      let lit_length = !cur_pos - !lit_start in
      sequences :=
        { lit_length; match_offset = offset; match_length = length }
        :: !sequences;
      (* Index the positions covered by the match so later data can refer to
         them. The first position was already inserted by the match finder. *)
      for i = !cur_pos + 1 to !cur_pos + length - 1 do
        if i + 4 <= limit then begin
          let h = hash4 src i land ht.mask in
          ht.chain.(i) <- ht.table.(h);
          ht.table.(h) <- i
        end
      done;
      cur_pos := !cur_pos + length;
      lit_start := !cur_pos
    end else
      incr cur_pos
  done;
  (* Whatever follows the last match is emitted as a literal-only record. *)
  let remaining = limit - !lit_start in
  let tail = { lit_length = remaining; match_offset = 0; match_length = 0 } in
  let all =
    match !sequences with
    | [] -> [ tail ]
    | found -> if remaining > 0 then tail :: found else found
  in
  List.rev all
182
(** [encode_lit_length_code lit_len] maps a literal length to
    [(code, extra_value, extra_bits)] following the Literals_Length_Code
    table of RFC 8878:

    - 0..15: the code is the length itself, no extra bits;
    - 16..23: codes 16..19, 1 extra bit each;
    - 24..31: codes 20..21, 2 extra bits each;
    - 32..47: codes 22..23, 3 extra bits each;
    - 48..63: code 24, 4 extra bits;
    - 64 and above: code [19 + highbit], baseline [2^highbit], [highbit]
      extra bits, capped at code 35 (baseline 65536, 16 bits).

    The previous version used ad-hoc formulas for lengths 16..127 that gave
    codes and bit counts different from the table (for example two extra bits
    for code 16), which a conforming decoder cannot read back. *)
let encode_lit_length_code lit_len =
  (* Index of the highest set bit of a positive integer. *)
  let rec highest_bit v acc =
    if v <= 1 then acc else highest_bit (v lsr 1) (acc + 1)
  in
  if lit_len < 16 then (lit_len, 0, 0)
  else if lit_len < 24 then (16 + (lit_len - 16) / 2, (lit_len - 16) land 1, 1)
  else if lit_len < 32 then (20 + (lit_len - 24) / 4, (lit_len - 24) land 3, 2)
  else if lit_len < 48 then (22 + (lit_len - 32) / 8, (lit_len - 32) land 7, 3)
  else if lit_len < 64 then (24, lit_len - 48, 4)
  else begin
    (* From 64 upwards every code covers exactly one power-of-two range. *)
    let code = min 35 (highest_bit lit_len 0 + 19) in
    let bits = code - 19 in
    (code, lit_len - (1 lsl bits), bits)
  end
201
(** Smallest match length a zstd sequence can describe; match length codes
    are computed relative to it. *)
let min_match = 3

(** [encode_match_length_code match_len] maps a match length to
    [(code, extra_value, extra_bits)] following the Match_Length_Code table
    of RFC 8878. With [ml = match_len - min_match]:

    - ml 0..31: the code is [ml], no extra bits;
    - ml 32..39: codes 32..35, 1 extra bit each;
    - ml 40..47: codes 36..37, 2 extra bits each;
    - ml 48..63: codes 38..39, 3 extra bits each;
    - ml 64..95: codes 40..41, 4 extra bits each;
    - ml 96..127: code 42, 5 extra bits;
    - ml 128 and above: code [36 + highbit ml], [highbit ml] extra bits,
      capped at code 52 (16 bits).

    The previous version treated every [ml] in 32..63 as a one-bit code, so
    match lengths 43..66 were written with the wrong number of extra bits. *)
let encode_match_length_code match_len =
  (* Index of the highest set bit of a positive integer. *)
  let rec highest_bit v acc =
    if v <= 1 then acc else highest_bit (v lsr 1) (acc + 1)
  in
  let ml = match_len - min_match in
  if ml < 32 then (ml, 0, 0)
  else if ml < 40 then (32 + (ml - 32) / 2, (ml - 32) land 1, 1)
  else if ml < 48 then (36 + (ml - 40) / 4, (ml - 40) land 3, 2)
  else if ml < 64 then (38 + (ml - 48) / 8, (ml - 48) land 7, 3)
  else if ml < 96 then (40 + (ml - 64) / 16, (ml - 64) land 15, 4)
  else if ml < 128 then (42, ml - 96, 5)
  else begin
    (* From 128 upwards every code covers exactly one power-of-two range. *)
    let code = min 52 (highest_bit ml 0 + 36) in
    let bits = code - 36 in
    (code, ml - (1 lsl bits), bits)
  end
221
(** [encode_offset_code offset offset_history] returns
    [(of_code, extra_value, extra_bits)] for a match offset.

    The offset is always written as a real offset: the offset value is
    [offset + 3], [of_code] is the index of its highest set bit and the extra
    value is the remaining low [of_code] bits.

    Repeat-offset codes (offset values 1 to 3) are deliberately never
    produced, so [offset_history] is accepted for interface compatibility but
    not consulted. Emitting them correctly needs information this function
    does not receive: when a sequence has no literals the three codes take
    shifted meanings (RFC 8878, repeat offsets), and the history has to be
    updated after every repeat code and carried from one compressed block to
    the next. The earlier comparison against the history ignored all of
    that, which produced streams that decode to different offsets. Real
    offsets are valid in every context, at the cost of a few extra bits.

    [offset] must be positive; an offset of 0 would collide with a
    repeat-offset code. *)
let encode_offset_code offset (_offset_history : int array) =
  let off_base = offset + 3 in
  let of_code = Fse.highest_set_bit off_base in
  let extra = off_base land ((1 lsl of_code) - 1) in
  (of_code, extra, of_code)
243
(** [write_raw_literals literals ~pos ~len output ~out_pos] writes a raw
    (uncompressed) literals section holding the [len] bytes of [literals]
    that start at [pos], and returns the number of bytes written to [output].

    Header layout (little-endian, literals block type 0 in bits 0-1):
    - [len < 32]: 1 byte, size in bits 3-7;
    - [len < 4096]: 2 bytes, size_format 0b01 in bits 2-3, 12-bit size from bit 4;
    - otherwise: 3 bytes, size_format 0b11 in bits 2-3, 20-bit size from bit 4.

    The 3-byte form previously used size_format 0b10, which decoders read as
    the 1-byte layout, and truncated the size to 14 bits.

    @raise Invalid_argument if [len] does not fit in 20 bits. *)
let write_raw_literals literals ~pos ~len output ~out_pos =
  if len > 0xfffff then
    invalid_arg "write_raw_literals: literals section too large";
  let header_len, header =
    if len < 32 then (1, (len land 0x1f) lsl 3)
    else if len < 4096 then (2, 0b0100 lor (len lsl 4))
    else (3, 0b1100 lor (len lsl 4))
  in
  for i = 0 to header_len - 1 do
    Bytes.set_uint8 output (out_pos + i) ((header lsr (8 * i)) land 0xff)
  done;
  (* Nothing to copy for an empty section; skipping the blit also avoids
     validating [pos] against [literals] in that case. *)
  if len <> 0 then Bytes.blit literals pos output (out_pos + header_len) len;
  header_len + len
274
(** [write_compressed_literals literals ~pos ~len output ~out_pos] writes a
    literals section for the [len] bytes of [literals] starting at [pos] and
    returns the number of bytes written to [output].

    Huffman coding is attempted for inputs of at least 32 bytes. The section
    falls back to the raw form when the input is shorter, when no Huffman
    table could be built, or when table plus streams would not save at least
    about 10 percent.

    Header of a compressed section (little-endian bit field, RFC 8878):
    bits 0-1 hold the block type (2), bits 2-3 the size format, followed by
    the regenerated size and then the compressed size:
    - format 0: one stream, 10 + 10 size bits, 3 bytes;
    - format 1: four streams, 10 + 10 size bits, 3 bytes;
    - format 2: four streams, 14 + 14 size bits, 4 bytes;
    - format 3: four streams, 18 + 18 size bits, 5 bytes.

    The previous header inserted a separate stream flag and placed the size
    fields at positions that do not exist in the format. *)
let write_compressed_literals literals ~pos ~len output ~out_pos =
  if len < 32 then
    (* Too small for Huffman coding to pay for its table. *)
    write_raw_literals literals ~pos ~len output ~out_pos
  else begin
    (* Byte histogram of the literals. *)
    let counts = Array.make 256 0 in
    for i = pos to pos + len - 1 do
      let c = Bytes.get_uint8 literals i in
      counts.(c) <- counts.(c) + 1
    done;

    (* Largest byte value that actually occurs. *)
    let max_symbol = ref 0 in
    for i = 0 to 255 do
      if counts.(i) > 0 then max_symbol := i
    done;

    let ctable = Huffman.build_ctable counts !max_symbol Constants.max_huffman_bits in

    if ctable.num_symbols = 0 then
      write_raw_literals literals ~pos ~len output ~out_pos
    else begin
      (* One stream for short inputs, four streams from 256 bytes upwards. *)
      let use_4streams = len >= 256 in

      (* Serialise the Huffman table description into a scratch buffer. *)
      let header_buf = Bytes.create 256 in
      let header_stream = Bit_writer.Forward.of_bytes header_buf in
      let _num_written = Huffman.write_header header_stream ctable in
      let table_size = Bit_writer.Forward.byte_position header_stream in

      let compressed =
        if use_4streams then
          Huffman.compress_4stream ctable literals ~pos ~len
        else
          Huffman.compress_1stream ctable literals ~pos ~len
      in
      let compressed_size = Bytes.length compressed in

      (* The compressed size field counts the table description plus the
         stream data.
         NOTE(review): this assumes the 4-stream output already contains its
         jump table - confirm against Huffman.compress_4stream. *)
      let total_compressed_size = table_size + compressed_size in
      if total_compressed_size >= len - len / 10 then
        write_raw_literals literals ~pos ~len output ~out_pos
      else begin
        let regen_size = len in
        let lit_type = 2 in
        (* Pick the smallest layout whose two size fields hold both values.
           A single stream only exists in the 10-bit layout; that is always
           sufficient here because single-stream inputs are under 256 bytes. *)
        let size_format, header_len, size_bits =
          if regen_size < 1024 && total_compressed_size < 1024 then
            ((if use_4streams then 1 else 0), 3, 10)
          else if regen_size < 16384 && total_compressed_size < 16384 then
            (2, 4, 14)
          else
            (3, 5, 18)
        in
        let header =
          lit_type
          lor (size_format lsl 2)
          lor (regen_size lsl 4)
          lor (total_compressed_size lsl (4 + size_bits))
        in
        for i = 0 to header_len - 1 do
          Bytes.set_uint8 output (out_pos + i) ((header lsr (8 * i)) land 0xff)
        done;

        (* Table description first, then the compressed stream data. *)
        let table_pos = out_pos + header_len in
        Bytes.blit header_buf 0 output table_pos table_size;
        let data_pos = table_pos + table_size in
        Bytes.blit compressed 0 output data_pos compressed_size;

        data_pos + compressed_size - out_pos
      end
    end
  end
362
(** [compress_literals lits ~pos ~len dst ~out_pos] writes the literals
    section for [len] bytes of [lits] starting at [pos] into [dst] at
    [out_pos] and returns its size in bytes. It delegates entirely to
    [write_compressed_literals], which picks between the Huffman and raw
    forms. *)
let compress_literals lits ~pos ~len dst ~out_pos =
  write_compressed_literals lits ~pos ~len dst ~out_pos
366
(** Predefined FSE compression tables for literal lengths, match lengths and
    offsets. Each table is built from its default distribution the first time
    it is forced and then reused. *)
let ll_ctable =
  Lazy.from_fun (fun () ->
    Fse.build_predefined_ctable
      Constants.ll_default_distribution
      Constants.ll_default_accuracy_log)

let ml_ctable =
  Lazy.from_fun (fun () ->
    Fse.build_predefined_ctable
      Constants.ml_default_distribution
      Constants.ml_default_accuracy_log)

let of_ctable =
  Lazy.from_fun (fun () ->
    Fse.build_predefined_ctable
      Constants.of_default_distribution
      Constants.of_default_accuracy_log)
371
(** [compress_sequences sequences output ~out_pos offset_history] writes the
    sequences section of a compressed block into [output] at [out_pos] and
    returns its size in bytes. Only the predefined FSE tables are used.

    Layout: sequence count (1 to 3 bytes), one compression-modes byte, then a
    bitstream written back to front. An empty list produces the single byte 0.

    Bitstream order, mirroring the reference encoder:
    + the three FSE states are initialised from the codes of the last sequence;
    + the extra bits of the last sequence are written (LL, ML, OF);
    + for each earlier sequence, last to first: symbols are encoded in the
      order OF, ML, LL, followed by the extra bits in the order LL, ML, OF;
    + the states are flushed in the order ML, OF, LL.

    [offset_history] is copied before use; the caller's array is left
    untouched. *)
let compress_sequences sequences output ~out_pos offset_history =
  match sequences with
  | [] ->
    (* A count of zero is the whole section. *)
    Bytes.set_uint8 output out_pos 0;
    1
  | _ :: _ ->
    let seqs = Array.of_list sequences in
    let num_seq = Array.length seqs in

    (* Sequence count, using the shortest of the three available forms. *)
    let count_len =
      if num_seq < 128 then begin
        Bytes.set_uint8 output out_pos num_seq;
        1
      end else if num_seq < 0x7f00 then begin
        Bytes.set_uint8 output out_pos ((num_seq lsr 8) + 128);
        Bytes.set_uint8 output (out_pos + 1) (num_seq land 0xff);
        2
      end else begin
        Bytes.set_uint8 output out_pos 0xff;
        Bytes.set_uint16_le output (out_pos + 1) (num_seq - 0x7f00);
        3
      end
    in

    (* Compression-modes byte. In RFC 8878 bits 7-6 select the literal length
       mode, bits 5-4 the offset mode, bits 3-2 the match length mode and
       bits 1-0 are reserved. All zero selects the predefined tables. *)
    Bytes.set_uint8 output (out_pos + count_len) 0b00;
    let header_size = count_len + 1 in

    let ll_ct = Lazy.force ll_ctable in
    let ml_ct = Lazy.force ml_ctable in
    let of_ct = Lazy.force of_ctable in

    (* Codes are computed front to back because the offset history evolves
       in input order. Each entry groups the three (code, extra, bits)
       triples of one sequence. *)
    let history = Array.copy offset_history in
    let encoded =
      Array.map
        (fun seq ->
          let ll = encode_lit_length_code seq.lit_length in
          let ml = encode_match_length_code seq.match_length in
          let ((of_code, _, _) as off) =
            encode_offset_code seq.match_offset history
          in
          (* An offset code above 1 denotes a real offset, which becomes the
             most recent history entry. *)
          if seq.match_offset > 0 && of_code > 1 then begin
            history.(2) <- history.(1);
            history.(1) <- history.(0);
            history.(0) <- seq.match_offset
          end;
          (ll, ml, off))
        seqs
    in

    let stream = Bit_writer.Backward.create (num_seq * 20 + 32) in
    let code_of (code, _, _) = code in
    let put_extra (_, value, nbits) =
      if nbits > 0 then Bit_writer.Backward.write_bits stream value nbits
    in

    (* The last sequence seeds the states and contributes only extra bits. *)
    let last = num_seq - 1 in
    let ll_last, ml_last, of_last = encoded.(last) in
    let ll_state = Fse.init_cstate2 ll_ct (code_of ll_last) in
    let ml_state = Fse.init_cstate2 ml_ct (code_of ml_last) in
    let of_state = Fse.init_cstate2 of_ct (code_of of_last) in
    put_extra ll_last;
    put_extra ml_last;
    put_extra of_last;

    (* Remaining sequences, from the second to last down to the first. *)
    for i = last - 1 downto 0 do
      let ll, ml, off = encoded.(i) in
      Fse.encode_symbol stream of_state (code_of off);
      Fse.encode_symbol stream ml_state (code_of ml);
      Fse.encode_symbol stream ll_state (code_of ll);
      put_extra ll;
      put_extra ml;
      put_extra off
    done;

    Fse.flush_cstate stream ml_state;
    Fse.flush_cstate stream of_state;
    Fse.flush_cstate stream ll_state;

    (* Close the bitstream and append it after the section header. *)
    let seq_data = Bit_writer.Backward.finalize stream in
    let seq_len = Bytes.length seq_data in
    Bytes.blit seq_data 0 output (out_pos + header_size) seq_len;
    header_size + seq_len
492
(** [write_raw_block src ~pos ~len output ~out_pos] stores [len] bytes of
    [src] starting at [pos] as an uncompressed block and returns the number
    of bytes written ([3 + len]).

    The 3-byte little-endian header has the last-block flag in bit 0 (left
    clear here), the block type in bits 1-2 and the block size in bits 3-23. *)
let write_raw_block src ~pos ~len output ~out_pos =
  let header = (Constants.block_raw lsl 1) lor ((len land 0x1fffff) lsl 3) in
  for i = 0 to 2 do
    Bytes.set_uint8 output (out_pos + i) ((header lsr (8 * i)) land 0xff)
  done;
  Bytes.blit src pos output (out_pos + 3) len;
  3 + len
504
(** [write_compressed_block src ~pos ~len sequences output ~out_pos
    offset_history] encodes one block from its parsed [sequences] and returns
    the number of bytes written to [output], including the 3-byte block
    header.

    The block body is a literals section followed by a sequences section. It
    is assembled in a scratch buffer first; if the body is not smaller than
    the [len] input bytes, the block is stored raw instead. The last-block
    flag is left clear. *)
let write_compressed_block src ~pos ~len sequences output ~out_pos offset_history =
  (* Gather every literal run into one contiguous buffer. *)
  let total_lit_len =
    List.fold_left (fun sum seq -> sum + seq.lit_length) 0 sequences
  in
  let literals = Bytes.create total_lit_len in
  let (_ : int * int) =
    List.fold_left
      (fun (src_at, lit_at) seq ->
        let src_at, lit_at =
          if seq.lit_length > 0 then begin
            Bytes.blit src src_at literals lit_at seq.lit_length;
            (src_at + seq.lit_length, lit_at + seq.lit_length)
          end else
            (src_at, lit_at)
        in
        (* Matched bytes are skipped in the source; they are not literals. *)
        (src_at + seq.match_length, lit_at))
      (pos, 0) sequences
  in

  let block_buf = Bytes.create (len * 2 + 256) in

  (* Literals section at the start of the body. *)
  let lit_size =
    compress_literals literals ~pos:0 ~len:total_lit_len block_buf ~out_pos:0
  in

  (* Literal-only records (no offset and no length) carry trailing literals
     and must not be written as sequences. *)
  let real_sequences =
    List.filter
      (fun seq -> seq.match_length > 0 || seq.match_offset > 0)
      sequences
  in

  (* Sequences section directly after the literals. *)
  let seq_size =
    compress_sequences real_sequences block_buf ~out_pos:lit_size offset_history
  in
  let block_size = lit_size + seq_size in

  if block_size >= len then
    (* No gain over storing the input bytes as they are. *)
    write_raw_block src ~pos ~len output ~out_pos
  else begin
    let header =
      (Constants.block_compressed lsl 1) lor ((block_size land 0x1fffff) lsl 3)
    in
    for i = 0 to 2 do
      Bytes.set_uint8 output (out_pos + i) ((header lsr (8 * i)) land 0xff)
    done;
    Bytes.blit block_buf 0 output (out_pos + 3) block_size;
    3 + block_size
  end
554
(** [write_rle_block byte len output ~out_pos] writes a run-length block that
    expands to [len] copies of [byte]. It always occupies 4 bytes, which is
    also the return value.

    The 3-byte little-endian header has the last-block flag in bit 0 (left
    clear here), the block type in bits 1-2 and the regenerated size in
    bits 3-23; the repeated byte follows the header. *)
let write_rle_block byte len output ~out_pos =
  let header = (Constants.block_rle lsl 1) lor ((len land 0x1fffff) lsl 3) in
  for i = 0 to 2 do
    Bytes.set_uint8 output (out_pos + i) ((header lsr (8 * i)) land 0xff)
  done;
  Bytes.set_uint8 output (out_pos + 3) byte;
  4
566
(** [is_rle_block src ~pos ~len] returns [Some b] when all [len] bytes of
    [src] starting at [pos] have the same value [b], and [None] otherwise,
    including for an empty range.

    The scan stops at the first byte that differs. It used to walk the whole
    range even after a mismatch, which made every non-uniform block pay for a
    full pass. *)
let is_rle_block src ~pos ~len =
  if len = 0 then None
  else begin
    let first = Bytes.get_uint8 src pos in
    let stop = pos + len in
    (* Tail-recursive: the recursive call is the right operand of [&&]. *)
    let rec uniform i =
      i >= stop || (Bytes.get_uint8 src i = first && uniform (i + 1))
    in
    if uniform (pos + 1) then Some first else None
  end
578
(** [compress_block src ~pos ~len output ~out_pos params offset_history]
    encodes one block and returns the number of bytes written to [output]
    (0 for an empty block). The last-block flag is not set here.

    Choice of representation:
    - a block longer than 4 bytes that consists of one repeated byte becomes
      an RLE block;
    - a block of at least 64 bytes in which the parser found at least two
      matches is passed to the compressed-block writer;
    - everything else is stored raw. *)
let compress_block src ~pos ~len output ~out_pos params offset_history =
  if len = 0 then 0
  else begin
    let uniform = is_rle_block src ~pos ~len in
    match uniform with
    | Some byte when len > 4 ->
      (* Four bytes in place of len + 3. *)
      write_rle_block byte len output ~out_pos
    | Some _ | None ->
      let sequences = parse_sequences src ~pos ~len params in
      let with_match = List.filter (fun s -> s.match_length > 0) sequences in
      let worth_compressing = List.length with_match >= 2 && len >= 64 in
      if worth_compressing then
        write_compressed_block src ~pos ~len sequences output ~out_pos
          offset_history
      else
        write_raw_block src ~pos ~len output ~out_pos
  end
601
(** [write_frame_header output ~pos content_size window_log checksum_flag]
    writes a frame header at [pos] and returns its length in bytes.

    Layout: 4-byte magic number, 1 descriptor byte, an optional window
    descriptor byte, then the frame content size field.

    Frames of at most 131072 bytes use single-segment mode, which omits the
    window descriptor. The content size field is then 1 byte for sizes up to
    255, 2 bytes for sizes up to 65791 (stored minus 256), and otherwise 4 or
    8 bytes.

    Descriptor bits: 0-1 dictionary ID flag (always 0), 2 content checksum
    flag, 3 reserved, 4 unused, 5 single-segment flag, 6-7 content size
    field flag. *)
let write_frame_header output ~pos content_size window_log checksum_flag =
  Bytes.set_int32_le output pos Constants.zstd_magic_number;

  let single_segment = content_size <= 131072L in

  (* Field flag and byte width of the content size field. *)
  let fcs_flag, fcs_bytes =
    match single_segment with
    | true ->
      if content_size <= 255L then (0, 1)
      else if content_size <= 65791L then (1, 2)
      else if content_size <= 0xFFFFFFFFL then (2, 4)
      else (3, 8)
    | false ->
      (* Without single-segment mode, flag 0 means the field is absent. *)
      if content_size = 0L then (0, 0)
      else if content_size <= 65535L then (1, 2)
      else if content_size <= 0xFFFFFFFFL then (2, 4)
      else (3, 8)
  in

  let checksum_bit = if checksum_flag then 0b00000100 else 0 in
  let segment_bit = if single_segment then 0b00100000 else 0 in
  Bytes.set_uint8 output (pos + 4) (checksum_bit lor segment_bit lor (fcs_flag lsl 6));

  (* The window descriptor keeps the exponent in its upper five bits and a
     zero mantissa; it is present only outside single-segment mode. *)
  let fcs_pos =
    if single_segment then pos + 5
    else begin
      Bytes.set_uint8 output (pos + 5) ((window_log - 10) lsl 3);
      pos + 6
    end
  in

  begin match fcs_bytes with
  | 1 -> Bytes.set_uint8 output fcs_pos (Int64.to_int content_size)
  | 2 ->
    (* The 2-byte form stores the size minus 256. *)
    Bytes.set_uint16_le output fcs_pos (Int64.to_int (Int64.sub content_size 256L))
  | 4 -> Bytes.set_int32_le output fcs_pos (Int64.to_int32 content_size)
  | 8 -> Bytes.set_int64_le output fcs_pos content_size
  | _ -> ()
  end;

  fcs_pos + fcs_bytes - pos
673
(** [compress ?level ?checksum src] returns [src] encoded as a single zstd
    frame.

    @param level compression level, default 3; clamped by the level lookup.
    @param checksum when [true] (the default) the low 32 bits of the XXH64
      digest of [src] are appended after the last block.

    The input is cut into blocks of at most [Constants.block_size_max] bytes,
    each encoded independently; the last one gets the last-block flag. Empty
    input produces a frame with a single empty raw block. *)
let compress ?(level = 3) ?(checksum = true) src =
  let src = Bytes.of_string src in
  let len = Bytes.length src in
  let params = get_level_params level in

  (* Output capacity: the input size plus a small margin. *)
  let output = Bytes.create (len + len / 128 + 256) in

  let offset_history = Array.copy Constants.initial_repeat_offsets in

  let header_size =
    write_frame_header output ~pos:0 (Int64.of_int len) params.window_log checksum
  in

  let blocks_end =
    if len = 0 then begin
      (* Empty raw block: the header value is 1, i.e. only the last-block
         flag is set. *)
      Bytes.set_uint8 output header_size 0x01;
      Bytes.set_uint8 output (header_size + 1) 0x00;
      Bytes.set_uint8 output (header_size + 2) 0x00;
      header_size + 3
    end else begin
      let block_size = min len Constants.block_size_max in
      let rec emit_blocks in_pos out_pos =
        if in_pos >= len then out_pos
        else begin
          let this_block = min block_size (len - in_pos) in
          let written =
            compress_block src ~pos:in_pos ~len:this_block output ~out_pos
              params offset_history
          in
          (* The last-block flag is bit 0 of the first header byte. *)
          if in_pos + this_block >= len then
            Bytes.set_uint8 output out_pos
              (Bytes.get_uint8 output out_pos lor 0x01);
          emit_blocks (in_pos + this_block) (out_pos + written)
        end
      in
      emit_blocks 0 header_size
    end
  in

  let total =
    if checksum then begin
      let hash = Xxhash.hash64 src ~pos:0 ~len in
      (* Only the low 32 bits of the digest are stored. *)
      Bytes.set_int32_le output blocks_end (Int64.to_int32 hash);
      blocks_end + 4
    end else
      blocks_end
  in

  Bytes.sub_string output 0 total
730
(** [compress_bound len] is the output buffer size reserved for an input of
    [len] bytes: the input size, plus one byte per 128 input bytes, plus a
    fixed 256-byte margin. *)
let compress_bound len =
  let per_128_overhead = len / 128 in
  let fixed_margin = 256 in
  len + per_128_overhead + fixed_margin
734
(** [write_skippable_frame ?variant content] wraps [content] in a skippable
    frame: a 4-byte magic number, the content length as a 4-byte
    little-endian integer, then the content itself.

    @param variant magic number variant, default 0; values outside 0..15 are
      clamped into that range.
    @param content payload to embed.
    @return the complete frame as a string.
    @raise Invalid_argument if [content] is longer than [0xFFFFFFFF] bytes. *)
let write_skippable_frame ?(variant = 0) content =
  let variant =
    if variant < 0 then 0
    else if variant > 15 then 15
    else variant
  in
  let len = String.length content in
  if len > 0xFFFFFFFF then
    invalid_arg "Skippable frame content too large (max 4GB)";
  let frame = Bytes.create (Constants.skippable_header_size + len) in
  (* The magic number is the base skippable value plus the variant. *)
  Bytes.set_int32_le frame 0
    (Int32.add Constants.skippable_magic_start (Int32.of_int variant));
  Bytes.set_int32_le frame 4 (Int32.of_int len);
  Bytes.blit_string content 0 frame 8 len;
  Bytes.unsafe_to_string frame
752 Bytes.unsafe_to_string output