(* Zstd compression in pure OCaml.
   Repository viewer residue: branch "main", 752 lines, 29 kB. *)
(** Zstandard compression implementation.

    Implements LZ77 matching, block compression, and frame encoding. *)

(** Match-finder tuning for one compression level. Higher levels search
    harder (slower) in exchange for a better ratio. *)
type compression_level = {
  window_log : int;  (* Log2 of window size *)
  chain_log : int;  (* Log2 of hash chain length *)
  hash_log : int;  (* Log2 of hash table size *)
  search_log : int;  (* Log2 of the number of chain candidates tried per position *)
  min_match : int;  (* Minimum match length *)
  target_len : int;  (* Target match length *)
  strategy : int;  (* 0=fast, 1=greedy, 2=lazy *)
}

(** Parameters indexed by level, 0 to 19. Slots 0 and 1 hold the same
    "fast" parameters. *)
let level_params = [|
  (* Level 0/1: Fast *)
  { window_log = 17; chain_log = 12; hash_log = 11; search_log = 1; min_match = 4; target_len = 0; strategy = 0 };
  { window_log = 17; chain_log = 12; hash_log = 11; search_log = 1; min_match = 4; target_len = 0; strategy = 0 };
  (* Level 2 *)
  { window_log = 18; chain_log = 13; hash_log = 12; search_log = 1; min_match = 5; target_len = 4; strategy = 0 };
  (* Level 3 *)
  { window_log = 18; chain_log = 14; hash_log = 13; search_log = 1; min_match = 5; target_len = 8; strategy = 1 };
  (* Level 4 *)
  { window_log = 18; chain_log = 14; hash_log = 14; search_log = 2; min_match = 4; target_len = 8; strategy = 1 };
  (* Level 5 *)
  { window_log = 18; chain_log = 15; hash_log = 14; search_log = 3; min_match = 4; target_len = 16; strategy = 1 };
  (* Level 6 *)
  { window_log = 19; chain_log = 16; hash_log = 15; search_log = 3; min_match = 4; target_len = 32; strategy = 1 };
  (* Level 7 *)
  { window_log = 19; chain_log = 16; hash_log = 15; search_log = 4; min_match = 4; target_len = 32; strategy = 2 };
  (* Level 8 *)
  { window_log = 19; chain_log = 17; hash_log = 16; search_log = 4; min_match = 4; target_len = 64; strategy = 2 };
  (* Level 9 *)
  { window_log = 20; chain_log = 17; hash_log = 16; search_log = 5; min_match = 4; target_len = 64; strategy = 2 };
  (* Level 10 *)
  { window_log = 20; chain_log = 17; hash_log = 16; search_log = 6; min_match = 4; target_len = 128; strategy = 2 };
  (* Level 11 *)
  { window_log = 20; chain_log = 18; hash_log = 17; search_log = 6; min_match = 4; target_len = 128; strategy = 2 };
  (* Level 12 *)
  { window_log = 21; chain_log = 18; hash_log = 17; search_log = 7; min_match = 4; target_len = 256; strategy = 2 };
  (* Level 13 *)
  { window_log = 21; chain_log = 19; hash_log = 18; search_log = 7; min_match = 4; target_len = 256; strategy = 2 };
  (* Level 14 *)
  { window_log = 22; chain_log = 19; hash_log = 18; search_log = 8; min_match = 4; target_len = 256; strategy = 2 };
  (* Level 15 *)
  { window_log = 22; chain_log = 20; hash_log = 18; search_log = 9; min_match = 4; target_len = 256; strategy = 2 };
  (* Level 16 *)
  { window_log = 22; chain_log = 20; hash_log = 19; search_log = 10; min_match = 4; target_len = 512; strategy = 2 };
  (* Level 17 *)
  { window_log = 22; chain_log = 21; hash_log = 19; search_log = 11; min_match = 4; target_len = 512; strategy = 2 };
  (* Level 18 *)
  { window_log = 22; chain_log = 21; hash_log = 20; search_log = 12; min_match = 4; target_len = 512; strategy = 2 };
  (* Level 19 *)
  { window_log = 23; chain_log = 22; hash_log = 20; search_log = 12; min_match = 4; target_len = 1024; strategy = 2 };
|]

(** [get_level_params level] returns the parameters for [level], clamping
    it to the supported range: anything below 1 is treated as 1 and
    anything above the last table entry as that entry. *)
let get_level_params level =
  (* Derive the upper bound from the table so the two cannot drift apart. *)
  let max_level = Array.length level_params - 1 in
  level_params.(max 1 (min level max_level))

(** A sequence represents a literal run followed by a match.
    [match_offset = 0] and [match_length = 0] denote a literal-only run. *)
