forked from
anil.recoil.org/ocaml-requests
A batteries included HTTP/1.1 client in OCaml
(*---------------------------------------------------------------------------
  Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
  SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)
5
(** WebSocket Protocol Support (RFC 6455)

    This module provides functions for the WebSocket HTTP upgrade handshake.
    WebSocket connections are established by upgrading an HTTP/1.1 connection
    using the Upgrade mechanism.

    @see <https://www.rfc-editor.org/rfc/rfc6455>
    RFC 6455: The WebSocket Protocol *)
14
(* Log source named requests.websocket; the handshake helpers in this module
   report debug and warning messages through the [Log] module built from it. *)
let src = Logs.Src.create ~doc:"WebSocket Support" "requests.websocket"

module Log = (val Logs.src_log src : Logs.LOG)
18
(** [err_unexpected_status status] is an [Error] reporting that [status] was
    received where 101 (Switching Protocols) was required. *)
let err_unexpected_status status =
  Printf.sprintf "Expected status 101, got %d" status |> Result.error
21
(** [err_bad_upgrade upgrade] is an [Error] quoting the [Upgrade] header value
    that was received instead of websocket. *)
let err_bad_upgrade upgrade =
  Printf.ksprintf Result.error "Upgrade header is '%s', expected 'websocket'"
    upgrade
24
(** {1 Constants} *)

(** Value sent in the Sec-WebSocket-Version header. Version 13 is the only
    protocol version RFC 6455 defines. *)
let protocol_version = "13"

(** Fixed GUID that is appended to the client key before hashing when the
    Sec-WebSocket-Accept value is derived.
    @see <https://www.rfc-editor.org/rfc/rfc6455#section-1.3>
    RFC 6455 Section 1.3 *)
let magic_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
35
(** {1 Sec-WebSocket-Key Generation}

    The client generates a random 16-byte value, base64-encodes it, and sends
    it in the Sec-WebSocket-Key header. The server must echo a value derived
    from it, which shows that the server understands the WebSocket protocol.

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-4.1>
    RFC 6455 Section 4.1 *)

(** Generate a fresh Sec-WebSocket-Key value.

    Draws 16 bytes from [Crypto_rng] and returns their base64 encoding, ready
    to be sent in the Sec-WebSocket-Key header. The generated key is logged at
    debug level. *)
let generate_key () =
  let key = Crypto_rng.generate 16 |> Base64.encode_exn in
  Log.debug (fun m -> m "Generated WebSocket key: %s" key);
  key
54
(** {1 Sec-WebSocket-Accept Computation}

    The server computes Sec-WebSocket-Accept as:
    [base64(SHA-1(Sec-WebSocket-Key + magic_guid))]

    This proves the server received the client's handshake and understands the
    WebSocket protocol.

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-4.2.2>
    RFC 6455 Section 4.2.2 *)

(** Compute the Sec-WebSocket-Accept value expected for a given key.

    The key is concatenated with {!magic_guid}, hashed with SHA-1, and the raw
    20-byte digest is base64-encoded. For example, the sample key from RFC 6455
    [dGhlIHNhbXBsZSBub25jZQ==] yields [s3pPLMBiTxaQ9kYGzzhZRbK+xOo=]. The
    result is logged at debug level.

    @param key The Sec-WebSocket-Key sent by the client
    @return The expected Sec-WebSocket-Accept value *)
let compute_accept ~key =
  let digest = Digestif.SHA1.digest_string (key ^ magic_guid) in
  let accept = Base64.encode_exn (Digestif.SHA1.to_raw_string digest) in
  Log.debug (fun m -> m "Computed WebSocket accept for key %s: %s" key accept);
  accept
76
(** Validate a server's Sec-WebSocket-Accept value.

    RFC 6455 Section 4.1 requires the client to ignore leading and trailing
    whitespace in the received value. The received [accept] is therefore
    trimmed before it is compared with the value computed from [key]. A
    mismatch is logged at warning level, showing the untrimmed value received.

    @param key The Sec-WebSocket-Key that was sent
    @param accept The Sec-WebSocket-Accept received from the server
    @return [true] if the accept value is correct *)
