Bitlevel streams for OCaml
(** Bitstream - Bit-level I/O for binary formats.

    Provides forward and backward bitstream reading and writing for parsing
    and generating binary formats that operate at the bit level.

    Forward streams read/write from the start of a buffer towards the end.
    Backward streams read/write from the end of a buffer towards the start,
    which is required by some compression algorithms (FSE, ANS). *)
9
10(** {1 Slice Type} *)
11
module Slice = struct
  (** A window onto a byte buffer: [length] bytes of [bytes] starting at
      index [first]. The underlying buffer is shared, never copied, until
      {!to_bytes} is called. No bounds are validated by this module. *)
  type t = {
    bytes : bytes;
    first : int;
    length : int;
  }

  (** [make bytes ~first ~length] wraps the given range as-is. *)
  let make bytes ~first ~length = { bytes; first; length }

  (** [of_bytes ?first ?length bytes] builds a slice over [bytes].
      [first] defaults to [0]; [length] defaults to everything from [first]
      up to the end of [bytes]. *)
  let of_bytes ?(first = 0) ?length bytes =
    let length =
      match length with
      | Some n -> n
      | None -> Bytes.length bytes - first
    in
    make bytes ~first ~length

  (** [to_bytes t] returns a fresh copy of the bytes covered by [t].
      @raise Invalid_argument (from [Bytes.sub]) if the range does not lie
      inside the underlying buffer. *)
  let to_bytes { bytes; first; length } = Bytes.sub bytes first length

  (** [is_empty t] is [true] exactly when [t] has length [0]. *)
  let is_empty { length; _ } = length = 0

  (** [sub t ~first ~length] is the slice of [length] bytes that starts
      [first] bytes after the start of [t], sharing the same buffer. *)
  let sub t ~first ~length = { t with first = t.first + first; length }
end
36
37(** {1 Exceptions} *)
38
(** Signals an attempt to read beyond the end of the stream. *)
exception End_of_stream

(** Signals that an operation was invoked in a state it does not support
    (e.g., the stream is not byte aligned). The payload describes the
    violation. *)
exception Invalid_state of string

(** Signals malformed stream data (e.g., an invalid padding marker). The
    payload describes the problem. *)
