My aggregated monorepo of OCaml code, automaintained
at doc-fixes 553 lines 21 kB view raw
(*---------------------------------------------------------------------------
   Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
   SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

let src = Logs.Src.create "claude.client" ~doc:"Claude client"

module Log = (val Logs.src_log src : Logs.LOG)

(** Control response builders using Sdk_control codecs *)
module Control_response = struct
  (** [success ~request_id ~response] builds a success control response for
      [request_id] and encodes it to JSON. [response] is an option that is
      forwarded as the optional payload. Encoding errors go through
      [Err.get_ok] (presumably raising; see [Err]). *)
  let success ~request_id ~response =
    let resp = Sdk_control.Response.success ~request_id ?response () in
    let ctrl = Sdk_control.create_response ~response:resp () in
    Jsont.Json.encode Sdk_control.jsont ctrl
    |> Err.get_ok ~msg:"Control_response.success: "

  (** [error ~request_id ~code ~message ?data ()] builds an error control
      response for [request_id] and encodes it to JSON. *)
  let error ~request_id ~code ~message ?data () =
    let error_detail =
      Sdk_control.Response.error_detail ~code ~message ?data ()
    in
    let resp = Sdk_control.Response.error ~request_id ~error:error_detail () in
    let ctrl = Sdk_control.create_response ~response:resp () in
    Jsont.Json.encode Sdk_control.jsont ctrl
    |> Err.get_ok ~msg:"Control_response.error: "
end

(* Helper functions for JSON manipulation using jsont *)

(** [json_to_string json] serialises a generic JSON value to a string. Used
    throughout this file for log output. *)
let json_to_string json =
  Jsont_bytesrw.encode_string' Jsont.json json
  |> Result.map_error Jsont.Error.to_string
  |> Err.get_ok ~msg:""

(** Wire-level codec for hook matcher configuration sent to CLI. *)
module Hook_matcher_wire = struct
  type t = { matcher : string option; hook_callback_ids : string list }

  (** Object codec: optional ["matcher"] member and a ["hookCallbackIds"]
      string list. *)
  let jsont : t Jsont.t =
    let make matcher hook_callback_ids = { matcher; hook_callback_ids } in
    Jsont.Object.map ~kind:"HookMatcherWire" make
    |> Jsont.Object.opt_mem "matcher" Jsont.string ~enc:(fun r -> r.matcher)
    |> Jsont.Object.mem "hookCallbackIds" (Jsont.list Jsont.string)
         ~enc:(fun r -> r.hook_callback_ids)
    |> Jsont.Object.finish

  (** [encode matchers] encodes every matcher with {!jsont} and wraps the
      results in a JSON list. *)
  let encode matchers =
    List.map
      (fun m ->
        Jsont.Json.encode jsont m
        |> Err.get_ok ~msg:"Hook_matcher_wire.encode: ")
      matchers
    |> Jsont.Json.list
end

(** Client state. *)
type t = {
  transport : Transport.t;
  mutable permission_callback : Permissions.callback option;
  mutable permission_log : Permissions.Rule.t list ref option;
  (* Hook callbacks keyed by the callback id announced in the initialize
     request (see [create]) *)
  hook_callbacks : (string, Jsont.json -> Proto.Hooks.result) Hashtbl.t;
  mutable session_id : string option;
  (* Control responses received but not yet claimed, keyed by request id.
     Guarded by [control_mutex]; arrivals are signalled on
     [control_condition]. *)
  control_responses : (string, Jsont.json) Hashtbl.t;
  control_mutex : Eio.Mutex.t;
  control_condition : Eio.Condition.t;
  clock : float Eio.Time.clock_ty Eio.Resource.t;
  (* Track tool_use_ids we've already responded to, preventing duplicates *)
  responded_tool_ids : (string, unit) Hashtbl.t;
  (* In-process MCP servers for custom tools *)
  mcp_servers : (string, Mcp_server.t) Hashtbl.t;
}

(** [session_id t] is the session id captured from the most recent system
    message that carried one, if any. *)
let session_id t = t.session_id

(** [handle_control_request t ctrl_req] answers a control request received
    from the transport by sending a control response on [t.transport].

    - Permission requests run the configured permission callback
      ([Permissions.default_allow] when none is set).
    - Hook callback requests run the callback registered under the requested
      id; an unknown id yields a [`Method_not_found] error response and an
      exception raised by the callback yields an [`Internal_error] one.
    - MCP messages are dispatched to the named in-process server; an unknown
      server yields a JSON-RPC error payload (code -32601).
    - Any other request type yields an [`Invalid_request] error response.

    @raise Failure if the permission result cannot be encoded. *)
let handle_control_request t (ctrl_req : Sdk_control.control_request) =
  let request_id = ctrl_req.request_id in
  Log.info (fun m -> m "Handling control request: %s" request_id);

  match ctrl_req.request with
  | Sdk_control.Request.Permission req ->
      let tool_name = req.tool_name in
      let input_json = req.input in
      Log.info (fun m ->
          m "Permission request for tool '%s' with input: %s" tool_name
            (json_to_string input_json));
      (* Convert permission_suggestions to suggested rules *)
      let suggestions = Option.value req.permission_suggestions ~default:[] in
      let suggested_rules =
        Permissions.extract_rules_from_proto_updates suggestions
      in

      (* Convert input to Tool_input.t *)
      let input = Tool_input.of_json input_json in

      (* Create context *)
      let context : Permissions.context =
        { tool_name; input; suggested_rules }
      in

      Log.info (fun m ->
          m "Invoking permission callback for tool: %s" tool_name);
      let callback =
        Option.value t.permission_callback
          ~default:Permissions.default_allow
      in
      let decision = callback context in
      Log.info (fun m ->
          m "Permission callback returned: %s"
            (if Permissions.Decision.is_allow decision then "ALLOW"
             else "DENY"));

      (* Convert permission decision to proto result *)
      let proto_result =
        Permissions.Decision.to_proto_result ~original_input:input decision
      in

      (* Encode to JSON *)
      let response_data =
        match
          Jsont.Json.encode Proto.Permissions.Result.jsont proto_result
        with
        | Ok json -> json
        | Error err ->
            Log.err (fun m ->
                m "Failed to encode permission result: %s" err);
            failwith "Permission result encoding failed"
      in
      let response =
        Control_response.success ~request_id ~response:(Some response_data)
      in
      Log.info (fun m ->
          m "Sending control response: %s" (json_to_string response));
      Transport.send t.transport response
  | Sdk_control.Request.Hook_callback req -> (
      let callback_id = req.callback_id in
      let input = req.input in
      Log.info (fun m ->
          m "Hook callback request for callback_id: %s" callback_id);

      (* Look the callback up explicitly: catching [Not_found] around the
         callback invocation would misreport a [Not_found] raised by the
         callback itself as a missing registration. *)
      match Hashtbl.find_opt t.hook_callbacks callback_id with
      | None ->
          let error_msg =
            Printf.sprintf "Hook callback not found: %s" callback_id
          in
          Log.err (fun m -> m "%s" error_msg);
          Transport.send t.transport
            (Control_response.error ~request_id ~code:`Method_not_found
               ~message:error_msg ())
      | Some callback -> (
          let outcome =
            try
              let result = callback input in
              let result_json =
                Jsont.Json.encode Proto.Hooks.result_jsont result
                |> Err.get_ok ~msg:"Failed to encode hook result: "
              in
              Log.debug (fun m ->
                  m "Hook result JSON: %s" (json_to_string result_json));
              Ok
                (Control_response.success ~request_id
                   ~response:(Some result_json))
            with
            (* Cancellation must propagate, not be reported as a hook
               failure. *)
            | Eio.Cancel.Cancelled _ as ex -> raise ex
            | exn ->
                Error
                  (Printf.sprintf "Hook callback error: %s"
                     (Printexc.to_string exn))
          in
          match outcome with
          | Ok response ->
              Log.info (fun m ->
                  m "Hook callback succeeded, sending response");
              Transport.send t.transport response
          | Error error_msg ->
              Log.err (fun m -> m "%s" error_msg);
              Transport.send t.transport
                (Control_response.error ~request_id ~code:`Internal_error
                   ~message:error_msg ())))
  | Sdk_control.Request.Mcp_message req -> (
      (* Handle MCP request for in-process SDK servers *)
      let module J = Jsont.Json in
      let server_name = req.server_name in
      let message = req.message in
      Log.info (fun m -> m "MCP request for server '%s'" server_name);

      match Hashtbl.find_opt t.mcp_servers server_name with
      | None ->
          let error_msg =
            Printf.sprintf "MCP server '%s' not found" server_name
          in
          Log.err (fun m -> m "%s" error_msg);
          (* Return JSONRPC error in mcp_response format *)
          let mcp_error =
            J.object'
              [
                J.mem (J.name "jsonrpc") (J.string "2.0");
                J.mem (J.name "id") (J.null ());
                J.mem (J.name "error")
                  (J.object'
                     [
                       J.mem (J.name "code") (J.number (-32601.0));
                       J.mem (J.name "message") (J.string error_msg);
                     ]);
              ]
          in
          let response_data =
            J.object' [ J.mem (J.name "mcp_response") mcp_error ]
          in
          let response =
            Control_response.success ~request_id
              ~response:(Some response_data)
          in
          Transport.send t.transport response
      | Some server ->
          let mcp_response = Mcp_server.handle_json_message server message in
          Log.debug (fun m ->
              m "MCP response: %s" (json_to_string mcp_response));
          let response_data =
            J.object' [ J.mem (J.name "mcp_response") mcp_response ]
          in
          let response =
            Control_response.success ~request_id
              ~response:(Some response_data)
          in
          Transport.send t.transport response)
  | _ ->
      (* Other request types not handled here *)
      let error_msg = "Unsupported control request type" in
      Transport.send t.transport
        (Control_response.error ~request_id ~code:`Invalid_request
           ~message:error_msg ())

(** [handle_control_response t control_resp] records a control response under
    its request id and wakes every fiber waiting in [send_control_request]. *)
let handle_control_response t control_resp =
  let request_id =
    match control_resp.Sdk_control.response with
    | Sdk_control.Response.Success s -> s.request_id
    | Sdk_control.Response.Error e -> e.request_id
  in
  Log.debug (fun m ->
      m "Received control response for request_id: %s" request_id);

  (* Store the response as JSON and signal waiting threads *)
  let json =
    Jsont.Json.encode Sdk_control.control_response_jsont control_resp
    |> Err.get_ok ~msg:"Failed to encode control response: "
  in
  Eio.Mutex.use_rw ~protect:false t.control_mutex (fun () ->
      Hashtbl.replace t.control_responses request_id json;
      Eio.Condition.broadcast t.control_condition)

(** [handle_raw_messages t] is the lazy sequence of decoded incoming messages
    read line by line from the transport. Lines that fail to decode are
    logged and skipped; the sequence ends when the transport reports EOF. *)
let handle_raw_messages t =
  let rec loop () =
    match Transport.receive_line t.transport with
    | None ->
        (* EOF *)
        Log.debug (fun m -> m "Handle messages: EOF received");
        Seq.Nil
    | Some line -> (
        (* Use unified Incoming codec for all message types *)
        match Jsont_bytesrw.decode_string' Incoming.jsont line with
        | Ok incoming -> Seq.Cons (incoming, loop)
        | Error err ->
            Log.err (fun m ->
                m "Failed to decode incoming message: %s\nLine: %s"
                  (Jsont.Error.to_string err)
                  line);
            loop ())
  in
  Log.debug (fun m -> m "Starting message handler");
  loop

(** [handle_messages t] forces the raw message sequence and returns the first
    node of the response sequence. Regular messages are converted to
    responses with [Response.of_message] (recording the session id carried by
    system messages); control responses and control requests are handled
    internally and never surface to the consumer. *)
let handle_messages t =
  let raw_seq = handle_raw_messages t in
  let rec loop raw_seq =
    match raw_seq () with
    | Seq.Nil -> Seq.Nil
    | Seq.Cons (incoming, rest) -> (
        match incoming with
        | Incoming.Message msg ->
            Log.info (fun m -> m "← %a" Message.pp msg);

            (* Extract session ID from system messages *)
            (match msg with
            | Message.System sys ->
                Message.System.session_id sys
                |> Option.iter (fun session_id ->
                       t.session_id <- Some session_id;
                       Log.debug (fun m ->
                           m "Stored session ID: %s" session_id))
            | _ -> ());

            (* Convert message to response events *)
            let responses = Response.of_message msg in
            emit_responses responses rest
        | Incoming.Control_response resp ->
            handle_control_response t resp;
            loop rest
        | Incoming.Control_request ctrl_req ->
            Log.info (fun m ->
                m "Received control request (request_id: %s)"
                  ctrl_req.request_id);
            handle_control_request t ctrl_req;
            loop rest)
  (* Emit the responses derived from one message one at a time, then resume
     reading from the transport. *)
  and emit_responses responses rest =
    match responses with
    | [] -> loop rest
    | r :: rs -> Seq.Cons (r, fun () -> emit_responses rs rest)
  in
  loop raw_seq

(** [create ?options ~sw ~process_mgr ~clock ()] creates the transport and
    the client state. In-process MCP servers from [options] are registered by
    name, and when hooks are configured each hook callback is registered
    under a fresh ["hook_<n>"] id and an initialize control request (request
    id ["init_hooks"]) announcing them is sent. *)
let create ?(options = Options.default) ~sw ~process_mgr ~clock () =
  (* Automatically enable permission prompt tool when callback is configured
     (matching Python SDK behavior in client.py:104-121) *)
  let options =
    match Options.permission_callback options with
    | Some _ when Options.permission_prompt_tool_name options = None ->
        (* Set permission_prompt_tool_name to "stdio" to enable control
           protocol *)
        Options.with_permission_prompt_tool_name "stdio" options
    | _ -> options
  in
  let transport = Transport.create ~sw ~process_mgr ~options () in

  (* Setup hook callbacks *)
  let hook_callbacks = Hashtbl.create 16 in
  let next_callback_id = ref 0 in

  (* Setup MCP servers from options *)
  let mcp_servers_ht = Hashtbl.create 16 in
  List.iter
    (fun (name, server) ->
      Log.info (fun m -> m "Registering MCP server: %s" name);
      Hashtbl.add mcp_servers_ht name server)
    (Options.mcp_servers options);

  let t =
    {
      transport;
      permission_callback = Options.permission_callback options;
      permission_log = None;
      hook_callbacks;
      session_id = None;
      control_responses = Hashtbl.create 16;
      control_mutex = Eio.Mutex.create ();
      control_condition = Eio.Condition.create ();
      clock;
      responded_tool_ids = Hashtbl.create 16;
      mcp_servers = mcp_servers_ht;
    }
  in

  (* Register hooks and send initialize if hooks are configured *)
  Options.hooks options
  |> Option.iter (fun hooks_config ->
         Log.info (fun m -> m "Registering hooks...");

         (* Get callbacks in wire format from the new Hooks API *)
         let callbacks_by_event = Hooks.get_callbacks hooks_config in

         (* Build hooks configuration with callback IDs as
            (string * Jsont.json) list *)
         let hooks_list =
           List.map
             (fun (event, matchers) ->
               let event_name = Proto.Hooks.event_to_string event in
               let matcher_wires =
                 List.map
                   (fun (pattern, callback) ->
                     let callback_id =
                       Printf.sprintf "hook_%d" !next_callback_id
                     in
                     incr next_callback_id;
                     Hashtbl.add hook_callbacks callback_id callback;
                     Log.debug (fun m ->
                         m "Registered callback: %s for event: %s"
                           callback_id event_name);
                     Hook_matcher_wire.
                       {
                         matcher = pattern;
                         hook_callback_ids = [ callback_id ];
                       })
                   matchers
               in
               (event_name, Hook_matcher_wire.encode matcher_wires))
             callbacks_by_event
         in

         (* Create initialize request using Sdk_control codec *)
         let request = Sdk_control.Request.initialize ~hooks:hooks_list () in
         let ctrl_req =
           Sdk_control.create_request ~request_id:"init_hooks" ~request ()
         in
         let initialize_msg =
           Jsont.Json.encode Sdk_control.jsont ctrl_req
           |> Err.get_ok ~msg:"Failed to encode initialize request: "
         in
         Log.info (fun m -> m "Sending hooks initialize request");
         Transport.send t.transport initialize_msg);

  t

(* Helper to send a message with proper "type" wrapper via Proto.Outgoing *)
let send_message t msg =
  Log.info (fun m -> m "→ %a" Message.pp msg);
  let proto_msg = Message.to_proto msg in
  let outgoing = Proto.Outgoing.Message proto_msg in
  let json = Proto.Outgoing.to_json outgoing in
  Transport.send t.transport json

(** [query t prompt] sends [prompt] as a user message. *)
let query t prompt =
  let msg = Message.user_string prompt in
  send_message t msg

(** [respond_to_tool t ~tool_use_id ~content ?is_error ()] sends a tool
    result for [tool_use_id]. A second response for the same id is logged and
    dropped. [is_error] defaults to [false]. *)
let respond_to_tool t ~tool_use_id ~content ?(is_error = false) () =
  (* Check for duplicate response - prevents API errors from multiple
     responses *)
  if Hashtbl.mem t.responded_tool_ids tool_use_id then
    Log.warn (fun m ->
        m "Skipping duplicate tool response for tool_use_id: %s" tool_use_id)
  else begin
    Hashtbl.add t.responded_tool_ids tool_use_id ();
    let user_msg =
      Message.User.with_tool_result ~tool_use_id ~content ~is_error ()
    in
    let msg = Message.User user_msg in
    send_message t msg
  end

(** [respond_to_tools t responses] sends all not-yet-answered tool results in
    a single user message. Each element is
    [(tool_use_id, content, is_error option)]; a missing [is_error] means
    [false]. Nothing is sent when every entry is a duplicate. *)
let respond_to_tools t responses =
  (* Filter out duplicates *)
  let new_responses =
    List.filter
      (fun (tool_use_id, _, _) ->
        if Hashtbl.mem t.responded_tool_ids tool_use_id then begin
          Log.warn (fun m ->
              m "Skipping duplicate tool response for tool_use_id: %s"
                tool_use_id);
          false
        end
        else begin
          Hashtbl.add t.responded_tool_ids tool_use_id ();
          true
        end)
      responses
  in
  match new_responses with
  | [] -> ()
  | _ :: _ ->
      let tool_results =
        List.map
          (fun (tool_use_id, content, is_error_opt) ->
            let is_error = Option.value is_error_opt ~default:false in
            Content_block.tool_result ~tool_use_id ~content ~is_error ())
          new_responses
      in
      let user_msg = Message.User.of_blocks tool_results in
      let msg = Message.User user_msg in
      send_message t msg

(** [clear_tool_response_tracking t] forgets which tool_use_ids have already
    been answered. *)
let clear_tool_response_tracking t = Hashtbl.clear t.responded_tool_ids

(** [receive t] is the lazy sequence of responses; reading from the transport
    only happens as the sequence is forced. *)
let receive t = fun () -> handle_messages t

(** [run t ~handler] dispatches every response from [receive t] to
    [handler]. *)
let run t ~handler = Seq.iter (Handler.dispatch handler) (receive t)

(** [receive_all t] collects responses in order, stopping after the first
    [Response.Complete] (which is included) or at the end of the sequence. *)
let receive_all t =
  let rec collect acc seq =
    match seq () with
    | Seq.Nil ->
        Log.debug (fun m ->
            m "End of response sequence (%d responses)" (List.length acc));
        List.rev acc
    | Seq.Cons ((Response.Complete _ as resp), _) ->
        Log.debug (fun m -> m "Received final Complete response");
        List.rev (resp :: acc)
    | Seq.Cons (resp, rest) -> collect (resp :: acc) rest
  in
  collect [] (receive t)

(** [interrupt t] forwards to [Transport.interrupt]. *)
let interrupt t = Transport.interrupt t.transport

(** [enable_permission_discovery t] replaces the permission callback with
    [Permissions.discovery] over a fresh log, readable through
    {!discovered_permissions}. *)
let enable_permission_discovery t =
  let log = ref [] in
  let callback = Permissions.discovery log in
  t.permission_callback <- Some callback;
  t.permission_log <- Some log

(** [discovered_permissions t] is the current content of the discovery log,
    or [[]] when discovery has not been enabled. *)
let discovered_permissions t =
  t.permission_log |> Option.map ( ! ) |> Option.value ~default:[]

(* Helper to send a control request and wait for response *)

(** [send_control_request t ~request_id request] sends [request] and blocks
    the calling fiber until the response for [request_id] has been recorded
    by [handle_control_response], or 10 seconds have elapsed.

    This function does not read from the transport itself: responses are only
    recorded while the sequence returned by [receive] is being consumed, so
    that must happen in another fiber.

    Returns the optional payload of a success response.

    @raise Failure on timeout, or when the response is an error response. *)
let send_control_request t ~request_id request =
  (* Send the control request *)
  let control_msg = Sdk_control.create_request ~request_id ~request () in
  let json =
    Jsont.Json.encode Sdk_control.jsont control_msg
    |> Err.get_ok ~msg:"Failed to encode control request: "
  in
  Log.info (fun m -> m "Sending control request: %s" (json_to_string json));
  Transport.send t.transport json;

  (* Wait for the response with timeout *)
  let max_wait = 10.0 in
  (* 10 seconds timeout *)
  let wait_for_response () =
    (* [use_ro] rather than [use_rw]: if the wait is cancelled by the timeout
       the table is still consistent, so the mutex must be released normally
       instead of being poisoned. *)
    Eio.Mutex.use_ro t.control_mutex (fun () ->
        let rec wait () =
          match Hashtbl.find_opt t.control_responses request_id with
          | Some response_json ->
              (* Remove it from the table *)
              Hashtbl.remove t.control_responses request_id;
              response_json
          | None ->
              (* [await] releases the mutex while suspended, letting
                 [handle_control_response] store the response and broadcast,
                 and re-acquires it before returning. *)
              Eio.Condition.await t.control_condition t.control_mutex;
              wait ()
        in
        wait ())
  in
  let response_json =
    match Eio.Time.with_timeout_exn t.clock max_wait wait_for_response with
    | response_json -> response_json
    | exception Eio.Time.Timeout ->
        raise
          (Failure
             (Printf.sprintf "Timeout waiting for control response: %s"
                request_id))
  in
  Log.debug (fun m ->
      m "Received control response: %s" (json_to_string response_json));

  (* Parse the response - extract the "response" field using jsont codec *)
  let response_field_codec =
    Jsont.Object.map ~kind:"ResponseField" Fun.id
    |> Jsont.Object.mem "response" Jsont.json ~enc:Fun.id
    |> Jsont.Object.finish
  in
  let response_data =
    Jsont.Json.decode response_field_codec response_json
    |> Err.get_ok' ~msg:"Failed to extract response field: "
  in
  let response =
    Jsont.Json.decode Sdk_control.Response.jsont response_data
    |> Err.get_ok' ~msg:"Failed to decode response: "
  in
  match response with
  | Sdk_control.Response.Success s -> s.response
  | Sdk_control.Response.Error e ->
      raise
        (Failure
           (Printf.sprintf "Control request failed: [%d] %s" e.error.code
              e.error.message))

(** [set_permission_mode t mode] sends a set-permission-mode control request
    and waits for its response.
    @raise Failure as {!send_control_request}. *)
let set_permission_mode t mode =
  let request_id = Printf.sprintf "set_perm_mode_%f" (Eio.Time.now t.clock) in
  let proto_mode = Permissions.Mode.to_proto mode in
  let request = Sdk_control.Request.set_permission_mode ~mode:proto_mode () in
  let _response = send_control_request t ~request_id request in
  Log.info (fun m ->
      m "Permission mode set to: %s" (Permissions.Mode.to_string mode))

(** [set_model t model] sends a set-model control request and waits for its
    response.
    @raise Failure as {!send_control_request}. *)
let set_model t model =
  let model_str = Model.to_string model in
  let request_id = Printf.sprintf "set_model_%f" (Eio.Time.now t.clock) in
  let request = Sdk_control.Request.set_model ~model:model_str () in
  let _response = send_control_request t ~request_id request in
  Log.info (fun m -> m "Model set to: %s" model_str)

(** [get_server_info t] requests the server info and decodes the response
    payload. A success response without a payload is turned into an error
    through [Err.get_ok].
    @raise Failure as {!send_control_request}. *)
let get_server_info t =
  let request_id =
    Printf.sprintf "get_server_info_%f" (Eio.Time.now t.clock)
  in
  let request = Sdk_control.Request.get_server_info () in
  let response_data =
    send_control_request t ~request_id request
    |> Option.to_result ~none:"No response data from get_server_info request"
    |> Err.get_ok ~msg:""
  in
  let server_info =
    Jsont.Json.decode Sdk_control.Server_info.jsont response_data
    |> Err.get_ok' ~msg:"Failed to decode server info: "
  in
  Log.info (fun m ->
      m "Retrieved server info: %a"
        (Jsont.pp_value Sdk_control.Server_info.jsont ())
        server_info);
  Server_info.of_sdk_control server_info

(** Lower-level send/receive entry points. *)
module Advanced = struct
  (** Same as the top-level [send_message]. *)
  let send_message t msg = send_message t msg

  (** [send_user_message t user_msg] wraps [user_msg] in [Message.User] and
      sends it. *)
  let send_user_message t user_msg =
    let msg = Message.User user_msg in
    send_message t msg

  (** [send_raw t control] encodes a control message with
      [Sdk_control.jsont] and sends it. *)
  let send_raw t control =
    let json =
      Jsont.Json.encode Sdk_control.jsont control
      |> Err.get_ok ~msg:"Failed to encode control message: "
    in
    Log.info (fun m -> m "→ Raw control: %s" (json_to_string json));
    Transport.send t.transport json

  (** [send_json t json] sends an arbitrary JSON value unchanged. *)
  let send_json t json =
    Log.info (fun m -> m "→ Raw JSON: %s" (json_to_string json));
    Transport.send t.transport json

  (** [receive_raw t] is the sequence of decoded incoming messages without
      the control handling done by [receive]. *)
  let receive_raw t = handle_raw_messages t
end