(* forked from
   anil.recoil.org/monopam-myspace
   My aggregated monorepo of OCaml code, automaintained *)
(*---------------------------------------------------------------------------
  Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
  SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

(** WebSocket Protocol Support (RFC 6455)

    This module provides functions for the WebSocket HTTP upgrade handshake.
    WebSocket connections are established by upgrading an HTTP/1.1 connection
    using the Upgrade mechanism.

    @see <https://www.rfc-editor.org/rfc/rfc6455> RFC 6455: The WebSocket Protocol *)

(** Log source for this module, registered under the name
    ["requests.websocket"]. *)
let src = Logs.Src.create ~doc:"WebSocket Support" "requests.websocket"

module Log = (val Logs.src_log src : Logs.LOG)

(** {1 Constants} *)

(** Value sent in the Sec-WebSocket-Version header. RFC 6455 defines exactly
    one protocol version, ["13"]. *)
let protocol_version : string = "13"

(** Fixed GUID that is appended to the client's Sec-WebSocket-Key before
    hashing to derive Sec-WebSocket-Accept.
    @see <https://www.rfc-editor.org/rfc/rfc6455#section-1.3> RFC 6455 Section 1.3 *)
let magic_guid : string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

(** {1 Sec-WebSocket-Key Generation}

    The client generates a random 16-byte value, base64-encodes it, and sends
    it in the Sec-WebSocket-Key header. Because the server must answer with a
    value derived from this key, the exchange proves that the server
    understands the WebSocket protocol.

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-4.1> RFC 6455 Section 4.1 *)

(** Generate a fresh Sec-WebSocket-Key value.

    Draws 16 random bytes from [Mirage_crypto_rng], base64-encodes them and
    returns the encoding, ready for the Sec-WebSocket-Key header. The key is
    also logged at debug level. *)
let generate_key () =
  (* NOTE(review): Mirage_crypto_rng.generate presumably needs a default RNG
     to have been initialised by the application; confirm at call sites. *)
  let encoded = Base64.encode_exn (Mirage_crypto_rng.generate 16) in
  Log.debug (fun m -> m "Generated WebSocket key: %s" encoded);
  encoded

(** {1 Sec-WebSocket-Accept Computation}

    The server computes Sec-WebSocket-Accept as
    [base64(SHA-1(Sec-WebSocket-Key ^ magic_guid))].

    This proves that the server received the client's handshake and
    understands the WebSocket protocol.

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-4.2.2> RFC 6455 Section 4.2.2 *)

(** [compute_accept ~key] is the Sec-WebSocket-Accept value a server must
    return for the client key [key]: the base64 encoding of the raw SHA-1
    digest of [key] concatenated with {!magic_guid}. Both the key and the
    result are logged at debug level.

    For the RFC 6455 sample key ["dGhlIHNhbXBsZSBub25jZQ=="] the result is
    ["s3pPLMBiTxaQ9kYGzzhZRbK+xOo="].

    @param key The Sec-WebSocket-Key sent by the client
    @return The expected Sec-WebSocket-Accept value *)
let compute_accept ~key =
  let digest = Digestif.SHA1.digest_string (key ^ magic_guid) in
  let accept = Base64.encode_exn (Digestif.SHA1.to_raw_string digest) in
  Log.debug (fun m -> m "Computed WebSocket accept for key %s: %s" key accept);
  accept

(** Check a server's Sec-WebSocket-Accept value against the key that was sent.

    The expected value is recomputed with {!compute_accept} and compared
    exactly, case-sensitively and without trimming. On a mismatch, both values
    are logged at warning level.

    @param key The Sec-WebSocket-Key that was sent
    @param accept The Sec-WebSocket-Accept received from the server
    @return [true] if the accept value is correct *)
let validate_accept ~key ~accept =
  let expected = compute_accept ~key in
  if String.equal expected accept then true
  else begin
    Log.warn (fun m ->
        m "WebSocket accept validation failed: expected %s, got %s" expected
          accept);
    false
  end

(** {1 Sec-WebSocket-Protocol Negotiation}

    The client sends a list of desired subprotocols, and the server selects
    one of them. Common subprotocols include "graphql-ws",
    "graphql-transport-ws" and "wamp.2.json".

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-11.3.4> RFC 6455 Section 11.3.4 *)

(** Split a Sec-WebSocket-Protocol header value into its protocol names.

    The value is split on commas. Each element is trimmed of surrounding
    whitespace, and elements that end up empty are dropped, so
    [parse_protocols ""] is [[]]. *)
let parse_protocols s =
  List.filter_map
    (fun item ->
      match String.trim item with
      | "" -> None
      | proto -> Some proto)
    (String.split_on_char ',' s)

(** Join protocol names into a Sec-WebSocket-Protocol header value, separated
    by [", "]. An empty list yields the empty string. *)
let protocols_to_string = String.concat ", "

(** Pick the subprotocol to use for a connection.

    Walks [supported] in order and returns the first entry that also appears
    in [offered], using structural equality. The local preference order
    therefore takes precedence over the client's ordering.

    @param offered The protocols offered by the client
    @param supported The protocols we support (in preference order)
    @return The selected protocol, or [None] if no match *)
let select_protocol ~offered ~supported =
  let rec first_offered = function
    | [] -> None
    | candidate :: rest ->
        if List.mem candidate offered then Some candidate
        else first_offered rest
  in
  first_offered supported

(** {1 Sec-WebSocket-Extensions Parsing}

    Extensions provide additional capabilities such as compression. The most
    common extension is "permessage-deflate" (RFC 7692).

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-9> RFC 6455 Section 9
    @see <https://www.rfc-editor.org/rfc/rfc7692> RFC 7692: Compression Extensions *)

(** An extension with optional parameters.

    Each parameter is a [(key, value)] pair whose value is [None] when the
    parameter appears without [=value].

    Example:
    [{ name = "permessage-deflate"; params = [ ("client_max_window_bits", None) ] }] *)
type extension = {
  name : string;
  params : (string * string option) list;
}

(** Parse a single extension entry: a name followed by optional
    semicolon-separated parameters, as in [name; param1; param2=value].

    - Whitespace around the name, keys and values is trimmed.
    - A value written as a quoted-string has its enclosing quotes removed,
      and every backslash-escaped character is replaced by the character
      itself (RFC 6455 Section 9.1 requires quoted-string unescaping). A value
      that is not wrapped in a matching pair of quotes is kept unchanged.
    - Empty parameters (e.g. from a trailing semicolon) and parameters with an
      empty key are dropped.

    @return [None] when the entry's name is empty (e.g. an empty string, or
    the blank entry produced by a stray comma), and [Some ext] otherwise. *)
let parse_single_extension s =
  (* Undo a quoted-string: strip the surrounding quotes and unescape
     backslash pairs. Only values with a quote at BOTH ends are touched. *)
  let unquote v =
    let len = String.length v in
    if len >= 2 && v.[0] = '"' && v.[len - 1] = '"' then begin
      let buf = Buffer.create (len - 2) in
      (* Copy the characters strictly between the enclosing quotes. *)
      let rec copy i =
        if i >= len - 1 then ()
        else if v.[i] = '\\' && i + 1 < len - 1 then begin
          Buffer.add_char buf v.[i + 1];
          copy (i + 2)
        end
        else begin
          Buffer.add_char buf v.[i];
          copy (i + 1)
        end
      in
      copy 1;
      Buffer.contents buf
    end
    else v
  in
  (* [p] is already trimmed by the caller below. *)
  let parse_param p =
    match String.index_opt p '=' with
    | None -> (p, None)
    | Some eq ->
        let key = String.trim (String.sub p 0 eq) in
        let raw = String.sub p (eq + 1) (String.length p - eq - 1) in
        (key, Some (unquote (String.trim raw)))
  in
  match List.map String.trim (String.split_on_char ';' s) with
  | [] | "" :: _ ->
      (* split_on_char never yields [], so blank entries arrive here as an
         empty name; they do not describe an extension. *)
      None
  | name :: raw_params ->
      let params =
        raw_params
        |> List.filter (fun p -> p <> "")
        |> List.map parse_param
        |> List.filter (fun (key, _) -> key <> "")
      in
      Some { name; params }

(** Parse a Sec-WebSocket-Extensions header value.

    The value is a comma-separated list of extensions, each with
    semicolon-separated parameters, e.g.
    [permessage-deflate; client_max_window_bits, another-ext].
    Blank entries (e.g. from a trailing comma) are skipped.

    Splitting on commas and semicolons is not quote-aware. That is sufficient
    for conforming input, because RFC 6455 Section 9.1 requires each parameter
    value to be a token after quoted-string unescaping, and tokens cannot
    contain commas or semicolons. *)
let parse_extensions s =
  String.split_on_char ',' s |> List.filter_map parse_single_extension

(** Format extensions as a Sec-WebSocket-Extensions header value.

    Extensions are separated by [", "], and each extension's name and
    parameters by ["; "]. A parameter without a value is written as its bare
    key. A parameter with a value is written as [key=value] when the value is
    a non-empty token (RFC 7230 Section 3.2.6). Any other value is emitted as
    a quoted-string, with embedded double quotes and backslashes
    backslash-escaped. This prevents a value containing separators such as
    commas or semicolons from corrupting the header structure. *)
let extensions_to_string extensions =
  (* RFC 7230 tchar: ASCII letters, digits and the symbols listed here. *)
  let is_tchar = function
    | 'a' .. 'z' | 'A' .. 'Z' | '0' .. '9' | '!' | '#' | '$' | '%' | '&'
    | '\'' | '*' | '+' | '-' | '.' | '^' | '_' | '`' | '|' | '~' ->
        true
    | _ -> false
  in
  let is_token v =
    let len = String.length v in
    let rec all_tchar i = i >= len || (is_tchar v.[i] && all_tchar (i + 1)) in
    len > 0 && all_tchar 0
  in
  let quote v =
    let buf = Buffer.create (String.length v + 2) in
    Buffer.add_char buf '"';
    String.iter
      (fun c ->
        if c = '"' || c = '\\' then Buffer.add_char buf '\\';
        Buffer.add_char buf c)
      v;
    Buffer.add_char buf '"';
    Buffer.contents buf
  in
  let param_to_string (key, value) =
    match value with
    | None -> key
    | Some v -> key ^ "=" ^ if is_token v then v else quote v
  in
  let ext_to_string ext =
    String.concat "; " (ext.name :: List.map param_to_string ext.params)
  in
  String.concat ", " (List.map ext_to_string extensions)

(** [has_extension ~name extensions] is [true] iff some element of
    [extensions] is named exactly [name] (case-sensitive comparison). *)
let has_extension ~name extensions =
  List.exists
    (fun { name = candidate; _ } -> String.equal candidate name)
    extensions

(** [get_extension_params ~name extensions] returns the parameter list of the
    first extension in [extensions] named exactly [name], or [None] if there
    is no such extension. *)
let get_extension_params ~name extensions =
  extensions
  |> List.find_opt (fun ext -> String.equal ext.name name)
  |> Option.map (fun ext -> ext.params)

(** {1 Handshake Header Helpers} *)

(** Build the headers for a WebSocket upgrade request.

    The following headers are always set: [Upgrade: websocket],
    [Connection: Upgrade], Sec-WebSocket-Key, and Sec-WebSocket-Version (to
    {!protocol_version}). Sec-WebSocket-Protocol and Sec-WebSocket-Extensions
    are added only when the corresponding list is given and non-empty. Origin
    is added only when given.

    @param key The Sec-WebSocket-Key (use {!generate_key} to create)
    @param protocols Optional list of subprotocols to request
    @param extensions Optional list of extensions to request
    @param origin Optional Origin header value *)
let make_upgrade_headers ~key ?protocols ?extensions ?origin () =
  let add_protocols h =
    match protocols with
    | None | Some [] -> h
    | Some ps -> Headers.set `Sec_websocket_protocol (protocols_to_string ps) h
  in
  let add_extensions h =
    match extensions with
    | None | Some [] -> h
    | Some exts ->
        Headers.set `Sec_websocket_extensions (extensions_to_string exts) h
  in
  let add_origin h =
    match origin with
    | None -> h
    | Some o -> Headers.set `Origin o h
  in
  Headers.empty
  |> Headers.set `Upgrade "websocket"
  |> Headers.set `Connection "Upgrade"
  |> Headers.set `Sec_websocket_key key
  |> Headers.set `Sec_websocket_version protocol_version
  |> add_protocols
  |> add_extensions
  |> add_origin

(** [string_contains ~needle haystack] is [true] iff [needle] occurs as a
    contiguous substring of [haystack]. The comparison is byte-wise and
    case-sensitive, and an empty [needle] is contained in every string. *)
let string_contains ~needle haystack =
  let nlen = String.length needle in
  let hlen = String.length haystack in
  (* Compare in place instead of allocating a String.sub at every offset. *)
  let matches_at i =
    let rec go j = j >= nlen || (haystack.[i + j] = needle.[j] && go (j + 1)) in
    go 0
  in
  (* The bound [i + nlen <= hlen] also covers the case nlen > hlen. *)
  let rec search i = i + nlen <= hlen && (matches_at i || search (i + 1)) in
  search 0

(** Validate a WebSocket upgrade response.

    Checks, in order, that:
    - the status code is 101 (Switching Protocols);
    - the Upgrade header is present and equals "websocket" (ASCII
      case-insensitive);
    - the Connection header is present and its comma-separated list contains
      an "Upgrade" token (ASCII case-insensitive);
    - Sec-WebSocket-Accept is present and correct for the given key.

    The first failing check determines the error.

    @param key The Sec-WebSocket-Key that was sent
    @param status The HTTP status code
    @param headers The response headers
    @return [Ok ()] if valid, [Error reason] if invalid
    @see <https://www.rfc-editor.org/rfc/rfc6455#section-4.1> RFC 6455 Section 4.1 *)
let validate_upgrade_response ~key ~status ~headers =
  let ( let* ) = Result.bind in
  (* RFC 6455 Section 4.1 requires the Connection header to contain a token
     matching "Upgrade". Compare whole comma-separated tokens; a substring
     search would wrongly accept values such as "no-upgrade". *)
  let has_upgrade_token conn =
    List.exists
      (fun token ->
        String.equal (String.lowercase_ascii (String.trim token)) "upgrade")
      (String.split_on_char ',' conn)
  in
  let* () =
    if status = 101 then Ok ()
    else Error (Printf.sprintf "Expected status 101, got %d" status)
  in
  let* () =
    match Headers.get `Upgrade headers with
    | None -> Error "Missing Upgrade header"
    | Some upgrade when String.lowercase_ascii upgrade <> "websocket" ->
        Error
          (Printf.sprintf "Upgrade header is '%s', expected 'websocket'" upgrade)
    | Some _ -> Ok ()
  in
  let* () =
    match Headers.get `Connection headers with
    | None -> Error "Missing Connection header"
    | Some conn when not (has_upgrade_token conn) ->
        Error
          (Printf.sprintf "Connection header is '%s', expected 'Upgrade'" conn)
    | Some _ -> Ok ()
  in
  match Headers.get `Sec_websocket_accept headers with
  | None -> Error "Missing Sec-WebSocket-Accept header"
  | Some accept when validate_accept ~key ~accept -> Ok ()
  | Some _ -> Error "Sec-WebSocket-Accept validation failed"
258 Ok ()
259 else
260 Error "Sec-WebSocket-Accept validation failed"