Punycode (RFC3492) in OCaml
1(*---------------------------------------------------------------------------
2 Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
3 SPDX-License-Identifier: ISC
4 ---------------------------------------------------------------------------*)
5
6(* RFC 3492 Punycode Implementation *)
7
8(* {1 Bootstring Parameters for Punycode (RFC 3492 Section 5)} *)
9
(** Number of distinct digit values: the letters [a]-[z] and the decimal
    digits [0]-[9]. *)
let base = 36

(** Smallest threshold a digit position may take. *)
let tmin = 1

(** Largest threshold a digit position may take. *)
let tmax = 26

(** Skew term of the bias adaptation formula. *)
let skew = 38

(** Divisor applied to the delta on the first bias adaptation. *)
let damp = 700

(** Bias in effect before any delta has been processed. *)
let initial_bias = 72

(** Starting value of the code point state [n] (U+0080, the first
    non-ASCII code point). *)
let initial_n = 128

(** Separator written between the literal ASCII part and the encoded
    deltas. *)
let delimiter = '-'

(** Prefix that marks a domain label as Punycode-encoded. *)
let ace_prefix = "xn--"

(** Upper bound, in bytes, on the length of one domain label. *)
let max_label_length = 63
20
21(* {1 Position Tracking} *)
22
(** Location attached to an error so that it can be reported to the user. *)
type position = {
  byte_offset : int;  (** Offset counted in bytes. *)
  char_index : int;  (** Index counted in characters (code points). *)
}
27
(** [position_byte_offset pos] is the byte offset stored in [pos]. *)
let position_byte_offset { byte_offset; _ } = byte_offset

(** [position_char_index pos] is the character index stored in [pos]. *)
let position_char_index { char_index; _ } = char_index
30
(** [pp_position fmt pos] prints both coordinates of [pos] on [fmt], byte
    offset first. *)
let pp_position fmt { byte_offset; char_index } =
  Format.fprintf fmt "byte %d, char %d" byte_offset char_index
33
34
35(* {1 Error Types} *)
36
(** Ways in which encoding or decoding can fail. *)
type error =
  | Overflow of position
      (** An arithmetic step would exceed the native integer range. *)
  | Invalid_character of position * Uchar.t
      (** A code point that is not acceptable where it was found. *)
  | Invalid_digit of position * char
      (** A character that is not a Punycode digit. *)
  | Unexpected_end of position
      (** The input stopped in the middle of a variable-length integer. *)
  | Invalid_utf8 of position
      (** The input is not well-formed UTF-8 at this position. *)
  | Label_too_long of int
      (** A label of the given length in bytes exceeds the allowed maximum. *)
  | Empty_label
      (** A label of length zero was supplied. *)
45
(** [pp_error fmt err] prints a one-line, human-readable description of
    [err] on [fmt]. Errors that carry a position have it appended using
    [pp_position]. *)
let pp_error fmt err =
  match err with
  | Empty_label -> Format.fprintf fmt "empty label"
  | Label_too_long len ->
    Format.fprintf fmt "label too long: %d bytes (max %d)" len max_label_length
  | Invalid_utf8 pos ->
    Format.fprintf fmt "invalid UTF-8 sequence at %a" pp_position pos
  | Unexpected_end pos ->
    Format.fprintf fmt "unexpected end of input at %a" pp_position pos
  | Invalid_digit (pos, c) ->
    (* Show the byte value as well, since [c] may not be printable. *)
    let byte = Char.code c in
    Format.fprintf fmt "invalid Punycode digit '%c' (0x%02X) at %a"
      c byte pp_position pos
  | Invalid_character (pos, u) ->
    let scalar = Uchar.to_int u in
    Format.fprintf fmt "invalid character U+%04X at %a" scalar pp_position pos
  | Overflow pos ->
    Format.fprintf fmt "arithmetic overflow at %a" pp_position pos
63
64(* {1 Error Constructors} *)
65
(* Each helper below wraps one [error] constructor in [Error] so that
   call sites can return a failure in a single application. *)