type sequence = {
  lit_length : int;
  match_offset : int;
  match_length : int;
}

(** Hash table for fast match finding *)
type hash_table = {
  table : int array;  (* Most recent position for each hash bucket, -1 if empty *)
  chain : int array;  (* Previous position with the same hash, -1 if none *)
  mask : int;  (* [Array.length table - 1]; applied to a hash to get a bucket *)
}

(** [create_hash_table log_size] allocates an empty table with
    [2^log_size] buckets and a chain array of [2^20] entries, all set to
    [-1].

    @raise Invalid_argument if [log_size] is negative. *)
let create_hash_table log_size =
  (* [1 lsl n] is unspecified for negative [n]; it used to yield an empty
     table with mask -1, which only failed later at lookup time. *)
  if log_size < 0 then invalid_arg "create_hash_table: negative log_size";
  let size = 1 lsl log_size in
  {
    table = Array.make size (-1);
    (* The chain length is fixed and always a power of two. Nothing here
       ties it to the input size, so users of the table must keep their
       chain indices within this array. *)
    chain = Array.make (1 lsl 20) (-1);
    mask = size - 1;
  }

(** Compute hash of 4 bytes *)
let[@inline] hash4 src pos =
  (* Reads the 4 bytes at [pos] as a little-endian int32 and mixes them
     with one multiply and one xor-shift. The result is not reduced to a
     table size and may be negative: callers apply [land ht.mask].
     [Bytes.get_int32_le] raises [Invalid_argument] when fewer than
     4 bytes remain at [pos]. *)
  let v = Bytes.get_int32_le src pos in
  let h = Int32.to_int (Int32.mul v 0xcc9e2d51l) in
  h lxor (h lsr 15)

(* Slot of position [pos] in the chain array.

   The chain used to be indexed by the absolute input position, so any
   position beyond the chain length raised an out-of-bounds exception.
   Wrapping keeps every access in range. This assumes the chain length is
   a power of two, as [create_hash_table] allocates it. If a region longer
   than the chain is parsed with one table, slots alias; that can only
   cost compression ratio, because every candidate is re-verified byte by
   byte before it is used. *)
let[@inline] chain_slot ht pos = pos land (Array.length ht.chain - 1)

(** [match_length src pos1 pos2 limit] returns how many bytes starting at
    [pos1] equal the bytes starting at the earlier position [pos2].

    The result never exceeds [limit - pos1], nor [pos1 - pos2]: matches
    that would overlap their own source are cut at the offset. *)
let match_length src pos1 pos2 limit =
  let max_len = min (limit - pos1) (pos1 - pos2) in
  let rec extend n =
    if n < max_len
       && Bytes.get_uint8 src (pos1 + n) = Bytes.get_uint8 src (pos2 + n)
    then extend (n + 1)
    else n
  in
  extend 0

(** [find_best_match ht src pos limit params] records [pos] in the hash
    table, then walks the hash chain for the longest earlier match.

    Returns [(offset, length)], or [(0, 0)] when fewer than 4 bytes remain
    or no candidate reaches [params.min_match]. At most
    [2^params.search_log] candidates are tried, and candidates further
    back than [2^params.window_log] end the search. *)
