(* Bitlevel streams for OCaml *)
(** Bitstream - Bit-level I/O for binary formats.

    Provides forward and backward bitstream reading and writing for parsing
    and generating binary formats that operate at the bit level.

    Forward streams read/write from the start of a buffer towards the end.
    Backward streams read/write from the end of a buffer towards the start,
    which is required by some compression algorithms (FSE, ANS). *)

(** {1 Slice Type} *)

module Slice = struct
  (** A window of [length] bytes starting at offset [first] inside [bytes].
      The window shares storage with [bytes]; nothing is copied. *)
  type t = {
    bytes : bytes;
    first : int;
    length : int;
  }

  (** [make bytes ~first ~length] builds a slice. Bounds are not validated. *)
  let make bytes ~first ~length =
    { bytes; first; length }

  (** [of_bytes ?first ?length bytes] builds a slice over [bytes].
      [first] defaults to [0]; [length] defaults to everything from [first]
      to the end of [bytes]. Bounds are not validated. *)
  let of_bytes ?first ?length bytes =
    let first = Option.value first ~default:0 in
    let length = Option.value length ~default:(Bytes.length bytes - first) in
    { bytes; first; length }

  (** [to_bytes t] copies the bytes covered by [t] into a fresh buffer.
      @raise Invalid_argument (from [Bytes.sub]) if the slice does not
      designate a valid range of its underlying buffer. *)
  let to_bytes t =
    Bytes.sub t.bytes t.first t.length

  (** [is_empty t] is [true] iff the slice has length zero. *)
  let is_empty t =
    t.length = 0

  (** [sub t ~first ~length] is the sub-window of [t] that starts [first]
      bytes into [t] and spans [length] bytes. Bounds are not validated
      against [t]. *)
  let sub t ~first ~length =
    { bytes = t.bytes; first = t.first + first; length }
end

(** {1 Exceptions} *)

exception End_of_stream
(** Raised when attempting to read past the end of the stream. *)

exception Invalid_state of string
(** Raised when an operation requires a specific state (e.g., byte alignment). *)

exception Corrupted_stream of string
(** Raised when stream data is malformed (e.g., invalid padding marker). *)

(** {1 Forward Bitstream Reader} *)