let overflow pos : (_, error) result = Error (Overflow pos)

let invalid_character pos u : (_, error) result =
  Error (Invalid_character (pos, u))

let invalid_digit pos c : (_, error) result = Error (Invalid_digit (pos, c))

let unexpected_end pos : (_, error) result = Error (Unexpected_end pos)

(* The leading underscore marks this helper as currently unused. *)
let _invalid_utf8 pos : (_, error) result = Error (Invalid_utf8 pos)

let label_too_long len : (_, error) result = Error (Label_too_long len)

let empty_label : (_, error) result = Error Empty_label
73
74(* {1 Case Flags} *)
75
(** Case annotation attached to a single code point. *)
type case_flag =
  | Uppercase
  | Lowercase
77
78(* {1 Basic Predicates} *)
79
(** [is_basic u] is [true] exactly when [u] lies in the ASCII range
    U+0000..U+007F. *)
let is_basic u = Uchar.to_int u <= 0x7F
82
(** [is_ascii_string s] is [true] when every byte of [s] is below 0x80.
    The empty string qualifies. *)
let is_ascii_string s =
  let len = String.length s in
  let rec ascii_from i =
    i >= len || (Char.code s.[i] < 0x80 && ascii_from (i + 1))
  in
  ascii_from 0
85
(** [has_ace_prefix s] is [true] when [s] starts with [xn--], where the two
    letters are matched without regard to case. *)
let has_ace_prefix s =
  String.length s >= 4
  && (match (s.[0], s.[1], s.[2], s.[3]) with
      | ('x' | 'X'), ('n' | 'N'), '-', '-' -> true
      | _ -> false)
92
93(* {1 Digit Encoding/Decoding (RFC 3492 Section 5)}
94
95 Digit values:
96 - 0-25: a-z (or A-Z)
97 - 26-35: 0-9
98*)
99
(** [encode_digit d case_flag] is the character for digit value [d].
    Values below 26 map to a letter, uppercase when [case_flag] is
    [Uppercase] and lowercase otherwise. Larger values map to decimal
    digits starting at ['0'] for 26; [case_flag] is ignored for those. *)
let encode_digit d case_flag =
  let origin =
    if d >= 26 then 0x30 - 26
    else
      match case_flag with
      | Uppercase -> 0x41
      | Lowercase -> 0x61
  in
  Char.chr (d + origin)
105
(** [decode_digit c] is [Some v] where [v] is the digit value of [c], or
    [None] when [c] is not a Punycode digit. Letters of either case give
    0..25 and decimal digits give 26..35. *)
let decode_digit c =
  match c with
  | 'a' .. 'z' -> Some (Char.code c - Char.code 'a')
  | 'A' .. 'Z' -> Some (Char.code c - Char.code 'A')
  | '0' .. '9' -> Some (Char.code c - Char.code '0' + 26)
  | _ -> None
116
(** [is_flagged c] is [true] when [c] is an uppercase ASCII letter, which
    is how a case annotation of "uppercase" is represented. *)
let is_flagged c =
  match c with
  | 'A' .. 'Z' -> true
  | _ -> false
121
122(* {1 Bias Adaptation (RFC 3492 Section 6.1)} *)
123
(** [adapt ~delta ~numpoints ~firsttime] is the bias to use for the next
    delta.

    - [delta]: the delta that was just processed.
    - [numpoints]: the number of code points handled so far, used as a
      divisor and therefore expected to be non-zero.
    - [firsttime]: when [true] the delta is scaled down by [damp],
      otherwise it is halved. *)
let adapt ~delta ~numpoints ~firsttime =
  let span = base - tmin in
  let threshold = (span * tmax) / 2 in
  let scaled = delta / (if firsttime then damp else 2) in
  let remaining = ref (scaled + (scaled / numpoints)) in
  let k = ref 0 in
  (* Shrink the delta until it is small enough, counting one [base] per
     division. *)
  while !remaining > threshold do
    remaining := !remaining / span;
    k := !k + base
  done;
  !k + (((span + 1) * !remaining) / (!remaining + skew))
