OCaml Claude SDK using Eio and Jsont
at main 598 lines 22 kB view raw
(*---------------------------------------------------------------------------
   Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
   SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

let src = Logs.Src.create "claude.client" ~doc:"Claude client"

module Log = (val Logs.src_log src : Logs.LOG)

(** Control response builders using Sdk_control codecs. Each builder returns
    the encoded JSON, ready to be written to the transport. *)
module Control_response = struct
  (** [success ~request_id ~response] encodes a success control response for
      [request_id], carrying [response] as its payload when it is [Some _]. *)
  let success ~request_id ~response =
    let resp = Sdk_control.Response.success ~request_id ?response () in
    let ctrl = Sdk_control.create_response ~response:resp () in
    Jsont.Json.encode Sdk_control.jsont ctrl
    |> Err.get_ok ~msg:"Control_response.success: "

  (** [error ~request_id ~code ~message ?data ()] encodes an error control
      response for [request_id] with the given [code], [message] and optional
      extra [data]. *)
  let error ~request_id ~code ~message ?data () =
    let error_detail =
      Sdk_control.Response.error_detail ~code ~message ?data ()
    in
    let resp = Sdk_control.Response.error ~request_id ~error:error_detail () in
    let ctrl = Sdk_control.create_response ~response:resp () in
    Jsont.Json.encode Sdk_control.jsont ctrl
    |> Err.get_ok ~msg:"Control_response.error: "
end

(* Helper functions for JSON manipulation using jsont *)

(** [json_to_string json] serialises [json] to a string (used for logging).
    Encoding errors are converted to strings and handed to [Err.get_ok]. *)
let json_to_string json =
  Jsont_bytesrw.encode_string' Jsont.json json
  |> Result.map_error Jsont.Error.to_string
  |> Err.get_ok ~msg:""

(** Wire-level codec for hook matcher configuration sent to CLI. *)
module Hook_matcher_wire = struct
  (** One matcher entry: an optional [matcher] pattern and the callback ids
      associated with it. *)
  type t = { matcher : string option; hook_callback_ids : string list }

  (** JSON object with an optional ["matcher"] member and a
      ["hookCallbackIds"] list of strings. *)
  let jsont : t Jsont.t =
    let make matcher hook_callback_ids = { matcher; hook_callback_ids } in
    Jsont.Object.map ~kind:"HookMatcherWire" make
    |> Jsont.Object.opt_mem "matcher" Jsont.string ~enc:(fun r -> r.matcher)
    |> Jsont.Object.mem "hookCallbackIds" (Jsont.list Jsont.string)
         ~enc:(fun r -> r.hook_callback_ids)
    |> Jsont.Object.finish

  (** [encode matchers] encodes every matcher and wraps them in a JSON
      array. *)
  let encode matchers =
    List.map
      (fun m ->
        Jsont.Json.encode jsont m
        |> Err.get_ok ~msg:"Hook_matcher_wire.encode: ")
      matchers
    |> Jsont.Json.list
end

type t = {
  transport : Transport.t;
  mutable permission_callback : Permissions.callback option;
  (* Set by [enable_permission_discovery]; read by [discovered_permissions]. *)
  mutable permission_log : Permissions.Rule.t list ref option;
  (* Hook callbacks keyed by the "hook_N" ids registered in [create]. *)
  hook_callbacks : (string, Jsont.json -> Proto.Hooks.result) Hashtbl.t;
  (* Most recent session id seen on a system message. *)
  mutable session_id : string option;
  (* Control responses from the CLI keyed by request_id, waiting to be taken
     by [send_control_request]. Guarded by [control_mutex];
     [control_condition] is broadcast whenever a response is stored. *)
  control_responses : (string, Jsont.json) Hashtbl.t;
  control_mutex : Eio.Mutex.t;
  control_condition : Eio.Condition.t;
  clock : float Eio.Time.clock_ty Eio.Resource.t;
  (* Track tool_use_ids we've already responded to, preventing duplicates *)
  responded_tool_ids : (string, unit) Hashtbl.t;
  (* In-process MCP servers for custom tools *)
  mcp_servers : (string, Mcp_server.t) Hashtbl.t;
}

(** [session_id t] is the last session id reported by the CLI, if any. *)
let session_id t = t.session_id

(** [handle_control_request t ctrl_req] answers a control request initiated
    by the CLI:
    - [Permission]: runs [t.permission_callback] (defaulting to
      [Permissions.default_allow]) and replies with the encoded decision;
    - [Hook_callback]: runs the registered hook callback and replies with its
      encoded result;
    - [Mcp_message]: forwards the message to the named in-process MCP server
      and replies with its JSON-RPC response wrapped as ["mcp_response"];
    - anything else gets an [`Invalid_request] error response.

    Exceptions raised by user permission/hook callbacks are reported to the
    CLI as [`Internal_error] responses instead of escaping; Eio cancellation
    is always re-raised. *)
let handle_control_request t (ctrl_req : Sdk_control.control_request) =
  let request_id = ctrl_req.request_id in
  Log.info (fun m -> m "Handling control request: %s" request_id);
  (* Answer this request with an error control response. *)
  let reply_error ~code message =
    Transport.send t.transport
      (Control_response.error ~request_id ~code ~message ())
  in
  match ctrl_req.request with
  | Sdk_control.Request.Permission req -> (
      let tool_name = req.tool_name in
      let input_json = req.input in
      Log.info (fun m ->
          m "Permission request for tool '%s' with input: %s" tool_name
            (json_to_string input_json));
      (* Convert permission_suggestions to suggested rules *)
      let suggested_rules =
        Option.value req.permission_suggestions ~default:[]
        |> Permissions.extract_rules_from_proto_updates
      in
      let input = Tool_input.of_json input_json in
      let context : Permissions.context =
        { tool_name; input; suggested_rules }
      in
      Log.info (fun m ->
          m "Invoking permission callback for tool: %s" tool_name);
      let callback =
        Option.value t.permission_callback ~default:Permissions.default_allow
      in
      (* Only the user callback is guarded: a crash there must still produce
         a response, like the hook branch below. *)
      match callback context with
      | exception (Eio.Cancel.Cancelled _ as ex) -> raise ex
      | exception exn ->
          let error_msg =
            Printf.sprintf "Permission callback error: %s"
              (Printexc.to_string exn)
          in
          Log.err (fun m -> m "%s" error_msg);
          reply_error ~code:`Internal_error error_msg
      | decision ->
          Log.info (fun m ->
              m "Permission callback returned: %s"
                (if Permissions.Decision.is_allow decision then "ALLOW"
                 else "DENY"));
          (* Convert permission decision to proto result *)
          let proto_result =
            Permissions.Decision.to_proto_result ~original_input:input
              decision
          in
          let response_data =
            match
              Jsont.Json.encode Proto.Permissions.Result.jsont proto_result
            with
            | Ok json -> json
            | Error err ->
                Log.err (fun m ->
                    m "Failed to encode permission result: %s" err);
                failwith "Permission result encoding failed"
          in
          let response =
            Control_response.success ~request_id ~response:(Some response_data)
          in
          Log.info (fun m ->
              m "Sending control response: %s" (json_to_string response));
          Transport.send t.transport response)
  | Sdk_control.Request.Hook_callback req -> (
      let callback_id = req.callback_id in
      Log.info (fun m ->
          m "Hook callback request for callback_id: %s" callback_id);
      (* Look the callback up explicitly so a Not_found raised *inside* a user
         callback is not misreported as a missing callback. *)
      match Hashtbl.find_opt t.hook_callbacks callback_id with
      | None ->
          let error_msg =
            Printf.sprintf "Hook callback not found: %s" callback_id
          in
          Log.err (fun m -> m "%s" error_msg);
          reply_error ~code:`Method_not_found error_msg
      | Some callback -> (
          (* Run the callback and encode its result; a failure in either step
             is reported to the CLI. Sending is deliberately outside this
             scope so a transport failure does not trigger a second send. *)
          let run_hook () =
            callback req.input
            |> Jsont.Json.encode Proto.Hooks.result_jsont
            |> Err.get_ok ~msg:"Failed to encode hook result: "
          in
          match run_hook () with
          | exception (Eio.Cancel.Cancelled _ as ex) -> raise ex
          | exception exn ->
              let error_msg =
                Printf.sprintf "Hook callback error: %s"
                  (Printexc.to_string exn)
              in
              Log.err (fun m -> m "%s" error_msg);
              reply_error ~code:`Internal_error error_msg
          | result_json ->
              Log.debug (fun m ->
                  m "Hook result JSON: %s" (json_to_string result_json));
              let response =
                Control_response.success ~request_id
                  ~response:(Some result_json)
              in
              Log.info (fun m -> m "Hook callback succeeded, sending response");
              Transport.send t.transport response))
  | Sdk_control.Request.Mcp_message req -> (
      (* Handle MCP request for in-process SDK servers *)
      let module J = Jsont.Json in
      let server_name = req.server_name in
      Log.info (fun m -> m "MCP request for server '%s'" server_name);
      (* Send [mcp_json] wrapped as {"mcp_response": mcp_json} in a success
         control response. *)
      let reply_mcp mcp_json =
        let response_data =
          J.object' [ J.mem (J.name "mcp_response") mcp_json ]
        in
        Transport.send t.transport
          (Control_response.success ~request_id ~response:(Some response_data))
      in
      match Hashtbl.find_opt t.mcp_servers server_name with
      | None ->
          let error_msg =
            Printf.sprintf "MCP server '%s' not found" server_name
          in
          Log.err (fun m -> m "%s" error_msg);
          (* JSON-RPC error -32601 ("method not found"); id is sent as null. *)
          reply_mcp
            (J.object'
               [
                 J.mem (J.name "jsonrpc") (J.string "2.0");
                 J.mem (J.name "id") (J.null ());
                 J.mem (J.name "error")
                   (J.object'
                      [
                        J.mem (J.name "code") (J.number (-32601.0));
                        J.mem (J.name "message") (J.string error_msg);
                      ]);
               ])
      | Some server ->
          let mcp_response =
            Mcp_server.handle_json_message server req.message
          in
          Log.debug (fun m ->
              m "MCP response: %s" (json_to_string mcp_response));
          reply_mcp mcp_response)
  | _ ->
      (* Other request types not handled here *)
      reply_error ~code:`Invalid_request "Unsupported control request type"

(** [handle_control_response t control_resp] stores [control_resp] (encoded
    as JSON) under its request_id and wakes every fiber waiting on
    [t.control_condition]. A response with the same request_id replaces any
    earlier one. *)
let handle_control_response t control_resp =
  let request_id =
    match control_resp.Sdk_control.response with
    | Sdk_control.Response.Success s -> s.request_id
    | Sdk_control.Response.Error e -> e.request_id
  in
  Log.debug (fun m ->
      m "Received control response for request_id: %s" request_id);
  (* Store the response as JSON and signal waiting fibers *)
  let json =
    Jsont.Json.encode Sdk_control.control_response_jsont control_resp
    |> Err.get_ok ~msg:"Failed to encode control response: "
  in
  Eio.Mutex.use_rw ~protect:false t.control_mutex (fun () ->
      Hashtbl.replace t.control_responses request_id json;
      Eio.Condition.broadcast t.control_condition)

(** [handle_raw_messages t] is a lazy sequence of decoded incoming messages.
    Each step reads one line from the transport; lines that fail to decode
    are logged and skipped. The sequence ends at EOF. It is not memoised:
    forcing a node twice reads the transport again. *)
let handle_raw_messages t =
  let rec loop () =
    match Transport.receive_line t.transport with
    | None ->
        (* EOF *)
        Log.debug (fun m -> m "Handle messages: EOF received");
        Seq.Nil
    | Some line -> (
        (* Use unified Incoming codec for all message types *)
        match Jsont_bytesrw.decode_string' Incoming.jsont line with
        | Ok incoming -> Seq.Cons (incoming, loop)
        | Error err ->
            Log.err (fun m ->
                m "Failed to decode incoming message: %s\nLine: %s"
                  (Jsont.Error.to_string err)
                  line);
            loop ())
  in
  Log.debug (fun m -> m "Starting message handler");
  loop

(** [handle_messages t] returns the first node of the response stream.
    Ordinary messages are converted to responses with [Response.of_message]
    (system messages also update [t.session_id]); control responses and
    control requests are handled internally and produce no output. *)
let handle_messages t =
  let raw_seq = handle_raw_messages t in
  let rec loop raw_seq =
    match raw_seq () with
    | Seq.Nil -> Seq.Nil
    | Seq.Cons (incoming, rest) -> (
        match incoming with
        | Incoming.Message msg ->
            Log.info (fun m -> m "← %a" Message.pp msg);
            (* Extract session ID from system messages *)
            (match msg with
            | Message.System sys ->
                Message.System.session_id sys
                |> Option.iter (fun session_id ->
                       t.session_id <- Some session_id;
                       Log.debug (fun m -> m "Stored session ID: %s" session_id))
            | _ -> ());
            (* Convert message to response events *)
            emit_responses (Response.of_message msg) rest
        | Incoming.Control_response resp ->
            handle_control_response t resp;
            loop rest
        | Incoming.Control_request ctrl_req ->
            Log.info (fun m ->
                m "Received control request (request_id: %s)"
                  ctrl_req.request_id);
            handle_control_request t ctrl_req;
            loop rest)
  (* Yield each of [responses] in order, then continue with [rest]. *)
  and emit_responses responses rest =
    match responses with
    | [] -> loop rest
    | r :: rs -> Seq.Cons (r, fun () -> emit_responses rs rest)
  in
  loop raw_seq

(** [create ?options ~sw ~process_mgr ~clock ()] starts the CLI transport and
    returns a client.

    If a permission callback is configured but no permission prompt tool name
    is set, the tool name is set to ["stdio"] to enable the control protocol
    (matching Python SDK behavior in client.py:104-121). If hooks are
    configured, each hook callback is registered under an id ["hook_N"] and
    an initialize control request (request_id ["init_hooks"]) listing them is
    sent to the CLI. *)
let create ?(options = Options.default) ~sw ~process_mgr ~clock () =
  let options =
    match Options.permission_callback options with
    | Some _ when Option.is_none (Options.permission_prompt_tool_name options)
      ->
        (* Set permission_prompt_tool_name to "stdio" to enable control protocol *)
        Options.with_permission_prompt_tool_name "stdio" options
    | _ -> options
  in
  let transport = Transport.create ~sw ~process_mgr ~options () in

  (* Setup hook callbacks *)
  let hook_callbacks = Hashtbl.create 16 in
  let next_callback_id = ref 0 in

  (* Setup MCP servers from options; a later server with the same name
     replaces an earlier one. *)
  let mcp_servers_ht = Hashtbl.create 16 in
  List.iter
    (fun (name, server) ->
      Log.info (fun m -> m "Registering MCP server: %s" name);
      Hashtbl.replace mcp_servers_ht name server)
    (Options.mcp_servers options);

  let t =
    {
      transport;
      permission_callback = Options.permission_callback options;
      permission_log = None;
      hook_callbacks;
      session_id = None;
      control_responses = Hashtbl.create 16;
      control_mutex = Eio.Mutex.create ();
      control_condition = Eio.Condition.create ();
      clock;
      responded_tool_ids = Hashtbl.create 16;
      mcp_servers = mcp_servers_ht;
    }
  in

  (* Register hooks and send initialize if hooks are configured *)
  Options.hooks options
  |> Option.iter (fun hooks_config ->
         Log.info (fun m -> m "Registering hooks...");

         (* Get callbacks in wire format from the new Hooks API *)
         let callbacks_by_event = Hooks.get_callbacks hooks_config in

         (* Build hooks configuration with callback IDs as (string * Jsont.json) list *)
         let hooks_list =
           List.map
             (fun (event, matchers) ->
               let event_name = Proto.Hooks.event_to_string event in
               let matcher_wires =
                 List.map
                   (fun (pattern, callback) ->
                     let callback_id =
                       Printf.sprintf "hook_%d" !next_callback_id
                     in
                     incr next_callback_id;
                     Hashtbl.add hook_callbacks callback_id callback;
                     Log.debug (fun m ->
                         m "Registered callback: %s for event: %s" callback_id
                           event_name);
                     Hook_matcher_wire.
                       { matcher = pattern; hook_callback_ids = [ callback_id ] })
                   matchers
               in
               (event_name, Hook_matcher_wire.encode matcher_wires))
             callbacks_by_event
         in

         (* Create initialize request using Sdk_control codec *)
         let request = Sdk_control.Request.initialize ~hooks:hooks_list () in
         let ctrl_req =
           Sdk_control.create_request ~request_id:"init_hooks" ~request ()
         in
         let initialize_msg =
           Jsont.Json.encode Sdk_control.jsont ctrl_req
           |> Err.get_ok ~msg:"Failed to encode initialize request: "
         in
         Log.info (fun m -> m "Sending hooks initialize request");
         Transport.send t.transport initialize_msg);

  t

(* Helper to send a message with proper "type" wrapper via Proto.Outgoing *)
let send_message t msg =
  Log.info (fun m -> m "→ %a" Message.pp msg);
  let proto_msg = Message.to_proto msg in
  let outgoing = Proto.Outgoing.Message proto_msg in
  let json = Proto.Outgoing.to_json outgoing in
  Transport.send t.transport json

(** [query t prompt] sends [prompt] to the CLI as a user message. *)
let query t prompt =
  let msg = Message.user_string prompt in
  send_message t msg

(** [respond_to_tool t ~tool_use_id ~content ?is_error ()] sends a single tool
    result. A [tool_use_id] that was already answered (since the last
    [clear_tool_response_tracking]) is skipped with a warning, preventing
    API errors from multiple responses. *)
let respond_to_tool t ~tool_use_id ~content ?(is_error = false) () =
  if Hashtbl.mem t.responded_tool_ids tool_use_id then
    Log.warn (fun m ->
        m "Skipping duplicate tool response for tool_use_id: %s" tool_use_id)
  else begin
    Hashtbl.replace t.responded_tool_ids tool_use_id ();
    let user_msg =
      Message.User.with_tool_result ~tool_use_id ~content ~is_error ()
    in
    let msg = Message.User user_msg in
    send_message t msg
  end

(** [respond_to_tools t responses] sends several tool results in one user
    message. [responses] holds [(tool_use_id, content, is_error)] triples,
    where [is_error] defaults to [false]. Already-answered ids — including
    repeats within [responses] — are skipped with a warning; nothing is sent
    if no result remains. *)
let respond_to_tools t responses =
  (* Filter out duplicates, marking the kept ids as answered *)
  let new_responses =
    List.filter
      (fun (tool_use_id, _, _) ->
        if Hashtbl.mem t.responded_tool_ids tool_use_id then begin
          Log.warn (fun m ->
              m "Skipping duplicate tool response for tool_use_id: %s"
                tool_use_id);
          false
        end
        else begin
          Hashtbl.replace t.responded_tool_ids tool_use_id ();
          true
        end)
      responses
  in
  if new_responses <> [] then begin
    let tool_results =
      List.map
        (fun (tool_use_id, content, is_error_opt) ->
          let is_error = Option.value is_error_opt ~default:false in
          Content_block.tool_result ~tool_use_id ~content ~is_error ())
        new_responses
    in
    let user_msg = Message.User.of_blocks tool_results in
    let msg = Message.User user_msg in
    send_message t msg
  end

(** [clear_tool_response_tracking t] forgets which tool_use_ids were
    answered, allowing them to be answered again. *)
let clear_tool_response_tracking t = Hashtbl.clear t.responded_tool_ids

(** [receive t] is the response stream; see {!handle_messages}. *)
let receive t = fun () -> handle_messages t

(** [run t ~handler] dispatches every response to [handler] until a
    [Response.Complete] has been dispatched or the stream ends. *)
let run t ~handler =
  (* Stop after Complete response - don't wait for EOF since Claude CLI
     waits for stdin to close before exiting, causing a deadlock *)
  let rec loop seq =
    match seq () with
    | Seq.Nil -> ()
    | Seq.Cons ((Response.Complete _ as resp), _) ->
        Handler.dispatch handler resp
    | Seq.Cons (resp, rest) ->
        Handler.dispatch handler resp;
        loop rest
  in
  loop (receive t)

(** [receive_all t] collects responses in order, up to and including the
    first [Response.Complete], or until the stream ends. *)
let receive_all t =
  let rec collect acc seq =
    match seq () with
    | Seq.Nil ->
        Log.debug (fun m ->
            m "End of response sequence (%d responses)" (List.length acc));
        List.rev acc
    | Seq.Cons ((Response.Complete _ as resp), _) ->
        Log.debug (fun m -> m "Received final Complete response");
        List.rev (resp :: acc)
    | Seq.Cons (resp, rest) -> collect (resp :: acc) rest
  in
  collect [] (receive t)

(** [interrupt t] delegates to [Transport.interrupt]. *)
let interrupt t = Transport.interrupt t.transport

(** [enable_permission_discovery t] replaces the permission callback with
    [Permissions.discovery], recording into a fresh log that
    [discovered_permissions] reads. *)
let enable_permission_discovery t =
  let log = ref [] in
  let callback = Permissions.discovery log in
  t.permission_callback <- Some callback;
  t.permission_log <- Some log

(** [discovered_permissions t] is the content of the discovery log, or [[]]
    if discovery was never enabled. *)
let discovered_permissions t =
  t.permission_log |> Option.map ( ! ) |> Option.value ~default:[]

(** [send_control_request t ~request_id request] sends [request] to the CLI
    and suspends the calling fiber until [handle_control_response] stores a
    response with the same [request_id]. Some other fiber must be consuming
    [receive t] for that to happen.

    Returns the optional payload of a success response.
    @raise Failure if no response arrives within 10 seconds, or if the CLI
      answers with an error response. *)
let send_control_request t ~request_id request =
  (* Send the control request *)
  let control_msg = Sdk_control.create_request ~request_id ~request () in
  let json =
    Jsont.Json.encode Sdk_control.jsont control_msg
    |> Err.get_ok ~msg:"Failed to encode control request: "
  in
  Log.info (fun m -> m "Sending control request: %s" (json_to_string json));
  Transport.send t.transport json;

  (* 10 seconds timeout *)
  let max_wait = 10.0 in
  (* Must be called with [control_mutex] held. [Eio.Condition.await] releases
     the mutex while suspended and re-acquires it before returning, so
     [handle_control_response] can store the response and broadcast. Loops
     because a broadcast may be for a different request_id. *)
  let rec take_response () =
    match Hashtbl.find_opt t.control_responses request_id with
    | Some response_json ->
        Hashtbl.remove t.control_responses request_id;
        response_json
    | None ->
        Eio.Condition.await t.control_condition t.control_mutex;
        take_response ()
  in
  (* [use_ro] rather than [use_rw]: when the timeout cancels the wait, [use_rw]
     would poison [control_mutex] for every later control request. The only
     mutation is the final [Hashtbl.remove], which cannot leave the table
     half-updated. *)
  let wait () = Eio.Mutex.use_ro t.control_mutex take_response in
  let response_json =
    match Eio.Time.with_timeout t.clock max_wait (fun () -> Ok (wait ())) with
    | Ok response_json -> response_json
    | Error `Timeout ->
        failwith
          (Printf.sprintf "Timeout waiting for control response: %s"
             request_id)
  in
  Log.debug (fun m ->
      m "Received control response: %s" (json_to_string response_json));

  (* Parse the response - extract the "response" field using jsont codec *)
  let response_field_codec =
    Jsont.Object.map ~kind:"ResponseField" Fun.id
    |> Jsont.Object.mem "response" Jsont.json ~enc:Fun.id
    |> Jsont.Object.finish
  in
  let response_data =
    Jsont.Json.decode response_field_codec response_json
    |> Err.get_ok' ~msg:"Failed to extract response field: "
  in
  let response =
    Jsont.Json.decode Sdk_control.Response.jsont response_data
    |> Err.get_ok' ~msg:"Failed to decode response: "
  in
  match response with
  | Sdk_control.Response.Success s -> s.response
  | Sdk_control.Response.Error e ->
      failwith
        (Printf.sprintf "Control request failed: [%d] %s" e.error.code
           e.error.message)

(** [set_permission_mode t mode] asks the CLI to switch to [mode] and waits
    for its acknowledgement (see {!send_control_request}). *)
let set_permission_mode t mode =
  let request_id = Printf.sprintf "set_perm_mode_%f" (Eio.Time.now t.clock) in
  let proto_mode = Permissions.Mode.to_proto mode in
  let request = Sdk_control.Request.set_permission_mode ~mode:proto_mode () in
  let _response = send_control_request t ~request_id request in
  Log.info (fun m ->
      m "Permission mode set to: %s" (Permissions.Mode.to_string mode))

(** [set_model t model] asks the CLI to switch to [model] and waits for its
    acknowledgement (see {!send_control_request}). *)
let set_model t model =
  let model_str = Model.to_string model in
  let request_id = Printf.sprintf "set_model_%f" (Eio.Time.now t.clock) in
  let request = Sdk_control.Request.set_model ~model:model_str () in
  let _response = send_control_request t ~request_id request in
  Log.info (fun m -> m "Model set to: %s" model_str)

(** [get_server_info t] requests server information from the CLI and decodes
    it. Fails (through [Err]) if the response carries no payload or cannot
    be decoded. *)
let get_server_info t =
  let request_id = Printf.sprintf "get_server_info_%f" (Eio.Time.now t.clock) in
  let request = Sdk_control.Request.get_server_info () in
  let response_data =
    send_control_request t ~request_id request
    |> Option.to_result ~none:"No response data from get_server_info request"
    |> Err.get_ok ~msg:""
  in
  let server_info =
    Jsont.Json.decode Sdk_control.Server_info.jsont response_data
    |> Err.get_ok' ~msg:"Failed to decode server info: "
  in
  Log.info (fun m ->
      m "Retrieved server info: %a"
        (Jsont.pp_value Sdk_control.Server_info.jsont ())
        server_info);
  Server_info.of_sdk_control server_info

(** Lower-level access to the transport, bypassing the typed helpers above. *)
module Advanced = struct
  (** [send_message t msg] sends [msg] wrapped via [Proto.Outgoing]. *)
  let send_message t msg = send_message t msg

  (** [send_user_message t user_msg] sends [user_msg] as a user message. *)
  let send_user_message t user_msg =
    let msg = Message.User user_msg in
    send_message t msg

  (** [send_raw t control] encodes an [Sdk_control] value and sends it
      verbatim. *)
  let send_raw t control =
    let json =
      Jsont.Json.encode Sdk_control.jsont control
      |> Err.get_ok ~msg:"Failed to encode control message: "
    in
    Log.info (fun m -> m "→ Raw control: %s" (json_to_string json));
    Transport.send t.transport json

  (** [send_json t json] sends arbitrary JSON on the transport. *)
  let send_json t json =
    Log.info (fun m -> m "→ Raw JSON: %s" (json_to_string json));
    Transport.send t.transport json

  (** [receive_raw t] is the undecoded-control stream; see
      {!handle_raw_messages}. *)
  let receive_raw t = handle_raw_messages t
end