exception Corrupted_stream of string
47
48(** {1 Forward Bitstream Reader} *)
49
module Forward_reader = struct
  (** Read cursor over the byte range [start_pos, limit) of [src]. Bits
      inside a byte are consumed starting from the least-significant bit. *)
  type t = {
    src : bytes;
    start_pos : int;
    limit : int;
    mutable byte_pos : int;
    mutable bit_pos : int; (* 0-7, bits consumed in current byte *)
  }

  (** [of_slice slice] creates a reader positioned at the first bit of
      [slice]. The slice bounds are not validated here. *)
  let of_slice (slice : Slice.t) =
    { src = slice.bytes;
      start_pos = slice.first;
      limit = slice.first + slice.length;
      byte_pos = slice.first;
      bit_pos = 0 }

  (** [of_bytes src] reads the whole of [src]. *)
  let of_bytes src =
    of_slice (Slice.of_bytes src)

  (** [create src ~pos ~len] reads [len] bytes of [src] starting at [pos]. *)
  let create src ~pos ~len =
    of_slice (Slice.make src ~first:pos ~length:len)

  (** Number of bits between the cursor and the end of the range. *)
  let[@inline] remaining t =
    (t.limit - t.byte_pos) * 8 - t.bit_pos

  (** [true] when the cursor sits on a byte boundary. *)
  let[@inline] is_byte_aligned t =
    t.bit_pos = 0

  (** [read_bits t n] consumes [n] bits and returns them as an integer whose
      bit 0 is the first bit read. Returns [0] without consuming anything
      when [n <= 0].
      @raise Invalid_argument if [n > 57].
      @raise End_of_stream if fewer than [n] bits remain; in that case the
      cursor is left untouched. *)
  let[@inline] read_bits t n =
    if n <= 0 then 0
    else if n > 57 then invalid_arg "read_bits: n > 57"
    else if n > remaining t then
      (* Check up front so a failed read never leaves the cursor halfway. *)
      raise End_of_stream
    else begin
      let result = ref 0 in
      let bits_read = ref 0 in
      while !bits_read < n do
        let byte = Bytes.get_uint8 t.src t.byte_pos in
        (* Take as much as the current byte still offers, but no more than
           what is still needed. *)
        let to_read = min (8 - t.bit_pos) (n - !bits_read) in
        let mask = (1 lsl to_read) - 1 in
        let bits = (byte lsr t.bit_pos) land mask in
        result := !result lor (bits lsl !bits_read);
        bits_read := !bits_read + to_read;
        t.bit_pos <- t.bit_pos + to_read;
        if t.bit_pos >= 8 then begin
          t.bit_pos <- 0;
          t.byte_pos <- t.byte_pos + 1
        end
      done;
      !result
    end

  (** [read_byte t] consumes and returns the next byte.
      @raise Invalid_state if the cursor is not byte aligned.
      @raise End_of_stream if no byte remains. *)
  let[@inline] read_byte t =
    if t.bit_pos <> 0 then
      raise (Invalid_state "read_byte: not byte aligned");
    if t.byte_pos >= t.limit then
      raise End_of_stream;
    let b = Bytes.get_uint8 t.src t.byte_pos in
    t.byte_pos <- t.byte_pos + 1;
    b

  (** [rewind_bits t n] moves the cursor back by [n] bits.
      @raise End_of_stream if the resulting position would fall outside the
      range covered by the reader (before its start, or, for a negative
      [n], past its end). *)
  let rewind_bits t n =
    let total_bits = (t.byte_pos - t.start_pos) * 8 + t.bit_pos in
    let new_total = total_bits - n in
    if new_total < 0 || new_total > (t.limit - t.start_pos) * 8 then
      raise End_of_stream;
    t.byte_pos <- t.start_pos + new_total / 8;
    t.bit_pos <- new_total mod 8

  (** [align t] discards the unread bits of a partially consumed byte so
      that the cursor sits on the next byte boundary. No-op when already
      aligned. *)
  let align t =
    if t.bit_pos <> 0 then begin
      t.bit_pos <- 0;
      t.byte_pos <- t.byte_pos + 1
    end

  (** Index, within the underlying buffer, of the next byte to read.
      @raise Invalid_state if the cursor is not byte aligned. *)
  let byte_position t =
    if t.bit_pos <> 0 then
      raise (Invalid_state "byte_position: not byte aligned");
    t.byte_pos

  (* Internal helper shared by [get_slice], [advance] and [sub]: validates
     a request for [n] whole bytes on behalf of operation [who], moves the
     cursor past them and returns the index where they start. *)
  let claim_bytes t ~who n =
    if t.bit_pos <> 0 then
      raise (Invalid_state (who ^ ": not byte aligned"));
    if n < 0 then
      invalid_arg (who ^ ": negative length");
    (* Written as a subtraction so a huge [n] cannot overflow the sum. *)
    if n > t.limit - t.byte_pos then
      raise End_of_stream;
    let first = t.byte_pos in
    t.byte_pos <- first + n;
    first

  (** [get_slice t n] consumes the next [n] bytes and returns them as a
      slice sharing the reader's buffer.
      @raise Invalid_state if the cursor is not byte aligned.
      @raise Invalid_argument if [n < 0].
      @raise End_of_stream if fewer than [n] bytes remain. *)
  let get_slice t n : Slice.t =
    let first = claim_bytes t ~who:"get_slice" n in
    Slice.make t.src ~first ~length:n

  (** Like {!get_slice} but returns a fresh copy of the bytes. *)
  let get_bytes t n =
    Slice.to_bytes (get_slice t n)

  (** [to_slice t] is the unread remainder as a slice; the cursor does not
      move.
      @raise Invalid_state if the cursor is not byte aligned. *)
  let to_slice t : Slice.t =
    if t.bit_pos <> 0 then
      raise (Invalid_state "to_slice: not byte aligned");
    Slice.make t.src ~first:t.byte_pos ~length:(t.limit - t.byte_pos)

  (** [advance t n] skips [n] whole bytes.
      @raise Invalid_state if the cursor is not byte aligned.
      @raise Invalid_argument if [n < 0].
      @raise End_of_stream if fewer than [n] bytes remain. *)
  let advance t n =
    ignore (claim_bytes t ~who:"advance" n : int)

  (** [sub t n] consumes the next [n] bytes and returns an independent
      reader over exactly those bytes.
      @raise Invalid_state if the cursor is not byte aligned.
      @raise Invalid_argument if [n < 0].
      @raise End_of_stream if fewer than [n] bytes remain. *)
  let sub t n =
    let first = claim_bytes t ~who:"sub" n in
    of_slice (Slice.make t.src ~first ~length:n)

  (** Number of whole bytes left to read.
      @raise Invalid_state if the cursor is not byte aligned. *)
  let remaining_bytes t =
    if t.bit_pos <> 0 then
      raise (Invalid_state "remaining_bytes: not byte aligned");
    t.limit - t.byte_pos

  (** [skip_bits t n] moves the cursor forward by [n] bits without decoding
      them; any [n] up to {!remaining} is accepted (there is no 57-bit
      limit). No-op when [n <= 0].
      @raise End_of_stream if fewer than [n] bits remain; the cursor is
      then left untouched. *)
  let skip_bits t n =
    if n > 0 then begin
      if n > remaining t then
        raise End_of_stream;
      (* Pure position arithmetic: no need to touch the data. *)
      let target = t.bit_pos + n in
      t.byte_pos <- t.byte_pos + target / 8;
      t.bit_pos <- target mod 8
    end