module Forward_reader = struct
  (** Reader state. Bits inside a byte are consumed starting from the least
      significant bit. *)
  type t = {
    src : bytes;
    start_pos : int;
    limit : int;  (* one past the last readable byte *)
    mutable byte_pos : int;
    mutable bit_pos : int;  (* 0-7, bits consumed in current byte *)
  }

  (** [of_slice slice] is a reader positioned on the first bit of [slice]. *)
  let of_slice (slice : Slice.t) =
    { src = slice.bytes;
      start_pos = slice.first;
      limit = slice.first + slice.length;
      byte_pos = slice.first;
      bit_pos = 0 }

  (** [of_bytes src] is a reader over the whole of [src]. *)
  let of_bytes src =
    of_slice (Slice.of_bytes src)

  (** [create src ~pos ~len] is a reader over [len] bytes of [src] starting
      at offset [pos]. *)
  let create src ~pos ~len =
    of_slice (Slice.make src ~first:pos ~length:len)

  (** [remaining t] is the number of unread bits. *)
  let[@inline] remaining t =
    (t.limit - t.byte_pos) * 8 - t.bit_pos

  (** [is_byte_aligned t] is [true] iff the next bit to read is the first
      bit of a byte. *)
  let[@inline] is_byte_aligned t =
    t.bit_pos = 0

  (** [read_bits t n] consumes [n] bits and returns them as an integer; the
      first bit consumed becomes bit 0 of the result. Returns [0] when
      [n <= 0].
      @raise Invalid_argument if [n > 57].
      @raise End_of_stream if fewer than [n] bits remain. Bits consumed
      before the end was reached stay consumed. *)
  let[@inline] read_bits t n =
    if n <= 0 then 0
    else if n > 57 then invalid_arg "read_bits: n > 57"
    else begin
      let result = ref 0 in
      let bits_read = ref 0 in
      while !bits_read < n do
        if t.byte_pos >= t.limit then
          raise End_of_stream;
        let byte = Bytes.get_uint8 t.src t.byte_pos in
        (* Take as much as possible from the current byte. *)
        let available = 8 - t.bit_pos in
        let to_read = min available (n - !bits_read) in
        let mask = (1 lsl to_read) - 1 in
        let bits = (byte lsr t.bit_pos) land mask in
        result := !result lor (bits lsl !bits_read);
        bits_read := !bits_read + to_read;
        t.bit_pos <- t.bit_pos + to_read;
        if t.bit_pos >= 8 then begin
          t.bit_pos <- 0;
          t.byte_pos <- t.byte_pos + 1
        end
      done;
      !result
    end

  (** [read_byte t] consumes and returns one whole byte.
      @raise Invalid_state if the reader is not byte aligned.
      @raise End_of_stream if no byte remains. *)
  let[@inline] read_byte t =
    if t.bit_pos <> 0 then
      raise (Invalid_state "read_byte: not byte aligned");
    if t.byte_pos >= t.limit then
      raise End_of_stream;
    let b = Bytes.get_uint8 t.src t.byte_pos in
    t.byte_pos <- t.byte_pos + 1;
    b

  (** [rewind_bits t n] moves the read position back by [n] bits.
      @raise End_of_stream if that would move before the start of the
      stream. *)
  let rewind_bits t n =
    let total_bits = (t.byte_pos - t.start_pos) * 8 + t.bit_pos in
    let new_total = total_bits - n in
    if new_total < 0 then
      raise End_of_stream;
    t.byte_pos <- t.start_pos + new_total / 8;
    t.bit_pos <- new_total mod 8

  (** [align t] skips the unread bits of a partially consumed byte so the
      reader is byte aligned. Does nothing when already aligned. *)
  let align t =
    if t.bit_pos <> 0 then begin
      t.bit_pos <- 0;
      t.byte_pos <- t.byte_pos + 1
    end

  (** [byte_position t] is the current offset into the underlying buffer.
      @raise Invalid_state if the reader is not byte aligned. *)
  let byte_position t =
    if t.bit_pos <> 0 then
      raise (Invalid_state "byte_position: not byte aligned");
    t.byte_pos

  (* Shared guard for the byte-granular operations below. [fn] is the
     caller's name, used as the error-message prefix. The bound is written
     as [n > limit - byte_pos] because [byte_pos + n] can overflow for very
     large [n] and would then slip past a [> limit] test. *)
  let check_byte_count t ~fn n =
    if t.bit_pos <> 0 then
      raise (Invalid_state (fn ^ ": not byte aligned"));
    if n < 0 then
      invalid_arg (fn ^ ": negative length");
    if n > t.limit - t.byte_pos then
      raise End_of_stream

  (** [get_slice t n] consumes [n] bytes and returns them as a slice that
      shares storage with the source buffer.
      @raise Invalid_state if the reader is not byte aligned.
      @raise Invalid_argument if [n < 0].
      @raise End_of_stream if fewer than [n] bytes remain. *)
  let get_slice t n : Slice.t =
    check_byte_count t ~fn:"get_slice" n;
    let result = Slice.make t.src ~first:t.byte_pos ~length:n in
    t.byte_pos <- t.byte_pos + n;
    result

  (** [get_bytes t n] is like {!get_slice} but returns a fresh copy. *)
  let get_bytes t n =
    Slice.to_bytes (get_slice t n)

  (** [to_slice t] is the unread remainder as a slice; nothing is consumed.
      @raise Invalid_state if the reader is not byte aligned. *)
  let to_slice t : Slice.t =
    if t.bit_pos <> 0 then
      raise (Invalid_state "to_slice: not byte aligned");
    Slice.make t.src ~first:t.byte_pos ~length:(t.limit - t.byte_pos)

  (** [advance t n] skips [n] bytes.
      @raise Invalid_state if the reader is not byte aligned.
      @raise Invalid_argument if [n < 0].
      @raise End_of_stream if fewer than [n] bytes remain. *)
  let advance t n =
    check_byte_count t ~fn:"advance" n;
    t.byte_pos <- t.byte_pos + n

  (** [sub t n] consumes [n] bytes from [t] and returns an independent
      reader over exactly those bytes.
      @raise Invalid_state if the reader is not byte aligned.
      @raise Invalid_argument if [n < 0].
      @raise End_of_stream if fewer than [n] bytes remain. *)
  let sub t n =
    check_byte_count t ~fn:"sub" n;
    let result = of_slice (Slice.make t.src ~first:t.byte_pos ~length:n) in
    t.byte_pos <- t.byte_pos + n;
    result

  (** [remaining_bytes t] is the number of unread bytes.
      @raise Invalid_state if the reader is not byte aligned. *)
  let remaining_bytes t =
    if t.bit_pos <> 0 then
      raise (Invalid_state "remaining_bytes: not byte aligned");
    t.limit - t.byte_pos

  (** [skip_bits t n] consumes [n] bits and discards them. Same limits and
      exceptions as {!read_bits}. *)
  let skip_bits t n =
    ignore (read_bits t n)