let find_best_match ht src pos limit params =
  if pos + 4 > limit then (0, 0)
  else begin
    let h = hash4 src pos land ht.mask in
    let prev_pos = ht.table.(h) in

    (* Link this position in front of the bucket's previous head. *)
    ht.chain.(chain_slot ht pos) <- prev_pos;
    ht.table.(h) <- pos;

    let window = 1 lsl params.window_log in
    if prev_pos < 0 || pos - prev_pos > window then (0, 0)
    else begin
      let max_searches = 1 lsl params.search_log in
      let rec search candidate tried best_offset best_length =
        if candidate < 0 || tried >= max_searches then
          (best_offset, best_length)
        else begin
          let offset = pos - candidate in
          if offset > window then (best_offset, best_length)
          else begin
            let len = match_length src pos candidate limit in
            let next = ht.chain.(chain_slot ht candidate) in
            if len >= params.min_match && len > best_length then
              search next (tried + 1) offset len
            else
              search next (tried + 1) best_offset best_length
          end
        end
      in
      search prev_pos 0 0 0
    end
  end

(** [parse_sequences src ~pos ~len params] splits [src[pos .. pos+len)]
    into sequences by taking the best match found at each position.

    The returned list is in input order. Unmatched trailing bytes become a
    final literal-only sequence ([match_offset = 0], [match_length = 0]),
    which is also emitted, possibly empty, when no match was found at
    all. *)
let parse_sequences src ~pos ~len params =
  let limit = pos + len in
  let ht = create_hash_table params.hash_log in
  let sequences = ref [] in
  let cur_pos = ref pos in
  let lit_start = ref pos in

  while !cur_pos + 4 <= limit do
    let offset, length = find_best_match ht src !cur_pos limit params in
    if length >= params.min_match then begin
      sequences :=
        { lit_length = !cur_pos - !lit_start;
          match_offset = offset;
          match_length = length }
        :: !sequences;

      (* Index the positions covered by the match so later matches can
         refer to them. The first one was indexed by [find_best_match]. *)
      for i = !cur_pos + 1 to !cur_pos + length - 1 do
        if i + 4 <= limit then begin
          let h = hash4 src i land ht.mask in
          ht.chain.(chain_slot ht i) <- ht.table.(h);
          ht.table.(h) <- i
        end
      done;

      cur_pos := !cur_pos + length;
      lit_start := !cur_pos
    end else
      incr cur_pos
  done;

  (* Handle remaining literals *)
  let remaining = limit - !lit_start in
  if remaining > 0 || !sequences = [] then
    sequences :=
      { lit_length = remaining; match_offset = 0; match_length = 0 }
      :: !sequences;

  List.rev !sequences

(** [encode_lit_length_code lit_len] returns
    [(code, extra_value, extra_bits)] for a literal length.

    Lengths below 16 are their own code. Every larger length is looked up
    in [Constants.ll_baselines] / [Constants.ll_extra_bits], codes 16 to
    35. The hand-written ranges previously used for 16..127 produced
    codes and extra-bit counts that disagree with the RFC 8878 table (for
    example 2 extra bits for code 16, which has 1). *)
let encode_lit_length_code lit_len =
  if lit_len < 16 then (lit_len, 0, 0)
  else begin
    let last_code = 35 in
    let rec find_code code =
      if code >= last_code || lit_len < Constants.ll_baselines.(code + 1) then
        (code,
         lit_len - Constants.ll_baselines.(code),
         Constants.ll_extra_bits.(code))
      else find_code (code + 1)
    in
    find_code 16
  end

(** Minimum match length for zstd *)
let min_match = 3

(** [encode_match_length_code match_len] returns
    [(code, extra_value, extra_bits)] for a match length.

    Lengths 3..34 map to codes 0..31 with no extra bits. Every larger
    length is looked up in [Constants.ml_baselines] /
    [Constants.ml_extra_bits], codes 32 to 52. The baselines are taken to
    be full match lengths (35 for code 32), which is what the previous
    table-walk arithmetic already assumed. The hand-written range
    previously used for lengths 35..66 gave every code one extra bit and
    was wrong from length 43 upward. *)
let encode_match_length_code match_len =
  let ml = match_len - min_match in
  if ml < 32 then (ml, 0, 0)
  else begin
    let last_code = 52 in
    let rec find_code code =
      if code >= last_code || match_len < Constants.ml_baselines.(code + 1) then
        (code,
         match_len - Constants.ml_baselines.(code),
         Constants.ml_extra_bits.(code))
      else find_code (code + 1)
    in
    find_code 32
  end

(** Encode offset code.
    Returns (of_code, extra_value, extra_bits).

    Always emits a real offset:
    - offBase = offset + 3
    - ofCode = highbit(offBase)
    - extra = lower ofCode bits of offBase

    Repeat-offset codes (offBase 1, 2, 3) are deliberately not produced.
    Their meaning depends on whether the sequence's literal length is
    zero, and each use reorders the decoder's offset history. This
    function receives neither the literal length nor a history known to
    mirror the decoder's, so matching [offset] against the history could
    emit a code that decodes to a different offset. The history argument
    is kept for interface compatibility and ignored. *)
let encode_offset_code offset _offset_history =
  let off_base = offset + 3 in
  let of_code = Fse.highest_set_bit off_base in
  let extra = off_base land ((1 lsl of_code) - 1) in
  (of_code, extra, of_code)

(** [write_raw_literals literals ~pos ~len output ~out_pos] writes a
    Raw_Literals_Block (header, then [len] bytes copied from [literals])
    and returns the number of bytes written.

    Header, literals type 0 in bits 0-1:
    - [len < 32]: 1 byte, size_format 00, size in bits 3-7
    - [len < 4096]: 2 bytes, size_format 01, size in bits 4-15
    - otherwise: 3 bytes, size_format 11, size in bits 4-23

    @raise Invalid_argument if [len] does not fit in 20 bits. *)
let write_raw_literals literals ~pos ~len output ~out_pos =
  if len > 0xfffff then
    invalid_arg "write_raw_literals: literals section too large (max 2^20 - 1 bytes)";
  let header_size =
    if len < 32 then begin
      Bytes.set_uint8 output out_pos ((len land 0x1f) lsl 3);
      1
    end else if len < 4096 then begin
      Bytes.set_uint16_le output out_pos (0b0100 lor (len lsl 4));
      2
    end else begin
      (* size_format must be 11 here. The value 10 used before selects the
         1-byte header, and the 14-bit mask truncated large sizes. *)
      let header = 0b1100 lor (len lsl 4) in
      Bytes.set_uint8 output out_pos (header land 0xff);
      Bytes.set_uint8 output (out_pos + 1) ((header lsr 8) land 0xff);
      Bytes.set_uint8 output (out_pos + 2) ((header lsr 16) land 0xff);
      3
    end
  in
  if len <> 0 then Bytes.blit literals pos output (out_pos + header_size) len;
  header_size + len

(** [write_compressed_literals literals ~pos ~len output ~out_pos] writes
    the literals section, Huffman-compressed when that saves at least
    10%, raw otherwise. Returns the number of bytes written.

    Compressed header, literals type 2 in bits 0-1, then size_format in
    bits 2-3, then regenerated size and compressed size back to back:
    - 00: 1 stream, 10 bits each, 3 bytes
    - 01: 4 streams, 10 bits each, 3 bytes
    - 10: 4 streams, 14 bits each, 4 bytes
    - 11: 4 streams, 18 bits each, 5 bytes

    The compressed size covers the Huffman table description plus the
    stream data. The stream count is carried by size_format alone; the
    separate "streams" field written previously does not exist in the
    format and shifted both sizes to the wrong bit positions. *)
let write_compressed_literals literals ~pos ~len output ~out_pos =
  let raw () = write_raw_literals literals ~pos ~len output ~out_pos in
  (* Below 32 bytes Huffman cannot pay for its table. From 2^18 bytes up
     the compressed header cannot represent the size. *)
  if len < 32 || len >= 1 lsl 18 then raw ()
  else begin
    (* Count symbol frequencies *)
    let counts = Array.make 256 0 in
    for i = pos to pos + len - 1 do
      let c = Bytes.get_uint8 literals i in
      counts.(c) <- counts.(c) + 1
    done;

    (* Find max symbol used *)
    let max_symbol = ref 0 in
    for i = 0 to 255 do
      if counts.(i) > 0 then max_symbol := i
    done;

    let ctable = Huffman.build_ctable counts !max_symbol Constants.max_huffman_bits in
    if ctable.num_symbols = 0 then raw ()
    else begin
      let use_4streams = len >= 256 in

      (* Serialise the Huffman table description into a scratch buffer. *)
      let header_buf = Bytes.create 256 in
      let header_stream = Bit_writer.Forward.of_bytes header_buf in
      let _num_written = Huffman.write_header header_stream ctable in
      let header_size = Bit_writer.Forward.byte_position header_stream in

      let compressed =
        if use_4streams then Huffman.compress_4stream ctable literals ~pos ~len
        else Huffman.compress_1stream ctable literals ~pos ~len
      in
      let compressed_size = Bytes.length compressed in
      let total_compressed_size = header_size + compressed_size in

      (* Check if compression is worthwhile (should save at least 10%) *)
      if total_compressed_size >= len - len / 10 then raw ()
      else begin
        (* The compressed size is below [len] here, so a width that fits
           [len] fits both fields. The single-stream case always has
           [len < 256]. *)
        let size_format, size_bits =
          if not use_4streams then (0, 10)
          else if len < 1024 then (1, 10)
          else if len < 16384 then (2, 14)
          else (3, 18)
        in

        (* Little-endian bit packer. Whole bytes are flushed as soon as
           they are complete, so the accumulator never holds more than
           24 bits. The header is 24, 32 or 40 bits long, so it always
           ends on a byte boundary. *)
        let acc = ref 0 in
        let acc_bits = ref 0 in
        let write_pos = ref out_pos in
        let put value width =
          acc := !acc lor (value lsl !acc_bits);
          acc_bits := !acc_bits + width;
          while !acc_bits >= 8 do
            Bytes.set_uint8 output !write_pos (!acc land 0xff);
            incr write_pos;
            acc := !acc lsr 8;
            acc_bits := !acc_bits - 8
          done
        in
        put 2 2;  (* Literals_Block_Type = Compressed_Literals_Block *)
        put size_format 2;
        put len size_bits;  (* Regenerated_Size *)
        put total_compressed_size size_bits;  (* Compressed_Size *)

        (* Huffman table description, then the compressed stream data. *)
        Bytes.blit header_buf 0 output !write_pos header_size;
        Bytes.blit compressed 0 output (!write_pos + header_size) compressed_size;

        !write_pos + header_size + compressed_size - out_pos
      end
    end
  end

(** Compress literals - try Huffman, fall back to raw *)
let compress_literals literals ~pos ~len output ~out_pos =
  write_compressed_literals literals ~pos ~len output ~out_pos
(** Build predefined FSE compression tables *)
let ll_ctable = lazy (Fse.build_predefined_ctable Constants.ll_default_distribution Constants.ll_default_accuracy_log)
let ml_ctable = lazy (Fse.build_predefined_ctable Constants.ml_default_distribution Constants.ml_default_accuracy_log)
let of_ctable = lazy (Fse.build_predefined_ctable Constants.of_default_distribution Constants.of_default_accuracy_log)

(* Codes and extra bits of one sequence, ready for the bitstream. *)
type seq_codes = {
  ll_code : int; ll_extra : int; ll_bits : int;
  ml_code : int; ml_extra : int; ml_bits : int;
  of_code : int; of_extra : int; of_bits : int;
}

(* [offset_code_with_history ~lit_length offset hist] picks the offset code
   for one sequence and updates [hist] exactly as a decoder will when it
   reads that code. Returns (of_code, extra_value, extra_bits).

   With a non-empty literal run, offBase 1/2/3 mean hist.(0)/(1)/(2).
   With an empty literal run they mean hist.(1)/(2)/(hist.(0) - 1).
   Any other offset is sent as offBase = offset + 3.
   Using hist.(0) leaves the history alone; every other code moves the
   decoded offset to the front.

   Raises Invalid_argument when [offset] is not positive. *)
let offset_code_with_history ~lit_length offset hist =
  if offset <= 0 then invalid_arg "offset_code_with_history: offset must be positive";
  let r0 = hist.(0) and r1 = hist.(1) and r2 = hist.(2) in
  let update a b c = hist.(0) <- a; hist.(1) <- b; hist.(2) <- c in
  let off_base =
    if lit_length > 0 then begin
      if offset = r0 then 1
      else if offset = r1 then begin update r1 r0 r2; 2 end
      else if offset = r2 then begin update r2 r0 r1; 3 end
      else begin update offset r0 r1; offset + 3 end
    end else begin
      if offset = r1 then begin update r1 r0 r2; 1 end
      else if offset = r2 then begin update r2 r0 r1; 2 end
      else if offset = r0 - 1 then begin update offset r0 r1; 3 end
      else begin update offset r0 r1; offset + 3 end
    end
  in
  let of_code = Fse.highest_set_bit off_base in
  (of_code, off_base land ((1 lsl of_code) - 1), of_code)

(* Writes the sequences section at [out_pos] and returns
   (bytes_written, offset history after the last sequence).
   [offset_history] itself is left untouched so the caller can decide
   whether to keep the block. *)
let encode_sequences_section sequences output ~out_pos offset_history =
  let history = Array.copy offset_history in
  match sequences with
  | [] ->
    (* Zero sequences *)
    Bytes.set_uint8 output out_pos 0;
    (1, history)
  | _ :: _ ->
    let seq_array = Array.of_list sequences in
    let num_seq = Array.length seq_array in

    (* Sequence count (1-3 bytes) *)
    let count_size =
      if num_seq < 128 then begin
        Bytes.set_uint8 output out_pos num_seq;
        1
      end else if num_seq < 0x7f00 then begin
        Bytes.set_uint8 output out_pos ((num_seq lsr 8) + 128);
        Bytes.set_uint8 output (out_pos + 1) (num_seq land 0xff);
        2
      end else begin
        Bytes.set_uint8 output out_pos 0xff;
        Bytes.set_uint16_le output (out_pos + 1) (num_seq - 0x7f00);
        3
      end
    in

    (* Symbol compression modes byte:
       bits 6-7: Literals_Lengths_Mode (0 = predefined)
       bits 4-5: Offsets_Mode (0 = predefined)
       bits 2-3: Match_Lengths_Mode (0 = predefined)
       bits 0-1: reserved *)
    Bytes.set_uint8 output (out_pos + count_size) 0;
    let header_size = count_size + 1 in

    let ll_ct = Lazy.force ll_ctable in
    let ml_ct = Lazy.force ml_ctable in
    let of_ct = Lazy.force of_ctable in

    let codes_of seq =
      let ll_code, ll_extra, ll_bits = encode_lit_length_code seq.lit_length in
      let ml_code, ml_extra, ml_bits = encode_match_length_code seq.match_length in
      let of_code, of_extra, of_bits =
        offset_code_with_history ~lit_length:seq.lit_length seq.match_offset history
      in
      { ll_code; ll_extra; ll_bits;
        ml_code; ml_extra; ml_bits;
        of_code; of_extra; of_bits }
    in
    (* Offset codes depend on the running history, so the sequences must be
       visited first to last. An explicit loop makes that order certain. *)
    let encoded = Array.make num_seq (codes_of seq_array.(0)) in
    for i = 1 to num_seq - 1 do
      encoded.(i) <- codes_of seq_array.(i)
    done;

    let stream = Bit_writer.Backward.create (num_seq * 20 + 32) in
    (* Extra bits: LL, ML, OF order *)
    let write_extra_bits c =
      if c.ll_bits > 0 then
        Bit_writer.Backward.write_bits stream c.ll_extra c.ll_bits;
      if c.ml_bits > 0 then
        Bit_writer.Backward.write_bits stream c.ml_extra c.ml_bits;
      if c.of_bits > 0 then
        Bit_writer.Backward.write_bits stream c.of_extra c.of_bits
    in

    (* Initialize FSE states with the LAST sequence's codes, then write
       that sequence's extra bits. *)
    let last = encoded.(num_seq - 1) in
    let ll_state = Fse.init_cstate2 ll_ct last.ll_code in
    let ml_state = Fse.init_cstate2 ml_ct last.ml_code in
    let of_state = Fse.init_cstate2 of_ct last.of_code in
    write_extra_bits last;

    (* Process sequences from n-2 down to 0 *)
    for i = num_seq - 2 downto 0 do
      let c = encoded.(i) in
      (* FSE encode: OF, ML, LL order *)
      Fse.encode_symbol stream of_state c.of_code;
      Fse.encode_symbol stream ml_state c.ml_code;
      Fse.encode_symbol stream ll_state c.ll_code;
      write_extra_bits c
    done;

    (* Flush states: ML, OF, LL order *)
    Fse.flush_cstate stream ml_state;
    Fse.flush_cstate stream of_state;
    Fse.flush_cstate stream ll_state;

    let seq_data = Bit_writer.Backward.finalize stream in
    let seq_len = Bytes.length seq_data in
    Bytes.blit seq_data 0 output (out_pos + header_size) seq_len;
    (header_size + seq_len, history)

(** Compress sequences section using predefined FSE tables.
    Returns the number of bytes written at [out_pos].

    Follows the order of C zstd's ZSTD_encodeSequences_body:
    1. Initialize states with the LAST sequence's codes
    2. Write the LAST sequence's extra bits (LL, ML, OF order)
    3. For sequences n-2 down to 0:
       - FSE encode OF, ML, LL
       - Extra bits for LL, ML, OF
    4. Flush states ML, OF, LL

    [offset_history] is the repeat-offset history in force at the start
    of the block. It is read but not modified.

    @raise Invalid_argument if a sequence has a non-positive offset. *)
let compress_sequences sequences output ~out_pos offset_history =
  fst (encode_sequences_section sequences output ~out_pos offset_history)

(* Writes the 3-byte block header at [out_pos]:
   bit 0 = last_block (left clear here), bits 1-2 = block_type,
   bits 3-23 = block size. *)
let write_block_header output ~out_pos block_type size =
  let header = (block_type lsl 1) lor ((size land 0x1fffff) lsl 3) in
  Bytes.set_uint8 output out_pos (header land 0xff);
  Bytes.set_uint8 output (out_pos + 1) ((header lsr 8) land 0xff);
  Bytes.set_uint8 output (out_pos + 2) ((header lsr 16) land 0xff)

(** Write raw block (no compression): 3-byte header followed by the
    [len] source bytes. Returns the number of bytes written. *)
let write_raw_block src ~pos ~len output ~out_pos =
  write_block_header output ~out_pos Constants.block_raw len;
  Bytes.blit src pos output (out_pos + 3) len;
  3 + len

(** Write compressed block with sequences. Returns the number of bytes
    written, header included.

    Falls back to a raw block when the compressed form is not smaller
    than the input. When the compressed form is kept, [offset_history] is
    updated in place to the history after the block's last sequence, as
    the decoder carries it into the next compressed block. On the raw
    fallback it is left unchanged. *)
let write_compressed_block src ~pos ~len sequences output ~out_pos offset_history =
  (* Collect all literals *)
  let total_lit_len =
    List.fold_left (fun acc seq -> acc + seq.lit_length) 0 sequences
  in
  let literals = Bytes.create total_lit_len in
  let lit_pos = ref 0 in
  let src_pos = ref pos in
  List.iter (fun seq ->
    if seq.lit_length > 0 then begin
      Bytes.blit src !src_pos literals !lit_pos seq.lit_length;
      lit_pos := !lit_pos + seq.lit_length;
      src_pos := !src_pos + seq.lit_length
    end;
    (* Skip the matched bytes: they are reproduced from the offset. *)
    src_pos := !src_pos + seq.match_length
  ) sequences;

  (* Build block content in temp buffer *)
  let block_buf = Bytes.create (len * 2 + 256) in
  let lit_size =
    compress_literals literals ~pos:0 ~len:total_lit_len block_buf ~out_pos:0
  in

  (* A literal-only entry has no match to encode. Its bytes are already in
     the literals section. *)
  let real_sequences =
    List.filter
      (fun seq -> seq.match_length > 0 || seq.match_offset > 0)
      sequences
  in
  let seq_size, history_after =
    encode_sequences_section real_sequences block_buf ~out_pos:lit_size
      offset_history
  in
  let block_size = lit_size + seq_size in

  if block_size >= len then
    (* Not smaller: fall back to raw block *)
    write_raw_block src ~pos ~len output ~out_pos
  else begin
    Array.blit history_after 0 offset_history 0 (Array.length history_after);
    write_block_header output ~out_pos Constants.block_compressed block_size;
    Bytes.blit block_buf 0 output (out_pos + 3) block_size;
    3 + block_size
  end

(** Write RLE block (single byte repeated): 3-byte header whose size
    field is the expanded length [len], followed by the byte. Always
    returns 4. *)
let write_rle_block byte len output ~out_pos =
  write_block_header output ~out_pos Constants.block_rle len;
  Bytes.set_uint8 output (out_pos + 3) byte;
  4

(** [is_rle_block src ~pos ~len] is [Some b] when all [len] bytes from
    [pos] equal [b], and [None] otherwise or when [len = 0]. Stops at the
    first differing byte. *)
let is_rle_block src ~pos ~len =
  if len = 0 then None
  else begin
    let first = Bytes.get_uint8 src pos in
    let stop = pos + len in
    let rec all_same i =
      i >= stop || (Bytes.get_uint8 src i = first && all_same (i + 1))
    in
    if all_same (pos + 1) then Some first else None
  end

(** Compress a single block using LZ77 + FSE + Huffman.
    Returns the number of bytes written at [out_pos] (0 for an empty
    block). The last-block bit is left clear.

    Uses an RLE block when the block is one repeated byte and longer than
    4 bytes, a compressed block when at least 2 matches were found in a
    block of at least 64 bytes, and a raw block otherwise.
    [offset_history] may be updated, see [write_compressed_block]. *)
let compress_block src ~pos ~len output ~out_pos params offset_history =
  if len = 0 then 0
  else
    match is_rle_block src ~pos ~len with
    | Some byte when len > 4 -> write_rle_block byte len output ~out_pos
    | Some _ | None ->
      let sequences = parse_sequences src ~pos ~len params in
      let match_count =
        List.fold_left
          (fun acc s -> if s.match_length > 0 then acc + 1 else acc)
          0 sequences
      in
      if match_count >= 2 && len >= 64 then
        write_compressed_block src ~pos ~len sequences output ~out_pos
          offset_history
      else
        write_raw_block src ~pos ~len output ~out_pos

(** [write_frame_header output ~pos content_size window_log checksum_flag]
    writes the magic number and frame header at [pos] and returns the
    number of bytes written.

    Content of at most 131072 bytes uses single-segment mode (no window
    descriptor). Larger content gets a window descriptor with exponent
    [window_log - 10] and mantissa 0.

    @raise Invalid_argument if a window descriptor is needed and
    [window_log] is outside 10..41. *)
let write_frame_header output ~pos content_size window_log checksum_flag =
  let single_segment = content_size <= 131072L in
  (* The descriptor stores [window_log - 10] in 5 bits. *)
  if (not single_segment) && (window_log < 10 || window_log > 41) then
    invalid_arg "write_frame_header: window_log must be in 10..41";

  (* Frame content size field:
     - flag 0: 1 byte (0-255). Only valid in single-segment mode, which
       always holds for sizes this small.
     - flag 1: 2 bytes (256-65791, stored minus 256)
     - flag 2: 4 bytes
     - flag 3: 8 bytes *)
  let fcs_flag, fcs_bytes =
    if content_size <= 255L then (0, 1)
    else if content_size <= 65791L then (1, 2)
    else if content_size <= 0xFFFFFFFFL then (2, 4)
    else (3, 8)
  in

  (* Magic number *)
  Bytes.set_int32_le output pos Constants.zstd_magic_number;
  let out_pos = ref (pos + 4) in

  (* Frame header descriptor:
     bit 0-1: dict ID flag (0 = no dict)
     bit 2: content checksum flag
     bit 3: reserved
     bit 4: unused
     bit 5: single segment (no window descriptor)
     bit 6-7: FCS field size flag *)
  let descriptor =
    (if checksum_flag then 0b00000100 else 0)
    lor (if single_segment then 0b00100000 else 0)
    lor (fcs_flag lsl 6)
  in
  Bytes.set_uint8 output !out_pos descriptor;
  incr out_pos;

  (* Window descriptor (only if not single segment) *)
  if not single_segment then begin
    Bytes.set_uint8 output !out_pos ((window_log - 10) lsl 3);
    incr out_pos
  end;

  (* Frame content size *)
  begin match fcs_bytes with
  | 1 -> Bytes.set_uint8 output !out_pos (Int64.to_int content_size)
  | 2 ->
    Bytes.set_uint16_le output !out_pos
      (Int64.to_int (Int64.sub content_size 256L))
  | 4 -> Bytes.set_int32_le output !out_pos (Int64.to_int32 content_size)
  | _ -> Bytes.set_int64_le output !out_pos content_size
  end;
  out_pos := !out_pos + fcs_bytes;

  !out_pos - pos

(** Calculate maximum compressed size *)
let compress_bound len =
  len + len / 128 + 256

(** [compress ?level ?checksum src] compresses [src] into a single zstd
    frame and returns it as a string.

    @param level Compression level, clamped by [get_level_params]
    (default 3)
    @param checksum Append the low 32 bits of the content's XXH64 hash
    (default true) *)
let compress ?(level = 3) ?(checksum = true) src =
  let src = Bytes.of_string src in
  let len = Bytes.length src in
  let params = get_level_params level in

  (* Worst case is slightly larger than the input. *)
  let output = Bytes.create (compress_bound len) in

  (* Repeat-offset history, shared by all blocks of the frame. *)
  let offset_history = Array.copy Constants.initial_repeat_offsets in

  let header_size =
    write_frame_header output ~pos:0 (Int64.of_int len) params.window_log
      checksum
  in
  let out_pos = ref header_size in

  if len = 0 then begin
    (* Empty content: one empty raw block with the last-block bit set.
       Header = 1 | (0 << 1) | (0 << 3) = 0x01 *)
    Bytes.set_uint8 output !out_pos 0x01;
    Bytes.set_uint8 output (!out_pos + 1) 0x00;
    Bytes.set_uint8 output (!out_pos + 2) 0x00;
    out_pos := !out_pos + 3
  end else begin
    let block_size = min len Constants.block_size_max in
    let pos = ref 0 in
    while !pos < len do
      let this_block = min block_size (len - !pos) in
      let is_last = !pos + this_block >= len in
      let block_len =
        compress_block src ~pos:!pos ~len:this_block output ~out_pos:!out_pos
          params offset_history
      in
      (* The block writers leave bit 0 clear. Set it on the final block. *)
      if is_last then begin
        let current = Bytes.get_uint8 output !out_pos in
        Bytes.set_uint8 output !out_pos (current lor 0x01)
      end;
      out_pos := !out_pos + block_len;
      pos := !pos + this_block
    done
  end;

  (* Write checksum if requested *)
  if checksum then begin
    let hash = Xxhash.hash64 src ~pos:0 ~len in
    (* Only the lower 32 bits are stored. *)
    Bytes.set_int32_le output !out_pos (Int64.to_int32 hash);
    out_pos := !out_pos + 4
  end;

  Bytes.sub_string output 0 !out_pos

(** Write a skippable frame.
    @param variant Magic number variant 0-15 (values outside are clamped)
    @param content The content to embed in the skippable frame
    @return The complete skippable frame as a string
    @raise Invalid_argument if [content] is longer than 0xFFFFFFFF bytes *)
let write_skippable_frame ?(variant = 0) content =
  let variant = max 0 (min 15 variant) in
  let len = String.length content in
  if len > 0xFFFFFFFF then
    invalid_arg "Skippable frame content too large (max 4GB)";
  let output = Bytes.create (Constants.skippable_header_size + len) in
  (* Magic number: first skippable magic + variant *)
  let magic = Int32.add Constants.skippable_magic_start (Int32.of_int variant) in
  Bytes.set_int32_le output 0 magic;
  (* Content size (4 bytes little-endian). [Int32.of_int] keeps the low
     32 bits, which is the unsigned encoding wanted here. *)
  Bytes.set_int32_le output 4 (Int32.of_int len);
  (* Content follows the 4-byte magic and the 4-byte size. *)
  Bytes.blit_string content 0 output 8 len;
  (* [output] is not used after this point, so no copy is needed. *)
  Bytes.unsafe_to_string output