end
172
(** {1 Backward Bitstream Reader}

    Reads bits from the end of a buffer towards the start. The last byte of
    the buffer must be non-zero: its highest set bit is the padding marker,
    and the data begins with the bit immediately below that marker. *)
177
module Backward_reader = struct
  (** Reader state. [bit_offset] counts the bits that are still unread,
      measured from [start_pos]; it decreases with every read and may go
      negative once more bits have been requested than the stream holds. *)
  type t = {
    src : bytes;
    start_pos : int;
    mutable bit_offset : int; (* Bits remaining from end, decreasing *)
  }

  (** [of_slice slice] positions a reader just below the padding marker,
      i.e. the highest set bit of the last byte of [slice].
      @raise End_of_stream if [slice] is empty.
      @raise Corrupted_stream if the last byte is zero (no marker). *)
  let of_slice (slice : Slice.t) =
    if slice.length = 0 then raise End_of_stream;
    let final_index = slice.first + slice.length - 1 in
    let final_byte = Bytes.get_uint8 slice.bytes final_index in
    if final_byte = 0 then
      raise (Corrupted_stream "invalid padding marker");
    (* Scan from bit 7 downwards for the marker. [final_byte] is non-zero,
       so the [bit < 0] fallback is never actually taken. *)
    let rec highest_set_bit bit =
      if bit < 0 then 0
      else if (final_byte lsr bit) land 1 = 1 then bit
      else highest_set_bit (bit - 1)
    in
    (* The marker bit and every (zero) bit above it are padding. *)
    let padding = 8 - highest_set_bit 7 in
    { src = slice.bytes;
      start_pos = slice.first;
      bit_offset = slice.length * 8 - padding }

  (** [of_bytes src ~pos ~len] is {!of_slice} on [len] bytes of [src]
      starting at [pos]. *)
  let of_bytes src ~pos ~len =
    of_slice (Slice.make src ~first:pos ~length:len)

  (** Number of unread bits; negative after reading past the start. *)
  let[@inline] remaining t = t.bit_offset

  (** [true] once no unread bit is left. *)
  let[@inline] is_empty t = t.bit_offset <= 0

  (** [read_bits t n] consumes the [n] bits just below the current position
      and returns them as an integer. Bits requested from before the start
      of the stream read as zero and occupy the low end of the result.
      Returns [0] without consuming anything when [n <= 0].
      @raise Invalid_argument if [n > 57]. *)
  let[@inline] read_bits t n =
    if n <= 0 then 0
    else if n > 57 then invalid_arg "read_bits: n > 57"
    else begin
      let target = t.bit_offset - n in
      t.bit_offset <- target;
      (* [missing] = how many of the requested bits lie before the start. *)
      let missing = if target < 0 then -target else 0 in
      let wanted = n - missing in
      if wanted <= 0 then 0
      else begin
        let from = target + missing in
        (* Collect [wanted] bits, lowest first, walking forward through
           the bytes from bit position [from]. *)
        let rec gather acc got byte_idx bit_idx =
          if got >= wanted then acc
          else begin
            let byte = Bytes.get_uint8 t.src byte_idx in
            let take = min (8 - bit_idx) (wanted - got) in
            let chunk = (byte lsr bit_idx) land ((1 lsl take) - 1) in
            let acc = acc lor (chunk lsl got) in
            let bit_idx = bit_idx + take in
            if bit_idx >= 8 then gather acc (got + take) (byte_idx + 1) 0
            else gather acc (got + take) byte_idx bit_idx
          end
        in
        let low = gather 0 0 (t.start_pos + from / 8) (from mod 8) in
        (* Make room for the zero bits that stand in for missing data. *)
        low lsl missing
      end
    end

  (** [peek_bits t n] returns what [read_bits t n] would return, then
      restores the position so nothing is consumed. *)
  let peek_bits t n =
    let mark = t.bit_offset in
    let bits = read_bits t n in
    t.bit_offset <- mark;
    bits