let validate_accept ~key ~accept =
  let expected = compute_accept ~key in
  (* Surrounding whitespace is not significant per RFC 6455 Section 4.1. *)
  let valid = String.equal expected (String.trim accept) in
  if not valid then
    Log.warn (fun m ->
        m "WebSocket accept validation failed: expected %s, got %s" expected
          accept);
  valid
90
(** {1 Sec-WebSocket-Protocol Negotiation}

    The client sends a list of desired subprotocols; the server selects one.
    Common subprotocols include "graphql-ws", "graphql-transport-ws",
    "wamp.2.json".

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-11.3.4>
    RFC 6455 Section 11.3.4 *)

(** Parse a Sec-WebSocket-Protocol header value into a list of protocols.

    The value is split on commas, and each element is trimmed of surrounding
    whitespace. Elements that end up empty are dropped, and order is preserved. *)
let parse_protocols s =
  List.filter_map
    (fun raw -> match String.trim raw with "" -> None | proto -> Some proto)
    (String.split_on_char ',' s)
106
(** Render protocols as a Sec-WebSocket-Protocol header value, separating the
    entries with a comma followed by a space. An empty list yields the empty
    string. *)
let protocols_to_string = String.concat ", "
109
(** Select a protocol from the offered list that matches one we support.

    [supported] is walked in order, and its first element that also appears in
    [offered] is returned. Our preference order therefore wins over the order
    of the client's list.

    @param offered The protocols offered by the client
    @param supported The protocols we support (in preference order)
    @return The selected protocol, or [None] if no match *)
let select_protocol ~offered ~supported =
  let rec first_match = function
    | [] -> None
    | candidate :: rest ->
        if List.mem candidate offered then Some candidate else first_match rest
  in
  first_match supported
117
(** {1 Sec-WebSocket-Extensions Parsing}

    Extensions provide additional capabilities like compression. The most
    common extension is permessage-deflate (RFC 7692).

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-9> RFC 6455 Section 9
    @see <https://www.rfc-editor.org/rfc/rfc7692>
    RFC 7692: Compression Extensions *)

type extension = { name : string; params : (string * string option) list }
(** An extension with optional parameters.

    Example:
    [{ name = "permessage-deflate"; params = [ ("client_max_window_bits", None) ] }] *)

(** [split_outside_quotes sep s] splits [s] on each [sep] character that is not
    inside a double-quoted section. Inside such a section, a backslash escapes
    the next character, so an escaped quote does not end the section. Like
    [String.split_on_char], the pieces are returned untrimmed (quotes and
    backslashes intact), and an empty [s] yields a single empty piece. *)
let split_outside_quotes sep s =
  let buf = Buffer.create (String.length s) in
  let parts = ref [] in
  let in_quotes = ref false in
  let escaped = ref false in
  String.iter
    (fun c ->
      if !escaped then (
        Buffer.add_char buf c;
        escaped := false)
      else if c = sep && not !in_quotes then (
        parts := Buffer.contents buf :: !parts;
        Buffer.clear buf)
      else (
        if c = '"' then in_quotes := not !in_quotes
        else if c = '\\' && !in_quotes then escaped := true;
        Buffer.add_char buf c))
    s;
  List.rev (Buffer.contents buf :: !parts)

(** [unquote_param_value v] strips the surrounding double quotes from a
    quoted-string parameter value and resolves the backslash escapes inside it.
    A value that does not both start and end with a double quote is returned
    unchanged. *)
let unquote_param_value v =
  let n = String.length v in
  if n >= 2 && v.[0] = '"' && v.[n - 1] = '"' then (
    let buf = Buffer.create (n - 2) in
    let escaped = ref false in
    for i = 1 to n - 2 do
      let c = v.[i] in
      if !escaped then (
        Buffer.add_char buf c;
        escaped := false)
      else if c = '\\' then escaped := true
      else Buffer.add_char buf c
    done;
    Buffer.contents buf)
  else v

