My aggregated monorepo of OCaml code, automaintained
at doc-fixes 260 lines 9.8 kB view raw
(*---------------------------------------------------------------------------
  Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
  SPDX-License-Identifier: ISC
 ---------------------------------------------------------------------------*)

(** WebSocket Protocol Support (RFC 6455)

    This module provides functions for the WebSocket HTTP upgrade handshake.
    WebSocket connections are established by upgrading an HTTP/1.1 connection
    using the Upgrade mechanism.

    @see <https://www.rfc-editor.org/rfc/rfc6455> RFC 6455: The WebSocket Protocol *)

let src = Logs.Src.create "requests.websocket" ~doc:"WebSocket Support"
module Log = (val Logs.src_log src : Logs.LOG)

(** {1 Constants} *)

(** The WebSocket protocol version per RFC 6455.
    This is the only version defined by the RFC. *)
let protocol_version = "13"

(** The magic GUID used in Sec-WebSocket-Accept computation.
    @see <https://www.rfc-editor.org/rfc/rfc6455#section-1.3> RFC 6455 Section 1.3 *)
let magic_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

(** {1 Sec-WebSocket-Key Generation}

    The client generates a random 16-byte value, base64-encodes it, and sends
    it in the Sec-WebSocket-Key header. The server must echo a value derived
    from this key (see {!compute_accept}), which lets the client confirm that
    the server understood the WebSocket handshake.

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-4.1> RFC 6455 Section 4.1 *)

(** Generate a random Sec-WebSocket-Key value.

    Draws a 16-byte nonce from [Mirage_crypto_rng] and base64-encodes it
    (24 characters including padding). The result is suitable for use in the
    Sec-WebSocket-Key header.

    NOTE(review): presumably requires the application to have initialised the
    default [Mirage_crypto_rng] generator beforehand — confirm. *)
let generate_key () =
  let random_bytes = Mirage_crypto_rng.generate 16 in
  let key = Base64.encode_exn random_bytes in
  Log.debug (fun m -> m "Generated WebSocket key: %s" key);
  key

(** {1 Sec-WebSocket-Accept Computation}

    The server computes Sec-WebSocket-Accept as:
    [base64(SHA-1(Sec-WebSocket-Key + magic_guid))]

    This proves the server received the client's handshake and understands
    the WebSocket protocol.

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-4.2.2> RFC 6455 Section 4.2.2 *)

(** Compute the expected Sec-WebSocket-Accept value for a given key.

    @param key The Sec-WebSocket-Key sent by the client
    @return The expected Sec-WebSocket-Accept value *)
let compute_accept ~key =
  let combined = key ^ magic_guid in
  let hash = Digestif.SHA1.(digest_string combined |> to_raw_string) in
  let accept = Base64.encode_exn hash in
  Log.debug (fun m -> m "Computed WebSocket accept for key %s: %s" key accept);
  accept

(** Validate a server's Sec-WebSocket-Accept value.

    Logs a warning when validation fails.

    @param key The Sec-WebSocket-Key that was sent
    @param accept The Sec-WebSocket-Accept received from the server
    @return [true] if the accept value is correct *)
let validate_accept ~key ~accept =
  let expected = compute_accept ~key in
  let valid = String.equal expected accept in
  if not valid then
    Log.warn (fun m ->
        m "WebSocket accept validation failed: expected %s, got %s" expected accept);
  valid

(** {1 Sec-WebSocket-Protocol Negotiation}

    The client sends a list of desired subprotocols; the server selects one.
    Common subprotocols include "graphql-ws", "graphql-transport-ws", "wamp.2.json".

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-11.3.4> RFC 6455 Section 11.3.4 *)

(** Parse a Sec-WebSocket-Protocol header value into a list of protocols.

    The header value is a comma-separated list of protocol identifiers.
    Surrounding whitespace is trimmed and empty entries are dropped. *)
let parse_protocols s =
  String.split_on_char ',' s
  |> List.map String.trim
  |> List.filter (fun s -> String.length s > 0)

(** Format a list of protocols as a Sec-WebSocket-Protocol header value. *)
let protocols_to_string protocols =
  String.concat ", " protocols

(** Select a protocol from the offered list that matches one we support.

    @param offered The protocols offered by the client
    @param supported The protocols we support (in preference order)
    @return The first entry of [supported] that also appears in [offered],
            or [None] if no match *)
let select_protocol ~offered ~supported =
  List.find_opt (fun s -> List.exists (String.equal s) offered) supported

(** {1 Sec-WebSocket-Extensions Parsing}

    Extensions provide additional capabilities like compression.
    The most common extension is "permessage-deflate" (RFC 7692).

    @see <https://www.rfc-editor.org/rfc/rfc6455#section-9> RFC 6455 Section 9
    @see <https://www.rfc-editor.org/rfc/rfc7692> RFC 7692: Compression Extensions *)

(** An extension with optional parameters.

    Example: [("permessage-deflate", [("client_max_window_bits", None)])] *)
type extension = {
  name : string;
  params : (string * string option) list;
}

(** [split_outside_quotes sep s] splits [s] on every [sep] character that is
    not inside a double-quoted section. Inside quotes, a backslash escapes the
    following character, so an escaped double quote does not end the section.
    Like [String.split_on_char], it always returns at least one (possibly
    empty) element, and for input without quotes the two agree exactly. *)
let split_outside_quotes sep s =
  let len = String.length s in
  let rec go start i in_quotes acc =
    if i >= len then List.rev (String.sub s start (len - start) :: acc)
    else
      let c = s.[i] in
      if in_quotes then
        (* Skip the escaped character; also covers a trailing backslash. *)
        if Char.equal c '\\' then go start (i + 2) true acc
        else go start (i + 1) (not (Char.equal c '"')) acc
      else if Char.equal c '"' then go start (i + 1) true acc
      else if Char.equal c sep then
        go (i + 1) (i + 1) false (String.sub s start (i - start) :: acc)
      else go start (i + 1) false acc
  in
  go 0 0 false []

(** Parse a single extension (name with optional parameters).

    Format: [name; param1; param2=value; ...]

    Parameter values may be tokens or quoted-strings; a quoted-string has its
    surrounding double-quote characters removed and backslash escapes undone.
    Empty parameters are skipped.

    @return [None] if the extension name is empty (e.g. the text produced
            by a trailing comma in the header), otherwise the parsed
            extension *)
let parse_single_extension s =
  (* Undo quoted-string encoding; values that are not fully quoted are
     returned unchanged. *)
  let unquote v =
    let len = String.length v in
    if len >= 2 && Char.equal v.[0] '"' && Char.equal v.[len - 1] '"' then begin
      let buf = Buffer.create (len - 2) in
      let rec copy i =
        if i < len - 1 then
          if Char.equal v.[i] '\\' && i + 1 < len - 1 then begin
            Buffer.add_char buf v.[i + 1];
            copy (i + 2)
          end else begin
            Buffer.add_char buf v.[i];
            copy (i + 1)
          end
      in
      copy 1;
      Buffer.contents buf
    end else v
  in
  let parse_param p =
    match String.index_opt p '=' with
    | None -> (p, None)
    | Some eq_idx ->
      let key = String.trim (String.sub p 0 eq_idx) in
      let value =
        String.trim (String.sub p (eq_idx + 1) (String.length p - eq_idx - 1))
      in
      (key, Some (unquote value))
  in
  match split_outside_quotes ';' s |> List.map String.trim with
  | [] | "" :: _ -> None
  | name :: params ->
    let params = List.filter (fun p -> String.length p > 0) params in
    Some { name; params = List.map parse_param params }

(** Parse a Sec-WebSocket-Extensions header value.

    Format is comma-separated extensions, each with semicolon-separated parameters:
    [permessage-deflate; client_max_window_bits, another-ext]

    Commas inside quoted parameter values do not split extensions, and empty
    entries (including an empty header) are dropped. *)
let parse_extensions s =
  split_outside_quotes ',' s |> List.filter_map parse_single_extension

(** Format extensions as a Sec-WebSocket-Extensions header value.

    Parameter values that are valid tokens (RFC 7230) are written as-is; any
    other value (empty, or containing e.g. a comma, semicolon, space or double
    quote) is written as a quoted-string so the header re-parses with
    {!parse_extensions}. *)
let extensions_to_string extensions =
  let is_tchar = function
    | 'a' .. 'z' | 'A' .. 'Z' | '0' .. '9'
    | '!' | '#' | '$' | '%' | '&' | '\'' | '*' | '+' | '-' | '.'
    | '^' | '_' | '`' | '|' | '~' -> true
    | _ -> false
  in
  let is_token v =
    let len = String.length v in
    let rec all i = i >= len || (is_tchar v.[i] && all (i + 1)) in
    len > 0 && all 0
  in
  let format_value v =
    if is_token v then v
    else begin
      let buf = Buffer.create (String.length v + 2) in
      Buffer.add_char buf '"';
      String.iter
        (fun c ->
          if Char.equal c '"' || Char.equal c '\\' then Buffer.add_char buf '\\';
          Buffer.add_char buf c)
        v;
      Buffer.add_char buf '"';
      Buffer.contents buf
    end
  in
  let ext_to_string ext =
    let params_str =
      List.map
        (fun (k, v) ->
          match v with
          | None -> k
          | Some v -> Printf.sprintf "%s=%s" k (format_value v))
        ext.params
    in
    String.concat "; " (ext.name :: params_str)
  in
  String.concat ", " (List.map ext_to_string extensions)

(** Check if an extension is present in a list (exact, case-sensitive name
    match). *)
let has_extension ~name extensions =
  List.exists (fun ext -> String.equal ext.name name) extensions

(** Get parameters for a specific extension.

    @return The parameters of the first extension named [name], or [None]
            if no such extension is present *)
let get_extension_params ~name extensions =
  List.find_opt (fun ext -> String.equal ext.name name) extensions
  |> Option.map (fun ext -> ext.params)

(** {1 Handshake Header Helpers} *)

(** Build the headers for a WebSocket upgrade request.

    Always sets Upgrade, Connection, Sec-WebSocket-Key and
    Sec-WebSocket-Version; the optional headers are added only when supplied
    (and, for lists, non-empty).

    @param key The Sec-WebSocket-Key (use {!generate_key} to create)
    @param protocols Optional list of subprotocols to request
    @param extensions Optional list of extensions to request
    @param origin Optional Origin header value *)
let make_upgrade_headers ~key ?protocols ?extensions ?origin () =
  let headers =
    Headers.empty
    |> Headers.set `Upgrade "websocket"
    |> Headers.set `Connection "Upgrade"
    |> Headers.set `Sec_websocket_key key
    |> Headers.set `Sec_websocket_version protocol_version
  in
  let headers =
    match protocols with
    | Some (_ :: _ as ps) ->
      Headers.set `Sec_websocket_protocol (protocols_to_string ps) headers
    | Some [] | None -> headers
  in
  let headers =
    match extensions with
    | Some (_ :: _ as exts) ->
      Headers.set `Sec_websocket_extensions (extensions_to_string exts) headers
    | Some [] | None -> headers
  in
  match origin with
  | Some o -> Headers.set `Origin o headers
  | None -> headers

(** [string_contains ~needle haystack] is [true] iff [needle] occurs as a
    contiguous, case-sensitive substring of [haystack]. The empty needle
    matches every haystack. Compares characters in place, without
    allocating. *)
let string_contains ~needle haystack =
  let nlen = String.length needle in
  let hlen = String.length haystack in
  let rec matches_at i j =
    j >= nlen || (Char.equal haystack.[i + j] needle.[j] && matches_at i (j + 1))
  in
  let rec scan i = i + nlen <= hlen && (matches_at i 0 || scan (i + 1)) in
  scan 0

(** Validate a WebSocket upgrade response.

    Checks that:
    - Status code is 101 (Switching Protocols)
    - Upgrade header is "websocket" (ASCII case-insensitive, whitespace-trimmed)
    - Connection header contains an "Upgrade" token (ASCII case-insensitive,
      in a comma-separated list), as required by RFC 6455 Section 4.1
    - Sec-WebSocket-Accept is correct for the given key

    @param key The Sec-WebSocket-Key that was sent
    @param status The HTTP status code
    @param headers The response headers
    @return [Ok ()] if valid, [Error reason] if invalid *)
let validate_upgrade_response ~key ~status ~headers =
  (* Token-wise match: a plain substring test would wrongly accept values
     such as upgrade-insecure-requests. *)
  let has_upgrade_token conn =
    String.split_on_char ',' conn
    |> List.exists (fun tok ->
           String.equal (String.lowercase_ascii (String.trim tok)) "upgrade")
  in
  if status <> 101 then
    Error (Printf.sprintf "Expected status 101, got %d" status)
  else
    match Headers.get `Upgrade headers with
    | None -> Error "Missing Upgrade header"
    | Some upgrade
      when not
             (String.equal
                (String.lowercase_ascii (String.trim upgrade))
                "websocket") ->
      Error (Printf.sprintf "Upgrade header is '%s', expected 'websocket'" upgrade)
    | Some _ -> (
      match Headers.get `Connection headers with
      | None -> Error "Missing Connection header"
      | Some conn when not (has_upgrade_token conn) ->
        Error (Printf.sprintf "Connection header is '%s', expected 'Upgrade'" conn)
      | Some _ -> (
        match Headers.get `Sec_websocket_accept headers with
        | None -> Error "Missing Sec-WebSocket-Accept header"
        | Some accept ->
          if validate_accept ~key ~accept then Ok ()
          else Error "Sec-WebSocket-Accept validation failed"))