(* Punycode (RFC 3492) in OCaml *)
(* Snapshot of commit cb3b948db5e4331c200ef196b41ea35be325cf60 (507 lines, 15 kB). *)
(*---------------------------------------------------------------------------
   Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
   SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

(* RFC 3492 Punycode Implementation *)

(* {1 Bootstring Parameters for Punycode (RFC 3492 Section 5)} *)

let base = 36
let tmin = 1
let tmax = 26
let skew = 38
let damp = 700
let initial_bias = 72
let initial_n = 0x80 (* 128 *)
let delimiter = '-'

(* Prefix marking an ACE (Punycode-encoded) label; added by [encode_label]
   and recognised case-insensitively by [decode_label]. *)
let ace_prefix = "xn--"

(* Maximum label length in bytes, enforced by [encode_label]. *)
let max_label_length = 63

(* {1 Position Tracking} *)

(** Location attached to an error. [byte_offset] is an offset into the input
    string and [char_index] is a code-point index; the exact meaning of the
    index depends on the phase that failed. Errors raised by the encoder
    always carry [byte_offset = 0], because the encoder works on a code-point
    array rather than on bytes. *)
type position = {
  byte_offset : int;
  char_index : int;
}

let position_byte_offset pos = pos.byte_offset
let position_char_index pos = pos.char_index

let pp_position fmt pos =
  Format.fprintf fmt "byte %d, char %d" pos.byte_offset pos.char_index

(* {1 Error Types} *)

type error =
  | Overflow of position
  | Invalid_character of position * Uchar.t
  | Invalid_digit of position * char
  | Unexpected_end of position
  | Invalid_utf8 of position
  | Label_too_long of int
  | Empty_label

let pp_error fmt = function
  | Overflow pos ->
    Format.fprintf fmt "arithmetic overflow at %a" pp_position pos
  | Invalid_character (pos, u) ->
    Format.fprintf fmt "invalid character U+%04X at %a"
      (Uchar.to_int u) pp_position pos
  | Invalid_digit (pos, c) ->
    Format.fprintf fmt "invalid Punycode digit '%c' (0x%02X) at %a"
      c (Char.code c) pp_position pos
  | Unexpected_end pos ->
    Format.fprintf fmt "unexpected end of input at %a" pp_position pos
  | Invalid_utf8 pos ->
    Format.fprintf fmt "invalid UTF-8 sequence at %a" pp_position pos
  | Label_too_long len ->
    Format.fprintf fmt "label too long: %d bytes (max %d)" len max_label_length
  | Empty_label ->
    Format.fprintf fmt "empty label"

(* {1 Error Constructors} *)

(* Shorthands producing [Error _] results for each error case. *)

let overflow pos = Error (Overflow pos)
let invalid_character pos u = Error (Invalid_character (pos, u))
let invalid_digit pos c = Error (Invalid_digit (pos, c))
let unexpected_end pos = Error (Unexpected_end pos)
let _invalid_utf8 pos = Error (Invalid_utf8 pos)
let label_too_long len = Error (Label_too_long len)
let empty_label = Error Empty_label

(* {1 Case Flags} *)

(** Mixed-case annotation for one code point (RFC 3492 Appendix A). *)
type case_flag = Uppercase | Lowercase

(* {1 Basic Predicates} *)

(** [is_basic u] holds when [u] is a basic (ASCII, below 0x80) code point. *)
let is_basic u =
  Uchar.to_int u < 0x80

(** [is_ascii_string s] holds when every byte of [s] is below 0x80. *)
let is_ascii_string s =
  String.for_all (fun c -> Char.code c < 0x80) s

(** [has_ace_prefix s] holds when [s] starts with [xn--], where the two
    letters may be in either case. *)
let has_ace_prefix s =
  String.length s >= 4
  && Char.lowercase_ascii s.[0] = 'x'
  && Char.lowercase_ascii s.[1] = 'n'
  && s.[2] = '-'
  && s.[3] = '-'

(* {1 Digit Encoding/Decoding (RFC 3492 Section 5)}

   Digit values:
   - 0-25: a-z (or A-Z)
   - 26-35: 0-9
*)

(** [encode_digit d case_flag] is the character for digit value [d]
    (0-35). Letters honour [case_flag]; numerals have no case. *)
let encode_digit d case_flag =
  if d < 26 then
    Char.chr (d + (if case_flag = Uppercase then 0x41 else 0x61))
  else
    Char.chr (d - 26 + 0x30)

(** [decode_digit c] is the digit value of [c], or [None] if [c] is not a
    Punycode digit. Letters are accepted in either case. *)
let decode_digit c =
  let code = Char.code c in
  if code >= 0x30 && code <= 0x39 then
    Some (code - 0x30 + 26) (* '0'-'9' -> 26-35 *)
  else if code >= 0x41 && code <= 0x5A then
    Some (code - 0x41) (* 'A'-'Z' -> 0-25 *)
  else if code >= 0x61 && code <= 0x7A then
    Some (code - 0x61) (* 'a'-'z' -> 0-25 *)
  else
    None

(** [is_flagged c] holds when [c] is an upper-case ASCII letter, which is how
    the mixed-case annotation marks a code point. *)
let is_flagged c =
  let code = Char.code c in
  code >= 0x41 && code <= 0x5A (* 'A'-'Z' *)

(* {1 Digit Thresholds} *)

(** [digit_threshold ~bias k] is the threshold [t] used for the digit at
    position [k] (a multiple of [base]): [k - bias] clamped to the range
    tmin..tmax, as in the pseudocode of RFC 3492 Sections 6.2 and 6.3. *)
let digit_threshold ~bias k =
  if k <= bias then tmin
  else if k >= bias + tmax then tmax
  else k - bias

(* {1 Bias Adaptation (RFC 3492 Section 6.1)} *)

(** [adapt ~delta ~numpoints ~firsttime] computes the new bias after a delta
    has been encoded or decoded. [numpoints] must be positive. *)
let adapt ~delta ~numpoints ~firsttime =
  let delta = if firsttime then delta / damp else delta / 2 in
  let delta = delta + (delta / numpoints) in
  let threshold = ((base - tmin) * tmax) / 2 in
  let rec loop delta k =
    if delta > threshold then
      loop (delta / (base - tmin)) (k + base)
    else
      k + (((base - tmin + 1) * delta) / (delta + skew))
  in
  loop delta 0

(* {1 Overflow-Safe Arithmetic}

   RFC 3492 Section 6.4: Use detection to avoid overflow.
   A + B overflows iff B > maxint - A
   A + B*C overflows iff B > (maxint - A) / C
*)

let max_int_value = max_int

(** [safe_mul_add a b c pos] is [Ok (a + b * c)], or [Error (Overflow pos)]
    if that value would exceed [max_int]. Assumes [a], [b] and [c] are
    non-negative, which holds for every caller in this file. *)
let safe_mul_add a b c pos =
  if c = 0 then Ok a
  else if b > (max_int_value - a) / c then
    overflow pos
  else
    Ok (a + b * c)

(* {1 UTF-8 to Code Points Conversion} *)

(** [utf8_to_codepoints s] decodes the UTF-8 string [s] into an array of
    code points, or fails with [Invalid_utf8] at the first malformed
    sequence. *)
let utf8_to_codepoints s =
  let len = String.length s in
  let rec loop acc byte_offset char_index =
    if byte_offset >= len then Ok (Array.of_list (List.rev acc))
    else
      let dec = String.get_utf_8_uchar s byte_offset in
      if Uchar.utf_decode_is_valid dec then
        loop
          (Uchar.utf_decode_uchar dec :: acc)
          (byte_offset + Uchar.utf_decode_length dec)
          (char_index + 1)
      else Error (Invalid_utf8 { byte_offset; char_index })
  in
  loop [] 0 0

(* {1 Code Points to UTF-8 Conversion} *)

(** [codepoints_to_utf8 cps] is the UTF-8 encoding of [cps]. *)
let codepoints_to_utf8 codepoints =
  let buf = Buffer.create (Array.length codepoints * 2) in
  Array.iter (Buffer.add_utf_8_uchar buf) codepoints;
  Buffer.contents buf

(* {1 Punycode Encoding (RFC 3492 Section 6.3)} *)

(** [encode_impl codepoints case_flags] encodes [codepoints] as Punycode,
    without the ACE prefix.

    Basic code points are copied first, in order, followed by [delimiter]
    if there were any. With [case_flags = None], basic code points are
    copied verbatim (RFC 3492 Section 3.1) and all encoded digits are lower
    case. With [Some flags], [flags.(j)] forces the case of basic letter [j],
    and the final digit of the delta that inserts non-basic code point [j]
    is written in upper case when it is a letter and [flags.(j)] is
    [Uppercase]. Callers must ensure that [flags] has the same length as
    [codepoints].

    Returns [Error (Overflow _)] when the delta arithmetic would exceed
    [max_int]. *)
let encode_impl codepoints case_flags =
  let input_length = Array.length codepoints in
  if input_length = 0 then Ok ""
  else begin
    let exception Fail of error in
    let fail err = raise_notrace (Fail err) in
    let output = Buffer.create (input_length * 2) in
    let requested_case j = Option.map (fun flags -> flags.(j)) case_flags in
    (* Basic code point segregation. *)
    let basic_count = ref 0 in
    Array.iteri
      (fun j cp ->
        if is_basic cp then begin
          let c = Char.chr (Uchar.to_int cp) in
          let c =
            match requested_case j with
            | None -> c
            | Some Lowercase -> Char.lowercase_ascii c
            | Some Uppercase -> Char.uppercase_ascii c
          in
          Buffer.add_char output c;
          incr basic_count
        end)
      codepoints;
    let b = !basic_count in
    if b > 0 then Buffer.add_char output delimiter;
    let n = ref initial_n in
    let delta = ref 0 in
    let bias = ref initial_bias in
    let h = ref b in
    (* OCaml ints wrap to [min_int] (never to 0), so the overflow test of the
       RFC reference code must be done before the increment. *)
    let incr_delta pos =
      if !delta = max_int_value then fail (Overflow pos);
      incr delta
    in
    (* Emit [q] as a generalized variable-length integer; only the final
       digit carries the case annotation. *)
    let rec emit_delta q k final_case =
      let t = digit_threshold ~bias:!bias k in
      if q < t then Buffer.add_char output (encode_digit q final_case)
      else begin
        Buffer.add_char output
          (encode_digit (t + ((q - t) mod (base - t))) Lowercase);
        emit_delta ((q - t) / (base - t)) (k + base) final_case
      end
    in
    let handle_code_point j cp =
      let c = Uchar.to_int cp in
      let pos = { byte_offset = 0; char_index = j } in
      if c < !n then incr_delta pos
      else if c = !n then begin
        let final_case = Option.value ~default:Lowercase (requested_case j) in
        emit_delta !delta base final_case;
        bias := adapt ~delta:!delta ~numpoints:(!h + 1) ~firsttime:(!h = b);
        delta := 0;
        incr h
      end
    in
    let main_loop () =
      while !h < input_length do
        (* Smallest code point that has not been handled yet. *)
        let m =
          Array.fold_left
            (fun acc cp ->
              let c = Uchar.to_int cp in
              if c >= !n && c < acc then c else acc)
            max_int_value codepoints
        in
        (* Advance the state to <m, 0>. *)
        let pos = { byte_offset = 0; char_index = !h } in
        (match safe_mul_add !delta (m - !n) (!h + 1) pos with
         | Ok d -> delta := d
         | Error e -> fail e);
        n := m;
        Array.iteri handle_code_point codepoints;
        incr_delta { byte_offset = 0; char_index = !h };
        incr n
      done
    in
    match main_loop () with
    | () -> Ok (Buffer.contents output)
    | exception Fail e -> Error e
  end

(** [encode codepoints] is the Punycode encoding of [codepoints] (no ACE
    prefix). Basic code points are copied verbatim. *)
let encode codepoints =
  encode_impl codepoints None

(** [encode_with_case codepoints case_flags] encodes [codepoints] with the
    mixed-case annotation given by [case_flags].
    @raise Invalid_argument if the two arrays differ in length. *)
let encode_with_case codepoints case_flags =
  if Array.length codepoints <> Array.length case_flags then
    invalid_arg "encode_with_case: array lengths must match";
  encode_impl codepoints (Some case_flags)

(* {1 Punycode Decoding (RFC 3492 Section 6.2)} *)

(** [decode_impl input] decodes the Punycode string [input] (no ACE prefix)
    into its code points and the case flag of each one.

    Characters before the last [delimiter] must be basic and are copied
    verbatim. Their case flag is [Uppercase] exactly for A-Z. Each inserted
    code point takes its flag from the case of the final digit of its delta.

    Errors: [Invalid_character] for a non-basic byte before the delimiter or
    a decoded value that is not a Unicode scalar value; [Invalid_digit];
    [Unexpected_end] for a truncated delta; [Overflow]. *)
let decode_impl input =
  let input_length = String.length input in
  if input_length = 0 then Ok ([||], [||])
  else begin
    let exception Fail of error in
    let fail err = raise_notrace (Fail err) in
    (* Copy of [arr] with [x] inserted at index [idx]. *)
    let insert_at arr idx x =
      Array.init (Array.length arr + 1) (fun j ->
          if j < idx then arr.(j) else if j = idx then x else arr.(j - 1))
    in
    let case_of_char c = if is_flagged c then Uppercase else Lowercase in
    let b = Option.value ~default:0 (String.rindex_opt input delimiter) in
    let run () =
      (* Array.init visits indices in increasing order, so the first
         offending byte is the one reported. *)
      let output =
        ref
          (Array.init b (fun j ->
               let code = Char.code input.[j] in
               if code >= 0x80 then
                 fail
                   (Invalid_character
                      ({ byte_offset = j; char_index = j }, Uchar.of_int code));
               Uchar.of_int code))
      in
      let case_output = ref (Array.init b (fun j -> case_of_char input.[j])) in
      let n = ref initial_n in
      let i = ref 0 in
      let bias = ref initial_bias in
      (* The delimiter is consumed only if basic code points preceded it. *)
      let in_pos = ref (if b > 0 then b + 1 else 0) in
      (* Read one generalized variable-length integer, accumulating into
         [i]; [w] is the current weight and [k] the digit position. *)
      let rec read_digits w k =
        let pos =
          { byte_offset = !in_pos; char_index = Array.length !output }
        in
        if !in_pos >= input_length then fail (Unexpected_end pos);
        let c = input.[!in_pos] in
        incr in_pos;
        let digit =
          match decode_digit c with
          | Some d -> d
          | None -> fail (Invalid_digit (pos, c))
        in
        (match safe_mul_add !i digit w pos with
         | Ok v -> i := v
         | Error e -> fail e);
        let t = digit_threshold ~bias:!bias k in
        if digit >= t then begin
          let base_minus_t = base - t in
          if w > max_int_value / base_minus_t then fail (Overflow pos);
          read_digits (w * base_minus_t) (k + base)
        end
      in
      while !in_pos < input_length do
        let oldi = !i in
        read_digits 1 base;
        let out_len = Array.length !output in
        bias :=
          adapt ~delta:(!i - oldi) ~numpoints:(out_len + 1)
            ~firsttime:(oldi = 0);
        let pos = { byte_offset = !in_pos - 1; char_index = out_len } in
        (* n = n + i / (out_len + 1), with overflow check *)
        let increment = !i / (out_len + 1) in
        if increment > max_int_value - !n then fail (Overflow pos);
        n := !n + increment;
        i := !i mod (out_len + 1);
        if not (Uchar.is_valid !n) then
          fail (Invalid_character (pos, Uchar.rep));
        (* The last consumed character is the final digit of this delta. *)
        let flag = case_of_char input.[!in_pos - 1] in
        output := insert_at !output !i (Uchar.of_int !n);
        case_output := insert_at !case_output !i flag;
        incr i
      done;
      (!output, !case_output)
    in
    match run () with
    | r -> Ok r
    | exception Fail e -> Error e
  end

(** [decode input] decodes the Punycode string [input] (no ACE prefix) into
    code points, discarding case flags. *)
let decode input =
  Result.map fst (decode_impl input)

(** [decode_with_case input] decodes [input] and also returns the case flag
    of each code point. *)
let decode_with_case input =
  decode_impl input

(* {1 UTF-8 String Operations} *)

(** [encode_utf8 s] Punycode-encodes the UTF-8 string [s]. *)
let encode_utf8 s =
  Result.bind (utf8_to_codepoints s) encode

(** [decode_utf8 punycode] decodes [punycode] to a UTF-8 string. *)
let decode_utf8 punycode =
  Result.map codepoints_to_utf8 (decode punycode)

(* {1 Domain Label Operations} *)

(** [encode_label label] returns ASCII labels unchanged. Any other label is
    Punycode-encoded and given the ACE prefix. Fails with [Empty_label] on
    the empty string, and with [Label_too_long] when the result exceeds
    [max_label_length] bytes. *)
let encode_label label =
  let check_length s =
    let len = String.length s in
    if len > max_label_length then label_too_long len else Ok s
  in
  if String.length label = 0 then empty_label
  else if is_ascii_string label then check_length label
  else
    Result.bind (encode_utf8 label) (fun encoded ->
        check_length (ace_prefix ^ encoded))

(** [decode_label label] strips the ACE prefix (matched case-insensitively)
    and decodes the remainder to UTF-8. Labels without the prefix are
    returned unchanged, whether or not they are ASCII. Fails with
    [Empty_label] on the empty string. The length of [label] is not
    checked. *)
let decode_label label =
  if String.length label = 0 then empty_label
  else if has_ace_prefix label then
    let prefix_len = String.length ace_prefix in
    decode_utf8 (String.sub label prefix_len (String.length label - prefix_len))
  else Ok label