end

(** {1 Backward Bitstream Reader}

    Reads bits from the end of a buffer towards the start. The stream
    starts with a padding marker (highest 1-bit indicates start of data);
    the marker is looked up in the last byte of the buffer. *)

module Backward_reader = struct
  type t = {
    src : bytes;
    start_pos : int;
    mutable bit_offset : int;  (* Bits remaining from end, decreasing *)
  }

  (** [of_slice slice] locates the padding marker in the last byte of
      [slice] and positions the reader just below it.
      @raise End_of_stream if [slice] is empty.
      @raise Corrupted_stream if the last byte is zero (no marker bit). *)
  let of_slice (slice : Slice.t) =
    if slice.length = 0 then
      raise End_of_stream;
    let last_byte_pos = slice.first + slice.length - 1 in
    let last_byte = Bytes.get_uint8 slice.bytes last_byte_pos in
    if last_byte = 0 then
      raise (Corrupted_stream "invalid padding marker");
    (* Find the highest set bit - this is the padding marker *)
    let rec find_marker byte bit =
      if bit < 0 then 0
      else if (byte land (1 lsl bit)) <> 0 then bit
      else find_marker byte (bit - 1)
    in
    (* The marker bit and every bit above it in the last byte are padding. *)
    let padding = 8 - find_marker last_byte 7 in
    let bit_offset = slice.length * 8 - padding in
    { src = slice.bytes; start_pos = slice.first; bit_offset }

  (** [of_bytes src ~pos ~len] is [of_slice] over [len] bytes of [src]
      starting at offset [pos]. *)
  let of_bytes src ~pos ~len =
    of_slice (Slice.make src ~first:pos ~length:len)

  (** [remaining t] is the number of unread bits. It becomes negative once
      a read has gone past the start of the stream. *)
  let[@inline] remaining t = t.bit_offset

  (** [is_empty t] is [true] iff no unread bits are left. *)
  let[@inline] is_empty t = t.bit_offset <= 0

  (** [read_bits t n] consumes the [n] bits just below the current position
      and returns them, lowest-addressed bit in bit 0. Returns [0] when
      [n <= 0].

      Reading past the start of the stream does not raise: the bits that do
      exist are returned in the high positions of the [n]-bit result, the
      missing low positions are zero, and {!remaining} goes negative.
      @raise Invalid_argument if [n > 57]. *)
  let[@inline] read_bits t n =
    if n <= 0 then 0
    else if n > 57 then invalid_arg "read_bits: n > 57"
    else begin
      t.bit_offset <- t.bit_offset - n;
      (* Clamp to the start of the buffer when the read overshoots it. *)
      let actual_offset = max 0 t.bit_offset in
      let actual_bits = if t.bit_offset < 0 then n + t.bit_offset else n in
      if actual_bits <= 0 then 0
      else begin
        let byte_offset = t.start_pos + (actual_offset / 8) in
        let bit_offset = actual_offset mod 8 in
        let result = ref 0 in
        let bits_read = ref 0 in
        let current_byte = ref byte_offset in
        let current_bit = ref bit_offset in
        while !bits_read < actual_bits do
          let byte = Bytes.get_uint8 t.src !current_byte in
          let available = 8 - !current_bit in
          let to_read = min available (actual_bits - !bits_read) in
          let mask = (1 lsl to_read) - 1 in
          let bits = (byte lsr !current_bit) land mask in
          result := !result lor (bits lsl !bits_read);
          bits_read := !bits_read + to_read;
          current_bit := !current_bit + to_read;
          if !current_bit >= 8 then begin
            current_bit := 0;
            incr current_byte
          end
        done;
        (* If we read past the beginning, shift the result *)
        if t.bit_offset < 0 then
          !result lsl (-t.bit_offset)
        else
          !result
      end
    end

  (** [peek_bits t n] returns what [read_bits t n] would return, without
      consuming anything. *)
  let peek_bits t n =
    let saved_offset = t.bit_offset in
    let result = read_bits t n in
    t.bit_offset <- saved_offset;
    result
