···1+(*---------------------------------------------------------------------------
2+ Copyright (c) 2024 The bytesrw programmers. All rights reserved.
3+ SPDX-License-Identifier: ISC
4+ ---------------------------------------------------------------------------*)
5+6+open Bytesrw
7+8+(* Errors *)
9+10+type Bytes.Stream.error += Error of Zstd.error
11+12+let error_message = Zstd.error_message
13+14+let format_error =
15+ let case e = Error e in
16+ let message = function Error e -> error_message e | _ -> assert false in
17+ Bytes.Stream.make_format_error ~format:"zstd" ~case ~message
18+19+let _error e = Bytes.Stream.error format_error e
20+let reader_error r e = Bytes.Reader.error format_error r e
21+let writer_error w e = Bytes.Writer.error format_error w e
22+23+(* Library parameters *)
24+25+let version = "1.0.0-pure-ocaml"
26+let min_clevel = 1
27+let max_clevel = 19
28+let default_clevel = 3
29+30+(* Default slice length *)
31+let default_slice_length = 65536
32+33+(* Buffer all slices from a reader into a single bytes *)
34+let buffer_reader r =
35+ let buf = Buffer.create default_slice_length in
36+ let rec loop () =
37+ let slice = Bytes.Reader.read r in
38+ if Bytes.Slice.is_eod slice then
39+ Buffer.contents buf
40+ else begin
41+ Buffer.add_subbytes buf
42+ (Bytes.Slice.bytes slice)
43+ (Bytes.Slice.first slice)
44+ (Bytes.Slice.length slice);
45+ loop ()
46+ end
47+ in
48+ loop ()
49+50+(* Read a single zstd frame, returning leftover data *)
51+let read_single_frame r =
52+ (* Buffer slices until we have enough to detect frame boundaries *)
53+ let buf = Buffer.create default_slice_length in
54+ let rec loop () =
55+ let slice = Bytes.Reader.read r in
56+ if Bytes.Slice.is_eod slice then begin
57+ (* End of input - return what we have *)
58+ let data = Buffer.contents buf in
59+ (data, "")
60+ end else begin
61+ Buffer.add_subbytes buf
62+ (Bytes.Slice.bytes slice)
63+ (Bytes.Slice.first slice)
64+ (Bytes.Slice.length slice);
65+ (* Check if we have a complete frame *)
66+ let data = Buffer.contents buf in
67+ if String.length data >= 4 && Zstd.is_zstd_frame data then
68+ (* Try to find frame boundary by checking decompressed size or
69+ attempting decompression. For now, buffer everything. *)
70+ loop ()
71+ else
72+ loop ()
73+ end
74+ in
75+ loop ()
76+77+(* Create a reader that yields slices from a string *)
78+let reader_of_string ?(slice_length = default_slice_length) s =
79+ let len = String.length s in
80+ let pos = ref 0 in
81+ let bytes = Bytes.unsafe_of_string s in
82+ let read () =
83+ if !pos >= len then Bytes.Slice.eod
84+ else begin
85+ let chunk_len = min slice_length (len - !pos) in
86+ let slice = Bytes.Slice.make bytes ~first:!pos ~length:chunk_len in
87+ pos := !pos + chunk_len;
88+ slice
89+ end
90+ in
91+ Bytes.Reader.make ~slice_length read
92+93+(* Decompress *)
94+95+let decompress_reads ?(all_frames = true) () ?pos ?(slice_length = default_slice_length) r =
96+ let state = ref `Reading in
97+ let output_reader = ref None in
98+ let read () =
99+ match !state with
100+ | `Done -> Bytes.Slice.eod
101+ | `Outputting ->
102+ begin match !output_reader with
103+ | None -> Bytes.Slice.eod
104+ | Some or_ ->
105+ let slice = Bytes.Reader.read or_ in
106+ if Bytes.Slice.is_eod slice then begin
107+ state := `Done;
108+ output_reader := None;
109+ Bytes.Slice.eod
110+ end else
111+ slice
112+ end
113+ | `Reading ->
114+ (* Buffer all input *)
115+ let input =
116+ if all_frames then
117+ buffer_reader r
118+ else
119+ let (data, _leftover) = read_single_frame r in
120+ (* TODO: push back leftover to r *)
121+ data
122+ in
123+ if String.length input = 0 then begin
124+ state := `Done;
125+ Bytes.Slice.eod
126+ end else begin
127+ (* Decompress *)
128+ match Zstd.decompress input with
129+ | Error _msg ->
130+ state := `Done;
131+ reader_error r Zstd.Corruption
132+ | Ok decompressed ->
133+ let or_ = reader_of_string ~slice_length decompressed in
134+ output_reader := Some or_;
135+ state := `Outputting;
136+ let slice = Bytes.Reader.read or_ in
137+ if Bytes.Slice.is_eod slice then begin
138+ state := `Done;
139+ output_reader := None
140+ end;
141+ slice
142+ end
143+ in
144+ Bytes.Reader.make ?pos ~slice_length read
145+146+let decompress_writes () ?pos ?(slice_length = default_slice_length) ~eod w =
147+ let buf = Buffer.create default_slice_length in
148+ let write slice =
149+ if Bytes.Slice.is_eod slice then begin
150+ (* Decompress buffered data *)
151+ let input = Buffer.contents buf in
152+ if String.length input > 0 then begin
153+ match Zstd.decompress input with
154+ | Error _msg ->
155+ writer_error w Zstd.Corruption
156+ | Ok decompressed ->
157+ (* Write decompressed data in slices *)
158+ let len = String.length decompressed in
159+ let bytes = Bytes.unsafe_of_string decompressed in
160+ let rec write_chunks pos =
161+ if pos >= len then ()
162+ else begin
163+ let chunk_len = min (Bytes.Writer.slice_length w) (len - pos) in
164+ let slice = Bytes.Slice.make bytes ~first:pos ~length:chunk_len in
165+ Bytes.Writer.write w slice;
166+ write_chunks (pos + chunk_len)
167+ end
168+ in
169+ write_chunks 0
170+ end;
171+ if eod then Bytes.Writer.write_eod w
172+ end else begin
173+ Buffer.add_subbytes buf
174+ (Bytes.Slice.bytes slice)
175+ (Bytes.Slice.first slice)
176+ (Bytes.Slice.length slice)
177+ end
178+ in
179+ Bytes.Writer.make ?pos ~slice_length write
180+181+(* Compress *)
182+183+let compress_reads ?(level = default_clevel) () ?pos ?(slice_length = default_slice_length) r =
184+ let state = ref `Reading in
185+ let output_reader = ref None in
186+ let read () =
187+ match !state with
188+ | `Done -> Bytes.Slice.eod
189+ | `Outputting ->
190+ begin match !output_reader with
191+ | None -> Bytes.Slice.eod
192+ | Some or_ ->
193+ let slice = Bytes.Reader.read or_ in
194+ if Bytes.Slice.is_eod slice then begin
195+ state := `Done;
196+ output_reader := None;
197+ Bytes.Slice.eod
198+ end else
199+ slice
200+ end
201+ | `Reading ->
202+ (* Buffer all input *)
203+ let input = buffer_reader r in
204+ if String.length input = 0 then begin
205+ (* Compress empty input to get valid empty frame *)
206+ let compressed = Zstd.compress ~level "" in
207+ let or_ = reader_of_string ~slice_length compressed in
208+ output_reader := Some or_;
209+ state := `Outputting;
210+ Bytes.Reader.read or_
211+ end else begin
212+ (* Compress *)
213+ let compressed = Zstd.compress ~level input in
214+ let or_ = reader_of_string ~slice_length compressed in
215+ output_reader := Some or_;
216+ state := `Outputting;
217+ let slice = Bytes.Reader.read or_ in
218+ if Bytes.Slice.is_eod slice then begin
219+ state := `Done;
220+ output_reader := None
221+ end;
222+ slice
223+ end
224+ in
225+ Bytes.Reader.make ?pos ~slice_length read
226+227+let compress_writes ?(level = default_clevel) () ?pos ?(slice_length = default_slice_length) ~eod w =
228+ let buf = Buffer.create default_slice_length in
229+ let write slice =
230+ if Bytes.Slice.is_eod slice then begin
231+ (* Compress buffered data *)
232+ let input = Buffer.contents buf in
233+ let compressed = Zstd.compress ~level input in
234+ (* Write compressed data in slices *)
235+ let len = String.length compressed in
236+ let bytes = Bytes.unsafe_of_string compressed in
237+ let rec write_chunks pos =
238+ if pos >= len then ()
239+ else begin
240+ let chunk_len = min (Bytes.Writer.slice_length w) (len - pos) in
241+ let slice = Bytes.Slice.make bytes ~first:pos ~length:chunk_len in
242+ Bytes.Writer.write w slice;
243+ write_chunks (pos + chunk_len)
244+ end
245+ in
246+ write_chunks 0;
247+ if eod then Bytes.Writer.write_eod w
248+ end else begin
249+ Buffer.add_subbytes buf
250+ (Bytes.Slice.bytes slice)
251+ (Bytes.Slice.first slice)
252+ (Bytes.Slice.length slice)
253+ end
254+ in
255+ Bytes.Writer.make ?pos ~slice_length write
···1+(*---------------------------------------------------------------------------
2+ Copyright (c) 2024 The bytesrw programmers. All rights reserved.
3+ SPDX-License-Identifier: ISC
4+ ---------------------------------------------------------------------------*)
5+6+(** Zstd streams via pure OCaml implementation.
7+8+ This module provides support for reading and writing
9+ {{:https://www.rfc-editor.org/rfc/rfc8878.html}zstd} compressed
10+ streams using a pure OCaml zstd implementation.
11+12+ Unlike the C-based [bytesrw-zstd] package, this implementation:
13+ - Has no C dependencies
14+ - Buffers entire frames before processing (not true streaming)
15+ - Works anywhere OCaml runs
16+17+ {b Positions.} The positions of readers and writers created
18+ by filters of this module default to [0]. *)
19+20+open Bytesrw
21+22+(** {1:errors Errors} *)
23+24+type Bytes.Stream.error += Error of Zstd.error
25+(** The type for zstd stream errors.
26+27+ All functions of this module and resulting readers and writers may
28+ raise {!Bytesrw.Bytes.Stream.Error} with this error. *)
29+30+val error_message : Zstd.error -> string
31+(** [error_message e] is a human-readable message for error [e]. *)
32+33+(** {1:decompress Decompress} *)
34+35+val decompress_reads : ?all_frames:bool -> unit -> Bytes.Reader.filter
36+(** [decompress_reads () r] filters the reads of [r] by decompressing
37+ zstd frames.
38+ {ul
39+ {- [slice_length] defaults to [65536].}}
40+41+ If [all_frames] is:
42+ {ul
43+ {- [true] (default), this decompresses all frames until [r] returns
44+ {!Bytesrw.Bytes.Slice.eod} and concatenates the result.}
45+ {- [false], this decompresses a single frame. Once the resulting reader
46+ returns {!Bytesrw.Bytes.Slice.eod}, [r] is positioned exactly after
47+ the end of frame and can be used again to perform other non-filtered
48+ reads (e.g. a new zstd frame or other unrelated data).}}
49+50+ {b Note:} This implementation buffers the entire compressed input
51+ before decompressing. For large files, consider using the C-based
52+ [bytesrw-zstd] package instead. *)
53+54+val decompress_writes : unit -> Bytes.Writer.filter
55+(** [decompress_writes () w ~eod] filters the writes on [w] by decompressing
56+ sequences of zstd frames until {!Bytesrw.Bytes.Slice.eod} is written.
57+ If [eod] is [false] the last {!Bytesrw.Bytes.Slice.eod} is not written
58+ on [w] and at this point [w] can be used again to perform other
59+ non-filtered writes.
60+ {ul
61+ {- [slice_length] defaults to [65536].}}
62+63+ {b Note:} This implementation buffers the entire compressed input
64+ before decompressing. *)
65+66+(** {1:compress Compress} *)
67+68+val compress_reads : ?level:int -> unit -> Bytes.Reader.filter
69+(** [compress_reads () r] filters the reads of [r] by compressing them
70+ to a single zstd frame.
71+ {ul
72+ {- [level] is the compression level (1-19, default 3).}
73+ {- [slice_length] defaults to [65536].}}
74+75+ {b Note:} This implementation buffers the entire input before
76+ compressing. *)
77+78+val compress_writes : ?level:int -> unit -> Bytes.Writer.filter
79+(** [compress_writes () w ~eod] filters the writes on [w] by compressing
80+ them to a single zstd frame until {!Bytesrw.Bytes.Slice.eod} is written.
81+ If [eod] is [false] the last {!Bytesrw.Bytes.Slice.eod} is not written
82+ on [w] and at this point [w] can be used again to perform non-filtered
83+ writes.
84+ {ul
85+ {- [level] is the compression level (1-19, default 3).}
86+ {- [slice_length] defaults to [65536].}}
87+88+ {b Note:} This implementation buffers the entire input before
89+ compressing. *)
90+91+(** {1:params Library parameters} *)
92+93+val version : string
94+(** [version] is the version of this pure OCaml zstd implementation. *)
95+96+val min_clevel : int
97+(** [min_clevel] is the minimum compression level (1). *)
98+99+val max_clevel : int
100+(** [max_clevel] is the maximum compression level (19). *)
101+102+val default_clevel : int
103+(** [default_clevel] is the default compression level (3). *)
···1+(** Bitstream reader for Zstandard decompression.
2+3+ This module wraps the Bitstream library, translating exceptions
4+ to Zstd_error for consistent error handling. *)
5+6+(** Helper to wrap Bitstream operations and translate exceptions *)
7+let[@inline] wrap_truncated f =
8+ try f ()
9+ with Bitstream.End_of_stream ->
10+ raise (Constants.Zstd_error Constants.Truncated_input)
11+12+let[@inline] wrap_all f =
13+ try f ()
14+ with
15+ | Bitstream.End_of_stream ->
16+ raise (Constants.Zstd_error Constants.Truncated_input)
17+ | Bitstream.Invalid_state _ ->
18+ raise (Constants.Zstd_error Constants.Corruption)
19+ | Bitstream.Corrupted_stream _ ->
20+ raise (Constants.Zstd_error Constants.Corruption)
21+22+(** Forward bitstream reader - reads from start to end *)
23+module Forward = struct
24+ type t = Bitstream.Forward_reader.t
25+26+ let create src ~pos ~len =
27+ Bitstream.Forward_reader.create src ~pos ~len
28+29+ let of_bytes src =
30+ Bitstream.Forward_reader.of_bytes src
31+32+ let[@inline] remaining t =
33+ Bitstream.Forward_reader.remaining t
34+35+ let[@inline] is_byte_aligned t =
36+ Bitstream.Forward_reader.is_byte_aligned t
37+38+ let[@inline] read_bits t n =
39+ wrap_truncated (fun () -> Bitstream.Forward_reader.read_bits t n)
40+41+ let[@inline] read_byte t =
42+ wrap_all (fun () -> Bitstream.Forward_reader.read_byte t)
43+44+ let rewind_bits t n =
45+ wrap_truncated (fun () -> Bitstream.Forward_reader.rewind_bits t n)
46+47+ let align t =
48+ Bitstream.Forward_reader.align t
49+50+ let byte_position t =
51+ wrap_all (fun () -> Bitstream.Forward_reader.byte_position t)
52+53+ let get_bytes t n =
54+ wrap_all (fun () -> Bitstream.Forward_reader.get_bytes t n)
55+56+ let advance t n =
57+ wrap_all (fun () -> Bitstream.Forward_reader.advance t n)
58+59+ let sub t n =
60+ wrap_all (fun () -> Bitstream.Forward_reader.sub t n)
61+62+ let remaining_bytes t =
63+ wrap_all (fun () -> Bitstream.Forward_reader.remaining_bytes t)
64+end
65+66+(** Backward bitstream reader - reads from end to start.
67+ Used for FSE and Huffman coded streams. *)
68+module Backward = struct
69+ type t = Bitstream.Backward_reader.t
70+71+ let create src ~pos ~len =
72+ wrap_all (fun () -> Bitstream.Backward_reader.of_bytes src ~pos ~len)
73+74+ let of_bytes src ~pos ~len =
75+ create src ~pos ~len
76+77+ let[@inline] remaining t =
78+ Bitstream.Backward_reader.remaining t
79+80+ let[@inline] read_bits t n =
81+ Bitstream.Backward_reader.read_bits t n
82+83+ let[@inline] is_empty t =
84+ Bitstream.Backward_reader.is_empty t
85+end
86+87+(** Read little-endian integers from bytes *)
88+let[@inline] get_u16_le src pos =
89+ Bytes.get_uint16_le src pos
···1+(** Bitstream writer for Zstandard compression.
2+3+ This module wraps the Bitstream library for consistent API
4+ with the rest of the zstd implementation. *)
5+6+(** Forward bitstream writer - writes from start to end *)
7+module Forward = struct
8+ type t = Bitstream.Forward_writer.t
9+10+ let create dst ~pos =
11+ Bitstream.Forward_writer.create dst ~pos
12+13+ let of_bytes dst =
14+ Bitstream.Forward_writer.of_bytes dst
15+16+ let flush t =
17+ Bitstream.Forward_writer.flush t
18+19+ let write_bits t value n =
20+ Bitstream.Forward_writer.write_bits t value n
21+22+ let write_byte t value =
23+ Bitstream.Forward_writer.write_byte t value
24+25+ let write_bytes t src =
26+ Bitstream.Forward_writer.write_bytes t src
27+28+ let byte_position t =
29+ Bitstream.Forward_writer.byte_position t
30+31+ let finalize t =
32+ Bitstream.Forward_writer.finalize t
33+end
34+35+(** Backward bitstream writer - accumulates bits to be read backwards.
36+ Used for FSE and Huffman encoding. *)
37+module Backward = struct
38+ type t = Bitstream.Backward_writer.t
39+40+ let create size =
41+ Bitstream.Backward_writer.create size
42+43+ let[@inline] write_bits t value n =
44+ Bitstream.Backward_writer.write_bits t value n
45+46+ let flush_bytes t =
47+ Bitstream.Backward_writer.flush_bytes t
48+49+ let finalize t =
50+ Bitstream.Backward_writer.finalize t
51+52+ let current_size t =
53+ Bitstream.Backward_writer.current_size t
54+end
···1+(** Finite State Entropy (FSE) decoding for Zstandard.
2+3+ FSE is an entropy coding method based on ANS (Asymmetric Numeral Systems).
4+ FSE streams are read backwards (from end to beginning). *)
5+6+(** FSE decoding table entry *)
7+type entry = {
8+ symbol : int;
9+ num_bits : int;
10+ new_state_base : int;
11+}
12+13+(** FSE decoding table *)
14+type dtable = {
15+ entries : entry array;
16+ accuracy_log : int;
17+}
18+19+(** Find the highest set bit (floor(log2(n))) *)
20+let[@inline] highest_set_bit n =
21+ if n = 0 then -1
22+ else
23+ let rec loop i =
24+ if (1 lsl i) <= n then loop (i + 1)
25+ else i - 1
26+ in
27+ loop 0
28+29+(** Build FSE decoding table from normalized frequencies.
30+ Frequencies can be negative (-1 means probability < 1). *)
31+let build_dtable frequencies accuracy_log =
32+ let table_size = 1 lsl accuracy_log in
33+ let num_symbols = Array.length frequencies in
34+35+ (* Create entries array *)
36+ let entries = Array.init table_size (fun _ ->
37+ { symbol = 0; num_bits = 0; new_state_base = 0 }
38+ ) in
39+40+ (* Track state descriptors for each symbol *)
41+ let state_desc = Array.make num_symbols 0 in
42+43+ (* First pass: place symbols with prob < 1 at the end *)
44+ let high_threshold = ref table_size in
45+ for s = 0 to num_symbols - 1 do
46+ if frequencies.(s) = -1 then begin
47+ decr high_threshold;
48+ entries.(!high_threshold) <- { symbol = s; num_bits = 0; new_state_base = 0 };
49+ state_desc.(s) <- 1
50+ end
51+ done;
52+53+ (* Second pass: distribute remaining symbols using the step formula *)
54+ let step = (table_size lsr 1) + (table_size lsr 3) + 3 in
55+ let mask = table_size - 1 in
56+ let pos = ref 0 in
57+58+ for s = 0 to num_symbols - 1 do
59+ if frequencies.(s) > 0 then begin
60+ state_desc.(s) <- frequencies.(s);
61+ for _ = 0 to frequencies.(s) - 1 do
62+ entries.(!pos) <- { entries.(!pos) with symbol = s };
63+ (* Skip positions occupied by prob < 1 symbols *)
64+ pos := (!pos + step) land mask;
65+ while !pos >= !high_threshold do
66+ pos := (!pos + step) land mask
67+ done
68+ done
69+ end
70+ done;
71+72+ if !pos <> 0 then
73+ raise (Constants.Zstd_error Constants.Invalid_fse_table);
74+75+ (* Third pass: fill in num_bits and new_state_base *)
76+ for i = 0 to table_size - 1 do
77+ let s = entries.(i).symbol in
78+ let next_state_desc = state_desc.(s) in
79+ state_desc.(s) <- next_state_desc + 1;
80+81+ (* Number of bits is accuracy_log - log2(next_state_desc) *)
82+ let num_bits = accuracy_log - highest_set_bit next_state_desc in
83+ (* new_state_base = (next_state_desc << num_bits) - table_size *)
84+ let new_state_base = (next_state_desc lsl num_bits) - table_size in
85+86+ entries.(i) <- { entries.(i) with num_bits; new_state_base }
87+ done;
88+89+ { entries; accuracy_log }
90+91+(** Build RLE table (single symbol repeated) *)
92+let build_dtable_rle symbol =
93+ {
94+ entries = [| { symbol; num_bits = 0; new_state_base = 0 } |];
95+ accuracy_log = 0;
96+ }
97+98+(** Peek at the symbol for current state (doesn't update state) *)
99+let[@inline] peek_symbol dtable state =
100+ dtable.entries.(state).symbol
101+102+(** Update state by reading bits from the stream *)
103+let[@inline] update_state dtable state (stream : Bit_reader.Backward.t) =
104+ let entry = dtable.entries.(state) in
105+ let bits = Bit_reader.Backward.read_bits stream entry.num_bits in
106+ entry.new_state_base + bits
107+108+(** Decode symbol and update state *)
109+let[@inline] decode_symbol dtable state stream =
110+ let symbol = peek_symbol dtable state in
111+ let new_state = update_state dtable state stream in
112+ (symbol, new_state)
113+114+(** Initialize state by reading accuracy_log bits *)
115+let[@inline] init_state dtable (stream : Bit_reader.Backward.t) =
116+ Bit_reader.Backward.read_bits stream dtable.accuracy_log
117+118+(** Decode FSE header and build decoding table.
119+ Returns the table and advances the forward stream. *)
120+let decode_header (stream : Bit_reader.Forward.t) max_accuracy_log =
121+ (* Accuracy log is first 4 bits + 5 *)
122+ let accuracy_log = (Bit_reader.Forward.read_bits stream 4) + 5 in
123+ if accuracy_log > max_accuracy_log then
124+ raise (Constants.Zstd_error Constants.Invalid_fse_table);
125+126+ let table_size = 1 lsl accuracy_log in
127+ let frequencies = Array.make Constants.max_fse_symbols 0 in
128+129+ let remaining = ref table_size in
130+ let symbol = ref 0 in
131+132+ while !remaining > 0 && !symbol < Constants.max_fse_symbols do
133+ (* Determine how many bits we might need *)
134+ let bits_needed = highest_set_bit (!remaining + 1) + 1 in
135+ let value = Bit_reader.Forward.read_bits stream bits_needed in
136+137+ (* Small value optimization: values < threshold use one less bit *)
138+ let threshold = (1 lsl bits_needed) - 1 - (!remaining + 1) in
139+ let lower_mask = (1 lsl (bits_needed - 1)) - 1 in
140+141+ let (actual_value, bits_consumed) =
142+ if (value land lower_mask) < threshold then
143+ (value land lower_mask, bits_needed - 1)
144+ else if value > lower_mask then
145+ (value - threshold, bits_needed)
146+ else
147+ (value, bits_needed)
148+ in
149+150+ (* Rewind if we read too many bits *)
151+ if bits_consumed < bits_needed then
152+ Bit_reader.Forward.rewind_bits stream 1;
153+154+ (* Probability = value - 1 (so value 0 means prob = -1) *)
155+ let prob = actual_value - 1 in
156+ frequencies.(!symbol) <- prob;
157+ remaining := !remaining - abs prob;
158+ incr symbol;
159+160+ (* Handle zero probability with repeat flags *)
161+ if prob = 0 then begin
162+ let rec read_zeroes () =
163+ let repeat = Bit_reader.Forward.read_bits stream 2 in
164+ for _ = 1 to repeat do
165+ if !symbol < Constants.max_fse_symbols then begin
166+ frequencies.(!symbol) <- 0;
167+ incr symbol
168+ end
169+ done;
170+ if repeat = 3 then read_zeroes ()
171+ in
172+ read_zeroes ()
173+ end
174+ done;
175+176+ (* Align to byte boundary *)
177+ Bit_reader.Forward.align stream;
178+179+ if !remaining <> 0 then
180+ raise (Constants.Zstd_error Constants.Invalid_fse_table);
181+182+ (* Build the decoding table *)
183+ let freq_slice = Array.sub frequencies 0 !symbol in
184+ build_dtable freq_slice accuracy_log
185+186+(** Decompress interleaved 2-state FSE stream.
187+ Used for Huffman weight encoding. Returns number of symbols decoded. *)
188+let decompress_interleaved2 dtable src ~pos ~len output =
189+ let stream = Bit_reader.Backward.of_bytes src ~pos ~len in
190+191+ (* Initialize two states *)
192+ let state1 = ref (init_state dtable stream) in
193+ let state2 = ref (init_state dtable stream) in
194+195+ let out_pos = ref 0 in
196+ let out_len = Bytes.length output in
197+198+ (* Decode symbols alternating between states *)
199+ while Bit_reader.Backward.remaining stream >= 0 do
200+ if !out_pos >= out_len then
201+ raise (Constants.Zstd_error Constants.Output_too_small);
202+203+ let (sym1, new_state1) = decode_symbol dtable !state1 stream in
204+ Bytes.set_uint8 output !out_pos sym1;
205+ incr out_pos;
206+ state1 := new_state1;
207+208+ if Bit_reader.Backward.remaining stream < 0 then begin
209+ (* Stream exhausted, output final symbol from state2 *)
210+ if !out_pos < out_len then begin
211+ Bytes.set_uint8 output !out_pos (peek_symbol dtable !state2);
212+ incr out_pos
213+ end
214+ end else begin
215+ if !out_pos >= out_len then
216+ raise (Constants.Zstd_error Constants.Output_too_small);
217+218+ let (sym2, new_state2) = decode_symbol dtable !state2 stream in
219+ Bytes.set_uint8 output !out_pos sym2;
220+ incr out_pos;
221+ state2 := new_state2;
222+223+ if Bit_reader.Backward.remaining stream < 0 then begin
224+ (* Stream exhausted, output final symbol from state1 *)
225+ if !out_pos < out_len then begin
226+ Bytes.set_uint8 output !out_pos (peek_symbol dtable !state1);
227+ incr out_pos
228+ end
229+ end
230+ end
231+ done;
232+233+ !out_pos
234+235+(** Build decoding table from predefined distribution *)
236+let build_predefined_table distribution accuracy_log =
237+ build_dtable distribution accuracy_log
238+239+(* ========== ENCODING ========== *)
240+241+(** FSE compression table - matches C zstd's FSE_symbolCompressionTransform format.
242+ deltaNbBits is encoded as (maxBitsOut << 16) - minStatePlus
243+ This allows computing nbBitsOut = (state + deltaNbBits) >> 16 *)
244+type symbol_transform = {
245+ delta_nb_bits : int; (* (maxBitsOut << 16) - minStatePlus *)
246+ delta_find_state : int; (* Cumulative offset to find next state *)
247+}
248+249+(** FSE compression table *)
250+type ctable = {
251+ symbol_tt : symbol_transform array; (* Symbol compression transforms *)
252+ state_table : int array; (* Next state lookup table *)
253+ accuracy_log : int;
254+ table_size : int;
255+}
256+257+(** FSE compression state - matches C zstd's FSE_CState_t *)
258+type cstate = {
259+ mutable value : int; (* Current state value *)
260+ ctable : ctable; (* Reference to compression table *)
261+}
262+263+(** Count symbol frequencies *)
264+let count_symbols src ~pos ~len max_symbol =
265+ let counts = Array.make (max_symbol + 1) 0 in
266+ for i = pos to pos + len - 1 do
267+ let s = Bytes.get_uint8 src i in
268+ if s <= max_symbol then
269+ counts.(s) <- counts.(s) + 1
270+ done;
271+ counts
272+273+(** Normalize counts to sum to table_size *)
274+let normalize_counts counts total accuracy_log =
275+ let table_size = 1 lsl accuracy_log in
276+ let num_symbols = Array.length counts in
277+ let norm = Array.make num_symbols 0 in
278+279+ if total = 0 then norm
280+ else begin
281+ let scale = table_size * 256 / total in
282+ let distributed = ref 0 in
283+284+ for s = 0 to num_symbols - 1 do
285+ if counts.(s) > 0 then begin
286+ let proba = (counts.(s) * scale + 128) / 256 in
287+ let proba = max 1 proba in
288+ norm.(s) <- proba;
289+ distributed := !distributed + proba
290+ end
291+ done;
292+293+ while !distributed > table_size do
294+ let max_val = ref 0 in
295+ let max_idx = ref 0 in
296+ for s = 0 to num_symbols - 1 do
297+ if norm.(s) > !max_val then begin
298+ max_val := norm.(s);
299+ max_idx := s
300+ end
301+ done;
302+ norm.(!max_idx) <- norm.(!max_idx) - 1;
303+ decr distributed
304+ done;
305+306+ while !distributed < table_size do
307+ let min_val = ref max_int in
308+ let min_idx = ref 0 in
309+ for s = 0 to num_symbols - 1 do
310+ if norm.(s) > 0 && norm.(s) < !min_val then begin
311+ min_val := norm.(s);
312+ min_idx := s
313+ end
314+ done;
315+ norm.(!min_idx) <- norm.(!min_idx) + 1;
316+ incr distributed
317+ done;
318+319+ norm
320+ end
321+322+(** Build FSE compression table from normalized counts.
323+ Matches C zstd's FSE_buildCTable_wksp algorithm exactly. *)
324+let build_ctable norm_counts accuracy_log =
325+ let table_size = 1 lsl accuracy_log in
326+ let table_mask = table_size - 1 in
327+ let num_symbols = Array.length norm_counts in
328+ let step = (table_size lsr 1) + (table_size lsr 3) + 3 in
329+330+ (* Symbol distribution table - which symbol at each state *)
331+ let table_symbol = Array.make table_size 0 in
332+333+ (* Cumulative counts for state table indexing *)
334+ let cumul = Array.make (num_symbols + 1) 0 in
335+ cumul.(0) <- 0;
336+ for s = 0 to num_symbols - 1 do
337+ let count = if norm_counts.(s) = -1 then 1 else max 0 norm_counts.(s) in
338+ cumul.(s + 1) <- cumul.(s) + count
339+ done;
340+341+ (* Place low probability symbols at the end *)
342+ let high_threshold = ref (table_size - 1) in
343+ for s = 0 to num_symbols - 1 do
344+ if norm_counts.(s) = -1 then begin
345+ table_symbol.(!high_threshold) <- s;
346+ decr high_threshold
347+ end
348+ done;
349+350+ (* Spread remaining symbols using step formula *)
351+ let pos = ref 0 in
352+ for s = 0 to num_symbols - 1 do
353+ let count = norm_counts.(s) in
354+ if count > 0 then begin
355+ for _ = 0 to count - 1 do
356+ table_symbol.(!pos) <- s;
357+ pos := (!pos + step) land table_mask;
358+ while !pos > !high_threshold do
359+ pos := (!pos + step) land table_mask
360+ done
361+ done
362+ end
363+ done;
364+365+ (* Build state table - for each position, compute next state *)
366+ let state_table = Array.make table_size 0 in
367+ let cumul_copy = Array.copy cumul in
368+ for u = 0 to table_size - 1 do
369+ let s = table_symbol.(u) in
370+ state_table.(cumul_copy.(s)) <- table_size + u;
371+ cumul_copy.(s) <- cumul_copy.(s) + 1
372+ done;
373+374+ (* Build symbol compression transforms *)
375+ let symbol_tt = Array.init num_symbols (fun s ->
376+ let count = norm_counts.(s) in
377+ match count with
378+ | 0 ->
379+ (* Zero probability - use max bits (shouldn't be encoded) *)
380+ { delta_nb_bits = ((accuracy_log + 1) lsl 16) - (1 lsl accuracy_log);
381+ delta_find_state = 0 }
382+ | -1 | 1 ->
383+ (* Low probability symbol *)
384+ { delta_nb_bits = (accuracy_log lsl 16) - (1 lsl accuracy_log);
385+ delta_find_state = cumul.(s) - 1 }
386+ | _ ->
387+ (* Normal symbol *)
388+ let max_bits_out = accuracy_log - highest_set_bit (count - 1) in
389+ let min_state_plus = count lsl max_bits_out in
390+ { delta_nb_bits = (max_bits_out lsl 16) - min_state_plus;
391+ delta_find_state = cumul.(s) - count }
392+ ) in
393+394+ { symbol_tt; state_table; accuracy_log; table_size }
395+396+(** Initialize compression state - matches C's FSE_initCState *)
397+let init_cstate ctable =
398+ { value = 1 lsl ctable.accuracy_log; ctable }
399+400+(** Initialize compression state with first symbol - matches C's FSE_initCState2.
401+ This saves bits by using the smallest valid state for the first symbol. *)
402+let init_cstate2 ctable symbol =
403+ let st = ctable.symbol_tt.(symbol) in
404+ let nb_bits_out = (st.delta_nb_bits + (1 lsl 15)) lsr 16 in
405+ let init_value = (nb_bits_out lsl 16) - st.delta_nb_bits in
406+ let state_idx = (init_value lsr nb_bits_out) + st.delta_find_state in
407+ { value = ctable.state_table.(state_idx); ctable }
408+409+(** Encode a single symbol - matches C's FSE_encodeSymbol exactly.
410+ Outputs bits representing state transition and updates state. *)
411+let[@inline] encode_symbol (stream : Bit_writer.Backward.t) cstate symbol =
412+ let st = cstate.ctable.symbol_tt.(symbol) in
413+ let nb_bits_out = (cstate.value + st.delta_nb_bits) lsr 16 in
414+ Bit_writer.Backward.write_bits stream cstate.value nb_bits_out;
415+ let state_idx = (cstate.value lsr nb_bits_out) + st.delta_find_state in
416+ cstate.value <- cstate.ctable.state_table.(state_idx)
417+418+(** Flush compression state - matches C's FSE_flushCState.
419+ Outputs final state value to allow decoder to initialize. *)
420+let[@inline] flush_cstate (stream : Bit_writer.Backward.t) cstate =
421+ Bit_writer.Backward.write_bits stream cstate.value cstate.ctable.accuracy_log
422+423+(** Write FSE header (normalized counts) *)
424+let write_header (stream : Bit_writer.Forward.t) norm_counts accuracy_log =
425+ Bit_writer.Forward.write_bits stream (accuracy_log - 5) 4;
426+427+ let table_size = 1 lsl accuracy_log in
428+ let num_symbols = Array.length norm_counts in
429+ let remaining = ref table_size in
430+ let symbol = ref 0 in
431+432+ while !remaining > 0 && !symbol < num_symbols do
433+ let count = norm_counts.(!symbol) in
434+ let value = count + 1 in
435+436+ let bits_needed = highest_set_bit (!remaining + 1) + 1 in
437+ let threshold = (1 lsl bits_needed) - 1 - (!remaining + 1) in
438+439+ if value < threshold then
440+ Bit_writer.Forward.write_bits stream value (bits_needed - 1)
441+ else
442+ Bit_writer.Forward.write_bits stream (value + threshold) bits_needed;
443+444+ remaining := !remaining - abs count;
445+ incr symbol;
446+447+ if count = 0 then begin
448+ let rec count_zeroes acc =
449+ if !symbol < num_symbols && norm_counts.(!symbol) = 0 then begin
450+ incr symbol;
451+ count_zeroes (acc + 1)
452+ end else acc
453+ in
454+ let zeroes = count_zeroes 0 in
455+ let rec write_repeats n =
456+ if n >= 3 then begin
457+ Bit_writer.Forward.write_bits stream 3 2;
458+ write_repeats (n - 3)
459+ end else
460+ Bit_writer.Forward.write_bits stream n 2
461+ in
462+ write_repeats zeroes
463+ end
464+ done
465+466+(** Build encoding table from predefined distribution *)
467+let build_predefined_ctable distribution accuracy_log =
468+ build_ctable distribution accuracy_log
···1+(** Huffman coding for Zstandard literals decompression.
2+3+ Zstd uses canonical Huffman codes for literal compression.
4+ Huffman streams are read backwards like FSE streams. *)
5+6+(** Huffman decoding table entry *)
7+type entry = {
8+ symbol : int;
9+ num_bits : int;
10+}
11+12+(** Huffman decoding table *)
13+type dtable = {
14+ entries : entry array;
15+ max_bits : int;
16+}
17+18+let highest_set_bit = Fse.highest_set_bit
19+20+(** Build Huffman table from bit lengths.
21+ Uses canonical Huffman coding. *)
22+let build_dtable_from_bits bits num_symbols =
23+ if num_symbols > Constants.max_huffman_symbols then
24+ raise (Constants.Zstd_error Constants.Invalid_huffman_table);
25+26+ (* Find max bits and count symbols per bit length *)
27+ let max_bits = ref 0 in
28+ let rank_count = Array.make (Constants.max_huffman_bits + 1) 0 in
29+30+ for i = 0 to num_symbols - 1 do
31+ let b = bits.(i) in
32+ if b > Constants.max_huffman_bits then
33+ raise (Constants.Zstd_error Constants.Invalid_huffman_table);
34+ if b > !max_bits then max_bits := b;
35+ rank_count.(b) <- rank_count.(b) + 1
36+ done;
37+38+ if !max_bits = 0 then
39+ raise (Constants.Zstd_error Constants.Invalid_huffman_table);
40+41+ let table_size = 1 lsl !max_bits in
42+ let entries = Array.init table_size (fun _ ->
43+ { symbol = 0; num_bits = 0 }
44+ ) in
45+46+ (* Calculate starting indices for each rank *)
47+ let rank_idx = Array.make (Constants.max_huffman_bits + 1) 0 in
48+ rank_idx.(!max_bits) <- 0;
49+ for i = !max_bits downto 1 do
50+ rank_idx.(i - 1) <- rank_idx.(i) + rank_count.(i) * (1 lsl (!max_bits - i));
51+ (* Fill in num_bits for this range *)
52+ for j = rank_idx.(i) to rank_idx.(i - 1) - 1 do
53+ entries.(j) <- { entries.(j) with num_bits = i }
54+ done
55+ done;
56+57+ if rank_idx.(0) <> table_size then
58+ raise (Constants.Zstd_error Constants.Invalid_huffman_table);
59+60+ (* Assign symbols to table entries *)
61+ for i = 0 to num_symbols - 1 do
62+ let b = bits.(i) in
63+ if b <> 0 then begin
64+ let code = rank_idx.(b) in
65+ let len = 1 lsl (!max_bits - b) in
66+ for j = code to code + len - 1 do
67+ entries.(j) <- { entries.(j) with symbol = i }
68+ done;
69+ rank_idx.(b) <- code + len
70+ end
71+ done;
72+73+ { entries; max_bits = !max_bits }
74+75+(** Build table from weights (as decoded from zstd format) *)
76+let build_dtable_from_weights weights num_symbols =
77+ if num_symbols + 1 > Constants.max_huffman_symbols then
78+ raise (Constants.Zstd_error Constants.Invalid_huffman_table);
79+80+ let bits = Array.make (num_symbols + 1) 0 in
81+82+ (* Calculate weight sum to find max_bits and last weight *)
83+ let weight_sum = ref 0 in
84+ for i = 0 to num_symbols - 1 do
85+ let w = weights.(i) in
86+ if w > Constants.max_huffman_bits then
87+ raise (Constants.Zstd_error Constants.Invalid_huffman_table);
88+ if w > 0 then
89+ weight_sum := !weight_sum + (1 lsl (w - 1))
90+ done;
91+92+ (* Find max_bits (first power of 2 > weight_sum) *)
93+ let max_bits = highest_set_bit !weight_sum + 1 in
94+ let left_over = (1 lsl max_bits) - !weight_sum in
95+96+ (* left_over must be a power of 2 *)
97+ if left_over land (left_over - 1) <> 0 then
98+ raise (Constants.Zstd_error Constants.Invalid_huffman_table);
99+100+ let last_weight = highest_set_bit left_over + 1 in
101+102+ (* Convert weights to bit lengths *)
103+ for i = 0 to num_symbols - 1 do
104+ let w = weights.(i) in
105+ bits.(i) <- if w > 0 then max_bits + 1 - w else 0
106+ done;
107+ bits.(num_symbols) <- max_bits + 1 - last_weight;
108+109+ build_dtable_from_bits bits (num_symbols + 1)
110+111+(** Initialize Huffman state by reading max_bits *)
112+let[@inline] init_state dtable (stream : Bit_reader.Backward.t) =
113+ Bit_reader.Backward.read_bits stream dtable.max_bits
114+115+(** Decode a symbol and update state *)
116+let[@inline] decode_symbol dtable state (stream : Bit_reader.Backward.t) =
117+ let entry = dtable.entries.(state) in
118+ let symbol = entry.symbol in
119+ let bits_used = entry.num_bits in
120+ (* Shift out used bits and read new ones *)
121+ let mask = (1 lsl dtable.max_bits) - 1 in
122+ let rest = Bit_reader.Backward.read_bits stream bits_used in
123+ let new_state = ((state lsl bits_used) + rest) land mask in
124+ (symbol, new_state)
125+126+(** Decompress a single Huffman stream *)
127+let decompress_1stream dtable src ~pos ~len output ~out_pos ~out_len =
128+ let stream = Bit_reader.Backward.of_bytes src ~pos ~len in
129+ let state = ref (init_state dtable stream) in
130+131+ let written = ref 0 in
132+ while Bit_reader.Backward.remaining stream > -dtable.max_bits do
133+ if out_pos + !written >= out_pos + out_len then
134+ raise (Constants.Zstd_error Constants.Output_too_small);
135+136+ let (symbol, new_state) = decode_symbol dtable !state stream in
137+ Bytes.set_uint8 output (out_pos + !written) symbol;
138+ incr written;
139+ state := new_state
140+ done;
141+142+ (* Verify stream is exactly consumed *)
143+ if Bit_reader.Backward.remaining stream <> -dtable.max_bits then
144+ raise (Constants.Zstd_error Constants.Corruption);
145+146+ !written
147+148+(** Decompress 4 interleaved Huffman streams *)
149+let decompress_4stream dtable src ~pos ~len output ~out_pos ~regen_size =
150+ (* Read stream sizes from jump table (6 bytes) *)
151+ let size1 = Bit_reader.get_u16_le src pos in
152+ let size2 = Bit_reader.get_u16_le src (pos + 2) in
153+ let size3 = Bit_reader.get_u16_le src (pos + 4) in
154+ let size4 = len - 6 - size1 - size2 - size3 in
155+156+ if size4 < 1 then
157+ raise (Constants.Zstd_error Constants.Corruption);
158+159+ (* Calculate output sizes *)
160+ let out_size = (regen_size + 3) / 4 in
161+ let out_size4 = regen_size - 3 * out_size in
162+163+ (* Decompress each stream *)
164+ let stream_pos = pos + 6 in
165+166+ let written1 = decompress_1stream dtable src
167+ ~pos:stream_pos ~len:size1
168+ output ~out_pos ~out_len:out_size in
169+170+ let written2 = decompress_1stream dtable src
171+ ~pos:(stream_pos + size1) ~len:size2
172+ output ~out_pos:(out_pos + out_size) ~out_len:out_size in
173+174+ let written3 = decompress_1stream dtable src
175+ ~pos:(stream_pos + size1 + size2) ~len:size3
176+ output ~out_pos:(out_pos + 2 * out_size) ~out_len:out_size in
177+178+ let written4 = decompress_1stream dtable src
179+ ~pos:(stream_pos + size1 + size2 + size3) ~len:size4
180+ output ~out_pos:(out_pos + 3 * out_size) ~out_len:out_size4 in
181+182+ written1 + written2 + written3 + written4
183+184+(** Decode Huffman table from stream.
185+ Returns (dtable, bytes consumed) *)
186+let decode_table (stream : Bit_reader.Forward.t) =
187+ let header = Bit_reader.Forward.read_byte stream in
188+189+ let weights = Array.make Constants.max_huffman_symbols 0 in
190+ let num_symbols =
191+ if header >= 128 then begin
192+ (* Direct representation: 4 bits per weight *)
193+ let count = header - 127 in
194+ let bytes_needed = (count + 1) / 2 in
195+ let data = Bit_reader.Forward.get_bytes stream bytes_needed in
196+197+ for i = 0 to count - 1 do
198+ let byte = Bytes.get_uint8 data (i / 2) in
199+ weights.(i) <- if i mod 2 = 0 then byte lsr 4 else byte land 0xf
200+ done;
201+ count
202+ end else begin
203+ (* FSE compressed weights *)
204+ let compressed_size = header in
205+ let fse_data = Bit_reader.Forward.get_bytes stream compressed_size in
206+207+ (* Decode FSE table for weights (max accuracy 7) *)
208+ let fse_stream = Bit_reader.Forward.of_bytes fse_data in
209+ let fse_table = Fse.decode_header fse_stream 7 in
210+211+ (* Remaining bytes are the compressed weights *)
212+ let weights_pos = Bit_reader.Forward.byte_position fse_stream in
213+ let weights_len = compressed_size - weights_pos in
214+215+ let weight_bytes = Bytes.create Constants.max_huffman_symbols in
216+ let decoded = Fse.decompress_interleaved2 fse_table
217+ fse_data ~pos:weights_pos ~len:weights_len weight_bytes in
218+219+ for i = 0 to decoded - 1 do
220+ weights.(i) <- Bytes.get_uint8 weight_bytes i
221+ done;
222+ decoded
223+ end
224+ in
225+226+ build_dtable_from_weights weights num_symbols
227+228+(* ========== ENCODING ========== *)
229+230+(** Huffman encoding table *)
231+type ctable = {
232+ codes : int array; (* Canonical code for each symbol *)
233+ num_bits : int array; (* Bit length for each symbol *)
234+ max_bits : int;
235+ num_symbols : int;
236+}
237+238+(** Build Huffman code from frequencies using package-merge algorithm *)
239+let build_ctable counts max_symbol max_bits_limit =
240+ let num_symbols = max_symbol + 1 in
241+ let freqs = Array.sub counts 0 num_symbols in
242+243+ (* Count non-zero frequencies *)
244+ let non_zero = ref 0 in
245+ for i = 0 to num_symbols - 1 do
246+ if freqs.(i) > 0 then incr non_zero
247+ done;
248+249+ if !non_zero = 0 then
250+ { codes = [||]; num_bits = [||]; max_bits = 0; num_symbols = 0 }
251+ else if !non_zero = 1 then begin
252+ (* Single symbol case *)
253+ let num_bits = Array.make num_symbols 0 in
254+ for i = 0 to num_symbols - 1 do
255+ if freqs.(i) > 0 then num_bits.(i) <- 1
256+ done;
257+ let codes = Array.make num_symbols 0 in
258+ { codes; num_bits; max_bits = 1; num_symbols }
259+ end else begin
260+ (* Sort symbols by frequency *)
261+ let sorted = Array.init num_symbols (fun i -> (freqs.(i), i)) in
262+ Array.sort (fun (f1, _) (f2, _) -> compare f1 f2) sorted;
263+264+ (* Build Huffman tree using a simple greedy approach *)
265+ (* This produces a valid but not necessarily optimal tree *)
266+ let bit_lengths = Array.make num_symbols 0 in
267+268+ (* Assign bit lengths based on frequency rank *)
269+ let active_count = ref 0 in
270+ for i = 0 to num_symbols - 1 do
271+ let (freq, _sym) = sorted.(num_symbols - 1 - i) in
272+ if freq > 0 then incr active_count
273+ done;
274+275+ (* Use Kraft's inequality to assign optimal lengths *)
276+ (* Start with uniform distribution and adjust *)
277+ let target_bits = max 1 (highest_set_bit !active_count + 1) in
278+ let max_bits = min max_bits_limit (max target_bits 1) in
279+280+ (* Simple heuristic: assign bits based on frequency ranking *)
281+ let rank = ref 0 in
282+ for i = num_symbols - 1 downto 0 do
283+ let (freq, sym) = sorted.(i) in
284+ if freq > 0 then begin
285+ (* More frequent symbols get shorter codes *)
286+ let bits =
287+ if !rank < (1 lsl (max_bits - 1)) then
288+ min max_bits (max 1 (max_bits - highest_set_bit (!rank + 1)))
289+ else
290+ max_bits
291+ in
292+ bit_lengths.(sym) <- bits;
293+ incr rank
294+ end
295+ done;
296+297+ (* Validate and adjust bit lengths to satisfy Kraft inequality *)
298+ let rec adjust () =
299+ let kraft_sum = ref 0.0 in
300+ for i = 0 to num_symbols - 1 do
301+ if bit_lengths.(i) > 0 then
302+ kraft_sum := !kraft_sum +. (1.0 /. (float_of_int (1 lsl bit_lengths.(i))))
303+ done;
304+ if !kraft_sum > 1.0 then begin
305+ (* Increase some lengths *)
306+ for i = 0 to num_symbols - 1 do
307+ if bit_lengths.(i) > 0 && bit_lengths.(i) < max_bits then begin
308+ bit_lengths.(i) <- bit_lengths.(i) + 1
309+ end
310+ done;
311+ adjust ()
312+ end
313+ in
314+ adjust ();
315+316+ (* Build canonical codes *)
317+ let codes = Array.make num_symbols 0 in
318+ let actual_max = ref 0 in
319+ for i = 0 to num_symbols - 1 do
320+ if bit_lengths.(i) > !actual_max then actual_max := bit_lengths.(i)
321+ done;
322+323+ (* Count symbols at each bit length *)
324+ let bl_count = Array.make (!actual_max + 1) 0 in
325+ for i = 0 to num_symbols - 1 do
326+ if bit_lengths.(i) > 0 then
327+ bl_count.(bit_lengths.(i)) <- bl_count.(bit_lengths.(i)) + 1
328+ done;
329+330+ (* Calculate starting code for each bit length *)
331+ let next_code = Array.make (!actual_max + 1) 0 in
332+ let code = ref 0 in
333+ for bits = 1 to !actual_max do
334+ code := (!code + bl_count.(bits - 1)) lsl 1;
335+ next_code.(bits) <- !code
336+ done;
337+338+ (* Assign codes to symbols *)
339+ for i = 0 to num_symbols - 1 do
340+ let bits = bit_lengths.(i) in
341+ if bits > 0 then begin
342+ codes.(i) <- next_code.(bits);
343+ next_code.(bits) <- next_code.(bits) + 1
344+ end
345+ done;
346+347+ { codes; num_bits = bit_lengths; max_bits = !actual_max; num_symbols }
348+ end
349+350+(** Convert bit lengths to weights (zstd format) *)
351+let bits_to_weights num_bits num_symbols max_bits =
352+ let weights = Array.make num_symbols 0 in
353+ for i = 0 to num_symbols - 1 do
354+ if num_bits.(i) > 0 then
355+ weights.(i) <- max_bits + 1 - num_bits.(i)
356+ done;
357+ weights
358+359+(** Write Huffman table header using direct representation.
360+ Returns the number of actual symbols to encode.
361+ Note: For tables with >127 weights, FSE compression could be used
362+ for better ratios, but direct representation is always valid. *)
363+let write_header (stream : Bit_writer.Forward.t) ctable =
364+ if ctable.num_symbols = 0 then 0
365+ else begin
366+ let weights = bits_to_weights ctable.num_bits ctable.num_symbols ctable.max_bits in
367+368+ (* Find last non-zero weight (implicit last symbol) *)
369+ let last_nonzero = ref (ctable.num_symbols - 1) in
370+ while !last_nonzero > 0 && weights.(!last_nonzero) = 0 do
371+ decr last_nonzero
372+ done;
373+374+ let num_weights = !last_nonzero in (* Last weight is implicit *)
375+376+ (* Direct representation: header byte = 128 + num_weights, then 4 bits per weight *)
377+ let header = 128 + num_weights in
378+ Bit_writer.Forward.write_byte stream header;
379+380+ (* Write weights packed as pairs (high nibble, low nibble) *)
381+ for i = 0 to (num_weights - 1) / 2 do
382+ let w1 = if 2 * i < num_weights then weights.(2 * i) else 0 in
383+ let w2 = if 2 * i + 1 < num_weights then weights.(2 * i + 1) else 0 in
384+ Bit_writer.Forward.write_byte stream ((w1 lsl 4) lor w2)
385+ done;
386+387+ num_weights + 1
388+ end
389+390+(** Encode a single symbol (write to backward stream) *)
391+let[@inline] encode_symbol ctable (stream : Bit_writer.Backward.t) symbol =
392+ let code = ctable.codes.(symbol) in
393+ let bits = ctable.num_bits.(symbol) in
394+ if bits > 0 then
395+ Bit_writer.Backward.write_bits stream code bits
396+397+(** Compress literals to a single Huffman stream *)
398+let compress_1stream ctable literals ~pos ~len =
399+ let stream = Bit_writer.Backward.create (len * 2 + 16) in
400+401+ (* Encode symbols in reverse order *)
402+ for i = pos + len - 1 downto pos do
403+ let sym = Bytes.get_uint8 literals i in
404+ encode_symbol ctable stream sym
405+ done;
406+407+ Bit_writer.Backward.finalize stream
408+409+(** Compress literals to 4 interleaved Huffman streams *)
410+let compress_4stream ctable literals ~pos ~len =
411+ let chunk_size = (len + 3) / 4 in
412+ let chunk4_size = len - 3 * chunk_size in
413+414+ (* Compress each stream *)
415+ let stream1 = compress_1stream ctable literals ~pos ~len:chunk_size in
416+ let stream2 = compress_1stream ctable literals ~pos:(pos + chunk_size) ~len:chunk_size in
417+ let stream3 = compress_1stream ctable literals ~pos:(pos + 2 * chunk_size) ~len:chunk_size in
418+ let stream4 = compress_1stream ctable literals ~pos:(pos + 3 * chunk_size) ~len:chunk4_size in
419+420+ (* Build output with jump table *)
421+ let size1 = Bytes.length stream1 in
422+ let size2 = Bytes.length stream2 in
423+ let size3 = Bytes.length stream3 in
424+ let total = 6 + size1 + size2 + size3 + Bytes.length stream4 in
425+426+ let output = Bytes.create total in
427+ Bytes.set_uint16_le output 0 size1;
428+ Bytes.set_uint16_le output 2 size2;
429+ Bytes.set_uint16_le output 4 size3;
430+ Bytes.blit stream1 0 output 6 size1;
431+ Bytes.blit stream2 0 output (6 + size1) size2;
432+ Bytes.blit stream3 0 output (6 + size1 + size2) size3;
433+ Bytes.blit stream4 0 output (6 + size1 + size2 + size3) (Bytes.length stream4);
434+435+ output
···1+(** Pure OCaml implementation of Zstandard compression (RFC 8878).
2+3+ {2 Decoder}
4+5+ The decoder is fully compliant with the zstd format specification and can
6+ decompress any valid zstd frame produced by any conforming encoder. It
7+ supports all block types (raw, RLE, compressed), Huffman and FSE entropy
8+ coding, and content checksums.
9+10+ {2 Encoder}
11+12+ The encoder produces valid zstd frames that can be decompressed by any
13+ conforming decoder (including the reference C implementation). Current
14+ encoding strategy:
15+16+ - {b RLE blocks}: Data consisting of a single repeated byte is encoded as
17+ RLE blocks (4 bytes total regardless of decompressed size)
18+ - {b Raw blocks}: All other data is stored as raw (uncompressed) blocks
19+20+ This means the encoder always produces valid output, but compression ratios
21+ are not optimal for most data. The encoder is suitable for:
22+ - Applications where decompression speed matters more than compressed size
23+ - Data that is already compressed or has high entropy
24+ - Testing zstd decoders
25+26+ Future improvements planned:
27+ - LZ77 match finding with sequence encoding
28+ - Huffman compression for literals
29+ - FSE-compressed blocks for better ratios
30+31+ {2 Dictionary Support}
32+33+ Dictionary decompression is supported. Dictionary compression is not yet
34+ implemented (falls back to regular compression). *)
35+36+type error = Constants.error =
37+ | Invalid_magic_number
38+ | Invalid_frame_header
39+ | Invalid_block_type
40+ | Invalid_block_size
41+ | Invalid_literals_header
42+ | Invalid_huffman_table
43+ | Invalid_fse_table
44+ | Invalid_sequence_header
45+ | Invalid_offset
46+ | Invalid_match_length
47+ | Truncated_input
48+ | Output_too_small
49+ | Checksum_mismatch
50+ | Dictionary_mismatch
51+ | Corruption
52+53+exception Zstd_error = Constants.Zstd_error
54+55+type dictionary = Zstd_decode.dictionary
56+57+let error_message = Constants.error_message
58+59+(** Check if data starts with zstd magic number *)
60+let is_zstd_frame s =
61+ if String.length s < 4 then false
62+ else
63+ let b = Bytes.unsafe_of_string s in
64+ let magic = Bytes.get_int32_le b 0 in
65+ magic = Constants.zstd_magic_number
66+67+(** Get decompressed size from frame header *)
68+let get_decompressed_size s =
69+ if String.length s < 5 then None
70+ else
71+ let b = Bytes.unsafe_of_string s in
72+ Zstd_decode.get_decompressed_size b ~pos:0 ~len:(String.length s)
73+74+(** Calculate maximum compressed size *)
75+let compress_bound src_len =
76+ (* zstd guarantees compressed size <= src_len + (src_len >> 8) + constant *)
77+ src_len + (src_len lsr 8) + 64
78+79+(** Load dictionary *)
80+let load_dictionary s =
81+ let b = Bytes.of_string s in
82+ Zstd_decode.parse_dictionary b ~pos:0 ~len:(String.length s)
83+84+(** Decompress bytes *)
85+let decompress_bytes_exn src =
86+ Zstd_decode.decompress_frame src ~pos:0 ~len:(Bytes.length src)
87+88+let decompress_bytes src =
89+ try Ok (decompress_bytes_exn src)
90+ with Zstd_error e -> Error (error_message e)
91+92+(** Decompress string *)
93+let decompress_exn s =
94+ let src = Bytes.unsafe_of_string s in
95+ let result = Zstd_decode.decompress_frame src ~pos:0 ~len:(String.length s) in
96+ Bytes.unsafe_to_string result
97+98+let decompress s =
99+ try Ok (decompress_exn s)
100+ with Zstd_error e -> Error (error_message e)
101+102+(** Decompress with dictionary *)
103+let decompress_with_dict_exn dict s =
104+ let src = Bytes.unsafe_of_string s in
105+ let result = Zstd_decode.decompress_frame ~dict src ~pos:0 ~len:(String.length s) in
106+ Bytes.unsafe_to_string result
107+108+let decompress_with_dict dict s =
109+ try Ok (decompress_with_dict_exn dict s)
110+ with Zstd_error e -> Error (error_message e)
111+112+(** Decompress into pre-allocated buffer *)
113+let decompress_into ~src ~src_pos ~src_len ~dst ~dst_pos =
114+ let result = Zstd_decode.decompress_frame src ~pos:src_pos ~len:src_len in
115+ let result_len = Bytes.length result in
116+ if dst_pos + result_len > Bytes.length dst then
117+ raise (Zstd_error Output_too_small);
118+ Bytes.blit result 0 dst dst_pos result_len;
119+ result_len
120+121+(** Compress string *)
122+let compress ?(level=3) s =
123+ Zstd_encode.compress ~level ~checksum:true s
124+125+(** Compress bytes *)
126+let compress_bytes ?(level=3) src =
127+ let s = Bytes.unsafe_to_string src in
128+ let result = Zstd_encode.compress ~level ~checksum:true s in
129+ Bytes.of_string result
130+131+let compress_with_dict ?level _dict s =
132+ (* Dictionary compression uses same encoder but with preloaded tables *)
133+ (* For now, just compress without dictionary *)
134+ compress ?level s
135+136+let compress_into ?(level=3) ~src ~src_pos ~src_len ~dst ~dst_pos () =
137+ let input = Bytes.sub_string src src_pos src_len in
138+ let result = Zstd_encode.compress ~level ~checksum:true input in
139+ let result_len = String.length result in
140+ if dst_pos + result_len > Bytes.length dst then
141+ raise (Zstd_error Output_too_small);
142+ Bytes.blit_string result 0 dst dst_pos result_len;
143+ result_len
144+145+(** Check if data starts with skippable frame magic *)
146+let is_skippable_frame s =
147+ let b = Bytes.unsafe_of_string s in
148+ Zstd_decode.is_skippable_frame b ~pos:0 ~len:(String.length s)
149+150+(** Get skippable frame variant (0-15) *)
151+let get_skippable_variant s =
152+ let b = Bytes.unsafe_of_string s in
153+ Zstd_decode.get_skippable_variant b ~pos:0 ~len:(String.length s)
154+155+(** Write a skippable frame *)
156+let write_skippable_frame ?variant content =
157+ Zstd_encode.write_skippable_frame ?variant content
158+159+(** Read a skippable frame and return its content *)
160+let read_skippable_frame s =
161+ let b = Bytes.unsafe_of_string s in
162+ let (content, _) = Zstd_decode.read_skippable_frame b ~pos:0 ~len:(String.length s) in
163+ content
164+165+(** Get total size of skippable frame *)
166+let get_skippable_frame_size s =
167+ let b = Bytes.unsafe_of_string s in
168+ Zstd_decode.get_skippable_frame_size b ~pos:0 ~len:(String.length s)
169+170+(** Find compressed size of first frame *)
171+let find_frame_compressed_size s =
172+ let b = Bytes.unsafe_of_string s in
173+ Zstd_decode.find_frame_compressed_size b ~pos:0 ~len:(String.length s)
174+175+(** Decompress all frames *)
176+let decompress_all_exn s =
177+ let b = Bytes.unsafe_of_string s in
178+ let result = Zstd_decode.decompress_frames b ~pos:0 ~len:(String.length s) in
179+ Bytes.unsafe_to_string result
180+181+let decompress_all s =
182+ try Ok (decompress_all_exn s)
183+ with Zstd_error e -> Error (error_message e)
···1+(** Pure OCaml implementation of Zstandard compression (RFC 8878).
2+3+ Zstandard is a fast compression algorithm providing high compression
4+ ratios. This library provides both compression and decompression
5+ functionality in pure OCaml.
6+7+ {1 Quick Start}
8+9+ Decompress data:
10+ {[
11+ let compressed = ... in
12+ match Zstd.decompress compressed with
13+ | Ok data -> use data
14+ | Error msg -> handle_error msg
15+ ]}
16+17+ Compress data:
18+ {[
19+ let data = ... in
20+ let compressed = Zstd.compress data in
21+ ...
22+ ]}
23+24+ {1 Error Handling}
25+26+ Two styles are provided:
27+ - Result-based: [decompress] returns [(string, string) result]
28+ - Exception-based: [decompress_exn] raises [Zstd_error]
29+30+ {1 Compression Levels}
31+32+ Compression levels range from 1 (fastest) to 19 (best compression).
33+ The default level is 3, which provides a good balance.
34+ Level 0 is a special level meaning "use default".
35+*)
36+37+(** {1 Types} *)
38+39+(** Error codes for decompression failures *)
40+type error =
41+ | Invalid_magic_number
42+ | Invalid_frame_header
43+ | Invalid_block_type
44+ | Invalid_block_size
45+ | Invalid_literals_header
46+ | Invalid_huffman_table
47+ | Invalid_fse_table
48+ | Invalid_sequence_header
49+ | Invalid_offset
50+ | Invalid_match_length
51+ | Truncated_input
52+ | Output_too_small
53+ | Checksum_mismatch
54+ | Dictionary_mismatch
55+ | Corruption
56+57+(** Exception raised by [*_exn] functions *)
58+exception Zstd_error of error
59+60+(** Pre-loaded dictionary for compression/decompression *)
61+type dictionary
62+63+(** {1 Simple API} *)
64+65+(** Decompress a zstd-compressed string.
66+ @return [Ok data] on success, [Error msg] on failure *)
67+val decompress : string -> (string, string) result
68+69+(** Decompress a zstd-compressed string.
70+ @raise Zstd_error on failure *)
71+val decompress_exn : string -> string
72+73+(** Compress a string using zstd.
74+ @param level Compression level 1-19 (default: 3)
75+ @return Compressed data *)
76+val compress : ?level:int -> string -> string
77+78+(** {1 Bytes API} *)
79+80+(** Decompress from bytes.
81+ @return [Ok data] on success, [Error msg] on failure *)
82+val decompress_bytes : bytes -> (bytes, string) result
83+84+(** Decompress from bytes.
85+ @raise Zstd_error on failure *)
86+val decompress_bytes_exn : bytes -> bytes
87+88+(** Compress bytes.
89+ @param level Compression level 1-19 (default: 3) *)
90+val compress_bytes : ?level:int -> bytes -> bytes
91+92+(** {1 Low-allocation API} *)
93+94+(** Decompress into a pre-allocated buffer.
95+ @param src Source buffer with compressed data
96+ @param src_pos Start position in source
97+ @param src_len Length of compressed data
98+ @param dst Destination buffer
99+ @param dst_pos Start position in destination
100+ @return Number of bytes written to destination
101+ @raise Zstd_error on failure or if destination is too small *)
102+val decompress_into :
103+ src:bytes -> src_pos:int -> src_len:int ->
104+ dst:bytes -> dst_pos:int -> int
105+106+(** Compress into a pre-allocated buffer.
107+ @param level Compression level 1-19 (default: 3)
108+ @param src Source buffer
109+ @param src_pos Start position in source
110+ @param src_len Length of data to compress
111+ @param dst Destination buffer
112+ @param dst_pos Start position in destination
113+ @return Number of bytes written to destination
114+ @raise Zstd_error on failure or if destination is too small *)
115+val compress_into :
116+ ?level:int ->
117+ src:bytes -> src_pos:int -> src_len:int ->
118+ dst:bytes -> dst_pos:int -> unit -> int
119+120+(** {1 Frame Information} *)
121+122+(** Get the decompressed size from a frame header, if available.
123+ Returns [None] if the frame doesn't include the content size. *)
124+val get_decompressed_size : string -> int64 option
125+126+(** Check if data starts with a valid zstd magic number. *)
127+val is_zstd_frame : string -> bool
128+129+(** Calculate the maximum compressed size for a given input size.
130+ This can be used to allocate a buffer for compression. *)
131+val compress_bound : int -> int
132+133+(** {1 Dictionary Support} *)
134+135+(** Load a dictionary from data.
136+ The dictionary can be either a raw content dictionary or a
137+ formatted dictionary with pre-computed entropy tables. *)
138+val load_dictionary : string -> dictionary
139+140+(** Decompress using a dictionary.
141+ @return [Ok data] on success, [Error msg] on failure *)
142+val decompress_with_dict : dictionary -> string -> (string, string) result
143+144+(** Decompress using a dictionary.
145+ @raise Zstd_error on failure *)
146+val decompress_with_dict_exn : dictionary -> string -> string
147+148+(** Compress using a dictionary.
149+ @param level Compression level 1-19 (default: 3) *)
150+val compress_with_dict : ?level:int -> dictionary -> string -> string
151+152+(** {1 Error Utilities} *)
153+154+(** Convert an error code to a human-readable message. *)
155+val error_message : error -> string
156+157+(** {1 Frame Type Detection} *)
158+159+(** Check if data starts with a valid skippable frame magic number.
160+ Skippable frames have magic numbers in the range 0x184D2A50 to 0x184D2A5F. *)
161+val is_skippable_frame : string -> bool
162+163+(** Get the skippable frame variant (0-15) if present.
164+ Returns [None] if not a skippable frame. *)
165+val get_skippable_variant : string -> int option
166+167+(** {1 Skippable Frame Support} *)
168+169+(** Write a skippable frame.
170+ Skippable frames can contain arbitrary data that will be ignored by decoders.
171+ @param variant Magic number variant 0-15 (default: 0)
172+ @param content The content to embed
173+ @return The complete skippable frame *)
174+val write_skippable_frame : ?variant:int -> string -> string
175+176+(** Read a skippable frame and return its content.
177+ @return The content bytes
178+ @raise Zstd_error if not a valid skippable frame *)
179+val read_skippable_frame : string -> bytes
180+181+(** Get the total size of a skippable frame (header + content).
182+ @return [Some size] if a valid skippable frame, [None] otherwise *)
183+val get_skippable_frame_size : string -> int option
184+185+(** {1 Multi-Frame Support} *)
186+187+(** Find the compressed size of the first frame (zstd or skippable).
188+ This is useful for parsing concatenated frames.
189+ @return Size in bytes of the complete first frame
190+ @raise Zstd_error on invalid or truncated input *)
191+val find_frame_compressed_size : string -> int
192+193+(** Decompress all frames (including skipping skippable frames).
194+ Concatenated zstd frames are decompressed and their output concatenated.
195+ Skippable frames are silently skipped.
196+ @return The concatenated decompressed output *)
197+val decompress_all : string -> (string, string) result
198+199+(** Decompress all frames, raising on error.
200+ @raise Zstd_error on failure *)
201+val decompress_all_exn : string -> string
···1+(** Zstandard compression implementation.
2+3+ Implements LZ77 matching, block compression, and frame encoding. *)
4+5+(** Compression level affects speed vs ratio tradeoff *)
6+type compression_level = {
7+ window_log : int; (* Log2 of window size *)
8+ chain_log : int; (* Log2 of hash chain length *)
9+ hash_log : int; (* Log2 of hash table size *)
10+ search_log : int; (* Number of searches per position *)
11+ min_match : int; (* Minimum match length *)
12+ target_len : int; (* Target match length *)
13+ strategy : int; (* 0=fast, 1=greedy, 2=lazy *)
14+}
15+16+(** Default levels 1-19 *)
17+let level_params = [|
18+ (* Level 0/1: Fast *)
19+ { window_log = 17; chain_log = 12; hash_log = 11; search_log = 1; min_match = 4; target_len = 0; strategy = 0 };
20+ { window_log = 17; chain_log = 12; hash_log = 11; search_log = 1; min_match = 4; target_len = 0; strategy = 0 };
21+ (* Level 2 *)
22+ { window_log = 18; chain_log = 13; hash_log = 12; search_log = 1; min_match = 5; target_len = 4; strategy = 0 };
23+ (* Level 3 *)
24+ { window_log = 18; chain_log = 14; hash_log = 13; search_log = 1; min_match = 5; target_len = 8; strategy = 1 };
25+ (* Level 4 *)
26+ { window_log = 18; chain_log = 14; hash_log = 14; search_log = 2; min_match = 4; target_len = 8; strategy = 1 };
27+ (* Level 5 *)
28+ { window_log = 18; chain_log = 15; hash_log = 14; search_log = 3; min_match = 4; target_len = 16; strategy = 1 };
29+ (* Level 6 *)
30+ { window_log = 19; chain_log = 16; hash_log = 15; search_log = 3; min_match = 4; target_len = 32; strategy = 1 };
31+ (* Level 7 *)
32+ { window_log = 19; chain_log = 16; hash_log = 15; search_log = 4; min_match = 4; target_len = 32; strategy = 2 };
33+ (* Level 8 *)
34+ { window_log = 19; chain_log = 17; hash_log = 16; search_log = 4; min_match = 4; target_len = 64; strategy = 2 };
35+ (* Level 9 *)
36+ { window_log = 20; chain_log = 17; hash_log = 16; search_log = 5; min_match = 4; target_len = 64; strategy = 2 };
37+ (* Level 10 *)
38+ { window_log = 20; chain_log = 17; hash_log = 16; search_log = 6; min_match = 4; target_len = 128; strategy = 2 };
39+ (* Level 11 *)
40+ { window_log = 20; chain_log = 18; hash_log = 17; search_log = 6; min_match = 4; target_len = 128; strategy = 2 };
41+ (* Level 12 *)
42+ { window_log = 21; chain_log = 18; hash_log = 17; search_log = 7; min_match = 4; target_len = 256; strategy = 2 };
43+ (* Level 13 *)
44+ { window_log = 21; chain_log = 19; hash_log = 18; search_log = 7; min_match = 4; target_len = 256; strategy = 2 };
45+ (* Level 14 *)
46+ { window_log = 22; chain_log = 19; hash_log = 18; search_log = 8; min_match = 4; target_len = 256; strategy = 2 };
47+ (* Level 15 *)
48+ { window_log = 22; chain_log = 20; hash_log = 18; search_log = 9; min_match = 4; target_len = 256; strategy = 2 };
49+ (* Level 16 *)
50+ { window_log = 22; chain_log = 20; hash_log = 19; search_log = 10; min_match = 4; target_len = 512; strategy = 2 };
51+ (* Level 17 *)
52+ { window_log = 22; chain_log = 21; hash_log = 19; search_log = 11; min_match = 4; target_len = 512; strategy = 2 };
53+ (* Level 18 *)
54+ { window_log = 22; chain_log = 21; hash_log = 20; search_log = 12; min_match = 4; target_len = 512; strategy = 2 };
55+ (* Level 19 *)
56+ { window_log = 23; chain_log = 22; hash_log = 20; search_log = 12; min_match = 4; target_len = 1024; strategy = 2 };
57+|]
58+59+let get_level_params level =
60+ let level = max 1 (min level 19) in
61+ level_params.(level)
62+63+(** A sequence represents a literal run + match *)
64+type sequence = {
65+ lit_length : int;
66+ match_offset : int;
67+ match_length : int;
68+}
69+70+(** Hash table for fast match finding *)
71+type hash_table = {
72+ table : int array; (* Position indexed by hash *)
73+ chain : int array; (* Chain of previous matches at same hash *)
74+ mask : int;
75+}
76+77+let create_hash_table log_size =
78+ let size = 1 lsl log_size in
79+ {
80+ table = Array.make size (-1);
81+ chain = Array.make (1 lsl 20) (-1); (* Max input size *)
82+ mask = size - 1;
83+ }
84+85+(** Compute hash of 4 bytes *)
86+let[@inline] hash4 src pos =
87+ let v = Bytes.get_int32_le src pos in
88+ (* MurmurHash3-like mixing *)
89+ let h = Int32.to_int (Int32.mul v 0xcc9e2d51l) in
90+ (h lxor (h lsr 15))
91+92+(** Check if positions match and return length *)
93+let match_length src pos1 pos2 limit =
94+ let len = ref 0 in
95+ let max_len = min (limit - pos1) (pos1 - pos2) in
96+ while !len < max_len &&
97+ Bytes.get_uint8 src (pos1 + !len) = Bytes.get_uint8 src (pos2 + !len) do
98+ incr len
99+ done;
100+ !len
101+102+(** Find best match at current position *)
103+let find_best_match ht src pos limit params =
104+ if pos + 4 > limit then
105+ (0, 0)
106+ else begin
107+ let h = hash4 src pos land ht.mask in
108+ let prev_pos = ht.table.(h) in
109+110+ (* Update hash table *)
111+ ht.chain.(pos) <- prev_pos;
112+ ht.table.(h) <- pos;
113+114+ if prev_pos < 0 || pos - prev_pos > (1 lsl params.window_log) then
115+ (0, 0)
116+ else begin
117+ (* Search chain for best match *)
118+ let best_offset = ref 0 in
119+ let best_length = ref 0 in
120+ let chain_pos = ref prev_pos in
121+ let searches = ref 0 in
122+ let max_searches = 1 lsl params.search_log in
123+124+ while !chain_pos >= 0 && !searches < max_searches do
125+ let offset = pos - !chain_pos in
126+ if offset > (1 lsl params.window_log) then
127+ chain_pos := -1
128+ else begin
129+ let len = match_length src pos !chain_pos limit in
130+ if len >= params.min_match && len > !best_length then begin
131+ best_length := len;
132+ best_offset := offset
133+ end;
134+ chain_pos := ht.chain.(!chain_pos);
135+ incr searches
136+ end
137+ done;
138+139+ (!best_offset, !best_length)
140+ end
141+ end
142+143+(** Parse input into sequences using greedy/lazy matching *)
144+let parse_sequences src ~pos ~len params =
145+ let sequences = ref [] in
146+ let cur_pos = ref pos in
147+ let limit = pos + len in
148+ let lit_start = ref pos in
149+150+ let ht = create_hash_table params.hash_log in
151+152+ while !cur_pos + 4 <= limit do
153+ let (offset, length) = find_best_match ht src !cur_pos limit params in
154+155+ if length >= params.min_match then begin
156+ (* Emit sequence *)
157+ let lit_len = !cur_pos - !lit_start in
158+ sequences := { lit_length = lit_len; match_offset = offset; match_length = length } :: !sequences;
159+160+ (* Update hash table for matched positions *)
161+ for i = !cur_pos + 1 to !cur_pos + length - 1 do
162+ if i + 4 <= limit then begin
163+ let h = hash4 src i land ht.mask in
164+ ht.chain.(i) <- ht.table.(h);
165+ ht.table.(h) <- i
166+ end
167+ done;
168+169+ cur_pos := !cur_pos + length;
170+ lit_start := !cur_pos
171+ end else begin
172+ incr cur_pos
173+ end
174+ done;
175+176+ (* Handle remaining literals *)
177+ let remaining = limit - !lit_start in
178+ if remaining > 0 || !sequences = [] then
179+ sequences := { lit_length = remaining; match_offset = 0; match_length = 0 } :: !sequences;
180+181+ List.rev !sequences
182+183+(** Encode literal length code *)
184+let encode_lit_length_code lit_len =
185+ if lit_len < 16 then
186+ (lit_len, 0, 0)
187+ else if lit_len < 64 then
188+ (16 + (lit_len - 16) / 4, (lit_len - 16) mod 4, 2)
189+ else if lit_len < 128 then
190+ (28 + (lit_len - 64) / 8, (lit_len - 64) mod 8, 3)
191+ else begin
192+ (* Use baseline tables for larger values *)
193+ let rec find_code code =
194+ if code >= 35 then (35, lit_len - Constants.ll_baselines.(35), Constants.ll_extra_bits.(35))
195+ else if lit_len < Constants.ll_baselines.(code + 1) then
196+ (code, lit_len - Constants.ll_baselines.(code), Constants.ll_extra_bits.(code))
197+ else find_code (code + 1)
198+ in
199+ find_code 16
200+ end
201+202+(** Minimum match length for zstd *)
203+let min_match = 3
204+205+(** Encode match length code *)
206+let encode_match_length_code match_len =
207+ let ml = match_len - min_match in
208+ if ml < 32 then
209+ (ml, 0, 0)
210+ else if ml < 64 then
211+ (32 + (ml - 32) / 2, (ml - 32) mod 2, 1)
212+ else begin
213+ let rec find_code code =
214+ if code >= 52 then (52, ml - Constants.ml_baselines.(52) + 3, Constants.ml_extra_bits.(52))
215+ else if ml < Constants.ml_baselines.(code + 1) - 3 then
216+ (code, ml - Constants.ml_baselines.(code) + 3, Constants.ml_extra_bits.(code))
217+ else find_code (code + 1)
218+ in
219+ find_code 32
220+ end
221+222+(** Encode offset code.
223+ Returns (of_code, extra_value, extra_bits).
224+225+ Repeat offsets use offBase 1,2,3:
226+ - offBase=1: ofCode=0, no extra bits
227+ - offBase=2: ofCode=1, extra=0 (1 bit)
228+ - offBase=3: ofCode=1, extra=1 (1 bit)
229+230+ Real offsets use offBase = offset + 3:
231+ - ofCode = highbit(offBase)
232+ - extra = lower ofCode bits of offBase *)
233+let encode_offset_code offset offset_history =
234+ let off_base =
235+ if offset = offset_history.(0) then 1
236+ else if offset = offset_history.(1) then 2
237+ else if offset = offset_history.(2) then 3
238+ else offset + 3
239+ in
240+ let of_code = Fse.highest_set_bit off_base in
241+ let extra = off_base land ((1 lsl of_code) - 1) in
242+ (of_code, extra, of_code)
243+244+(** Write raw literals section *)
245+let write_raw_literals literals ~pos ~len output ~out_pos =
246+ if len = 0 then begin
247+ (* Empty literals: single-byte header with type=0, size=0 *)
248+ Bytes.set_uint8 output out_pos 0;
249+ 1
250+ end else if len < 32 then begin
251+ (* Raw literals, single stream, 1-byte header *)
252+ (* Header: type=0 (raw), size_format=0 (5-bit), regen_size in bits 3-7 *)
253+ let header = 0b00 lor ((len land 0x1f) lsl 3) in
254+ Bytes.set_uint8 output out_pos header;
255+ Bytes.blit literals pos output (out_pos + 1) len;
256+ 1 + len
257+ end else if len < 4096 then begin
258+ (* Raw literals, 2-byte header *)
259+ (* type=0 (bits 0-1), size_format=1 (bits 2-3), size in bits 4-15 *)
260+ let header = 0b0100 lor ((len land 0x0fff) lsl 4) in
261+ Bytes.set_uint16_le output out_pos header;
262+ Bytes.blit literals pos output (out_pos + 2) len;
263+ 2 + len
264+ end else begin
265+ (* Raw literals, 3-byte header *)
266+ (* type=0 (bits 0-1), size_format=2 (bits 2-3), size in bits 4-17 (14 bits) *)
267+ let header = 0b1000 lor ((len land 0x3fff) lsl 4) in
268+ Bytes.set_uint8 output out_pos (header land 0xff);
269+ Bytes.set_uint8 output (out_pos + 1) ((header lsr 8) land 0xff);
270+ Bytes.set_uint8 output (out_pos + 2) ((header lsr 16) land 0xff);
271+ Bytes.blit literals pos output (out_pos + 3) len;
272+ 3 + len
273+ end
274+275+(** Write compressed literals with Huffman encoding *)
276+let write_compressed_literals literals ~pos ~len output ~out_pos =
277+ if len < 32 then
278+ (* Too small for Huffman, use raw *)
279+ write_raw_literals literals ~pos ~len output ~out_pos
280+ else begin
281+ (* Count symbol frequencies *)
282+ let counts = Array.make 256 0 in
283+ for i = pos to pos + len - 1 do
284+ let c = Bytes.get_uint8 literals i in
285+ counts.(c) <- counts.(c) + 1
286+ done;
287+288+ (* Find max symbol used *)
289+ let max_symbol = ref 0 in
290+ for i = 0 to 255 do
291+ if counts.(i) > 0 then max_symbol := i
292+ done;
293+294+ (* Build Huffman table *)
295+ let ctable = Huffman.build_ctable counts !max_symbol Constants.max_huffman_bits in
296+297+ if ctable.num_symbols = 0 then
298+ write_raw_literals literals ~pos ~len output ~out_pos
299+ else begin
300+ (* Decide single vs 4-stream based on size *)
301+ let use_4streams = len >= 256 in
302+303+ (* Write Huffman table header to temp buffer *)
304+ let header_buf = Bytes.create 256 in
305+ let header_stream = Bit_writer.Forward.of_bytes header_buf in
306+ let _num_written = Huffman.write_header header_stream ctable in
307+ let header_size = Bit_writer.Forward.byte_position header_stream in
308+309+ (* Compress literals *)
310+ let compressed =
311+ if use_4streams then
312+ Huffman.compress_4stream ctable literals ~pos ~len
313+ else
314+ Huffman.compress_1stream ctable literals ~pos ~len
315+ in
316+ let compressed_size = Bytes.length compressed in
317+318+ (* Check if compression is worthwhile (should save at least 10%) *)
319+ let total_compressed_size = header_size + compressed_size in
320+ if total_compressed_size >= len - len / 10 then
321+ write_raw_literals literals ~pos ~len output ~out_pos
322+ else begin
323+ (* Write compressed literals header *)
324+ (* Type: 2 = compressed, size_format based on sizes *)
325+ let regen_size = len in
326+ let lit_type = 2 in (* Compressed_literals *)
327+328+ let header_pos = ref out_pos in
329+ if regen_size < 1024 && total_compressed_size < 1024 then begin
330+ (* 3-byte header: type(2) + size_format(2) + regen(10) + compressed(10) + streams(2) *)
331+ let size_format = 0 in
332+ let streams_flag = if use_4streams then 3 else 0 in
333+ let h0 = lit_type lor (size_format lsl 2) lor (streams_flag lsl 4) lor ((regen_size land 0x3f) lsl 6) in
334+ let h1 = ((regen_size lsr 6) land 0xf) lor ((total_compressed_size land 0xf) lsl 4) in
335+ let h2 = (total_compressed_size lsr 4) land 0xff in
336+ Bytes.set_uint8 output !header_pos h0;
337+ Bytes.set_uint8 output (!header_pos + 1) h1;
338+ Bytes.set_uint8 output (!header_pos + 2) h2;
339+ header_pos := !header_pos + 3
340+ end else begin
341+ (* 5-byte header for larger sizes *)
342+ let size_format = 1 in
343+ let streams_flag = if use_4streams then 3 else 0 in
344+ let h0 = lit_type lor (size_format lsl 2) lor (streams_flag lsl 4) lor ((regen_size land 0x3f) lsl 6) in
345+ Bytes.set_uint8 output !header_pos h0;
346+ Bytes.set_uint16_le output (!header_pos + 1) (((regen_size lsr 6) land 0x3fff) lor ((total_compressed_size land 0x3) lsl 14));
347+ Bytes.set_uint16_le output (!header_pos + 3) ((total_compressed_size lsr 2) land 0xffff);
348+ header_pos := !header_pos + 5
349+ end;
350+351+ (* Write Huffman table *)
352+ Bytes.blit header_buf 0 output !header_pos header_size;
353+ header_pos := !header_pos + header_size;
354+355+ (* Write compressed streams *)
356+ Bytes.blit compressed 0 output !header_pos compressed_size;
357+358+ !header_pos + compressed_size - out_pos
359+ end
360+ end
361+ end
362+363+(** Compress literals - try Huffman, fall back to raw *)
364+let compress_literals literals ~pos ~len output ~out_pos =
365+ write_compressed_literals literals ~pos ~len output ~out_pos
366+367+(** Build predefined FSE compression tables *)
368+let ll_ctable = lazy (Fse.build_predefined_ctable Constants.ll_default_distribution Constants.ll_default_accuracy_log)
369+let ml_ctable = lazy (Fse.build_predefined_ctable Constants.ml_default_distribution Constants.ml_default_accuracy_log)
370+let of_ctable = lazy (Fse.build_predefined_ctable Constants.of_default_distribution Constants.of_default_accuracy_log)
371+372+(** Compress sequences section using predefined FSE tables.
373+ This implements proper zstd sequence encoding following RFC 8878.
374+375+ Matches C zstd's ZSTD_encodeSequences_body exactly:
376+ 1. Initialize states with FSE_initCState2 using LAST sequence's codes
377+ 2. Write LAST sequence's extra bits (LL, ML, OF order)
378+ 3. For sequences n-2 down to 0:
379+ - FSE_encodeSymbol for OF, ML, LL
380+ - Extra bits for LL, ML, OF
381+ 4. FSE_flushCState for ML, OF, LL
382+*)
383+let compress_sequences sequences output ~out_pos offset_history =
384+ if sequences = [] then begin
385+ (* Zero sequences *)
386+ Bytes.set_uint8 output out_pos 0;
387+ 1
388+ end else begin
389+ let num_seq = List.length sequences in
390+ let header_size = ref 0 in
391+392+ (* Write sequence count (1-3 bytes) *)
393+ if num_seq < 128 then begin
394+ Bytes.set_uint8 output out_pos num_seq;
395+ header_size := 1
396+ end else if num_seq < 0x7f00 then begin
397+ Bytes.set_uint8 output out_pos ((num_seq lsr 8) + 128);
398+ Bytes.set_uint8 output (out_pos + 1) (num_seq land 0xff);
399+ header_size := 2
400+ end else begin
401+ Bytes.set_uint8 output out_pos 0xff;
402+ Bytes.set_uint16_le output (out_pos + 1) (num_seq - 0x7f00);
403+ header_size := 3
404+ end;
405+406+ (* Symbol compression modes byte:
407+ bits 0-1: Literals_Lengths_Mode (0 = predefined)
408+ bits 2-3: Offsets_Mode (0 = predefined)
409+ bits 4-5: Match_Lengths_Mode (0 = predefined)
410+ bits 6-7: reserved *)
411+ Bytes.set_uint8 output (out_pos + !header_size) 0b00;
412+ incr header_size;
413+414+ (* Get predefined FSE tables *)
415+ let ll_ct = Lazy.force ll_ctable in
416+ let ml_ct = Lazy.force ml_ctable in
417+ let of_ct = Lazy.force of_ctable in
418+419+ let offset_hist = Array.copy offset_history in
420+ let seq_array = Array.of_list sequences in
421+422+ (* Encode all sequences in forward order to track offset history *)
423+ let encoded = Array.map (fun seq ->
424+ let (ll_code, ll_extra, ll_extra_bits) = encode_lit_length_code seq.lit_length in
425+ let (ml_code, ml_extra, ml_extra_bits) = encode_match_length_code seq.match_length in
426+ let (of_code, of_extra, of_extra_bits) = encode_offset_code seq.match_offset offset_hist in
427+428+ (* Update offset history for real offsets (of_code > 1 means offBase > 2) *)
429+ if seq.match_offset > 0 && of_code > 1 then begin
430+ offset_hist.(2) <- offset_hist.(1);
431+ offset_hist.(1) <- offset_hist.(0);
432+ offset_hist.(0) <- seq.match_offset
433+ end;
434+435+ (ll_code, ll_extra, ll_extra_bits, ml_code, ml_extra, ml_extra_bits, of_code, of_extra, of_extra_bits)
436+ ) seq_array in
437+438+ (* Use a backward bit writer *)
439+ let stream = Bit_writer.Backward.create (num_seq * 20 + 32) in
440+441+ (* Get last sequence's codes for state initialization *)
442+ let last_idx = num_seq - 1 in
443+ let (ll_code_last, ll_extra_last, ll_extra_bits_last,
444+ ml_code_last, ml_extra_last, ml_extra_bits_last,
445+ of_code_last, of_extra_last, of_extra_bits_last) = encoded.(last_idx) in
446+447+ (* Initialize FSE states with LAST sequence's codes *)
448+ let ll_state = Fse.init_cstate2 ll_ct ll_code_last in
449+ let ml_state = Fse.init_cstate2 ml_ct ml_code_last in
450+ let of_state = Fse.init_cstate2 of_ct of_code_last in
451+452+ (* Write LAST sequence's extra bits first (LL, ML, OF order) *)
453+ if ll_extra_bits_last > 0 then
454+ Bit_writer.Backward.write_bits stream ll_extra_last ll_extra_bits_last;
455+ if ml_extra_bits_last > 0 then
456+ Bit_writer.Backward.write_bits stream ml_extra_last ml_extra_bits_last;
457+ if of_extra_bits_last > 0 then
458+ Bit_writer.Backward.write_bits stream of_extra_last of_extra_bits_last;
459+460+ (* Process sequences from n-2 down to 0 *)
461+ for i = last_idx - 1 downto 0 do
462+ let (ll_code, ll_extra, ll_extra_bits,
463+ ml_code, ml_extra, ml_extra_bits,
464+ of_code, of_extra, of_extra_bits) = encoded.(i) in
465+466+ (* FSE encode: OF, ML, LL order *)
467+ Fse.encode_symbol stream of_state of_code;
468+ Fse.encode_symbol stream ml_state ml_code;
469+ Fse.encode_symbol stream ll_state ll_code;
470+471+ (* Extra bits: LL, ML, OF order *)
472+ if ll_extra_bits > 0 then
473+ Bit_writer.Backward.write_bits stream ll_extra ll_extra_bits;
474+ if ml_extra_bits > 0 then
475+ Bit_writer.Backward.write_bits stream ml_extra ml_extra_bits;
476+ if of_extra_bits > 0 then
477+ Bit_writer.Backward.write_bits stream of_extra of_extra_bits
478+ done;
479+480+ (* Flush states: ML, OF, LL order *)
481+ Fse.flush_cstate stream ml_state;
482+ Fse.flush_cstate stream of_state;
483+ Fse.flush_cstate stream ll_state;
484+485+ (* Finalize and copy to output *)
486+ let seq_data = Bit_writer.Backward.finalize stream in
487+ let seq_len = Bytes.length seq_data in
488+ Bytes.blit seq_data 0 output (out_pos + !header_size) seq_len;
489+490+ !header_size + seq_len
491+ end
492+493+(** Write raw block (no compression) *)
494+let write_raw_block src ~pos ~len output ~out_pos =
495+ (* Raw block: header (3 bytes) + raw data
496+ Header format: bit 0 = last_block, bits 1-2 = block_type, bits 3-23 = block_size
497+ For raw: block_type = 0, block_size = number of bytes *)
498+ let header = (Constants.block_raw lsl 1) lor ((len land 0x1fffff) lsl 3) in
499+ Bytes.set_uint8 output out_pos (header land 0xff);
500+ Bytes.set_uint8 output (out_pos + 1) ((header lsr 8) land 0xff);
501+ Bytes.set_uint8 output (out_pos + 2) ((header lsr 16) land 0xff);
502+ Bytes.blit src pos output (out_pos + 3) len;
503+ 3 + len
504+505+(** Write compressed block with sequences *)
506+let write_compressed_block src ~pos ~len sequences output ~out_pos offset_history =
507+ (* Collect all literals *)
508+ let total_lit_len = List.fold_left (fun acc seq -> acc + seq.lit_length) 0 sequences in
509+ let literals = Bytes.create total_lit_len in
510+ let lit_pos = ref 0 in
511+ let src_pos = ref pos in
512+ List.iter (fun seq ->
513+ if seq.lit_length > 0 then begin
514+ Bytes.blit src !src_pos literals !lit_pos seq.lit_length;
515+ lit_pos := !lit_pos + seq.lit_length;
516+ src_pos := !src_pos + seq.lit_length
517+ end;
518+ src_pos := !src_pos + seq.match_length
519+ ) sequences;
520+521+ (* Build block content in temp buffer *)
522+ let block_buf = Bytes.create (len * 2 + 256) in
523+ let block_pos = ref 0 in
524+525+ (* Write literals section *)
526+ let lit_size = compress_literals literals ~pos:0 ~len:total_lit_len block_buf ~out_pos:!block_pos in
527+ block_pos := !block_pos + lit_size;
528+529+ (* Filter out sequences with only literals (match_length = 0 and match_offset = 0)
530+ at the end - the last sequence can be literal-only *)
531+ let real_sequences = List.filter (fun seq ->
532+ seq.match_length > 0 || seq.match_offset > 0
533+ ) sequences in
534+535+ (* Write sequences section *)
536+ let seq_size = compress_sequences real_sequences block_buf ~out_pos:!block_pos offset_history in
537+ block_pos := !block_pos + seq_size;
538+539+ let block_size = !block_pos in
540+541+ (* Check if compressed block is actually smaller *)
542+ if block_size >= len then begin
543+ (* Fall back to raw block *)
544+ write_raw_block src ~pos ~len output ~out_pos
545+ end else begin
546+ (* Write compressed block header *)
547+ let header = (Constants.block_compressed lsl 1) lor ((block_size land 0x1fffff) lsl 3) in
548+ Bytes.set_uint8 output out_pos (header land 0xff);
549+ Bytes.set_uint8 output (out_pos + 1) ((header lsr 8) land 0xff);
550+ Bytes.set_uint8 output (out_pos + 2) ((header lsr 16) land 0xff);
551+ Bytes.blit block_buf 0 output (out_pos + 3) block_size;
552+ 3 + block_size
553+ end
554+555+(** Write RLE block (single byte repeated) *)
556+let write_rle_block byte len output ~out_pos =
557+ (* RLE block: header (3 bytes) + single byte
558+ Header format: bit 0 = last_block, bits 1-2 = block_type, bits 3-23 = regen_size
559+ For RLE: block_type = 1, regen_size = number of bytes when expanded *)
560+ let header = (Constants.block_rle lsl 1) lor ((len land 0x1fffff) lsl 3) in
561+ Bytes.set_uint8 output out_pos (header land 0xff);
562+ Bytes.set_uint8 output (out_pos + 1) ((header lsr 8) land 0xff);
563+ Bytes.set_uint8 output (out_pos + 2) ((header lsr 16) land 0xff);
564+ Bytes.set_uint8 output (out_pos + 3) byte;
565+ 4
566+567+(** Check if block is all same byte *)
568+let is_rle_block src ~pos ~len =
569+ if len = 0 then None
570+ else begin
571+ let first = Bytes.get_uint8 src pos in
572+ let all_same = ref true in
573+ for i = pos + 1 to pos + len - 1 do
574+ if Bytes.get_uint8 src i <> first then all_same := false
575+ done;
576+ if !all_same then Some first else None
577+ end
578+579+(** Compress a single block using LZ77 + FSE + Huffman.
580+ Falls back to RLE for repetitive data, or raw blocks if compression doesn't help. *)
581+let compress_block src ~pos ~len output ~out_pos params offset_history =
582+ if len = 0 then
583+ 0
584+ else
585+ (* Check for RLE opportunity (all same byte) *)
586+ match is_rle_block src ~pos ~len with
587+ | Some byte when len > 4 ->
588+ (* RLE is worthwhile: 4 bytes instead of len+3 *)
589+ write_rle_block byte len output ~out_pos
590+ | _ ->
591+ (* Try LZ77 + FSE compression for compressible data *)
592+ let sequences = parse_sequences src ~pos ~len params in
593+ let match_count = List.fold_left (fun acc s ->
594+ if s.match_length > 0 then acc + 1 else acc) 0 sequences in
595+ (* Use compressed blocks for compressible data. The backward bitstream
596+ writer now uses periodic flushing like C zstd, supporting any size. *)
597+ if match_count >= 2 && len >= 64 then
598+ write_compressed_block src ~pos ~len sequences output ~out_pos offset_history
599+ else
600+ write_raw_block src ~pos ~len output ~out_pos
601+602+(** Write frame header *)
603+let write_frame_header output ~pos content_size window_log checksum_flag =
604+ (* Magic number *)
605+ Bytes.set_int32_le output pos Constants.zstd_magic_number;
606+ let out_pos = ref (pos + 4) in
607+608+ (* Use single segment mode for smaller content (no window descriptor needed).
609+ FCS field sizes when single_segment is set:
610+ - fcs_flag=0: 1 byte (content size 0-255)
611+ - fcs_flag=1: 2 bytes (content size 256-65791, stored with -256)
612+ - fcs_flag=2: 4 bytes
613+ - fcs_flag=3: 8 bytes *)
614+ let single_segment = content_size <= 131072L in
615+616+ let (fcs_flag, fcs_bytes) =
617+ if single_segment then begin
618+ if content_size <= 255L then (0, 1)
619+ else if content_size <= 65791L then (1, 2) (* 2-byte has +256 offset *)
620+ else if content_size <= 0xFFFFFFFFL then (2, 4)
621+ else (3, 8)
622+ end else begin
623+ (* For non-single-segment, fcs_flag=0 means no FCS field *)
624+ if content_size = 0L then (0, 0)
625+ else if content_size <= 65535L then (1, 2)
626+ else if content_size <= 0xFFFFFFFFL then (2, 4)
627+ else (3, 8)
628+ end
629+ in
630+631+ (* Frame header descriptor:
632+ bit 0-1: dict ID flag (0 = no dict)
633+ bit 2: content checksum flag
634+ bit 3: reserved
635+ bit 4: unused
636+ bit 5: single segment (no window descriptor)
637+ bit 6-7: FCS field size flag *)
638+ let descriptor =
639+ (if checksum_flag then 0b00000100 else 0)
640+ lor (if single_segment then 0b00100000 else 0)
641+ lor (fcs_flag lsl 6)
642+ in
643+ Bytes.set_uint8 output !out_pos descriptor;
644+ incr out_pos;
645+646+ (* Window descriptor (only if not single segment) *)
647+ if not single_segment then begin
648+ let window_desc = ((window_log - 10) lsl 3) in
649+ Bytes.set_uint8 output !out_pos window_desc;
650+ incr out_pos
651+ end;
652+653+ (* Frame content size *)
654+ begin match fcs_bytes with
655+ | 1 ->
656+ Bytes.set_uint8 output !out_pos (Int64.to_int content_size);
657+ out_pos := !out_pos + 1
658+ | 2 ->
659+ (* 2-byte FCS stores value - 256 *)
660+ let adjusted = Int64.sub content_size 256L in
661+ Bytes.set_uint16_le output !out_pos (Int64.to_int adjusted);
662+ out_pos := !out_pos + 2
663+ | 4 ->
664+ Bytes.set_int32_le output !out_pos (Int64.to_int32 content_size);
665+ out_pos := !out_pos + 4
666+ | 8 ->
667+ Bytes.set_int64_le output !out_pos content_size;
668+ out_pos := !out_pos + 8
669+ | _ -> ()
670+ end;
671+672+ !out_pos - pos
673+674+(** Compress data to zstd frame *)
675+let compress ?(level = 3) ?(checksum = true) src =
676+ let src = Bytes.of_string src in
677+ let len = Bytes.length src in
678+ let params = get_level_params level in
679+680+ (* Allocate output buffer - worst case is slightly larger than input *)
681+ let max_output = len + len / 128 + 256 in
682+ let output = Bytes.create max_output in
683+684+ (* Initialize offset history *)
685+ let offset_history = Array.copy Constants.initial_repeat_offsets in
686+687+ (* Write frame header *)
688+ let header_size = write_frame_header output ~pos:0 (Int64.of_int len) params.window_log checksum in
689+ let out_pos = ref header_size in
690+691+ (* Compress blocks *)
692+ if len = 0 then begin
693+ (* Empty content: write an empty raw block with last_block flag *)
694+ (* Block header: last_block=1, block_type=raw(0), block_size=0 *)
695+ (* Header = 1 | (0 << 1) | (0 << 3) = 0x01 *)
696+ Bytes.set_uint8 output !out_pos 0x01;
697+ Bytes.set_uint8 output (!out_pos + 1) 0x00;
698+ Bytes.set_uint8 output (!out_pos + 2) 0x00;
699+ out_pos := !out_pos + 3
700+ end else begin
701+ let block_size = min len Constants.block_size_max in
702+ let pos = ref 0 in
703+704+ while !pos < len do
705+ let this_block = min block_size (len - !pos) in
706+ let is_last = !pos + this_block >= len in
707+708+ let block_len = compress_block src ~pos:!pos ~len:this_block output ~out_pos:!out_pos params offset_history in
709+710+ (* Set last block flag *)
711+ if is_last then begin
712+ let current = Bytes.get_uint8 output !out_pos in
713+ Bytes.set_uint8 output !out_pos (current lor 0x01)
714+ end;
715+716+ out_pos := !out_pos + block_len;
717+ pos := !pos + this_block
718+ done
719+ end;
720+721+ (* Write checksum if requested *)
722+ if checksum then begin
723+ let hash = Xxhash.hash64 src ~pos:0 ~len in
724+ (* Write only lower 32 bits *)
725+ Bytes.set_int32_le output !out_pos (Int64.to_int32 hash);
726+ out_pos := !out_pos + 4
727+ end;
728+729+ Bytes.sub_string output 0 !out_pos
730+731+(** Calculate maximum compressed size *)
732+let compress_bound len =
733+ len + len / 128 + 256
734+735+(** Write a skippable frame.
736+ @param variant Magic number variant 0-15
737+ @param content The content to embed in the skippable frame
738+ @return The complete skippable frame as a string *)
739+let write_skippable_frame ?(variant = 0) content =
740+ let variant = max 0 (min 15 variant) in
741+ let len = String.length content in
742+ if len > 0xFFFFFFFF then
743+ invalid_arg "Skippable frame content too large (max 4GB)";
744+ let output = Bytes.create (Constants.skippable_header_size + len) in
745+ (* Magic number: 0x184D2A50 + variant *)
746+ let magic = Int32.add Constants.skippable_magic_start (Int32.of_int variant) in
747+ Bytes.set_int32_le output 0 magic;
748+ (* Content size (4 bytes little-endian) *)
749+ Bytes.set_int32_le output 4 (Int32.of_int len);
750+ (* Content *)
751+ Bytes.blit_string content 0 output 8 len;
752+ Bytes.unsafe_to_string output
+5
test-interop/dune
···00000
···1+; Test: Verify pure OCaml can decompress C-compressed data
2+; and C zstd can decompress pure OCaml compressed data
3+(test
4+ (name test_interop)
5+ (libraries zstd alcotest))