end
252
253(** {1 Forward Bitstream Writer} *)
254
module Forward_writer = struct
  (** Write cursor into [dst]. Bits are packed into [current_byte] starting
      at its least-significant bit and stored to [dst] once eight have
      accumulated (or on {!flush}). *)
  type t = {
    dst : bytes;
    start_pos : int;
    mutable byte_pos : int;
    mutable bit_pos : int; (* 0-7, bits written in current byte *)
    mutable current_byte : int;
  }

  (** [of_slice slice] starts writing at the first byte of [slice]. Only
      the start of the slice is used: its length is not recorded, so the
      only bound on writes is the size of the underlying buffer. *)
  let of_slice ({ Slice.bytes; first; _ } : Slice.t) =
    { dst = bytes;
      start_pos = first;
      byte_pos = first;
      bit_pos = 0;
      current_byte = 0 }

  (** [of_bytes dst] writes from the beginning of [dst]. *)
  let of_bytes dst = Slice.of_bytes dst |> of_slice

  (** [create dst ~pos] writes into [dst] starting at index [pos]. *)
  let create dst ~pos =
    let length = Bytes.length dst - pos in
    of_slice (Slice.make dst ~first:pos ~length)

  (** [flush t] stores a partially filled byte (unused high bits zero) and
      moves to the next byte. No-op when the writer is byte aligned. *)
  let flush t =
    if t.bit_pos > 0 then begin
      Bytes.set_uint8 t.dst t.byte_pos t.current_byte;
      t.byte_pos <- t.byte_pos + 1;
      t.bit_pos <- 0;
      t.current_byte <- 0
    end

  (** [write_bits t value n] appends the [n] low bits of [value], lowest
      bit first. Does nothing when [n <= 0].
      @raise Invalid_argument if [n > 57], or (from [Bytes.set_uint8]) when
      a completed byte falls outside [dst]. *)
  let[@inline] write_bits t value n =
    if n <= 0 then ()
    else if n > 57 then invalid_arg "write_bits: n > 57"
    else begin
      let rec emit pending left =
        if left > 0 then begin
          (* Fill whatever room the current byte still has. *)
          let chunk = min (8 - t.bit_pos) left in
          let low = pending land ((1 lsl chunk) - 1) in
          t.current_byte <- t.current_byte lor (low lsl t.bit_pos);
          t.bit_pos <- t.bit_pos + chunk;
          (* A full byte is committed right away. *)
          if t.bit_pos = 8 then flush t;
          emit (pending lsr chunk) (left - chunk)
        end
      in
      emit value n
    end

  (** [write_byte t value] stores one byte, first flushing (zero-padding)
      any partially written byte. *)
  let write_byte t value =
    if t.bit_pos <> 0 then flush t;
    let at = t.byte_pos in
    Bytes.set_uint8 t.dst at value;
    t.byte_pos <- at + 1

  (** [write_slice t slice] copies the contents of [slice] into the output,
      first flushing any partially written byte. *)
  let write_slice t ({ Slice.bytes; first; length } : Slice.t) =
    if t.bit_pos <> 0 then flush t;
    Bytes.blit bytes first t.dst t.byte_pos length;
    t.byte_pos <- t.byte_pos + length

  (** [write_bytes t src] copies all of [src] into the output. *)
  let write_bytes t src = write_slice t (Slice.of_bytes src)

  (** Index in [dst] just past the last byte touched so far, counting a
      partially filled byte as a whole one. Does not flush. *)
  let byte_position t =
    t.byte_pos + (if t.bit_pos > 0 then 1 else 0)

  (** [finalize t] flushes pending bits and returns the number of bytes
      written since the writer's start position. *)
  let finalize t =
    flush t;
    t.byte_pos - t.start_pos

  (** [to_slice t] flushes pending bits and returns the written region as a
      slice of [dst]. *)
  let to_slice t : Slice.t =
    flush t;
    let length = t.byte_pos - t.start_pos in
    Slice.make t.dst ~first:t.start_pos ~length
end
334
(** {1 Backward Bitstream Writer}

    Accumulates bits that are meant to be read back in reverse order by a
    backward reader. Used for FSE and Huffman encoding. *)
338
module Backward_writer = struct
  (** Writer state. Pending bits live in the 64-bit accumulator [bits]
      (first-written bit at bit 0); whole bytes are moved from there into
      [buffer], which is filled from its end towards its start. *)
  type t = {
    mutable bits : int64;
    mutable num_bits : int;
    buffer : bytes;
    mutable buf_pos : int;
  }

  (** [create size] allocates a writer able to hold [size] output bytes. *)
  let create size =
    { bits = 0L; num_bits = 0; buffer = Bytes.create size; buf_pos = size }

  (** [flush_bytes t] moves every complete byte from the accumulator into
      the buffer, leaving at most 7 pending bits.
      @raise Invalid_argument if the buffer has no room left. *)
  let flush_bytes t =
    while t.num_bits >= 8 do
      (* Fail before touching [buf_pos] so the state stays consistent. *)
      if t.buf_pos <= 0 then invalid_arg "Backward_writer: buffer full";
      t.buf_pos <- t.buf_pos - 1;
      Bytes.set_uint8 t.buffer t.buf_pos
        (Int64.to_int (Int64.logand t.bits 0xFFL));
      t.bits <- Int64.shift_right_logical t.bits 8;
      t.num_bits <- t.num_bits - 8
    done

  (** [write_bits t value n] appends the [n] low bits of [value]; bits of
      [value] above position [n] are ignored. Does nothing when [n <= 0].
      Complete bytes are flushed automatically when the accumulator would
      otherwise overflow, so calling {!flush_bytes} between writes is
      optional.
      @raise Invalid_argument if [n] bits cannot fit in the accumulator
      even after flushing (never the case for [n <= 57]), or if the buffer
      is full. *)
  let[@inline] write_bits t value n =
    if n > 0 then begin
      (* Shifting an int64 by 64 or more is unspecified, so make room
         before the pending bits would exceed the accumulator width. *)
      if t.num_bits + n > 64 then flush_bytes t;
      if t.num_bits + n > 64 then
        invalid_arg "write_bits: too many bits for accumulator";
      let v = Int64.of_int value in
      (* Drop anything above bit [n] (including sign extension) so it
         cannot leak into bits written later. *)
      let v =
        if n >= 64 then v
        else Int64.logand v (Int64.pred (Int64.shift_left 1L n))
      in
      t.bits <- Int64.logor t.bits (Int64.shift_left v t.num_bits);
      t.num_bits <- t.num_bits + n
    end

  (** [finalize_to_slice t] appends the 1-bit padding marker, zero-pads to
      a whole byte, flushes everything and returns the produced bytes as a
      slice of the internal buffer, ordered so that the byte holding the
      marker comes last. Not idempotent: each call appends another marker
      and reverses the stored bytes again. *)
  let finalize_to_slice t : Slice.t =
    write_bits t 1 1;
    (* Round up so the final partial byte is flushed too; the extra high
       bits are already zero in the accumulator. *)
    if t.num_bits mod 8 <> 0 then
      t.num_bits <- ((t.num_bits + 7) / 8) * 8;
    flush_bytes t;
    let len = Bytes.length t.buffer - t.buf_pos in
    (* Reverse bytes in place so marker ends up at the end *)
    for i = 0 to len / 2 - 1 do
      let j = t.buf_pos + i in
      let k = t.buf_pos + len - 1 - i in
      let tmp = Bytes.get t.buffer j in
      Bytes.set t.buffer j (Bytes.get t.buffer k);
      Bytes.set t.buffer k tmp
    done;
    Slice.make t.buffer ~first:t.buf_pos ~length:len

  (** Like {!finalize_to_slice} but returns a fresh copy of the bytes. *)
  let finalize t =
    Slice.to_bytes (finalize_to_slice t)

  (** Bytes produced so far, counting pending bits rounded up to whole
      bytes (the final padding marker is not included). *)
  let current_size t =
    Bytes.length t.buffer - t.buf_pos + (t.num_bits + 7) / 8
end
385end