(* Forked from anil.recoil.org/monopam-myspace: my aggregated monorepo of
   OCaml code, automaintained. *)
(*---------------------------------------------------------------------------
  Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
  SPDX-License-Identifier: ISC
 ---------------------------------------------------------------------------*)
5
(* Log source for this module; every message below is tagged
   "claude.client" so it can be filtered separately. *)
let src = Logs.Src.create ~doc:"Claude client" "claude.client"

module Log = (val Logs.src_log src : Logs.LOG)
9
(** Builders for control-protocol replies sent back to the CLI.

    Each builder wraps an {!Sdk_control.Response} value in a control
    envelope and encodes it to JSON; encoding failures are reported through
    [Err.get_ok] with a builder-specific prefix. *)
module Control_response = struct
  (** [success ~request_id ~response] encodes a successful reply to
      [request_id]. [response] is the optional payload. *)
  let success ~request_id ~response =
    let envelope =
      Sdk_control.create_response
        ~response:(Sdk_control.Response.success ~request_id ?response ())
        ()
    in
    Err.get_ok ~msg:"Control_response.success: "
      (Jsont.Json.encode Sdk_control.jsont envelope)

  (** [error ~request_id ~code ~message ?data ()] encodes an error reply to
      [request_id] carrying [code], [message] and optional extra [data]. *)
  let error ~request_id ~code ~message ?data () =
    let detail = Sdk_control.Response.error_detail ~code ~message ?data () in
    let envelope =
      Sdk_control.create_response
        ~response:(Sdk_control.Response.error ~request_id ~error:detail ())
        ()
    in
    Err.get_ok ~msg:"Control_response.error: "
      (Jsont.Json.encode Sdk_control.jsont envelope)
end
25
(* Render a JSON value as text with the jsont bytesrw encoder. An encoder
   error is turned into a string and surfaced through [Err.get_ok]. *)
let json_to_string json =
  let encoded = Jsont_bytesrw.encode_string' Jsont.json json in
  Err.get_ok ~msg:"" (Result.map_error Jsont.Error.to_string encoded)
31
(** Wire-level codec for hook matcher configuration sent to the CLI.

    A matcher pairs an optional [matcher] pattern with the IDs of the hook
    callbacks registered for it. *)
module Hook_matcher_wire = struct
  type t = { matcher : string option; hook_callback_ids : string list }

  (** JSON object codec: optional ["matcher"] string and a required
      ["hookCallbackIds"] string array. *)
  let jsont : t Jsont.t =
    let module O = Jsont.Object in
    O.map ~kind:"HookMatcherWire" (fun matcher hook_callback_ids ->
        { matcher; hook_callback_ids })
    |> O.opt_mem "matcher" Jsont.string ~enc:(fun { matcher; _ } -> matcher)
    |> O.mem "hookCallbackIds" (Jsont.list Jsont.string)
         ~enc:(fun { hook_callback_ids; _ } -> hook_callback_ids)
    |> O.finish

  (** [encode matchers] encodes every matcher with {!jsont} and gathers the
      results into a single JSON array. *)
  let encode matchers =
    let encode_one matcher =
      Err.get_ok ~msg:"Hook_matcher_wire.encode: "
        (Jsont.Json.encode jsont matcher)
    in
    Jsont.Json.list (List.map encode_one matchers)
end
52
(* Client state shared by the message loop ([handle_messages]) and the
   request/response API ([send_control_request], [respond_to_tool], ...). *)