135
136(* {1 Overflow-Safe Arithmetic}
137
138 RFC 3492 Section 6.4: Use detection to avoid overflow.
139 A + B overflows iff B > maxint - A
140 A + B*C overflows iff B > (maxint - A) / C
141*)
142
(** Largest representable native integer; the ceiling used by the
    overflow checks. *)
let max_int_value = Stdlib.max_int
144
(** [safe_mul_add a b c pos] is [Ok (a + b * c)], or an [Overflow] error
    located at [pos] when [b] exceeds [(max_int_value - a) / c]. A zero
    [c] yields [Ok a] without dividing. The test is intended for
    non-negative operands. *)
let safe_mul_add a b c pos =
  match c with
  | 0 -> Ok a
  | _ ->
    let headroom = (max_int_value - a) / c in
    if b <= headroom then Ok (a + (b * c)) else overflow pos
151
152(* {1 UTF-8 to Code Points Conversion} *)
153
(** [utf8_to_codepoints s] decodes the UTF-8 string [s] into an array of
    code points, in order.

    Returns [Error (Invalid_utf8 pos)] at the first malformed sequence,
    where [pos] holds the byte offset of that sequence and the number of
    code points decoded before it. *)
let utf8_to_codepoints s =
  let len = String.length s in
  (* [acc] collects the decoded code points in reverse order. *)
  let rec scan byte_offset char_index acc =
    if byte_offset >= len then Ok (Array.of_list (List.rev acc))
    else
      let dec = String.get_utf_8_uchar s byte_offset in
      if Uchar.utf_decode_is_valid dec then
        scan
          (byte_offset + Uchar.utf_decode_length dec)
          (char_index + 1)
          (Uchar.utf_decode_uchar dec :: acc)
      else Error (Invalid_utf8 { byte_offset; char_index })
  in
  scan 0 0 []
174
175(* {1 Code Points to UTF-8 Conversion} *)
176
(** [codepoints_to_utf8 codepoints] is the UTF-8 encoding of [codepoints],
    concatenated in array order. *)
let codepoints_to_utf8 codepoints =
  let count = Array.length codepoints in
  (* The initial size is only a hint; the buffer grows as needed. *)
  let buf = Buffer.create (2 * count) in
  for idx = 0 to count - 1 do
    Buffer.add_utf_8_uchar buf codepoints.(idx)
  done;
  Buffer.contents buf
181
182(* {1 Punycode Encoding (RFC 3492 Section 6.3)} *)
183
(** [encode_impl codepoints case_flags] is the Punycode encoding of
    [codepoints], following the encoding procedure of RFC 3492
    section 6.3.

    - [codepoints]: the code points to encode. An empty array encodes to
      the empty string.
    - [case_flags]: optional case annotation indexed like [codepoints], so
      it is expected to have the same length. With [Some flags], ASCII
      letters are forced to the flagged case and the final digit of each
      delta is written in the flagged case. With [None] every position
      counts as [Lowercase], so uppercase ASCII letters in the input are
      written in lowercase.

    Returns [Ok s] with the encoded ASCII string, or
    [Error (Overflow pos)] when a delta would exceed [max_int_value].
    The input is an array rather than a byte string, so
    [pos.byte_offset] is always 0; [pos.char_index] is either the index
    being scanned or the number of code points already emitted.

    NOTE(review): the RFC 3492 sample encoder copies basic code points
    unchanged when no case flags are supplied. The lowercasing done here
    for [None] may be deliberate; confirm before relying on either
    behaviour. *)