end

(** {1 Forward Bitstream Writer} *)

module Forward_writer = struct
  (** Writer state. Bits are packed into each byte starting from the least
      significant bit; a byte is stored into [dst] once all 8 bits are
      filled, or when the writer is flushed. *)
  type t = {
    dst : bytes;
    start_pos : int;
    mutable byte_pos : int;
    mutable bit_pos : int;  (* 0-7, bits written in current byte *)
    mutable current_byte : int;
  }

  (** [of_slice slice] is a writer that starts at the first byte of [slice].
      Only the start offset is used: the slice length is not enforced, so
      writes are bounded by the underlying buffer alone. *)
  let of_slice (slice : Slice.t) =
    { dst = slice.bytes;
      start_pos = slice.first;
      byte_pos = slice.first;
      bit_pos = 0;
      current_byte = 0 }

  (** [of_bytes dst] is a writer that starts at offset 0 of [dst]. *)
  let of_bytes dst =
    of_slice (Slice.of_bytes dst)

  (** [create dst ~pos] is a writer that starts at offset [pos] of [dst]. *)
  let create dst ~pos =
    of_slice (Slice.make dst ~first:pos ~length:(Bytes.length dst - pos))

  (** [flush t] stores a partially filled byte, leaving its unwritten high
      bits zero, and makes the writer byte aligned. Does nothing when
      already aligned. *)
  let flush t =
    if t.bit_pos > 0 then begin
      Bytes.set_uint8 t.dst t.byte_pos t.current_byte;
      t.byte_pos <- t.byte_pos + 1;
      t.bit_pos <- 0;
      t.current_byte <- 0
    end

  (** [write_bits t value n] appends the low [n] bits of [value], lowest
      bit first. Does nothing when [n <= 0].
      @raise Invalid_argument if [n > 57], or (from [Bytes.set_uint8]) if
      the write runs past the end of the destination buffer. *)
  let[@inline] write_bits t value n =
    if n <= 0 then ()
    else if n > 57 then invalid_arg "write_bits: n > 57"
    else begin
      let value = ref value in
      let remaining = ref n in

      while !remaining > 0 do
        (* Fill what is left of the current byte. *)
        let available = 8 - t.bit_pos in
        let to_write = min available !remaining in
        let mask = (1 lsl to_write) - 1 in
        t.current_byte <- t.current_byte lor ((!value land mask) lsl t.bit_pos);
        value := !value lsr to_write;
        remaining := !remaining - to_write;
        t.bit_pos <- t.bit_pos + to_write;

        if t.bit_pos = 8 then begin
          Bytes.set_uint8 t.dst t.byte_pos t.current_byte;
          t.byte_pos <- t.byte_pos + 1;
          t.bit_pos <- 0;
          t.current_byte <- 0
        end
      done
    end

  (** [write_byte t value] flushes any partial byte, then stores [value] as
      one whole byte. *)
  let write_byte t value =
    if t.bit_pos <> 0 then flush t;
    Bytes.set_uint8 t.dst t.byte_pos value;
    t.byte_pos <- t.byte_pos + 1

  (** [write_slice t slice] flushes any partial byte, then copies the bytes
      of [slice] into the output. *)
  let write_slice t (slice : Slice.t) =
    if t.bit_pos <> 0 then flush t;
    Bytes.blit slice.bytes slice.first t.dst t.byte_pos slice.length;
    t.byte_pos <- t.byte_pos + slice.length

  (** [write_bytes t src] is [write_slice] over the whole of [src]. *)
  let write_bytes t src =
    write_slice t (Slice.of_bytes src)

  (** [byte_position t] is the offset in the destination buffer just past
      the last byte touched so far, counting a partially filled byte as a
      whole byte. Nothing is flushed. *)
  let byte_position t =
    if t.bit_pos > 0 then t.byte_pos + 1 else t.byte_pos

  (** [finalize t] flushes any partial byte and returns the number of bytes
      written since the writer was created. *)
  let finalize t =
    flush t;
    t.byte_pos - t.start_pos

  (** [to_slice t] flushes any partial byte and returns the written region
      as a slice sharing storage with the destination buffer. *)
  let to_slice t : Slice.t =
    flush t;
    Slice.make t.dst ~first:t.start_pos ~length:(t.byte_pos - t.start_pos)
