Punycode (RFC3492) in OCaml

more idiomatic ocaml

+101 -116
+43 -53
lib/punycode.ml
··· 61 61 | Empty_label -> 62 62 Format.fprintf fmt "empty label" 63 63 64 + (* {1 Error Constructors} *) 65 + 66 + let overflow pos = Error (Overflow pos) 67 + let invalid_character pos u = Error (Invalid_character (pos, u)) 68 + let invalid_digit pos c = Error (Invalid_digit (pos, c)) 69 + let unexpected_end pos = Error (Unexpected_end pos) 70 + let _invalid_utf8 pos = Error (Invalid_utf8 pos) 71 + let label_too_long len = Error (Label_too_long len) 72 + let empty_label = Error Empty_label 64 73 65 74 (* {1 Case Flags} *) 66 75 ··· 70 79 71 80 let is_basic u = 72 81 Uchar.to_int u < 0x80 73 - 74 - 75 - let is_delimiter c = c = delimiter 76 82 77 83 let is_ascii_string s = 78 - let rec loop i = 79 - if i >= String.length s then true 80 - else if Char.code s.[i] >= 0x80 then false 81 - else loop (i + 1) 82 - in 83 - loop 0 84 + String.for_all (fun c -> Char.code c < 0x80) s 84 85 85 86 let has_ace_prefix s = 86 87 let len = String.length s in ··· 144 145 let safe_mul_add a b c pos = 145 146 if c = 0 then Ok a 146 147 else if b > (max_int_value - a) / c then 147 - Error (Overflow pos) 148 + overflow pos 148 149 else 149 150 Ok (a + b * c) 150 151 ··· 228 229 229 230 while !h < input_length && !result = Ok () do 230 231 (* Find minimum code point >= n *) 231 - let m = ref max_int_value in 232 - for j = 0 to input_length - 1 do 233 - let cp = Uchar.to_int codepoints.(j) in 234 - if cp >= !n && cp < !m then 235 - m := cp 236 - done; 232 + let m = Array.fold_left (fun acc cp -> 233 + let cp_val = Uchar.to_int cp in 234 + if cp_val >= !n && cp_val < acc then cp_val else acc 235 + ) max_int_value codepoints in 237 236 238 237 (* Increase delta to advance state to <m, 0> *) 239 238 let pos = { byte_offset = 0; char_index = !h } in 240 - (match safe_mul_add !delta (!m - !n) (!h + 1) pos with 239 + (match safe_mul_add !delta (m - !n) (!h + 1) pos with 241 240 | Error e -> result := Error e 242 241 | Ok new_delta -> 243 242 delta := new_delta; 244 - n := !m; 243 + n := m; 245 244 246 245 (* Process each code point *) 247 246 let j = ref 0 in ··· 252 251 if cp < !n then begin 253 252 incr delta; 254 253 if !delta = 0 then (* Overflow *) 255 - result := Error (Overflow pos) 254 + result := overflow pos 256 255 end 257 256 else if cp = !n then begin 258 257 (* Encode delta as variable-length integer *) ··· 316 315 Ok ([||], [||]) 317 316 else begin 318 317 (* Find last delimiter *) 319 - let last_delim = ref (-1) in 320 - for j = 0 to input_length - 1 do 321 - if is_delimiter input.[j] then 322 - last_delim := j 323 - done; 324 - let b = if !last_delim < 0 then 0 else !last_delim in 318 + let b = Option.value ~default:0 (String.rindex_opt input delimiter) in 325 319 326 320 (* Copy basic code points and extract case flags *) 327 321 let output = ref [] in ··· 365 359 let pos = { byte_offset = !in_pos; char_index = Array.length !output } in 366 360 367 361 if !in_pos >= input_length then begin 368 - result := Error (Unexpected_end pos); 362 + result := unexpected_end pos; 369 363 done_decoding := true 370 364 end else begin 371 365 let c = input.[!in_pos] in ··· 373 367 374 368 match decode_digit c with 375 369 | None -> 376 - result := Error (Invalid_digit (pos, c)); 370 + result := invalid_digit pos c; 377 371 done_decoding := true 378 372 | Some digit -> 379 373 (* i = i + digit * w, with overflow check *) ··· 397 391 (* w = w * (base - t), with overflow check *) 398 392 let base_minus_t = base - t in 399 393 if !w > max_int_value / base_minus_t then begin 400 - result := Error (Overflow pos); 394 + result := overflow pos; 401 395 done_decoding := true 402 396 end else begin 403 397 w := !w * base_minus_t; ··· 416 410 (* n = n + i / (out_len + 1), with overflow check *) 417 411 let increment = !i / (out_len + 1) in 418 412 if increment > max_int_value - !n then 419 - result := Error (Overflow pos) 413 + result := overflow pos 420 414 else begin 421 415 n := !n + increment; 422 416 i := !i mod (out_len + 1); 423 417 424 418 (* Validate that n is a valid Unicode scalar value *) 425 419 if not (Uchar.is_valid !n) then 426 - result := Error (Invalid_character (pos, Uchar.rep)) 420 + result := invalid_character pos Uchar.rep 427 421 else begin 428 422 (* Insert n at position i *) 429 423 let new_output = Array.make (out_len + 1) (Uchar.of_int 0) in ··· 456 450 end 457 451 458 452 let decode input = 459 - match decode_impl input with 460 - | Error e -> Error e 461 - | Ok (codepoints, _) -> Ok codepoints 453 + Result.map fst (decode_impl input) 462 454 463 455 let decode_with_case input = 464 456 decode_impl input ··· 466 458 (* {1 UTF-8 String Operations} *) 467 459 468 460 let encode_utf8 s = 469 - match utf8_to_codepoints s with 470 - | Error e -> Error e 471 - | Ok codepoints -> encode codepoints 461 + let open Result.Syntax in 462 + let* codepoints = utf8_to_codepoints s in 463 + encode codepoints 472 464 473 465 let decode_utf8 punycode = 474 - match decode punycode with 475 - | Error e -> Error e 476 - | Ok codepoints -> Ok (codepoints_to_utf8 codepoints) 466 + let open Result.Syntax in 467 + let+ codepoints = decode punycode in 468 + codepoints_to_utf8 codepoints 477 469 478 470 (* {1 Domain Label Operations} *) 479 471 480 472 let encode_label label = 481 473 if String.length label = 0 then 482 - Error Empty_label 474 + empty_label 483 475 else if is_ascii_string label then begin 484 476 (* All ASCII - return as-is, but check length *) 485 477 let len = String.length label in 486 478 if len > max_label_length then 487 - Error (Label_too_long len) 479 + label_too_long len 488 480 else 489 481 Ok label 490 - end else begin 482 + end else 491 483 (* Has non-ASCII - encode with Punycode *) 492 - match encode_utf8 label with 493 - | Error e -> Error e 494 - | Ok encoded -> 495 - let result = ace_prefix ^ encoded in 496 - let len = String.length result in 497 - if len > max_label_length then 498 - Error (Label_too_long len) 499 - else 500 - Ok result 501 - end 484 + let open Result.Syntax in 485 + let* encoded = encode_utf8 label in 486 + let result = ace_prefix ^ encoded in 487 + let len = String.length result in 488 + if len > max_label_length then 489 + label_too_long len 490 + else 491 + Ok result 502 492 503 493 let decode_label label = 504 494 if String.length label = 0 then 505 - Error Empty_label 495 + empty_label 506 496 else if has_ace_prefix label then begin 507 497 (* Remove ACE prefix and decode *) 508 498 let punycode = String.sub label 4 (String.length label - 4) in
+58 -63
lib/punycode_idna.ml
··· 28 28 | Verification_failed -> 29 29 Format.fprintf fmt "IDNA verification failed (round-trip mismatch)" 30 30 31 + (* {1 Error Constructors} *) 32 + 33 + let punycode_error e = Error (Punycode_error e) 34 + let invalid_label msg = Error (Invalid_label msg) 35 + let domain_too_long len = Error (Domain_too_long len) 36 + let _normalization_failed = Error Normalization_failed 37 + let verification_failed = Error Verification_failed 31 38 32 39 (* {1 Unicode Normalization} *) 33 40 ··· 44 51 - Cannot start or end with hyphen *) 45 52 let is_std3_valid label = 46 53 let len = String.length label in 47 - if len = 0 then false 48 - else if label.[0] = '-' || label.[len - 1] = '-' then false 49 - else 50 - let rec check i = 51 - if i >= len then true 52 - else 53 - let c = label.[i] in 54 - let valid = 55 - (c >= 'a' && c <= 'z') || 56 - (c >= 'A' && c <= 'Z') || 57 - (c >= '0' && c <= '9') || 58 - c = '-' 59 - in 60 - if valid then check (i + 1) else false 61 - in 62 - check 0 54 + let is_ldh c = 55 + (c >= 'a' && c <= 'z') || 56 + (c >= 'A' && c <= 'Z') || 57 + (c >= '0' && c <= '9') || 58 + c = '-' 59 + in 60 + len > 0 && 61 + label.[0] <> '-' && 62 + label.[len - 1] <> '-' && 63 + String.for_all is_ldh label 63 64 64 65 (* Check hyphen placement: hyphens not in positions 3 and 4 (except for ACE) *) 65 66 let check_hyphen_rules label = ··· 75 76 let label_to_ascii_impl ~check_hyphens ~use_std3_rules label = 76 77 let len = String.length label in 77 78 if len = 0 then 78 - Error (Invalid_label "empty label") 79 + invalid_label "empty label" 79 80 else if len > Punycode.max_label_length then 80 - Error (Punycode_error (Punycode.Label_too_long len)) 81 + punycode_error (Punycode.Label_too_long len) 81 82 else if Punycode.is_ascii_string label then begin 82 83 (* All ASCII - validate and pass through *) 83 84 if use_std3_rules && not (is_std3_valid label) then 84 - Error (Invalid_label "STD3 rules violation") 85 + invalid_label "STD3 rules violation" 85 86 else if check_hyphens && not (check_hyphen_rules label) then 86 - Error (Invalid_label "invalid hyphen placement") 87 + invalid_label "invalid hyphen placement" 87 88 else 88 89 Ok label 89 90 end else begin ··· 92 93 93 94 (* Encode to Punycode *) 94 95 match Punycode.encode_utf8 normalized with 95 - | Error e -> Error (Punycode_error e) 96 + | Error e -> punycode_error e 96 97 | Ok encoded -> 97 98 let result = Punycode.ace_prefix ^ encoded in 98 99 let result_len = String.length result in 99 100 if result_len > Punycode.max_label_length then 100 - Error (Punycode_error (Punycode.Label_too_long result_len)) 101 + punycode_error (Punycode.Label_too_long result_len) 101 102 else if check_hyphens && not (check_hyphen_rules result) then 102 - Error (Invalid_label "invalid hyphen placement in encoded label") 103 + invalid_label "invalid hyphen placement in encoded label" 103 104 else 104 105 (* Verification: decode and compare to original normalized form *) 105 106 match Punycode.decode_utf8 encoded with 106 - | Error _ -> Error Verification_failed 107 + | Error _ -> verification_failed 107 108 | Ok decoded -> 108 109 if decoded <> normalized then 109 - Error Verification_failed 110 + verification_failed 110 111 else 111 112 Ok result 112 113 end ··· 118 119 if is_ace_label label then begin 119 120 let encoded = String.sub label 4 (String.length label - 4) in 120 121 match Punycode.decode_utf8 encoded with 121 - | Error e -> Error (Punycode_error e) 122 + | Error e -> punycode_error e 122 123 | Ok decoded -> Ok decoded 123 124 end else 124 125 Ok label ··· 133 134 let join_labels labels = 134 135 String.concat "." labels 135 136 137 + (* Map a function returning Result over a list, short-circuiting on first Error *) 138 + let map_result f lst = 139 + List.fold_right (fun x acc -> 140 + let open Result.Syntax in 141 + let* y = f x in 142 + let+ ys = acc in 143 + y :: ys 144 + ) lst (Ok []) 145 + 136 146 let to_ascii ?(check_hyphens = true) ?(check_bidi = false) 137 147 ?(check_joiners = false) ?(use_std3_rules = false) 138 148 ?(transitional = false) domain = ··· 142 152 let _ = check_joiners in 143 153 let _ = transitional in 144 154 155 + let open Result.Syntax in 145 156 let labels = split_domain domain in 146 - let rec process acc = function 147 - | [] -> 148 - let result = join_labels (List.rev acc) in 149 - let len = String.length result in 150 - if len > max_domain_length then 151 - Error (Domain_too_long len) 152 - else 153 - Ok result 154 - | label :: rest -> 155 - match label_to_ascii_impl ~check_hyphens ~use_std3_rules label with 156 - | Error e -> Error e 157 - | Ok encoded -> process (encoded :: acc) rest 158 - in 159 - process [] labels 157 + let* encoded_labels = map_result (label_to_ascii_impl ~check_hyphens ~use_std3_rules) labels in 158 + let result = join_labels encoded_labels in 159 + let len = String.length result in 160 + if len > max_domain_length then 161 + domain_too_long len 162 + else 163 + Ok result 160 164 161 165 let to_unicode domain = 166 + let open Result.Syntax in 162 167 let labels = split_domain domain in 163 - let rec process acc = function 164 - | [] -> Ok (join_labels (List.rev acc)) 165 - | label :: rest -> 166 - match label_to_unicode label with 167 - | Error e -> Error e 168 - | Ok decoded -> process (decoded :: acc) rest 169 - in 170 - process [] labels 168 + let+ decoded_labels = map_result label_to_unicode labels in 169 + join_labels decoded_labels 171 170 172 171 (* {1 Domain Name Library Integration} *) 173 172 174 173 let domain_to_ascii ?(check_hyphens = true) ?(use_std3_rules = false) domain = 174 + let open Result.Syntax in 175 175 let s = Domain_name.to_string domain in 176 - match to_ascii ~check_hyphens ~use_std3_rules s with 177 - | Error e -> Error e 178 - | Ok ascii -> 179 - match Domain_name.of_string ascii with 180 - | Error (`Msg msg) -> Error (Invalid_label msg) 181 - | Ok d -> Ok d 176 + let* ascii = to_ascii ~check_hyphens ~use_std3_rules s in 177 + match Domain_name.of_string ascii with 178 + | Error (`Msg msg) -> invalid_label msg 179 + | Ok d -> Ok d 182 180 183 181 let domain_to_unicode domain = 182 + let open Result.Syntax in 184 183 let s = Domain_name.to_string domain in 185 - match to_unicode s with 186 - | Error e -> Error e 187 - | Ok unicode -> 188 - match Domain_name.of_string unicode with 189 - | Error (`Msg msg) -> Error (Invalid_label msg) 190 - | Ok d -> Ok d 184 + let* unicode = to_unicode s in 185 + match Domain_name.of_string unicode with 186 + | Error (`Msg msg) -> invalid_label msg 187 + | Ok d -> Ok d 191 188 192 189 (* {1 Validation} *) 193 190 194 191 let is_idna_valid domain = 195 - match to_ascii domain with 196 - | Ok _ -> true 197 - | Error _ -> false 192 + Result.is_ok (to_ascii domain)