(** Parse a single extension (name with optional parameters).

    Format: [name; param1; param2=value; ...]. Parameter values written as
    quoted-strings are unquoted. Empty parameters, such as those produced by a
    trailing semicolon, are skipped. Returns [None] when the extension name is
    empty. *)
let parse_single_extension s =
  let parse_param p =
    match String.index_opt p '=' with
    | None -> (p, None)
    | Some eq_idx ->
        let key = String.trim (String.sub p 0 eq_idx) in
        let value =
          String.trim
            (String.sub p (eq_idx + 1) (String.length p - eq_idx - 1))
        in
        (key, Some (unquote_param_value value))
  in
  match List.map String.trim (split_outside_quotes ';' s) with
  | [] | "" :: _ -> None
  | name :: params ->
      let params =
        params
        |> List.filter (fun p -> String.length p > 0)
        |> List.map parse_param
      in
      Some { name; params }

(** Parse a Sec-WebSocket-Extensions header value.

    The format is a comma-separated list of extensions, each with
    semicolon-separated parameters: [permessage-deflate; client_max_window_bits, another-ext].
    Commas and semicolons inside quoted parameter values are not treated as
    delimiters, and empty list elements are ignored. *)
let parse_extensions s =
  split_outside_quotes ',' s |> List.filter_map parse_single_extension
167
(** Format extensions as a Sec-WebSocket-Extensions header value.

    Extensions are joined with a comma and a space. Within each extension, the
    name and parameters are joined with a semicolon and a space. A parameter
    value is written bare when it is a non-empty RFC 7230 token. Any other
    value (empty, or containing characters such as commas, semicolons or
    spaces) is written as a quoted-string, with embedded double-quote and
    backslash characters backslash-escaped. This keeps the output within the
    extension-param grammar of RFC 6455 Section 9.1. *)
let extensions_to_string extensions =
  (* tchar as defined in RFC 7230 Section 3.2.6 *)
  let is_tchar = function
    | 'a' .. 'z' | 'A' .. 'Z' | '0' .. '9' | '!' | '#' | '$' | '%' | '&'
    | '\'' | '*' | '+' | '-' | '.' | '^' | '_' | '`' | '|' | '~' ->
        true
    | _ -> false
  in
  let is_token v =
    let len = String.length v in
    let rec all_tchar i = i >= len || (is_tchar v.[i] && all_tchar (i + 1)) in
    len > 0 && all_tchar 0
  in
  let quote v =
    let buf = Buffer.create (String.length v + 2) in
    Buffer.add_char buf '"';
    String.iter
      (fun c ->
        if c = '"' || c = '\\' then Buffer.add_char buf '\\';
        Buffer.add_char buf c)
      v;
    Buffer.add_char buf '"';
    Buffer.contents buf
  in
  let param_to_string (k, v) =
    match v with
    | None -> k
    | Some v -> k ^ "=" ^ if is_token v then v else quote v
  in
  let ext_to_string ext =
    String.concat "; " (ext.name :: List.map param_to_string ext.params)
  in
  String.concat ", " (List.map ext_to_string extensions)
179
(** [has_extension ~name extensions] is [true] when some element of
    [extensions] has exactly the given [name] (case-sensitive comparison). *)
let has_extension ~name extensions =
  extensions
  |> List.exists (fun { name = ext_name; _ } -> String.equal name ext_name)
183
(** [extension_params ~name extensions] returns the parameter list of the first
    extension in [extensions] whose name equals [name] (case-sensitive), or
    [None] when no extension has that name. *)
let extension_params ~name extensions =
  List.find_opt (fun ext -> String.equal ext.name name) extensions
  |> Option.map (fun ext -> ext.params)
189
(** {1 Handshake Header Helpers} *)