end

(** {1 Backward Bitstream Writer}

    Accumulates bits to be read backwards. Used for FSE and Huffman encoding. *)

module Backward_writer = struct
  (** Writer state. [bits] holds up to 64 pending bits, [num_bits] of which
      are valid. Completed bytes are stored into [buffer] from its end
      towards its start; [buf_pos] is the index of the most recently stored
      byte (or [Bytes.length buffer] when none has been stored). *)
  type t = {
    mutable bits : int64;
    mutable num_bits : int;
    buffer : bytes;
    mutable buf_pos : int;
  }

  (** [create size] is an empty writer able to hold [size] output bytes. *)
  let create size =
    { bits = 0L; num_bits = 0; buffer = Bytes.create size; buf_pos = size }

  (** [flush_bytes t] moves every complete pending byte out of the
      accumulator into the buffer, leaving at most 7 pending bits.
      @raise Invalid_argument (from [Bytes.set_uint8]) if the buffer is
      full. *)
  let flush_bytes t =
    while t.num_bits >= 8 do
      t.buf_pos <- t.buf_pos - 1;
      Bytes.set_uint8 t.buffer t.buf_pos (Int64.to_int (Int64.logand t.bits 0xFFL));
      t.bits <- Int64.shift_right_logical t.bits 8;
      t.num_bits <- t.num_bits - 8
    done

  (** [write_bits t value n] appends the low [n] bits of [value]. Does
      nothing when [n <= 0]. Bits of [value] above bit [n - 1] are ignored.

      Complete bytes are flushed automatically whenever the new bits would
      not fit in the 64-bit accumulator, so calling {!flush_bytes} between
      writes is optional.
      @raise Invalid_argument if [n > 57], or (from [Bytes.set_uint8]) if
      an automatic flush finds the buffer full. *)
  let[@inline] write_bits t value n =
    if n <= 0 then ()
    else if n > 57 then invalid_arg "write_bits: n > 57"
    else begin
      (* After a flush at most 7 bits stay pending, and 7 + 57 = 64, so the
         new bits always fit and the shift below stays under 64. *)
      if t.num_bits + n > 64 then flush_bytes t;
      (* Drop stray high bits so they cannot corrupt later fields. *)
      let masked = value land ((1 lsl n) - 1) in
      t.bits <-
        Int64.logor t.bits (Int64.shift_left (Int64.of_int masked) t.num_bits);
      t.num_bits <- t.num_bits + n
    end

  (** [finalize_to_slice t] terminates the stream: it appends the single
      1-bit padding marker, zero-pads to a byte boundary, flushes, and
      returns the output as a slice of the internal buffer, ordered so that
      the byte holding the marker is last.

      Call this at most once per writer: every call appends another marker
      and reverses the stored bytes again. *)
  let finalize_to_slice t : Slice.t =
    write_bits t 1 1;
    (* Round up so the final partial byte is flushed too; the accumulator
       bits above the marker are already zero. *)
    if t.num_bits mod 8 <> 0 then
      t.num_bits <- ((t.num_bits + 7) / 8) * 8;
    flush_bytes t;
    let len = Bytes.length t.buffer - t.buf_pos in
    (* Reverse bytes in place so marker ends up at the end *)
    for i = 0 to len / 2 - 1 do
      let j = t.buf_pos + i in
      let k = t.buf_pos + len - 1 - i in
      let tmp = Bytes.get t.buffer j in
      Bytes.set t.buffer j (Bytes.get t.buffer k);
      Bytes.set t.buffer k tmp
    done;
    Slice.make t.buffer ~first:t.buf_pos ~length:len

  (** [finalize t] is {!finalize_to_slice} copied into a fresh buffer. *)
  let finalize t =
    Slice.to_bytes (finalize_to_slice t)

  (** [current_size t] is the number of bytes the output occupies so far,
      counting pending bits rounded up to whole bytes and excluding the
      final padding marker. *)
  let current_size t =
    Bytes.length t.buffer - t.buf_pos + (t.num_bits + 7) / 8
end