let encode_impl codepoints case_flags =
  let input_length = Array.length codepoints in
  if input_length = 0 then
    Ok ""
  else begin
    let output = Buffer.create (input_length * 2) in

    (* Case annotation for input index [j]; lowercase when the caller
       supplied no flags. *)
    let case_at j =
      match case_flags with
      | Some flags -> flags.(j)
      | None -> Lowercase
    in

    (* Copy the basic (ASCII) code points to the output, in order. *)
    let basic_count = ref 0 in
    for j = 0 to input_length - 1 do
      let cp = codepoints.(j) in
      if is_basic cp then begin
        let c = Uchar.to_int cp in
        let case = case_at j in
        (* Only ASCII letters are affected by the case annotation. *)
        let c' =
          if c >= 0x41 && c <= 0x5A then (* 'A'-'Z' *)
            if case = Lowercase then c + 0x20 else c
          else if c >= 0x61 && c <= 0x7A then (* 'a'-'z' *)
            if case = Uppercase then c - 0x20 else c
          else
            c
        in
        Buffer.add_char output (Char.chr c');
        incr basic_count
      end
    done;

    (* [b] is the number of basic code points, [h] the number of code
       points handled so far. *)
    let b = !basic_count in
    let h = ref b in

    (* The delimiter is only written when a literal part precedes it. *)
    if b > 0 then
      Buffer.add_char output delimiter;

    let n = ref initial_n in
    let delta = ref 0 in
    let bias = ref initial_bias in

    let result = ref (Ok ()) in

    while !h < input_length && !result = Ok () do
      (* Smallest code point that is still >= n, i.e. the next one to
         insert. *)
      let m = Array.fold_left (fun acc cp ->
          let cp_val = Uchar.to_int cp in
          if cp_val >= !n && cp_val < acc then cp_val else acc
        ) max_int_value codepoints in

      (* Increase delta to advance the decoder state to <m, 0>. *)
      let pos = { byte_offset = 0; char_index = !h } in
      (match safe_mul_add !delta (m - !n) (!h + 1) pos with
       | Error e -> result := Error e
       | Ok new_delta ->
         delta := new_delta;
         n := m;

         (* One pass over the input for the current value of n. *)
         let j = ref 0 in
         while !j < input_length && !result = Ok () do
           let cp = Uchar.to_int codepoints.(!j) in
           let pos = { byte_offset = 0; char_index = !j } in

           if cp < !n then begin
             (* Native integers wrap around silently instead of reaching
                zero, so the limit has to be tested before incrementing. *)
             if !delta = max_int_value then
               result := overflow pos
             else
               incr delta
           end
           else if cp = !n then begin
             (* Write delta as a generalized variable-length integer. *)
             let q = ref !delta in
             let k = ref base in
             let done_encoding = ref false in

             while not !done_encoding do
               let t =
                 if !k <= !bias then tmin
                 else if !k >= !bias + tmax then tmax
                 else !k - !bias
               in
               if !q < t then begin
                 (* Final digit: it carries the case annotation. *)
                 Buffer.add_char output (encode_digit !q (case_at !j));
                 done_encoding := true
               end else begin
                 (* Intermediate digit, always lowercase. *)
                 let digit = t + ((!q - t) mod (base - t)) in
                 Buffer.add_char output (encode_digit digit Lowercase);
                 q := (!q - t) / (base - t);
                 k := !k + base
               end
             done;

             bias := adapt ~delta:!delta ~numpoints:(!h + 1) ~firsttime:(!h = b);
             delta := 0;
             incr h
           end;
           incr j
         done;

         (* Step to the next candidate code point. This is skipped once
            an error is recorded or everything has been emitted, because
            the state is not read again in either case. *)
         if !result = Ok () && !h < input_length then begin
           if !delta = max_int_value then
             result := overflow { byte_offset = 0; char_index = !h }
           else begin
             incr delta;
             incr n
           end
         end)
    done;

    match !result with
    | Error e -> Error e
    | Ok () -> Ok (Buffer.contents output)
  end
301
(** [encode codepoints] is the Punycode encoding of [codepoints] with no
    case annotation supplied. See [encode_impl] for the result and the
    possible errors. *)
let encode codepoints =
  let no_case_flags = None in
  encode_impl codepoints no_case_flags
304
(** [encode_with_case codepoints case_flags] encodes [codepoints], using
    [case_flags.(j)] as the case annotation of [codepoints.(j)].

    @raise Invalid_argument if the two arrays differ in length. *)
let encode_with_case codepoints case_flags =
  let point_count = Array.length codepoints in
  let flag_count = Array.length case_flags in
  if point_count = flag_count then encode_impl codepoints (Some case_flags)
  else invalid_arg "encode_with_case: array lengths must match"
309
310(* {1 Punycode Decoding (RFC 3492 Section 6.2)} *)
311
(** [decode_impl input] decodes the Punycode string [input] following the
    decoding procedure of RFC 3492 section 6.2.

    Returns [Ok (codepoints, case_flags)], two arrays of equal length
    holding the decoded code points and their case annotation. For a
    literal character the annotation is taken from the character itself;
    for an inserted code point it is taken from the final digit of the
    delta that produced it. The empty string decodes to two empty arrays.

    Errors:
    - [Invalid_character] for a byte >= 0x80 before the last delimiter, or
      (carrying [Uchar.rep]) when a decoded value is not a valid Unicode
      scalar value;
    - [Invalid_digit] for a character that is not a Punycode digit;
    - [Unexpected_end] when the input stops inside a delta;
    - [Overflow] when an intermediate value would exceed [max_int_value].

    In error positions [byte_offset] is an offset into [input] and
    [char_index] is the number of code points decoded so far (for the
    literal part the two are equal). *)
let decode_impl input =
  let input_length = String.length input in
  if input_length = 0 then Ok ([||], [||])
  else
    (* Index of the last delimiter, or 0 when there is none. *)
    let b =
      match String.rindex_opt input delimiter with
      | Some idx -> idx
      | None -> 0
    in
    (* Everything before the last delimiter is literal and must be ASCII;
       report the first byte that is not. *)
    let rec first_non_basic j =
      if j >= b then None
      else if Char.code input.[j] >= 0x80 then Some j
      else first_non_basic (j + 1)
    in
    match first_non_basic 0 with
    | Some j ->
      let pos = { byte_offset = j; char_index = j } in
      Error (Invalid_character (pos, Uchar.of_int (Char.code input.[j])))
    | None ->
      let literal_points =
        Array.init b (fun j -> Uchar.of_int (Char.code input.[j]))
      in
      let literal_cases =
        Array.init b (fun j ->
            if is_flagged input.[j] then Uppercase else Lowercase)
      in
      (* Threshold for the digit at weight position [k]. *)
      let threshold ~k ~bias =
        if k <= bias then tmin
        else if k >= bias + tmax then tmax
        else k - bias
      in
      (* Read one variable-length integer starting at [in_pos] and add it
         to [i]. On success, gives the new [i] and the input position just
         past the final digit. *)
      let rec read_delta ~out_len ~bias i w k in_pos =
        let pos = { byte_offset = in_pos; char_index = out_len } in
        if in_pos >= input_length then unexpected_end pos
        else
          let c = input.[in_pos] in
          match decode_digit c with
          | None -> invalid_digit pos c
          | Some digit ->
            (* i + digit * w, refusing to overflow *)
            (match safe_mul_add i digit w pos with
             | Error e -> Error e
             | Ok i ->
               let t = threshold ~k ~bias in
               if digit < t then Ok (i, in_pos + 1)
               else
                 (* w * (base - t), refusing to overflow *)
                 let radix = base - t in
                 if w > max_int_value / radix then overflow pos
                 else
                   read_delta ~out_len ~bias i (w * radix) (k + base)
                     (in_pos + 1))
      in
      (* Decode deltas until the input is exhausted, inserting one code
         point into the output per delta. *)
      let rec insert_all output cases n i bias in_pos =
        if in_pos >= input_length then Ok (output, cases)
        else
          let out_len = Array.length output in
          match read_delta ~out_len ~bias i 1 base in_pos with
          | Error e -> Error e
          | Ok (new_i, next_pos) ->
            let bias =
              adapt ~delta:(new_i - i) ~numpoints:(out_len + 1)
                ~firsttime:(i = 0)
            in
            (* Later errors are reported at the final digit of this delta. *)
            let pos = { byte_offset = next_pos - 1; char_index = out_len } in
            let increment = new_i / (out_len + 1) in
            if increment > max_int_value - n then overflow pos
            else
              let n = n + increment in
              let at = new_i mod (out_len + 1) in
              if not (Uchar.is_valid n) then invalid_character pos Uchar.rep
              else begin
                (* Build copies one slot longer with the new entry at
                   index [at]. *)
                let grown = Array.make (out_len + 1) (Uchar.of_int n) in
                Array.blit output 0 grown 0 at;
                Array.blit output at grown (at + 1) (out_len - at);
                let flag =
                  if next_pos > 0 && is_flagged input.[next_pos - 1]
                  then Uppercase else Lowercase
                in
                let grown_cases = Array.make (out_len + 1) flag in
                Array.blit cases 0 grown_cases 0 at;
                Array.blit cases at grown_cases (at + 1) (out_len - at);
                insert_all grown grown_cases n (at + 1) bias next_pos
              end
      in
      (* Deltas start right after the last delimiter; when b is 0 they
         start at the very beginning. *)
      let start = if b > 0 then b + 1 else 0 in
      insert_all literal_points literal_cases initial_n 0 initial_bias start
451
(** [decode input] is the array of code points encoded by the Punycode
    string [input]; the case annotation computed by [decode_impl] is
    dropped. Errors are passed through unchanged. *)
let decode input =
  match decode_impl input with
  | Ok (codepoints, _case_flags) -> Ok codepoints
  | Error e -> Error e
454
(** [decode_with_case input] is the result of [decode_impl input]: the
    decoded code points paired with their case annotation. *)
let decode_with_case input = input |> decode_impl
457
458(* {1 UTF-8 String Operations} *)
459
(** [encode_utf8 s] is the Punycode encoding of the UTF-8 string [s].
    Fails with the error of [utf8_to_codepoints] when [s] is not valid
    UTF-8, and otherwise with whatever [encode] reports. *)
let encode_utf8 s =
  match utf8_to_codepoints s with
  | Ok codepoints -> encode codepoints
  | Error e -> Error e
464
(** [decode_utf8 punycode] decodes [punycode] with [decode] and returns the
    resulting code points as a UTF-8 string. Errors from [decode] are
    passed through unchanged. *)
let decode_utf8 punycode =
  match decode punycode with
  | Ok codepoints -> Ok (codepoints_to_utf8 codepoints)
  | Error e -> Error e
469
470(* {1 Domain Label Operations} *)
471
(** [encode_label label] converts one domain label to its ASCII form.

    An all-ASCII label is returned unchanged. Any other label is
    Punycode-encoded and prefixed with [ace_prefix]. In both cases the
    result is rejected with [Label_too_long] when it is longer than
    [max_label_length] bytes. An empty [label] gives [Empty_label], and
    encoding errors from [encode_utf8] are passed through. *)
let encode_label label =
  (* Accept [candidate] only if it fits within the label length limit. *)
  let within_limit candidate =
    let len = String.length candidate in
    if len > max_label_length then label_too_long len else Ok candidate
  in
  if label = "" then empty_label
  else if is_ascii_string label then within_limit label
  else
    match encode_utf8 label with
    | Error e -> Error e
    | Ok encoded -> within_limit (ace_prefix ^ encoded)
492
(** [decode_label label] converts one domain label to its Unicode form.

    A label that starts with the ACE prefix (see [has_ace_prefix]) has the
    prefix removed and the remainder decoded with [decode_utf8], whose
    errors are passed through. Any other non-empty label is returned
    unchanged, whether or not it is pure ASCII; no validation is applied
    to it. An empty [label] gives [Empty_label]. *)
let decode_label label =
  if String.length label = 0 then
    empty_label
  else if has_ace_prefix label then begin
    (* [has_ace_prefix] guarantees the label is at least as long as the
       prefix, so the substring bounds are valid. *)
    let prefix_len = String.length ace_prefix in
    let punycode = String.sub label prefix_len (String.length label - prefix_len) in
    decode_utf8 punycode
  end else
    (* No ACE prefix: nothing to decode. *)
    Ok label