(** Build the headers for a WebSocket upgrade request.

    The Upgrade, Connection, Sec-WebSocket-Key and Sec-WebSocket-Version
    headers are always set. Each optional header is added only when its
    argument is supplied. For [protocols] and [extensions], the supplied list
    must also be non-empty.

    @param key The Sec-WebSocket-Key (use {!generate_key} to create)
    @param protocols Optional list of subprotocols to request
    @param extensions Optional list of extensions to request
    @param origin Optional Origin header value *)
let upgrade_headers ~key ?protocols ?extensions ?origin () =
  (* Apply [Headers.set name v] only when a value is present. *)
  let set_if_some name value headers =
    match value with Some v -> Headers.set name v headers | None -> headers
  in
  (* Render an optional list, treating an absent or empty list as no value. *)
  let render_nonempty render = function
    | Some (_ :: _ as items) -> Some (render items)
    | Some [] | None -> None
  in
  Headers.empty
  |> Headers.set `Upgrade "websocket"
  |> Headers.set `Connection "Upgrade"
  |> Headers.set `Sec_websocket_key key
  |> Headers.set `Sec_websocket_version protocol_version
  |> set_if_some `Sec_websocket_protocol
       (render_nonempty protocols_to_string protocols)
  |> set_if_some `Sec_websocket_extensions
       (render_nonempty extensions_to_string extensions)
  |> set_if_some `Origin origin
226
(** [string_contains ~needle haystack] is [true] when [needle] occurs as a
    contiguous substring of [haystack]. The comparison is byte-wise and
    case-sensitive, and an empty [needle] is contained in every string. *)
let string_contains ~needle haystack =
  let n = String.length needle and h = String.length haystack in
  (* [matches_at start]: does needle line up with haystack from [start]? *)
  let matches_at start =
    let rec go k =
      k >= n || (Char.equal haystack.[start + k] needle.[k] && go (k + 1))
    in
    go 0
  in
  let rec scan start =
    start + n <= h && (matches_at start || scan (start + 1))
  in
  scan 0
239
(** Validate a WebSocket upgrade response.

    Checks that:
    - Status code is 101 (Switching Protocols)
    - Upgrade header is "websocket" (ASCII case-insensitive, surrounding
      whitespace ignored)
    - Connection header contains the token "Upgrade" (ASCII
      case-insensitive). The header is a comma-separated token list, so whole
      tokens are compared rather than searching for a substring, and a value
      such as "no-upgrade" is rejected.
    - Sec-WebSocket-Accept is correct for the given key

    The checks run in the order above, and the first failure is returned.

    @param key The Sec-WebSocket-Key that was sent
    @param status The HTTP status code
    @param headers The response headers
    @return [Ok ()] if valid, [Error reason] if invalid *)
let validate_upgrade_response ~key ~status ~headers =
  let ( let* ) = Result.bind in
  (* A missing header becomes the same error message callers already see. *)
  let require name = function
    | Some value -> Ok value
    | None -> Error ("Missing " ^ name ^ " header")
  in
  (* RFC 6455 4.1: Connection must include the Upgrade token. The value is a
     comma-separated list, so compare tokens instead of substrings. *)
  let has_upgrade_token conn =
    String.split_on_char ',' conn
    |> List.exists (fun token ->
           String.equal (String.lowercase_ascii (String.trim token)) "upgrade")
  in
  let* () = if status = 101 then Ok () else err_unexpected_status status in
  let* upgrade = require "Upgrade" (Headers.find `Upgrade headers) in
  let* () =
    if String.equal (String.lowercase_ascii (String.trim upgrade)) "websocket"
    then Ok ()
    else err_bad_upgrade upgrade
  in
  let* conn = require "Connection" (Headers.find `Connection headers) in
  let* () =
    if has_upgrade_token conn then Ok ()
    else Error (Fmt.str "Connection header is '%s', expected 'Upgrade'" conn)
  in
  let* accept =
    require "Sec-WebSocket-Accept" (Headers.find `Sec_websocket_accept headers)
  in
  if validate_accept ~key ~accept then Ok ()
  else Error "Sec-WebSocket-Accept validation failed"