type t = {
  transport : Transport.t;
  (* Consulted for permission control requests; when [None],
     [Permissions.default_allow] is used instead. Replaced by
     [enable_permission_discovery]. *)
  mutable permission_callback : Permissions.callback option;
  (* Log filled by the discovery callback installed by
     [enable_permission_discovery]; read by [discovered_permissions]. *)
  mutable permission_log : Permissions.Rule.t list ref option;
  (* Hook callbacks keyed by the generated IDs ("hook_N") advertised to the
     CLI in the initialize request sent by [create]. *)
  hook_callbacks : (string, Jsont.json -> Proto.Hooks.result) Hashtbl.t;
  (* Session ID taken from the latest system message that carried one. *)
  mutable session_id : string option;
  (* Control responses received but not yet claimed, keyed by request_id.
     Accessed under [control_mutex]; [control_condition] is broadcast each
     time an entry is stored. *)
  control_responses : (string, Jsont.json) Hashtbl.t;
  control_mutex : Eio.Mutex.t;
  control_condition : Eio.Condition.t;
  (* Used for control-request timing and timestamp-based request IDs. *)
  clock : float Eio.Time.clock_ty Eio.Resource.t;
  (* Track tool_use_ids we've already responded to, preventing duplicates *)
  responded_tool_ids : (string, unit) Hashtbl.t;
  (* In-process MCP servers for custom tools *)
  mcp_servers : (string, Mcp_server.t) Hashtbl.t;
}
68
(** The session ID reported by the CLI in a system message, if one has been
    seen yet. *)
let session_id { session_id; _ } = session_id
70
(* Log [message] as an error and send it to the CLI as an error control
   response to [request_id] with the given [code]. *)
let reply_error t ~request_id ~code message =
  Log.err (fun m -> m "%s" message);
  Transport.send t.transport
    (Control_response.error ~request_id ~code ~message ())

(* Answer a permission control request for [tool_name].

   The configured permission callback decides; when none is set,
   [Permissions.default_allow] decides. The decision is encoded and sent as
   a success response. If the callback raises, or its result cannot be
   encoded, an [`Internal_error] response is sent instead of letting the
   exception escape and end the message loop. Hook callbacks follow the same
   policy. *)
let handle_permission_request t ~request_id ~tool_name ~input_json ~suggestions
    =
  Log.info (fun m ->
      m "Permission request for tool '%s' with input: %s" tool_name
        (json_to_string input_json));
  (* The CLI's permission_suggestions become the context's suggested rules. *)
  let suggested_rules =
    Permissions.extract_rules_from_proto_updates
      (Option.value suggestions ~default:[])
  in
  let input = Tool_input.of_json input_json in
  let context : Permissions.context = { tool_name; input; suggested_rules } in
  Log.info (fun m -> m "Invoking permission callback for tool: %s" tool_name);
  let callback =
    Option.value t.permission_callback ~default:Permissions.default_allow
  in
  match callback context with
  | exception (Eio.Cancel.Cancelled _ as exn) ->
      (* Never swallow Eio cancellation. *)
      raise exn
  | exception exn ->
      reply_error t ~request_id ~code:`Internal_error
        (Printf.sprintf "Permission callback error: %s"
           (Printexc.to_string exn))
  | decision -> (
      Log.info (fun m ->
          m "Permission callback returned: %s"
            (if Permissions.Decision.is_allow decision then "ALLOW" else "DENY"));
      let proto_result =
        Permissions.Decision.to_proto_result ~original_input:input decision
      in
      match Jsont.Json.encode Proto.Permissions.Result.jsont proto_result with
      | Error err ->
          reply_error t ~request_id ~code:`Internal_error
            (Printf.sprintf "Failed to encode permission result: %s" err)
      | Ok response_data ->
          let response =
            Control_response.success ~request_id ~response:(Some response_data)
          in
          Log.info (fun m ->
              m "Sending control response: %s" (json_to_string response));
          Transport.send t.transport response)

(* Run the hook callback registered under [callback_id] on [input] and send
   its encoded result.

   An unknown ID gets [`Method_not_found]. A callback that raises, or a
   result that cannot be encoded, gets [`Internal_error]. Only the callback
   call itself is guarded, so a [Not_found] raised inside a callback is
   reported as a callback error, not as a missing callback. *)
let handle_hook_callback_request t ~request_id ~callback_id ~input =
  Log.info (fun m -> m "Hook callback request for callback_id: %s" callback_id);
  match Hashtbl.find_opt t.hook_callbacks callback_id with
  | None ->
      reply_error t ~request_id ~code:`Method_not_found
        (Printf.sprintf "Hook callback not found: %s" callback_id)
  | Some callback -> (
      match callback input with
      | exception (Eio.Cancel.Cancelled _ as exn) ->
          (* Never swallow Eio cancellation. *)
          raise exn
      | exception exn ->
          reply_error t ~request_id ~code:`Internal_error
            (Printf.sprintf "Hook callback error: %s" (Printexc.to_string exn))
      | result -> (
          match Jsont.Json.encode Proto.Hooks.result_jsont result with
          | Error err ->
              reply_error t ~request_id ~code:`Internal_error
                (Printf.sprintf "Failed to encode hook result: %s" err)
          | Ok result_json ->
              Log.debug (fun m ->
                  m "Hook result JSON: %s" (json_to_string result_json));
              let response =
                Control_response.success ~request_id
                  ~response:(Some result_json)
              in
              Log.info (fun m -> m "Hook callback succeeded, sending response");
              Transport.send t.transport response))

(* Forward an MCP JSON-RPC [message] to the in-process server [server_name]
   and send its reply wrapped as {"mcp_response": ...}. An unknown server
   produces a JSON-RPC error object (code -32601) in the same wrapper. *)
let handle_mcp_request t ~request_id ~server_name ~message =
  let module J = Jsont.Json in
  Log.info (fun m -> m "MCP request for server '%s'" server_name);
  let mcp_response =
    match Hashtbl.find_opt t.mcp_servers server_name with
    | Some server ->
        let reply = Mcp_server.handle_json_message server message in
        Log.debug (fun m -> m "MCP response: %s" (json_to_string reply));
        reply
    | None ->
        let error_msg = Printf.sprintf "MCP server '%s' not found" server_name in
        Log.err (fun m -> m "%s" error_msg);
        J.object'
          [
            J.mem (J.name "jsonrpc") (J.string "2.0");
            J.mem (J.name "id") (J.null ());
            J.mem (J.name "error")
              (J.object'
                 [
                   J.mem (J.name "code") (J.number (-32601.0));
                   J.mem (J.name "message") (J.string error_msg);
                 ]);
          ]
  in
  let response_data =
    J.object' [ J.mem (J.name "mcp_response") mcp_response ]
  in
  Transport.send t.transport
    (Control_response.success ~request_id ~response:(Some response_data))

(* Dispatch a control request initiated by the CLI (permission check, hook
   callback or MCP message) and always send exactly one reply for
   [ctrl_req.request_id]. Unsupported request kinds get an
   [`Invalid_request] error reply. *)
let handle_control_request t (ctrl_req : Sdk_control.control_request) =
  let request_id = ctrl_req.request_id in
  Log.info (fun m -> m "Handling control request: %s" request_id);
  match ctrl_req.request with
  | Sdk_control.Request.Permission req ->
      handle_permission_request t ~request_id ~tool_name:req.tool_name
        ~input_json:req.input ~suggestions:req.permission_suggestions
  | Sdk_control.Request.Hook_callback req ->
      handle_hook_callback_request t ~request_id ~callback_id:req.callback_id
        ~input:req.input
  | Sdk_control.Request.Mcp_message req ->
      handle_mcp_request t ~request_id ~server_name:req.server_name
        ~message:req.message
  | _ ->
      Log.warn (fun m ->
          m "Unsupported control request type (request_id: %s)" request_id);
      Transport.send t.transport
        (Control_response.error ~request_id ~code:`Invalid_request
           ~message:"Unsupported control request type" ())
193
(* Record the CLI's reply to one of our control requests under its
   request_id, then broadcast [control_condition] so that a fiber waiting in
   [send_control_request] can pick it up. *)
let handle_control_response t control_resp =
  let request_id =
    match control_resp.Sdk_control.response with
    | Sdk_control.Response.Success success -> success.request_id
    | Sdk_control.Response.Error failure -> failure.request_id
  in
  Log.debug (fun m ->
      m "Received control response for request_id: %s" request_id);
  let json =
    Err.get_ok ~msg:"Failed to encode control response: "
      (Jsont.Json.encode Sdk_control.control_response_jsont control_resp)
  in
  Eio.Mutex.use_rw ~protect:false t.control_mutex @@ fun () ->
  Hashtbl.replace t.control_responses request_id json;
  Eio.Condition.broadcast t.control_condition
211
(* Lazily read lines from the transport and decode each one with
   [Incoming.jsont]. A line that fails to decode is logged and skipped. The
   sequence ends when [Transport.receive_line] reports EOF. *)
let handle_raw_messages t =
  let rec next () =
    match Transport.receive_line t.transport with
    | None ->
        Log.debug (fun m -> m "Handle messages: EOF received");
        Seq.Nil
    | Some line -> (
        match Jsont_bytesrw.decode_string' Incoming.jsont line with
        | Error err ->
            Log.err (fun m ->
                m "Failed to decode incoming message: %s\nLine: %s"
                  (Jsont.Error.to_string err) line);
            next ()
        | Ok incoming -> Seq.Cons (incoming, next))
  in
  Log.debug (fun m -> m "Starting message handler");
  next
233
(* Turn the raw incoming stream into a stream of [Response] values.

   Regular messages are expanded with [Response.of_message]. System messages
   also update the stored session ID. Control responses and control requests
   are handled in place as they are pulled, and yield no output elements. *)
let handle_messages t =
  let remember_session_id = function
    | Message.System sys ->
        Option.iter
          (fun session_id ->
            t.session_id <- Some session_id;
            Log.debug (fun m -> m "Stored session ID: %s" session_id))
          (Message.System.session_id sys)
    | _ -> ()
  in
  let rec step incoming =
    match incoming () with
    | Seq.Nil -> Seq.Nil
    | Seq.Cons (Incoming.Message msg, rest) ->
        Log.info (fun m -> m "← %a" Message.pp msg);
        remember_session_id msg;
        emit (Response.of_message msg) rest
    | Seq.Cons (Incoming.Control_response resp, rest) ->
        handle_control_response t resp;
        step rest
    | Seq.Cons (Incoming.Control_request ctrl_req, rest) ->
        Log.info (fun m ->
            m "Received control request (request_id: %s)" ctrl_req.request_id);
        handle_control_request t ctrl_req;
        step rest
  (* Yield each pending response before resuming the incoming stream. *)
  and emit pending rest =
    match pending with
    | [] -> step rest
    | response :: more -> Seq.Cons (response, fun () -> emit more rest)
  in
  step (handle_raw_messages t)
272
(** [create ?options ~sw ~process_mgr ~clock ()] starts the CLI transport
    and returns a client bound to it.

    - If [options] has a permission callback but no permission prompt tool,
      the prompt tool is set to ["stdio"] to enable the control protocol.
      This matches Python SDK behaviour (client.py:104-121).
    - MCP servers listed in [options] are registered by name so that
      [Mcp_message] control requests can be served in-process.
    - If hooks are configured, every hook callback gets a generated ID
      (["hook_0"], ["hook_1"], ...). An initialize request with ID
      ["init_hooks"] advertising those IDs is then sent immediately. *)
let create ?(options = Options.default) ~sw ~process_mgr ~clock () =
  let options =
    let wants_stdio_prompt =
      Option.is_some (Options.permission_callback options)
      && Option.is_none (Options.permission_prompt_tool_name options)
    in
    if wants_stdio_prompt then
      Options.with_permission_prompt_tool_name "stdio" options
    else options
  in
  let transport = Transport.create ~sw ~process_mgr ~options () in
  let hook_callbacks = Hashtbl.create 16 in
  let next_callback_id = ref 0 in
  let mcp_servers = Hashtbl.create 16 in
  List.iter
    (fun (name, server) ->
      Log.info (fun m -> m "Registering MCP server: %s" name);
      Hashtbl.add mcp_servers name server)
    (Options.mcp_servers options);
  let t =
    {
      transport;
      permission_callback = Options.permission_callback options;
      permission_log = None;
      hook_callbacks;
      session_id = None;
      control_responses = Hashtbl.create 16;
      control_mutex = Eio.Mutex.create ();
      control_condition = Eio.Condition.create ();
      clock;
      responded_tool_ids = Hashtbl.create 16;
      mcp_servers;
    }
  in
  (* Allocate the next "hook_N" identifier. *)
  let fresh_callback_id () =
    let id = Printf.sprintf "hook_%d" !next_callback_id in
    incr next_callback_id;
    id
  in
  let register_hooks hooks_config =
    Log.info (fun m -> m "Registering hooks...");
    (* One wire matcher per (pattern, callback), each with its own ID. *)
    let wire_matchers event_name matchers =
      List.map
        (fun (pattern, callback) ->
          let callback_id = fresh_callback_id () in
          Hashtbl.add hook_callbacks callback_id callback;
          Log.debug (fun m ->
              m "Registered callback: %s for event: %s" callback_id event_name);
          {
            Hook_matcher_wire.matcher = pattern;
            hook_callback_ids = [ callback_id ];
          })
        matchers
    in
    let hooks_list =
      Hooks.get_callbacks hooks_config
      |> List.map (fun (event, matchers) ->
             let event_name = Proto.Hooks.event_to_string event in
             ( event_name,
               Hook_matcher_wire.encode (wire_matchers event_name matchers) ))
    in
    let initialize_msg =
      Sdk_control.create_request ~request_id:"init_hooks"
        ~request:(Sdk_control.Request.initialize ~hooks:hooks_list ())
        ()
      |> Jsont.Json.encode Sdk_control.jsont
      |> Err.get_ok ~msg:"Failed to encode initialize request: "
    in
    Log.info (fun m -> m "Sending hooks initialize request");
    Transport.send transport initialize_msg
  in
  Option.iter register_hooks (Options.hooks options);
  t
360
(* Send [msg] to the CLI. It is converted to its protocol form and wrapped
   as [Proto.Outgoing.Message], which supplies the "type" wrapper, before
   being written to the transport. *)
let send_message t msg =
  Log.info (fun m -> m "→ %a" Message.pp msg);
  Transport.send t.transport
    (Proto.Outgoing.to_json (Proto.Outgoing.Message (Message.to_proto msg)))
368
(** Send [prompt] as a plain-text user message. *)
let query t prompt = send_message t (Message.user_string prompt)
372
(** Send a single tool result for [tool_use_id].

    Each tool_use_id is answered at most once. A repeated ID is logged and
    ignored, because duplicate responses cause API errors. The ID is recorded
    before the message is sent. *)
let respond_to_tool t ~tool_use_id ~content ?(is_error = false) () =
  let already_answered = Hashtbl.mem t.responded_tool_ids tool_use_id in
  if not already_answered then begin
    Hashtbl.add t.responded_tool_ids tool_use_id ();
    let user_msg =
      Message.User.with_tool_result ~tool_use_id ~content ~is_error ()
    in
    send_message t (Message.User user_msg)
  end
  else
    Log.warn (fun m ->
        m "Skipping duplicate tool response for tool_use_id: %s" tool_use_id)
383
(** Send several tool results in one user message.

    [responses] holds [(tool_use_id, content, is_error)] triples; a missing
    [is_error] defaults to [false]. IDs already answered, including repeats
    earlier in the same list, are logged and dropped. Nothing is sent if no
    new IDs remain. *)
let respond_to_tools t responses =
  let fresh =
    List.fold_left
      (fun kept ((tool_use_id, _, _) as response) ->
        if Hashtbl.mem t.responded_tool_ids tool_use_id then (
          Log.warn (fun m ->
              m "Skipping duplicate tool response for tool_use_id: %s"
                tool_use_id);
          kept)
        else (
          Hashtbl.add t.responded_tool_ids tool_use_id ();
          response :: kept))
      [] responses
    |> List.rev
  in
  match fresh with
  | [] -> ()
  | _ ->
      let blocks =
        List.map
          (fun (tool_use_id, content, is_error) ->
            Content_block.tool_result ~tool_use_id ~content
              ~is_error:(Option.value is_error ~default:false)
              ())
          fresh
      in
      send_message t (Message.User (Message.User.of_blocks blocks))
407
(** Forget every tool_use_id recorded by [respond_to_tool] and
    [respond_to_tools], so those IDs may be answered again. *)
let clear_tool_response_tracking { responded_tool_ids; _ } =
  Hashtbl.clear responded_tool_ids
410
(** A lazy sequence of responses. Forcing it calls [handle_messages], which
    reads from the transport's current position. *)
let receive t () = handle_messages t
412
(** Consume [receive t] to the end, passing each response to
    [Handler.dispatch handler]. *)
let run t ~handler =
  let dispatch = Handler.dispatch handler in
  receive t |> Seq.iter dispatch
415
(** Collect responses in order, up to and including the first
    [Response.Complete], or until the stream ends. *)
let receive_all t =
  let rec gather acc responses =
    match responses () with
    | Seq.Nil ->
        Log.debug (fun m ->
            m "End of response sequence (%d responses)" (List.length acc));
        List.rev acc
    | Seq.Cons (response, rest) -> (
        match response with
        | Response.Complete _ ->
            Log.debug (fun m -> m "Received final Complete response");
            List.rev (response :: acc)
        | _ -> gather (response :: acc) rest)
  in
  gather [] (receive t)
429
(** Forward an interrupt to the underlying transport
    (see [Transport.interrupt]). *)
let interrupt { transport; _ } = Transport.interrupt transport
431
(** Replace the permission callback with [Permissions.discovery], backed by a
    fresh log that [discovered_permissions] can read later. *)
let enable_permission_discovery t =
  let discovered = ref [] in
  let discovery_callback = Permissions.discovery discovered in
  t.permission_callback <- Some discovery_callback;
  t.permission_log <- Some discovered
437
(** Rules gathered since [enable_permission_discovery], or [[]] if discovery
    was never enabled. *)
let discovered_permissions t =
  match t.permission_log with
  | None -> []
  | Some log -> !log
440
(* Send a control request and wait for the matching control response.

   The response is stored in [t.control_responses] by
   [handle_control_response], so another fiber must be driving the message
   loop while this one waits. Returns the optional payload of a successful
   response.

   @raise Failure if no response arrives within 10 seconds, or if the CLI
   answers with an error response. *)
let send_control_request t ~request_id request =
  let control_msg = Sdk_control.create_request ~request_id ~request () in
  let json =
    Jsont.Json.encode Sdk_control.jsont control_msg
    |> Err.get_ok ~msg:"Failed to encode control request: "
  in
  Log.info (fun m -> m "Sending control request: %s" (json_to_string json));
  Transport.send t.transport json;

  (* Seconds to wait for the CLI to answer before giving up. *)
  let max_wait = 10.0 in
  (* Claim our response from the table, sleeping on [control_condition]
     until it shows up. [Eio.Condition.await] releases [control_mutex] while
     suspended, so [handle_control_response] can store the reply, and
     re-acquires it before returning. [use_ro] is used instead of [use_rw]
     so that a cancellation (the timeout below) unlocks the mutex rather than
     poisoning it. The only mutation happens just before a normal return. *)
  let wait_for_response () =
    Eio.Mutex.use_ro t.control_mutex (fun () ->
        let rec take () =
          match Hashtbl.find_opt t.control_responses request_id with
          | Some response_json ->
              Hashtbl.remove t.control_responses request_id;
              response_json
          | None ->
              Eio.Condition.await t.control_condition t.control_mutex;
              take ()
        in
        take ())
  in
  let response_json =
    match
      Eio.Time.with_timeout t.clock max_wait (fun () ->
          Ok (wait_for_response ()))
    with
    | Ok response_json -> response_json
    | Error `Timeout ->
        failwith
          (Printf.sprintf "Timeout waiting for control response: %s"
             request_id)
  in
  Log.debug (fun m ->
      m "Received control response: %s" (json_to_string response_json));

  (* The stored JSON is the whole control envelope; pull out its "response"
     member and decode that as an [Sdk_control.Response.t]. *)
  let response_field_codec =
    Jsont.Object.map ~kind:"ResponseField" Fun.id
    |> Jsont.Object.mem "response" Jsont.json ~enc:Fun.id
    |> Jsont.Object.finish
  in
  let response_data =
    Jsont.Json.decode response_field_codec response_json
    |> Err.get_ok' ~msg:"Failed to extract response field: "
  in
  let response =
    Jsont.Json.decode Sdk_control.Response.jsont response_data
    |> Err.get_ok' ~msg:"Failed to decode response: "
  in
  match response with
  | Sdk_control.Response.Success s -> s.response
  | Sdk_control.Response.Error e ->
      failwith
        (Printf.sprintf "Control request failed: [%d] %s" e.error.code
           e.error.message)
499
(** Ask the CLI to switch to permission [mode]. The request ID is derived
    from the current clock time. *)
let set_permission_mode t mode =
  let request_id = Printf.sprintf "set_perm_mode_%f" (Eio.Time.now t.clock) in
  let request =
    Sdk_control.Request.set_permission_mode
      ~mode:(Permissions.Mode.to_proto mode)
      ()
  in
  ignore (send_control_request t ~request_id request);
  Log.info (fun m ->
      m "Permission mode set to: %s" (Permissions.Mode.to_string mode))
507
(** Ask the CLI to switch to [model]. The request ID is derived from the
    current clock time. *)
let set_model t model =
  let model_str = Model.to_string model in
  let request_id = Printf.sprintf "set_model_%f" (Eio.Time.now t.clock) in
  ignore
    (send_control_request t ~request_id
       (Sdk_control.Request.set_model ~model:model_str ()));
  Log.info (fun m -> m "Model set to: %s" model_str)
514
(** Query the CLI for its server information. A response without a payload
    or with an undecodable payload is reported through [Err]. *)
let get_server_info t =
  let request_id =
    Printf.sprintf "get_server_info_%f" (Eio.Time.now t.clock)
  in
  let payload =
    send_control_request t ~request_id (Sdk_control.Request.get_server_info ())
  in
  let response_data =
    Err.get_ok ~msg:""
      (Option.to_result
         ~none:"No response data from get_server_info request" payload)
  in
  let server_info =
    Err.get_ok' ~msg:"Failed to decode server info: "
      (Jsont.Json.decode Sdk_control.Server_info.jsont response_data)
  in
  Log.info (fun m ->
      m "Retrieved server info: %a"
        (Jsont.pp_value Sdk_control.Server_info.jsont ())
        server_info);
  Server_info.of_sdk_control server_info
532
(** Lower-level entry points that bypass the typed helpers above. *)
module Advanced = struct
  (** Send any [Message.t]; identical to the top-level [send_message]. *)
  let send_message = send_message

  (** Send a prebuilt user message. *)
  let send_user_message t user_msg = send_message t (Message.User user_msg)

  (** Encode an [Sdk_control] value and write it to the transport as is. *)
  let send_raw t control =
    let json =
      Err.get_ok ~msg:"Failed to encode control message: "
        (Jsont.Json.encode Sdk_control.jsont control)
    in
    Log.info (fun m -> m "→ Raw control: %s" (json_to_string json));
    Transport.send t.transport json

  (** Write an arbitrary JSON value to the transport without wrapping. *)
  let send_json t json =
    Log.info (fun m -> m "→ Raw JSON: %s" (json_to_string json));
    Transport.send t.transport json

  (** The decoded incoming stream, with no control-message handling. *)
  let receive_raw = handle_raw_messages
end