# Onnxrt Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Implement `onnxrt`, an OCaml library providing type-safe bindings to ONNX Runtime Web for browser-based ML inference via js_of_ocaml.

**Architecture:** Two-layer design. An internal `Promise_lwt` helper bridges JS Promises to Lwt. The public `Onnxrt` module exposes pure OCaml types (Bigarray, Lwt.t) and uses `Js.Unsafe` internally to call the `onnxruntime-web` JavaScript API. No `Js.t` types appear in the public API.

**Tech Stack:** OCaml 5.2+, js_of_ocaml 5.8+, Lwt, dune 3.17, onnxruntime-web (npm, loaded externally)

**Key JS API surface being bound:**

- `ort.env.wasm.*` / `ort.env.webgpu.*` — global config
- `new ort.Tensor(type, data, dims)` — tensor construction
- `ort.InferenceSession.create(model, options)` — returns Promise<session>
- `session.run(feeds)` — returns Promise<results>
- `session.inputNames` / `session.outputNames` — string arrays
- `session.release()` — returns Promise<void>
- `tensor.data` / `tensor.dims` / `tensor.type` / `tensor.size` — properties
- `tensor.location` — "cpu" | "gpu-buffer"
- `tensor.getData()` — returns Promise<TypedArray> (GPU download)
- `tensor.dispose()` — void
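On the OCaml side, the plan wraps this surface behind pure types. For orientation, here is a hypothetical, abridged sketch of the public shape the tasks below build toward (names follow the Task 2 stubs; `'a io` stands in for `'a Lwt.t` to keep the fragment self-contained, and the real functions also take optional configuration arguments):

```ocaml
(* Sketch only: the real signatures live in lib/onnxrt.mli. *)
module type ONNXRT_SURFACE = sig
  type 'a io
  (** Stands in for [Lwt.t] in this sketch. *)

  module Tensor : sig
    type t

    val dims : t -> int array
    val size : t -> int
    val dispose : t -> unit
  end

  module Session : sig
    type t

    val create : string -> t io
    val run : t -> (string * Tensor.t) list -> (string * Tensor.t) list io
    val input_names : t -> string list
    val output_names : t -> string list
    val release : t -> unit io
  end
end
```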
---

### Task 1: Promise_lwt bridge helper

The ONNX API is entirely Promise-based. We need a minimal helper to convert
JS Promises to Lwt threads. This is an internal module, not exposed publicly.

**Files:**
- Create: `lib/promise_lwt.ml`
- Create: `lib/promise_lwt.mli`

**Step 1: Write `lib/promise_lwt.mli`**

```ocaml
(** Internal: bridge JavaScript Promises to Lwt.

    Not part of the public API. *)

val to_lwt : 'a Js_of_ocaml.Js.t -> 'a Lwt.t
(** [to_lwt js_promise] converts a JavaScript Promise to an Lwt thread.
    If the Promise rejects, the Lwt thread fails with [Failure msg]. *)
```

**Step 2: Write `lib/promise_lwt.ml`**

```ocaml
open Js_of_ocaml

let to_lwt (promise : 'a Js.t) : 'a Lwt.t =
  let lwt_promise, resolver = Lwt.wait () in
  let on_resolve result = Lwt.wakeup resolver result in
  let on_reject error =
    let msg =
      Js.Opt.case
        (Js.Unsafe.meth_call error "toString" [||] : Js.js_string Js.t Js.Opt.t)
        (fun () -> "unknown error")
        Js.to_string
    in
    Lwt.wakeup_exn resolver (Failure msg)
  in
  let _ignored : 'b Js.t =
    Js.Unsafe.meth_call promise "then"
      [| Js.Unsafe.inject (Js.wrap_callback on_resolve);
         Js.Unsafe.inject (Js.wrap_callback on_reject) |]
  in
  lwt_promise
```
**Step 3: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds (the modules are internal, linked into the library)

**Step 4: Commit**

```
feat: add internal Promise_lwt bridge
```

---

### Task 2: Dtype module

Pure OCaml, no JS interop. Implements the GADT and conversion functions.

**Files:**
- Create: `lib/onnxrt.ml` (start with the Dtype module only)

**Step 1: Write the Dtype implementation in `lib/onnxrt.ml`**
```ocaml
module Dtype = struct
  type ('ocaml, 'elt) t =
    | Float32 : (float, Bigarray.float32_elt) t
    | Float64 : (float, Bigarray.float64_elt) t
    | Int8 : (int, Bigarray.int8_signed_elt) t
    | Uint8 : (int, Bigarray.int8_unsigned_elt) t
    | Int16 : (int, Bigarray.int16_signed_elt) t
    | Uint16 : (int, Bigarray.int16_unsigned_elt) t
    | Int32 : (int32, Bigarray.int32_elt) t

  type packed = Pack : ('ocaml, 'elt) t -> packed

  let to_string : type a b. (a, b) t -> string = function
    | Float32 -> "float32"
    | Float64 -> "float64"
    | Int8 -> "int8"
    | Uint8 -> "uint8"
    | Int16 -> "int16"
    | Uint16 -> "uint16"
    | Int32 -> "int32"

  let of_string = function
    | "float32" -> Some (Pack Float32)
    | "float64" -> Some (Pack Float64)
    | "int8" -> Some (Pack Int8)
    | "uint8" -> Some (Pack Uint8)
    | "int16" -> Some (Pack Int16)
    | "uint16" -> Some (Pack Uint16)
    | "int32" -> Some (Pack Int32)
    | _ -> None

  let equal : type a b c d. (a, b) t -> (c, d) t -> bool =
   fun a b ->
    match (a, b) with
    | Float32, Float32 -> true
    | Float64, Float64 -> true
    | Int8, Int8 -> true
    | Uint8, Uint8 -> true
    | Int16, Int16 -> true
    | Uint16, Uint16 -> true
    | Int32, Int32 -> true
    | _ -> false

  (* Internal: return the Bigarray kind for a dtype *)
  let to_bigarray_kind : type a b. (a, b) t -> (a, b) Bigarray.kind = function
    | Float32 -> Bigarray.float32
    | Float64 -> Bigarray.float64
    | Int8 -> Bigarray.int8_signed
    | Uint8 -> Bigarray.int8_unsigned
    | Int16 -> Bigarray.int16_signed
    | Uint16 -> Bigarray.int16_unsigned
    | Int32 -> Bigarray.int32

  (* Internal: return the JS TypedArray constructor name *)
  let typed_array_name : type a b. (a, b) t -> string = function
    | Float32 -> "Float32Array"
    | Float64 -> "Float64Array"
    | Int8 -> "Int8Array"
    | Uint8 -> "Uint8Array"
    | Int16 -> "Int16Array"
    | Uint16 -> "Uint16Array"
    | Int32 -> "Int32Array"
end
```
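As a quick sanity check, the round-trip behavior can be exercised standalone. The sketch below uses a trimmed two-constructor copy of the module (an illustration, not the library code): a dtype recovered by `of_string` still creates a correctly-typed Bigarray, because `Pack` carries the GADT evidence.

```ocaml
(* Trimmed two-constructor Dtype, for illustration only. *)
module Dtype = struct
  type ('ocaml, 'elt) t =
    | Float32 : (float, Bigarray.float32_elt) t
    | Int32 : (int32, Bigarray.int32_elt) t

  type packed = Pack : ('ocaml, 'elt) t -> packed

  let to_string : type a b. (a, b) t -> string = function
    | Float32 -> "float32"
    | Int32 -> "int32"

  let of_string = function
    | "float32" -> Some (Pack Float32)
    | "int32" -> Some (Pack Int32)
    | _ -> None

  let to_bigarray_kind : type a b. (a, b) t -> (a, b) Bigarray.kind = function
    | Float32 -> Bigarray.float32
    | Int32 -> Bigarray.int32
end

(* Round-trip: a dtype parsed from a string can still allocate a
   correctly-kinded bigarray; Pack carries the type evidence. *)
let () =
  match Dtype.of_string "float32" with
  | Some (Dtype.Pack dt) ->
      let ba =
        Bigarray.Array1.create (Dtype.to_bigarray_kind dt) Bigarray.c_layout 4
      in
      assert (Bigarray.Array1.dim ba = 4);
      assert (Dtype.to_string dt = "float32")
  | None -> assert false
```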
**Step 2: Add stub modules so the .ml satisfies the .mli**

Append these stubs to `lib/onnxrt.ml` so dune can type-check against the .mli:

```ocaml
module Tensor = struct
  type t = { js_tensor : 'a. 'a } [@@warning "-37"]
  type location = Cpu | Gpu_buffer
  let of_bigarray1 _ _ ~dims:_ = assert false
  let of_bigarray _ _ = assert false
  let of_float32s _ ~dims:_ = assert false
  let to_bigarray1_exn _ _ = assert false
  let to_bigarray_exn _ _ = assert false
  let download _ _ = assert false
  let dims _ = assert false
  let dtype _ = assert false
  let size _ = assert false
  let location _ = assert false
  let dispose _ = assert false
end

module Execution_provider = struct
  type t = Wasm | Webgpu
  let to_string = function Wasm -> "wasm" | Webgpu -> "webgpu"
end

type output_location = Cpu | Gpu_buffer
type graph_optimization = Disabled | Basic | Extended | All

module Session = struct
  type t = { js_session : 'a. 'a } [@@warning "-37"]
  let create ?execution_providers:_ ?graph_optimization:_
      ?preferred_output_location:_ ?log_level:_ _ () = assert false
  let create_from_buffer ?execution_providers:_ ?graph_optimization:_
      ?preferred_output_location:_ ?log_level:_ _ () = assert false
  let run _ _ = assert false
  let run_with_outputs _ _ ~output_names:_ = assert false
  let input_names _ = assert false
  let output_names _ = assert false
  let release _ = assert false
end

module Env = struct
  module Wasm = struct
    let set_num_threads _ = assert false
    let set_simd _ = assert false
    let set_proxy _ = assert false
    let set_wasm_paths _ = assert false
  end
  module Webgpu = struct
    let set_power_preference _ = assert false
  end
end
```

**Step 3: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds. All stubs satisfy the .mli signatures.

**Step 4: Commit**

```
feat: add Dtype implementation and skeleton stubs
```
---

### Task 3: Internal JS helpers

Shared helpers for accessing the `ort` global object and converting between
OCaml and JS types used across the Tensor, Session, and Env modules.

**Files:**
- Create: `lib/js_helpers.ml`

**Step 1: Write `lib/js_helpers.ml`**

```ocaml
(** Internal JS interop helpers. Not part of the public API. *)

open Js_of_ocaml

(** Access the global [ort] object (onnxruntime-web). *)
let ort () : 'a Js.t =
  let o = Js.Unsafe.global##.ort in
  if Js.Optdef.test o then (Js.Unsafe.coerce o : 'a Js.t)
  else failwith "onnxruntime-web is not loaded: global 'ort' object not found"

(** Convert an OCaml string list to a JS array of JS strings. *)
let js_string_array (strs : string list) : Js.js_string Js.t Js.js_array Js.t =
  Js.array (Array.of_list (List.map Js.string strs))

(** Convert a JS array of JS strings to an OCaml string list. *)
let string_list_of_js_array (arr : Js.js_string Js.t Js.js_array Js.t) :
    string list =
  Array.to_list (Array.map Js.to_string (Js.to_array arr))

(** Convert an OCaml int array to a JS array of ints. *)
let js_int_array (dims : int array) : int Js.js_array Js.t =
  Js.array dims

(** Convert a JS array of ints to an OCaml int array. *)
let int_array_of_js (arr : int Js.js_array Js.t) : int array =
  Js.to_array arr

(** Read a string property from a JS object. *)
let get_string (obj : 'a Js.t) (key : string) : string =
  Js.to_string (Js.Unsafe.get obj (Js.string key))

(** Read an int property from a JS object. *)
let get_int (obj : 'a Js.t) (key : string) : int =
  Js.Unsafe.get obj (Js.string key)

(** Set a property on a JS object. *)
let set (obj : 'a Js.t) (key : string) (value : 'b) : unit =
  Js.Unsafe.set obj (Js.string key) value

(** Get a nested property: obj.key1.key2 *)
let get_nested (obj : 'a Js.t) (key1 : string) (key2 : string) : 'b Js.t =
  Js.Unsafe.get (Js.Unsafe.get obj (Js.string key1)) (Js.string key2)
```

**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: add internal JS interop helpers
```
---

### Task 4: Env module implementation

Replace the Env stubs with real implementations that set properties on the
global `ort.env` object.

**Files:**
- Modify: `lib/onnxrt.ml` — replace Env module

**Step 1: Replace the Env stub**

Replace the `module Env = struct ... end` block in `onnxrt.ml` with the
following (the local binding is named `ort` rather than `env`, since
`Js_helpers.ort ()` returns the top-level `ort` object, not `ort.env`):

```ocaml
module Env = struct
  module Wasm = struct
    let set_num_threads n =
      let ort = Js_helpers.ort () in
      Js_helpers.set (Js_helpers.get_nested ort "env" "wasm") "numThreads" n

    let set_simd enabled =
      let ort = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested ort "env" "wasm")
        "simd" (Js_of_ocaml.Js.bool enabled)

    let set_proxy enabled =
      let ort = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested ort "env" "wasm")
        "proxy" (Js_of_ocaml.Js.bool enabled)

    let set_wasm_paths prefix =
      let ort = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested ort "env" "wasm")
        "wasmPaths" (Js_of_ocaml.Js.string prefix)
  end

  module Webgpu = struct
    let set_power_preference pref =
      let ort = Js_helpers.ort () in
      let s =
        match pref with
        | `High_performance -> "high-performance"
        | `Low_power -> "low-power"
      in
      Js_helpers.set
        (Js_helpers.get_nested ort "env" "webgpu")
        "powerPreference" (Js_of_ocaml.Js.string s)
  end
end
```

**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: implement Env module (WASM and WebGPU configuration)
```
---

### Task 5: Execution_provider and top-level types

These are trivial, but let's make sure the real implementations are in place.

**Files:**
- Modify: `lib/onnxrt.ml` — replace stubs

**Step 1: Replace Execution_provider and type stubs**

The Execution_provider stub from Task 2 is already correct. Verify that the types
`output_location` and `graph_optimization` are also correct (they are pure
OCaml types with no JS interop). No changes needed — these are already final.

Add internal helpers for converting `graph_optimization`, `output_location`, and
the session log level to JS strings, used by Session:

```ocaml
(* Internal: place after the type definitions, before the Session module *)

let graph_optimization_to_string = function
  | Disabled -> "disabled"
  | Basic -> "basic"
  | Extended -> "extended"
  | All -> "all"

let output_location_to_js = function
  | Cpu -> Js_of_ocaml.Js.string "cpu"
  | Gpu_buffer -> Js_of_ocaml.Js.string "gpu-buffer"

let log_level_to_string = function
  | `Verbose -> "verbose"
  | `Info -> "info"
  | `Warning -> "warning"
  | `Error -> "error"
  | `Fatal -> "fatal"
```
**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: add internal conversion helpers for session options
```

---

### Task 6: Tensor module implementation

The core of the bindings. Creates and reads ONNX tensors by constructing
`new ort.Tensor(type, typedArray, dims)` via `Js.Unsafe`.

**Files:**
- Modify: `lib/onnxrt.ml` — replace Tensor module

**Step 1: Replace the Tensor stub**

Replace the entire `module Tensor = struct ... end` block with:

```ocaml
module Tensor = struct
  type t = {
    js_tensor : Js_of_ocaml.Js.Unsafe.any;
    mutable disposed : bool;
  }

  type location = Cpu | Gpu_buffer

  let check_not_disposed t =
    if t.disposed then invalid_arg "Tensor has been disposed"

  let check_cpu t =
    check_not_disposed t;
    let loc =
      Js_helpers.get_string (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "location"
    in
    if loc <> "cpu" then
      invalid_arg "Tensor data is on GPU; use Tensor.download first"

  let check_dtype : type a b. (a, b) Dtype.t -> t -> unit =
   fun expected t ->
    let actual_str =
      Js_helpers.get_string (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "type"
    in
    let expected_str = Dtype.to_string expected in
    if actual_str <> expected_str then
      failwith
        (Printf.sprintf "Dtype mismatch: tensor is %s, expected %s" actual_str
           expected_str)

  (* Create a JS TypedArray from a Bigarray. The dtype parameter is unused at
     runtime (the kind is already encoded in [ba]); the underscore avoids an
     unused-variable warning. *)
  let typed_array_of_bigarray :
      type a b.
      (a, b) Dtype.t ->
      (a, b, Bigarray.c_layout) Bigarray.Array1.t ->
      Js_of_ocaml.Js.Unsafe.any =
   fun _dtype ba ->
    let open Js_of_ocaml in
    let ga = Bigarray.genarray_of_array1 ba in
    let ta = Typed_array.from_genarray ga in
    Js.Unsafe.coerce ta

  let of_bigarray1 :
      type a b.
      (a, b) Dtype.t ->
      (a, b, Bigarray.c_layout) Bigarray.Array1.t ->
      dims:int array ->
      t =
   fun dtype ba ~dims ->
    let expected_size = Array.fold_left ( * ) 1 dims in
    let actual_size = Bigarray.Array1.dim ba in
    if expected_size <> actual_size then
      invalid_arg
        (Printf.sprintf
           "Tensor.of_bigarray1: dims product (%d) <> bigarray length (%d)"
           expected_size actual_size);
    let open Js_of_ocaml in
    let ta = typed_array_of_bigarray dtype ba in
    let js_tensor =
      Js.Unsafe.new_obj
        (Js.Unsafe.get (Js_helpers.ort ()) (Js.string "Tensor"))
        [| Js.Unsafe.inject (Js.string (Dtype.to_string dtype));
           ta;
           Js.Unsafe.inject (Js_helpers.js_int_array dims) |]
    in
    { js_tensor = Js.Unsafe.coerce js_tensor; disposed = false }

  let of_bigarray :
      type a b.
      (a, b) Dtype.t -> (a, b, Bigarray.c_layout) Bigarray.Genarray.t -> t =
   fun dtype ga ->
    let dims = Bigarray.Genarray.dims ga in
    let flat = Bigarray.reshape_1 ga (Array.fold_left ( * ) 1 dims) in
    of_bigarray1 dtype flat ~dims

  let of_float32s data ~dims =
    let expected_size = Array.fold_left ( * ) 1 dims in
    if Array.length data <> expected_size then
      invalid_arg
        (Printf.sprintf
           "Tensor.of_float32s: array length (%d) <> dims product (%d)"
           (Array.length data) expected_size);
    let ba =
      Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout expected_size
    in
    Array.iteri (fun i v -> Bigarray.Array1.set ba i v) data;
    of_bigarray1 Dtype.Float32 ba ~dims

  let to_bigarray1_exn :
      type a b.
      (a, b) Dtype.t -> t -> (a, b, Bigarray.c_layout) Bigarray.Array1.t =
   fun dtype t ->
    check_cpu t;
    check_dtype dtype t;
    let open Js_of_ocaml in
    let data : Js.Unsafe.any = Js.Unsafe.get t.js_tensor (Js.string "data") in
    let ta = (Js.Unsafe.coerce data : Typed_array.arrayBufferView Js.t) in
    let ga = Typed_array.to_genarray ta in
    let size = Bigarray.Genarray.nth_dim ga 0 in
    let ba = Bigarray.reshape_1 ga size in
    (* The Bigarray kind from Typed_array.to_genarray matches the JS typed
       array. We need to coerce it to match the expected dtype. The
       check_dtype call above ensures this is safe. *)
    (Obj.magic ba : (a, b, Bigarray.c_layout) Bigarray.Array1.t)

  let to_bigarray_exn :
      type a b.
      (a, b) Dtype.t -> t -> (a, b, Bigarray.c_layout) Bigarray.Genarray.t =
   fun dtype t ->
    let flat = to_bigarray1_exn dtype t in
    let dims_js : Js_of_ocaml.Js.Unsafe.any =
      Js_of_ocaml.Js.Unsafe.get t.js_tensor (Js_of_ocaml.Js.string "dims")
    in
    let dims =
      Js_helpers.int_array_of_js (Js_of_ocaml.Js.Unsafe.coerce dims_js)
    in
    Bigarray.reshape (Bigarray.genarray_of_array1 flat) dims

  let download :
      type a b.
      (a, b) Dtype.t -> t -> (a, b, Bigarray.c_layout) Bigarray.Array1.t Lwt.t =
   fun dtype t ->
    check_not_disposed t;
    check_dtype dtype t;
    let open Js_of_ocaml in
    let promise = Js.Unsafe.meth_call t.js_tensor "getData" [||] in
    let open Lwt.Syntax in
    let+ data = Promise_lwt.to_lwt promise in
    let ta = (Js.Unsafe.coerce data : Typed_array.arrayBufferView Js.t) in
    let ga = Typed_array.to_genarray ta in
    let size = Bigarray.Genarray.nth_dim ga 0 in
    let ba = Bigarray.reshape_1 ga size in
    (Obj.magic ba : (a, b, Bigarray.c_layout) Bigarray.Array1.t)

  let dims t =
    check_not_disposed t;
    let open Js_of_ocaml in
    let dims_js = Js.Unsafe.get t.js_tensor (Js.string "dims") in
    Js_helpers.int_array_of_js (Js.Unsafe.coerce dims_js)

  let dtype t =
    check_not_disposed t;
    let type_str =
      Js_helpers.get_string (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "type"
    in
    match Dtype.of_string type_str with
    | Some p -> p
    | None -> failwith (Printf.sprintf "Unknown tensor dtype: %s" type_str)

  let size t =
    check_not_disposed t;
    Js_helpers.get_int (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "size"

  let location t =
    check_not_disposed t;
    let loc =
      Js_helpers.get_string (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "location"
    in
    match loc with
    | "cpu" -> Cpu
    | "gpu-buffer" -> Gpu_buffer
    | s -> failwith (Printf.sprintf "Unknown tensor location: %s" s)

  let dispose t =
    if not t.disposed then begin
      let open Js_of_ocaml in
      ignore (Js.Unsafe.meth_call t.js_tensor "dispose" [||] : Js.Unsafe.any);
      t.disposed <- true
    end
end
```

**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: implement Tensor module with CPU and GPU support
```

---
### Task 7: Session module implementation

Creates inference sessions, runs models, manages lifecycle.

**Files:**
- Modify: `lib/onnxrt.ml` — replace Session module

**Step 1: Replace the Session stub**

Replace the entire `module Session = struct ... end` block with:

```ocaml
module Session = struct
  type t = {
    js_session : Js_of_ocaml.Js.Unsafe.any;
    input_names_ : string list;
    output_names_ : string list;
  }

  let build_options ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level () =
    let open Js_of_ocaml in
    let pairs = ref [] in
    (match execution_providers with
    | Some eps ->
        let js_eps =
          Js.array
            (Array.of_list
               (List.map
                  (fun ep ->
                    Js.Unsafe.inject
                      (Js.string (Execution_provider.to_string ep)))
                  eps))
        in
        pairs := ("executionProviders", Js.Unsafe.inject js_eps) :: !pairs
    | None -> ());
    (match graph_optimization with
    | Some go ->
        pairs :=
          ( "graphOptimizationLevel",
            Js.Unsafe.inject (Js.string (graph_optimization_to_string go)) )
          :: !pairs
    | None -> ());
    (match preferred_output_location with
    | Some loc ->
        pairs :=
          ( "preferredOutputLocation",
            Js.Unsafe.inject (output_location_to_js loc) )
          :: !pairs
    | None -> ());
    (match log_level with
    | Some level ->
        pairs :=
          ( "logSeverityLevel",
            Js.Unsafe.inject (Js.string (log_level_to_string level)) )
          :: !pairs
    | None -> ());
    Js.Unsafe.obj (Array.of_list !pairs)

  let wrap_session js_session =
    let open Js_of_ocaml in
    let input_names_ =
      Js_helpers.string_list_of_js_array
        (Js.Unsafe.coerce (Js.Unsafe.get js_session (Js.string "inputNames")))
    in
    let output_names_ =
      Js_helpers.string_list_of_js_array
        (Js.Unsafe.coerce (Js.Unsafe.get js_session (Js.string "outputNames")))
    in
    { js_session = Js.Unsafe.coerce js_session; input_names_; output_names_ }

  let create ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level model_url () =
    let open Js_of_ocaml in
    let ort = Js_helpers.ort () in
    let inference_session = Js.Unsafe.get ort (Js.string "InferenceSession") in
    let options =
      build_options ?execution_providers ?graph_optimization
        ?preferred_output_location ?log_level ()
    in
    let promise =
      Js.Unsafe.meth_call inference_session "create"
        [| Js.Unsafe.inject (Js.string model_url); Js.Unsafe.inject options |]
    in
    let open Lwt.Syntax in
    let+ js_session = Promise_lwt.to_lwt promise in
    wrap_session js_session

  let create_from_buffer (type a b) ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level
      (buffer : (a, b, Bigarray.c_layout) Bigarray.Array1.t) () =
    let open Js_of_ocaml in
    let ort = Js_helpers.ort () in
    let inference_session = Js.Unsafe.get ort (Js.string "InferenceSession") in
    let options =
      build_options ?execution_providers ?graph_optimization
        ?preferred_output_location ?log_level ()
    in
    (* Convert the bigarray to a Uint8Array via its underlying ArrayBuffer *)
    let ga = Bigarray.genarray_of_array1 buffer in
    let ta = Typed_array.from_genarray ga in
    let ab : Typed_array.arrayBuffer Js.t =
      Js.Unsafe.get (Js.Unsafe.coerce ta) (Js.string "buffer")
    in
    let uint8 =
      Js.Unsafe.new_obj Js.Unsafe.global##._Uint8Array [| Js.Unsafe.inject ab |]
    in
    let promise =
      Js.Unsafe.meth_call inference_session "create"
        [| Js.Unsafe.inject uint8; Js.Unsafe.inject options |]
    in
    let open Lwt.Syntax in
    let+ js_session = Promise_lwt.to_lwt promise in
    wrap_session js_session

  let run_with_outputs t inputs ~output_names =
    let open Js_of_ocaml in
    let feeds =
      Js.Unsafe.obj
        (Array.of_list
           (List.map
              (fun (name, (tensor : Tensor.t)) ->
                (name, Js.Unsafe.inject tensor.js_tensor))
              inputs))
    in
    let promise =
      Js.Unsafe.meth_call t.js_session "run" [| Js.Unsafe.inject feeds |]
    in
    let open Lwt.Syntax in
    let+ results = Promise_lwt.to_lwt promise in
    List.map
      (fun name ->
        let js_tensor = Js.Unsafe.get results (Js.string name) in
        ( name,
          Tensor.{ js_tensor = Js.Unsafe.coerce js_tensor; disposed = false } ))
      output_names

  (* [run] is [run_with_outputs] over all declared outputs. *)
  let run t inputs = run_with_outputs t inputs ~output_names:t.output_names_

  let input_names t = t.input_names_
  let output_names t = t.output_names_

  let release t =
    let open Js_of_ocaml in
    let promise = Js.Unsafe.meth_call t.js_session "release" [||] in
    Promise_lwt.to_lwt promise |> Lwt.map (fun (_ : Js.Unsafe.any) -> ())
end
```
**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: implement Session module (create, run, release)
```

---

### Task 8: Final build verification and cleanup

Make sure the full library compiles cleanly with no warnings.

**Files:**
- Review: `lib/onnxrt.ml`
- Review: `lib/onnxrt.mli`

**Step 1: Full build with warnings enabled**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Clean build, no warnings

**Step 2: Check module structure**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune describe pp lib/onnxrt.ml 2>&1 | head -20`
Expected: Shows the preprocessed output with js_of_ocaml-ppx applied

**Step 3: Commit**

```
feat: onnxrt library complete — ONNX Runtime Web bindings for OCaml
```

---

## File Summary

| File | Purpose |
|------|---------|
| `lib/onnxrt.mli` | Public API (already written) |
| `lib/onnxrt.ml` | Implementation of all public modules |
| `lib/promise_lwt.mli` | Internal: Promise→Lwt bridge signature |
| `lib/promise_lwt.ml` | Internal: Promise→Lwt bridge implementation |
| `lib/js_helpers.ml` | Internal: shared JS interop utilities |
| `lib/dune` | Build config |
| `dune-project` | Project metadata and opam generation |
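The table lists `lib/dune`, but the plan never shows its contents. A plausible sketch, assuming the library is named `onnxrt` and noting that the `##.` syntax used in the code requires `js_of_ocaml-ppx` (exact stanzas are assumptions, not the project's actual file):

```
(library
 (name onnxrt)
 (public_name onnxrt)
 (libraries js_of_ocaml js_of_ocaml-lwt lwt)
 (preprocess (pps js_of_ocaml-ppx)))
```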
## Key Implementation Notes

1. **`Js.Unsafe` throughout**: All JS interop uses `Js.Unsafe.get`, `Js.Unsafe.set`,
   `Js.Unsafe.meth_call`, `Js.Unsafe.new_obj`, and `Js.Unsafe.obj`. No typed class
   bindings.

2. **`Obj.magic` in tensor conversion**: `to_bigarray1_exn` uses `Obj.magic` to cast
   the Bigarray kind after a runtime dtype check. This is safe because
   `Typed_array.to_genarray` returns a bigarray whose kind matches the JS
   TypedArray, and `check_dtype` verifies the match. The GADT prevents misuse
   at the public API boundary.
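For reference, the `Obj.magic` in note 2 could in principle be replaced with a type-equality witness, letting the compiler perform the cast. A standalone sketch of that pattern (trimmed two-constructor dtype; an illustration, not the library's code):

```ocaml
(* Evidence that two types are equal. *)
type (_, _) eq = Refl : ('a, 'a) eq

module Dtype = struct
  type ('ocaml, 'elt) t =
    | Float32 : (float, Bigarray.float32_elt) t
    | Int32 : (int32, Bigarray.int32_elt) t

  (* Same constructor on both sides yields evidence that the type
     parameters coincide; different constructors yield None. *)
  let witness : type a b c d. (a, b) t -> (c, d) t -> (a * b, c * d) eq option =
   fun x y ->
    match (x, y) with
    | Float32, Float32 -> Some Refl
    | Int32, Int32 -> Some Refl
    | _ -> None
end

(* With the witness in scope, the retyping needs no unsafe cast: matching
   Refl teaches the compiler that (c, d) equals (a, b). *)
let cast_ba (type a b c d) (w : (a * b, c * d) eq)
    (ba : (c, d, Bigarray.c_layout) Bigarray.Array1.t) :
    (a, b, Bigarray.c_layout) Bigarray.Array1.t =
  match w with Refl -> ba

let () =
  match Dtype.witness Dtype.Float32 Dtype.Float32 with
  | Some w ->
      let ba = Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout 2 in
      assert (Bigarray.Array1.dim (cast_ba w ba) = 2)
  | None -> assert false
```

The plan keeps `Obj.magic` because the JS side only hands back a string tag, so the witness would still have to be minted from the same runtime string comparison.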
3. **Promise bridging**: Uses the `.then(resolve, reject)` pattern. The reject handler
   calls `error.toString()` to extract a message string.

4. **Tensor.t record**: Holds `js_tensor` as `Js.Unsafe.any` (erased type) plus a
   `disposed` flag. The Session module constructs `Tensor.t` values directly with
   record syntax (`Tensor.{ ... }`); this works because Session lives in the same
   compilation unit as Tensor, so the `.mli` abstraction does not hide the fields
   from it.

5. **`ort` global**: All JS calls go through `Js_helpers.ort ()`, which checks that
   the onnxruntime-web library is loaded before proceeding.
---

`onnxrt/dune-project`:

```
(lang dune 3.17)

(using directory-targets 0.1)

(name onnxrt)

(generate_opam_files true)

(license ISC)

(package
 (name onnxrt)
 (synopsis "OCaml bindings to ONNX Runtime Web for browser-based ML inference")
 (description
  "Type-safe OCaml bindings to onnxruntime-web, enabling ML model inference in the browser via js_of_ocaml or wasm_of_ocaml. Supports WebAssembly (CPU) and WebGPU (GPU) execution providers.")
 (depends
  (ocaml (>= 5.2))
  (js_of_ocaml (>= 5.8))
  (js_of_ocaml-ppx (>= 5.8))
  (lwt (>= 5.7))
  (js_of_ocaml-lwt (>= 5.8))))
```
···11+open Js_of_ocaml
22+open Onnxrt
33+44+let max_length = 128
55+66+(* Inline promise-to-lwt bridge (Promise_lwt is internal to onnxrt) *)
77+let promise_to_lwt (promise : 'a Js.t) : 'a Lwt.t =
88+ let p, resolver = Lwt.wait () in
99+ let on_resolve result = Lwt.wakeup resolver result in
1010+ let on_reject error =
1111+ let msg =
1212+ Js.to_string
1313+ (Js.Unsafe.meth_call error "toString" [||] : Js.js_string Js.t)
1414+ in
1515+ Lwt.wakeup_exn resolver (Failure msg)
1616+ in
1717+ ignore
1818+ (Js.Unsafe.meth_call promise "then"
1919+ [| Js.Unsafe.inject (Js.wrap_callback on_resolve);
2020+ Js.Unsafe.inject (Js.wrap_callback on_reject) |]
2121+ : 'b Js.t);
2222+ p
2323+2424+let set_status msg =
2525+ let el = Dom_html.getElementById "status" in
2626+ el##.textContent := Js.some (Js.string msg)
2727+2828+let set_result html =
2929+ let el = Dom_html.getElementById "result" in
3030+ el##.innerHTML := Js.string html
3131+3232+let fetch_text url =
3333+ let open Lwt.Syntax in
3434+ let promise =
3535+ Js.Unsafe.meth_call Js.Unsafe.global "fetch"
3636+ [| Js.Unsafe.inject (Js.string url) |]
3737+ in
3838+ let* response = promise_to_lwt promise in
3939+ let text_promise = Js.Unsafe.meth_call response "text" [||] in
4040+ let+ text = promise_to_lwt text_promise in
4141+ Js.to_string (Js.Unsafe.coerce text : Js.js_string Js.t)
4242+4343+(* Access the global ort object *)
4444+let ort () : Js.Unsafe.any =
4545+ Js.Unsafe.get Js.Unsafe.global (Js.string "ort")
4646+4747+(* Create an int64 tensor via Js.Unsafe.
4848+ onnxrt's Dtype doesn't support int64, but DistilBERT requires it.
4949+ We construct: new ort.Tensor("int64", BigInt64Array.from(data.map(BigInt)), dims)
5050+ then wrap it as a Tensor.t using Obj.magic on the internal { js_tensor; disposed } record. *)
5151+let make_int64_tensor (data : int array) (dims : int array) : Tensor.t =
5252+ let ort_obj = ort () in
5353+ let tensor_ctor = Js.Unsafe.get ort_obj (Js.string "Tensor") in
5454+ (* Build a BigInt64Array from the int array *)
5555+ let js_data =
5656+ Js.array
5757+ (Array.map
5858+ (* Call the global BigInt constructor rather than eval'ing a
5959+ BigInt literal per element (slow, and blocked by strict CSP). *)
6060+ (fun x -> Js.Unsafe.fun_call Js.Unsafe.global##._BigInt [| Js.Unsafe.inject x |])
5959+ data)
6060+ in
6161+ let bigint64_ctor = Js.Unsafe.global##._BigInt64Array in
6262+ let bigint64_arr =
6363+ Js.Unsafe.meth_call bigint64_ctor "from" [| Js.Unsafe.inject js_data |]
6464+ in
6565+ let js_dims =
6666+ Js.Unsafe.inject
6767+ (Js.array (Array.map (fun d -> Js.Unsafe.inject d) dims))
6868+ in
6969+ let js_tensor =
7070+ Js.Unsafe.new_obj tensor_ctor
7171+ [| Js.Unsafe.inject (Js.string "int64");
7272+ Js.Unsafe.inject bigint64_arr;
7373+ js_dims |]
7474+ in
7575+ (* Forge a Tensor.t value matching the internal record layout:
7676+ type t = { js_tensor : Js.Unsafe.any; mutable disposed : bool } *)
7777+ (Obj.magic (Js.Unsafe.coerce js_tensor, false) : Tensor.t)
7878+7979+let softmax (logits : float array) : float array =
8080+ let max_val = Array.fold_left max neg_infinity logits in
8181+ let exps = Array.map (fun x -> exp (x -. max_val)) logits in
8282+ let sum = Array.fold_left ( +. ) 0.0 exps in
8383+ Array.map (fun e -> e /. sum) exps
8484+8585+let () =
8686+ Lwt.async @@ fun () ->
8787+ let open Lwt.Syntax in
8888+ (* Load vocabulary *)
8989+ set_status "Loading vocabulary...";
9090+ let* vocab_text = fetch_text "vocab.txt" in
9191+ let vocab = Vocab.load_from_string vocab_text in
9292+9393+ (* Load model *)
9494+ set_status "Loading model...";
9595+ let* session = Session.create "model_quantized.onnx" () in
9696+ set_status "Ready!";
9797+9898+ (* Enable the button *)
9999+ let btn : Dom_html.buttonElement Js.t =
100100+ Js.Unsafe.coerce (Dom_html.getElementById "analyze-btn")
101101+ in
102102+ btn##.disabled := Js._false;
103103+104104+ (* Set up click handler *)
105105+ let textarea : Dom_html.textAreaElement Js.t =
106106+ Js.Unsafe.coerce (Dom_html.getElementById "input-text")
107107+ in
108108+ let _listener =
109109+ Dom_html.addEventListener btn Dom_html.Event.click
110110+ (Dom_html.handler (fun _evt ->
111111+ let text = Js.to_string textarea##.value in
112112+ if String.trim text <> "" then begin
113113+ Lwt.async (fun () ->
114114+ let open Lwt.Syntax in
115115+ set_status "Analyzing...";
116116+ set_result "";
117117+ let encoded = Tokenizer.encode vocab text ~max_length in
118118+119119+ (* Create int64 tensors *)
120120+ let input_ids_tensor =
121121+ make_int64_tensor encoded.Tokenizer.input_ids [| 1; max_length |]
122122+ in
123123+ let attention_mask_tensor =
124124+ make_int64_tensor encoded.Tokenizer.attention_mask [| 1; max_length |]
125125+ in
126126+127127+ (* Run inference *)
128128+ let* outputs =
129129+ Session.run session
130130+ [
131131+ ("input_ids", input_ids_tensor);
132132+ ("attention_mask", attention_mask_tensor);
133133+ ]
134134+ in
135135+136136+ (* Extract logits *)
137137+ let logits_tensor = List.assoc "logits" outputs in
138138+ let logits_data =
139139+ Tensor.to_bigarray1_exn Dtype.Float32 logits_tensor
140140+ in
141141+ let logits =
142142+ [|
143143+ Bigarray.Array1.get logits_data 0;
144144+ Bigarray.Array1.get logits_data 1;
145145+ |]
146146+ in
147147+148148+ (* Compute probabilities *)
149149+ let probs = softmax logits in
150150+ let neg_prob = probs.(0) in
151151+ let pos_prob = probs.(1) in
152152+ let label, confidence =
153153+ if pos_prob > neg_prob then ("POSITIVE", pos_prob)
154154+ else ("NEGATIVE", neg_prob)
155155+ in
156156+ let color =
157157+ if label = "POSITIVE" then "#4ec9b0" else "#f44747"
158158+ in
159159+ set_result
160160+ (Printf.sprintf
161161+ "<span style=\"color:%s;font-size:1.5em;font-weight:bold\">%s</span><br>Confidence: %.1f%%"
162162+ color label (confidence *. 100.0));
163163+ set_status "Ready!";
164164+165165+ (* Dispose tensors *)
166166+ Tensor.dispose input_ids_tensor;
167167+ Tensor.dispose attention_mask_tensor;
168168+ Tensor.dispose logits_tensor;
169169+170170+ Lwt.return_unit);
171171+ end;
172172+ Js._false))
173173+ Js._false
174174+ in
175175+ ignore _listener;
176176+ Lwt.return_unit
onnxrt/example/sentiment/tokenizer.ml
···11+type encoded = {
22+ input_ids : int array;
33+ attention_mask : int array;
44+}
55+66+let is_punctuation c =
77+ match c with
88+ | '!' | '"' | '#' | '$' | '%' | '&' | '\'' | '(' | ')' | '*' | '+' | ','
99+ | '-' | '.' | '/' | ':' | ';' | '<' | '=' | '>' | '?' | '@' | '[' | '\\'
1010+ | ']' | '^' | '_' | '`' | '{' | '|' | '}' | '~' ->
1111+ true
1212+ | _ -> false
1313+1414+let split_on_punctuation word =
1515+ let len = String.length word in
1616+ if len = 0 then []
1717+ else begin
1818+ let tokens = ref [] in
1919+ let buf = Buffer.create 16 in
2020+ for i = 0 to len - 1 do
2121+ let c = word.[i] in
2222+ if is_punctuation c then begin
2323+ if Buffer.length buf > 0 then begin
2424+ tokens := Buffer.contents buf :: !tokens;
2525+ Buffer.clear buf
2626+ end;
2727+ tokens := String.make 1 c :: !tokens
2828+ end else
2929+ Buffer.add_char buf c
3030+ done;
3131+ if Buffer.length buf > 0 then
3232+ tokens := Buffer.contents buf :: !tokens;
3333+ List.rev !tokens
3434+ end
3535+3636+let wordpiece_tokenize vocab word =
3737+ let len = String.length word in
3838+ if len = 0 then []
3939+ else begin
4040+ let tokens = ref [] in
4141+ let start = ref 0 in
4242+ let failed = ref false in
4343+ while !start < len && not !failed do
4444+ let found = ref false in
4545+ let sub_end = ref len in
4646+ while !sub_end > !start && not !found do
4747+ let sub =
4848+ if !start > 0 then "##" ^ String.sub word !start (!sub_end - !start)
4949+ else String.sub word !start (!sub_end - !start)
5050+ in
5151+ match Vocab.find_token vocab sub with
5252+ | Some _id ->
5353+ tokens := sub :: !tokens;
5454+ start := !sub_end;
5555+ found := true
5656+ | None -> decr sub_end
5757+ done;
5858+ if not !found then failed := true
5959+ done;
6060+ (* Standard WordPiece: a word containing any unmatchable piece
6161+ becomes a single [UNK] token, with partial matches discarded. *)
6262+ if !failed then [ "[UNK]" ] else List.rev !tokens
6464+ end
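(* Hedged, self-contained sketch of the greedy longest-match WordPiece
   algorithm, using a toy vocabulary (all words and pieces here are made
   up). Runnable without [Vocab]; the [_demo] prefix keeps it unused. *)
let _demo_wordpiece () =
  let vocab = [ "un"; "##aff"; "##able"; "play"; "##ing" ] in
  let mem piece = List.mem piece vocab in
  let tokenize word =
    let len = String.length word in
    let tokens = ref [] and start = ref 0 and failed = ref false in
    while !start < len && not !failed do
      let found = ref false and sub_end = ref len in
      while !sub_end > !start && not !found do
        let piece = String.sub word !start (!sub_end - !start) in
        let piece = if !start > 0 then "##" ^ piece else piece in
        if mem piece then begin
          tokens := piece :: !tokens;
          start := !sub_end;
          found := true
        end
        else decr sub_end
      done;
      if not !found then failed := true
    done;
    if !failed then [ "[UNK]" ] else List.rev !tokens
  in
  assert (tokenize "playing" = [ "play"; "##ing" ]);
  assert (tokenize "unaffable" = [ "un"; "##aff"; "##able" ]);
  assert (tokenize "xyz" = [ "[UNK]" ])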
6565+6666+let encode vocab text ~max_length =
6767+ let text = String.lowercase_ascii text in
6868+ (* Split on whitespace *)
6969+ let words =
7070+ String.split_on_char ' ' text
7171+ |> List.concat_map (String.split_on_char '\t')
7272+ |> List.concat_map (String.split_on_char '\n')
7373+ |> List.concat_map (String.split_on_char '\r')
7373+ |> List.filter (fun s -> s <> "")
7474+ in
7575+ (* Split punctuation from words, then WordPiece tokenize *)
7676+ let subtokens =
7777+ words
7878+ |> List.concat_map split_on_punctuation
7979+ |> List.concat_map (wordpiece_tokenize vocab)
8080+ in
8181+ (* Convert to IDs: [CLS] + tokens + [SEP] *)
8282+ let lookup tok =
8383+ match Vocab.find_token vocab tok with Some id -> id | None -> vocab.unk_id
8484+ in
8585+ (* Truncate to max_length - 2 to leave room for [CLS] and [SEP] *)
8686+ let max_tokens = max_length - 2 in
8787+ let subtokens =
8888+ if List.length subtokens > max_tokens then
8989+ List.filteri (fun i _ -> i < max_tokens) subtokens
9090+ else subtokens
9191+ in
9292+ let ids = List.map lookup subtokens in
9393+ let token_ids = [ vocab.cls_id ] @ ids @ [ vocab.sep_id ] in
9494+ let real_len = List.length token_ids in
9595+ let input_ids = Array.make max_length vocab.pad_id in
9696+ let attention_mask = Array.make max_length 0 in
9797+ List.iteri (fun i id -> input_ids.(i) <- id) token_ids;
9898+ for i = 0 to real_len - 1 do
9999+ attention_mask.(i) <- 1
100100+ done;
101101+ { input_ids; attention_mask }
onnxrt/lib/js_helpers.ml
···11+(** Internal JS interop helpers. Not part of the public API. *)
22+33+open Js_of_ocaml
44+55+(** Access the global [ort] object (onnxruntime-web). *)
66+let ort () : 'a Js.t =
77+ let o = Js.Unsafe.global##.ort in
88+ if Js.Optdef.test o then (Js.Unsafe.coerce o : 'a Js.t)
99+ else failwith "onnxruntime-web is not loaded: global 'ort' object not found"
1010+1111+(** Convert an OCaml string list to a JS array of JS strings. *)
1212+let js_string_array (strs : string list) : Js.js_string Js.t Js.js_array Js.t =
1313+ Js.array (Array.of_list (List.map Js.string strs))
1414+1515+(** Convert a JS array of JS strings to an OCaml string list. *)
1616+let string_list_of_js_array (arr : Js.js_string Js.t Js.js_array Js.t) : string list =
1717+ Array.to_list (Array.map Js.to_string (Js.to_array arr))
1818+1919+(** Convert an OCaml int array to a JS array of ints. *)
2020+let js_int_array (dims : int array) : int Js.js_array Js.t =
2121+ Js.array dims
2222+2323+(** Convert a JS array of ints to an OCaml int array. *)
2424+let int_array_of_js (arr : int Js.js_array Js.t) : int array =
2525+ Js.to_array arr
2626+2727+(** Read a string property from a JS object. *)
2828+let get_string (obj : 'a Js.t) (key : string) : string =
2929+ Js.to_string (Js.Unsafe.get obj (Js.string key))
3030+3131+(** Read an int property from a JS object. *)
3232+let get_int (obj : 'a Js.t) (key : string) : int =
3333+ Js.Unsafe.get obj (Js.string key)
3434+3535+(** Set a property on a JS object. *)
3636+let set (obj : 'a Js.t) (key : string) (value : 'b) : unit =
3737+ Js.Unsafe.set obj (Js.string key) value
3838+3939+(** Get a nested property: obj.key1.key2 *)
4040+let get_nested (obj : 'a Js.t) (key1 : string) (key2 : string) : 'b Js.t =
4141+ Js.Unsafe.get (Js.Unsafe.get obj (Js.string key1)) (Js.string key2)
onnxrt/lib/onnxrt.ml
···11+module Dtype = struct
22+ type ('ocaml, 'elt) t =
33+ | Float32 : (float, Bigarray.float32_elt) t
44+ | Float64 : (float, Bigarray.float64_elt) t
55+ | Int8 : (int, Bigarray.int8_signed_elt) t
66+ | Uint8 : (int, Bigarray.int8_unsigned_elt) t
77+ | Int16 : (int, Bigarray.int16_signed_elt) t
88+ | Uint16 : (int, Bigarray.int16_unsigned_elt) t
99+ | Int32 : (int32, Bigarray.int32_elt) t
1010+1111+ type packed = Pack : ('ocaml, 'elt) t -> packed
1212+1313+ let to_string : type a b. (a, b) t -> string = function
1414+ | Float32 -> "float32"
1515+ | Float64 -> "float64"
1616+ | Int8 -> "int8"
1717+ | Uint8 -> "uint8"
1818+ | Int16 -> "int16"
1919+ | Uint16 -> "uint16"
2020+ | Int32 -> "int32"
2121+2222+ let of_string = function
2323+ | "float32" -> Some (Pack Float32)
2424+ | "float64" -> Some (Pack Float64)
2525+ | "int8" -> Some (Pack Int8)
2626+ | "uint8" -> Some (Pack Uint8)
2727+ | "int16" -> Some (Pack Int16)
2828+ | "uint16" -> Some (Pack Uint16)
2929+ | "int32" -> Some (Pack Int32)
3030+ | _ -> None
3131+3232+ let equal : type a b c d. (a, b) t -> (c, d) t -> bool =
3333+ fun a b ->
3434+ match (a, b) with
3535+ | Float32, Float32 -> true
3636+ | Float64, Float64 -> true
3737+ | Int8, Int8 -> true
3838+ | Uint8, Uint8 -> true
3939+ | Int16, Int16 -> true
4040+ | Uint16, Uint16 -> true
4141+ | Int32, Int32 -> true
4242+ | _ -> false
4343+4444+ let _to_bigarray_kind : type a b. (a, b) t -> (a, b) Bigarray.kind = function
4545+ | Float32 -> Bigarray.float32
4646+ | Float64 -> Bigarray.float64
4747+ | Int8 -> Bigarray.int8_signed
4848+ | Uint8 -> Bigarray.int8_unsigned
4949+ | Int16 -> Bigarray.int16_signed
5050+ | Uint16 -> Bigarray.int16_unsigned
5151+ | Int32 -> Bigarray.int32
5252+end
5353+5454+module Tensor = struct
5555+ type t = {
5656+ js_tensor : Js_of_ocaml.Js.Unsafe.any;
5757+ mutable disposed : bool;
5858+ }
5959+6060+ type location = Cpu | Gpu_buffer
6161+6262+ let check_not_disposed t =
6363+ if t.disposed then invalid_arg "Tensor has been disposed"
6464+6565+ let check_cpu t =
6666+ check_not_disposed t;
6767+ let loc = Js_helpers.get_string
6868+ (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "location" in
6969+ if loc <> "cpu" then
7070+ invalid_arg "Tensor data is on GPU; use Tensor.download first"
7171+7272+ let check_dtype : type a b. (a, b) Dtype.t -> t -> unit =
7373+ fun expected t ->
7474+ let actual_str = Js_helpers.get_string
7575+ (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "type" in
7676+ let expected_str = Dtype.to_string expected in
7777+ if actual_str <> expected_str then
7878+ failwith (Printf.sprintf "Dtype mismatch: tensor is %s, expected %s"
7979+ actual_str expected_str)
8080+8181+ let typed_array_of_bigarray :
8282+ type a b. (a, b) Dtype.t ->
8383+ (a, b, Bigarray.c_layout) Bigarray.Array1.t ->
8484+ Js_of_ocaml.Js.Unsafe.any =
8585+ fun dtype ba ->
8686+ let open Js_of_ocaml in
8787+ let ga = Bigarray.genarray_of_array1 ba in
8888+ match dtype with
8989+ | Dtype.Float32 ->
9090+ Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Float32 ga)
9191+ | Dtype.Float64 ->
9292+ Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Float64 ga)
9393+ | Dtype.Int8 ->
9494+ Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int8_signed ga)
9595+ | Dtype.Uint8 ->
9696+ Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int8_unsigned ga)
9797+ | Dtype.Int16 ->
9898+ Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int16_signed ga)
9999+ | Dtype.Uint16 ->
100100+ Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int16_unsigned ga)
101101+ | Dtype.Int32 ->
102102+ Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int32_signed ga)
103103+104104+ let of_bigarray1 :
105105+ type a b. (a, b) Dtype.t ->
106106+ (a, b, Bigarray.c_layout) Bigarray.Array1.t ->
107107+ dims:int array -> t =
108108+ fun dtype ba ~dims ->
109109+ let expected_size = Array.fold_left ( * ) 1 dims in
110110+ let actual_size = Bigarray.Array1.dim ba in
111111+ if expected_size <> actual_size then
112112+ invalid_arg (Printf.sprintf
113113+ "Tensor.of_bigarray1: dims product (%d) <> bigarray length (%d)"
114114+ expected_size actual_size);
115115+ let open Js_of_ocaml in
116116+ let ta = typed_array_of_bigarray dtype ba in
117117+ let js_tensor =
118118+ Js.Unsafe.new_obj
119119+ (Js.Unsafe.get (Js_helpers.ort ()) (Js.string "Tensor"))
120120+ [| Js.Unsafe.inject (Js.string (Dtype.to_string dtype));
121121+ ta;
122122+ Js.Unsafe.inject (Js_helpers.js_int_array dims) |]
123123+ in
124124+ { js_tensor = Js.Unsafe.coerce js_tensor; disposed = false }
125125+126126+ let of_bigarray :
127127+ type a b. (a, b) Dtype.t ->
128128+ (a, b, Bigarray.c_layout) Bigarray.Genarray.t -> t =
129129+ fun dtype ga ->
130130+ let dims = Bigarray.Genarray.dims ga in
131131+ let flat = Bigarray.reshape_1 ga (Array.fold_left ( * ) 1 dims) in
132132+ of_bigarray1 dtype flat ~dims
133133+134134+ let of_float32s data ~dims =
135135+ let expected_size = Array.fold_left ( * ) 1 dims in
136136+ if Array.length data <> expected_size then
137137+ invalid_arg (Printf.sprintf
138138+ "Tensor.of_float32s: array length (%d) <> dims product (%d)"
139139+ (Array.length data) expected_size);
140140+ let ba = Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout
141141+ expected_size in
142142+ Array.iteri (fun i v -> Bigarray.Array1.set ba i v) data;
143143+ of_bigarray1 Float32 ba ~dims
144144+145145+ let to_bigarray1_exn :
146146+ type a b. (a, b) Dtype.t -> t ->
147147+ (a, b, Bigarray.c_layout) Bigarray.Array1.t =
148148+ fun dtype t ->
149149+ check_cpu t;
150150+ check_dtype dtype t;
151151+ let open Js_of_ocaml in
152152+ let data : Js.Unsafe.any = Js.Unsafe.get t.js_tensor (Js.string "data") in
153153+ let ta = (Js.Unsafe.coerce data : (_, _, _) Typed_array.typedArray Js.t) in
154154+ let ga = Typed_array.to_genarray ta in
155155+ let size = Bigarray.Genarray.nth_dim ga 0 in
156156+ let ba = Bigarray.reshape_1 ga size in
157157+ (Obj.magic ba : (a, b, Bigarray.c_layout) Bigarray.Array1.t)
158158+159159+ let to_bigarray_exn :
160160+ type a b. (a, b) Dtype.t -> t ->
161161+ (a, b, Bigarray.c_layout) Bigarray.Genarray.t =
162162+ fun dtype t ->
163163+ let flat = to_bigarray1_exn dtype t in
164164+ let dims_js : Js_of_ocaml.Js.Unsafe.any =
165165+ Js_of_ocaml.Js.Unsafe.get t.js_tensor (Js_of_ocaml.Js.string "dims") in
166166+ let dims = Js_helpers.int_array_of_js (Js_of_ocaml.Js.Unsafe.coerce dims_js) in
167167+ Bigarray.reshape (Bigarray.genarray_of_array1 flat) dims
168168+169169+ let download :
170170+ type a b. (a, b) Dtype.t -> t ->
171171+ (a, b, Bigarray.c_layout) Bigarray.Array1.t Lwt.t =
172172+ fun dtype t ->
173173+ check_not_disposed t;
174174+ check_dtype dtype t;
175175+ let open Js_of_ocaml in
176176+ let promise = Js.Unsafe.meth_call t.js_tensor "getData" [||] in
177177+ let open Lwt.Syntax in
178178+ let+ data = Promise_lwt.to_lwt promise in
179179+ let ta = (Js.Unsafe.coerce data : (_, _, _) Typed_array.typedArray Js.t) in
180180+ let ga = Typed_array.to_genarray ta in
181181+ let size = Bigarray.Genarray.nth_dim ga 0 in
182182+ let ba = Bigarray.reshape_1 ga size in
183183+ (Obj.magic ba : (a, b, Bigarray.c_layout) Bigarray.Array1.t)
184184+185185+ let dims t =
186186+ check_not_disposed t;
187187+ let open Js_of_ocaml in
188188+ let dims_js = Js.Unsafe.get t.js_tensor (Js.string "dims") in
189189+ Js_helpers.int_array_of_js (Js.Unsafe.coerce dims_js)
190190+191191+ let dtype t =
192192+ check_not_disposed t;
193193+ let type_str = Js_helpers.get_string
194194+ (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "type" in
195195+ match Dtype.of_string type_str with
196196+ | Some p -> p
197197+ | None -> failwith (Printf.sprintf "Unknown tensor dtype: %s" type_str)
198198+199199+ let size t =
200200+ check_not_disposed t;
201201+ Js_helpers.get_int (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "size"
202202+203203+ let location t =
204204+ check_not_disposed t;
205205+ let loc = Js_helpers.get_string
206206+ (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "location" in
207207+ match loc with
208208+ | "cpu" -> Cpu
209209+ | "gpu-buffer" -> Gpu_buffer
210210+ | s -> failwith (Printf.sprintf "Unknown tensor location: %s" s)
211211+212212+ let dispose t =
213213+ if not t.disposed then begin
214214+ let open Js_of_ocaml in
215215+ ignore (Js.Unsafe.meth_call t.js_tensor "dispose" [||] : Js.Unsafe.any);
216216+ t.disposed <- true
217217+ end
218218+end
219219+220220+module Execution_provider = struct
221221+ type t = Wasm | Webgpu
222222+ let to_string = function Wasm -> "wasm" | Webgpu -> "webgpu"
223223+end
224224+225225+type output_location = Cpu | Gpu_buffer
226226+type graph_optimization = Disabled | Basic | Extended | All
227227+228228+let graph_optimization_to_string = function
229229+ | Disabled -> "disabled"
230230+ | Basic -> "basic"
231231+ | Extended -> "extended"
232232+ | All -> "all"
233233+234234+let output_location_to_js = function
235235+ | Cpu -> Js_of_ocaml.Js.string "cpu"
236236+ | Gpu_buffer -> Js_of_ocaml.Js.string "gpu-buffer"
237237+238238+let log_level_to_string = function
239239+ | `Verbose -> "verbose"
240240+ | `Info -> "info"
241241+ | `Warning -> "warning"
242242+ | `Error -> "error"
243243+ | `Fatal -> "fatal"
244244+245245+module Session = struct
246246+ type t = {
247247+ js_session : Js_of_ocaml.Js.Unsafe.any;
248248+ input_names_ : string list;
249249+ output_names_ : string list;
250250+ }
251251+252252+ let build_options ?execution_providers ?graph_optimization
253253+ ?preferred_output_location ?log_level () =
254254+ let open Js_of_ocaml in
255255+ let pairs = ref [] in
256256+ (match execution_providers with
257257+ | Some eps ->
258258+ let js_eps = Js.array (Array.of_list
259259+ (List.map (fun ep ->
260260+ Js.Unsafe.inject (Js.string (Execution_provider.to_string ep)))
261261+ eps)) in
262262+ pairs := ("executionProviders", Js.Unsafe.inject js_eps) :: !pairs
263263+ | None -> ());
264264+ (match graph_optimization with
265265+ | Some go ->
266266+ pairs := ("graphOptimizationLevel",
267267+ Js.Unsafe.inject (Js.string (graph_optimization_to_string go)))
268268+ :: !pairs
269269+ | None -> ());
270270+ (match preferred_output_location with
271271+ | Some loc ->
272272+ pairs := ("preferredOutputLocation",
273273+ Js.Unsafe.inject (output_location_to_js loc))
274274+ :: !pairs
275275+ | None -> ());
276276+ (match log_level with
277277+ | Some level ->
278278+ pairs := ("logSeverityLevel",
279279+ Js.Unsafe.inject (Js.string (log_level_to_string level)))
280280+ :: !pairs
281281+ | None -> ());
282282+ Js.Unsafe.obj (Array.of_list !pairs)
283283+284284+ let wrap_session js_session =
285285+ let open Js_of_ocaml in
286286+ let input_names_ = Js_helpers.string_list_of_js_array
287287+ (Js.Unsafe.coerce (Js.Unsafe.get js_session (Js.string "inputNames"))) in
288288+ let output_names_ = Js_helpers.string_list_of_js_array
289289+ (Js.Unsafe.coerce (Js.Unsafe.get js_session (Js.string "outputNames"))) in
290290+ { js_session = Js.Unsafe.coerce js_session; input_names_; output_names_ }
291291+292292+ let create ?execution_providers ?graph_optimization
293293+ ?preferred_output_location ?log_level model_url () =
294294+ let open Js_of_ocaml in
295295+ let ort = Js_helpers.ort () in
296296+ let inference_session = Js.Unsafe.get ort (Js.string "InferenceSession") in
297297+ let options = build_options ?execution_providers ?graph_optimization
298298+ ?preferred_output_location ?log_level () in
299299+ let promise = Js.Unsafe.meth_call inference_session "create"
300300+ [| Js.Unsafe.inject (Js.string model_url);
301301+ Js.Unsafe.inject options |] in
302302+ let open Lwt.Syntax in
303303+ let+ js_session = Promise_lwt.to_lwt promise in
304304+ wrap_session js_session
305305+306306+ let create_from_buffer (type a b) ?execution_providers ?graph_optimization
307307+ ?preferred_output_location ?log_level
308308+ (buffer : (a, b, Bigarray.c_layout) Bigarray.Array1.t) () =
309309+ let open Js_of_ocaml in
310310+ let ort = Js_helpers.ort () in
311311+ let inference_session = Js.Unsafe.get ort (Js.string "InferenceSession") in
312312+ let options = build_options ?execution_providers ?graph_optimization
313313+ ?preferred_output_location ?log_level () in
314314+ let ga = Bigarray.genarray_of_array1 buffer in
315315+ let ta = Typed_array.from_genarray Typed_array.Int8_unsigned (Obj.magic ga) in
316316+ let ab : Typed_array.arrayBuffer Js.t =
317317+ Js.Unsafe.get (Js.Unsafe.coerce ta) (Js.string "buffer") in
318318+ let uint8 = Js.Unsafe.new_obj
319319+ (Js.Unsafe.global##._Uint8Array)
320320+ [| Js.Unsafe.inject ab |] in
321321+ let promise = Js.Unsafe.meth_call inference_session "create"
322322+ [| Js.Unsafe.inject uint8;
323323+ Js.Unsafe.inject options |] in
324324+ let open Lwt.Syntax in
325325+ let+ js_session = Promise_lwt.to_lwt promise in
326326+ wrap_session js_session
327327+328328+ let run t inputs =
329329+ let open Js_of_ocaml in
330330+ let feeds = Js.Unsafe.obj
331331+ (Array.of_list
332332+ (List.map (fun (name, (tensor : Tensor.t)) ->
333333+ (name, Js.Unsafe.inject tensor.js_tensor))
334334+ inputs)) in
335335+ let promise = Js.Unsafe.meth_call t.js_session "run"
336336+ [| Js.Unsafe.inject feeds |] in
337337+ let open Lwt.Syntax in
338338+ let+ results = Promise_lwt.to_lwt promise in
339339+ List.map (fun name ->
340340+ let js_tensor = Js.Unsafe.get results (Js.string name) in
341341+ (name, Tensor.{ js_tensor = Js.Unsafe.coerce js_tensor;
342342+ disposed = false }))
343343+ t.output_names_
344344+345345+ let run_with_outputs t inputs ~output_names =
346346+ let open Js_of_ocaml in
347347+ let feeds = Js.Unsafe.obj
348348+ (Array.of_list
349349+ (List.map (fun (name, (tensor : Tensor.t)) ->
350350+ (name, Js.Unsafe.inject tensor.js_tensor))
351351+ inputs)) in
352352+ let promise = Js.Unsafe.meth_call t.js_session "run"
353353+ [| Js.Unsafe.inject feeds; Js.Unsafe.inject (Js_helpers.js_string_array output_names) |] in
354354+ let open Lwt.Syntax in
355355+ let+ results = Promise_lwt.to_lwt promise in
356356+ List.map (fun name ->
357357+ let js_tensor = Js.Unsafe.get results (Js.string name) in
358358+ (name, Tensor.{ js_tensor = Js.Unsafe.coerce js_tensor;
359359+ disposed = false }))
360360+ output_names
361361+362362+ let input_names t = t.input_names_
363363+ let output_names t = t.output_names_
364364+365365+ let release t =
366366+ let open Js_of_ocaml in
367367+ let promise = Js.Unsafe.meth_call t.js_session "release" [||] in
368368+ Promise_lwt.to_lwt promise |> Lwt.map (fun (_ : Js.Unsafe.any) -> ())
369369+end
370370+371371+module Env = struct
372372+ module Wasm = struct
373373+ let set_num_threads n =
374374+ let ort = Js_helpers.ort () in
375375+ Js_helpers.set
376376+ (Js_helpers.get_nested ort "env" "wasm")
377377+ "numThreads" n
378378+379379+ let set_simd enabled =
380380+ let ort = Js_helpers.ort () in
381381+ Js_helpers.set
382382+ (Js_helpers.get_nested ort "env" "wasm")
383383+ "simd" (Js_of_ocaml.Js.bool enabled)
384384+385385+ let set_proxy enabled =
386386+ let ort = Js_helpers.ort () in
387387+ Js_helpers.set
388388+ (Js_helpers.get_nested ort "env" "wasm")
389389+ "proxy" (Js_of_ocaml.Js.bool enabled)
390390+391391+ let set_wasm_paths prefix =
392392+ let ort = Js_helpers.ort () in
393393+ Js_helpers.set
394394+ (Js_helpers.get_nested ort "env" "wasm")
395395+ "wasmPaths" (Js_of_ocaml.Js.string prefix)
396396+ end
397397+398398+ module Webgpu = struct
399399+ let set_power_preference pref =
400400+ let ort = Js_helpers.ort () in
401401+ let s = match pref with
402402+ | `High_performance -> "high-performance"
403403+ | `Low_power -> "low-power"
404404+ in
405405+ Js_helpers.set
406406+ (Js_helpers.get_nested ort "env" "webgpu")
407407+ "powerPreference" (Js_of_ocaml.Js.string s)
408408+ end
409409+end
onnxrt/lib/onnxrt.mli
···11+(** ONNX Runtime Web bindings for OCaml.
22+33+ This library provides OCaml bindings to
44+ {{:https://onnxruntime.ai/} ONNX Runtime Web}, enabling ML model inference
55+ in the browser via [js_of_ocaml] or [wasm_of_ocaml].
66+77+ The bindings target the [onnxruntime-web] npm package and support both the
88+ WebAssembly (CPU) and WebGPU (GPU) execution providers.
99+1010+ {1 Quick start}
1111+1212+ {[
1313+ open Onnxrt
1414+1515+ let () =
1616+ Lwt.async @@ fun () ->
1717+ let open Lwt.Syntax in
1818+ (* Configure before creating any session *)
1919+ Env.Wasm.set_num_threads 2;
2020+ (* Load model *)
2121+ let* session = Session.create "model.onnx" () in
2222+ (* Prepare input *)
2323+ let ba = Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout (3 * 224 * 224) in
2424+ (* ... fill ba with image data ... *)
2525+ let input = Tensor.of_bigarray1 Dtype.Float32 ba ~dims:[| 1; 3; 224; 224 |] in
2626+ (* Run inference *)
2727+ let* outputs = Session.run session [ "input", input ] in
2828+ let output = List.assoc "output" outputs in
2929+ let result = Tensor.to_bigarray1_exn Dtype.Float32 output in
3030+ (* Clean up *)
3131+ Tensor.dispose output;
3232+ let* () = Session.release session in
3333+ Lwt.return_unit
3434+ ]}
3535+3636+ {1 Architecture}
3737+3838+ The library is structured in two layers:
3939+4040+ - {b Low-level}: Direct bindings to the onnxruntime-web JavaScript API via
4141+ [Js.Unsafe]. Not exposed publicly.
4242+ - {b High-level}: Pure OCaml types with {!Bigarray} for tensor data and
4343+ {!Lwt.t} for async operations. This is the public API documented here.
4444+4545+ {1 Execution providers}
4646+4747+ ONNX Runtime Web supports multiple backends for executing model operators:
4848+4949+ - {!Execution_provider.Wasm}: CPU inference via WebAssembly with SIMD and
5050+ optional multi-threading. Supports {b all} ONNX operators. This is the
5151+ default and most portable backend.
5252+ - {!Execution_provider.Webgpu}: GPU inference via WebGPU compute shaders.
5353+ Supports ~140 operators; unsupported operators fall back to WASM
5454+ automatically, though each fallback incurs a GPU↔CPU data transfer.
5555+5656+ Execution providers are specified as a preference list when creating a
5757+ session. The runtime tries each in order and falls back to the next:
5858+5959+ {[
6060+ Session.create "model.onnx"
6161+ ~execution_providers:[ Webgpu; Wasm ]
6262+ ()
6363+ ]}
6464+6565+ {1 Threading model}
6666+6767+ All operations that may block return [Lwt.t] promises. The WASM backend may
6868+ use internal Web Workers for multi-threading (transparent to the caller, but
6969+ requires [SharedArrayBuffer] and cross-origin isolation headers). WebGPU
7070+ dispatches compute shaders on the GPU asynchronously.
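    To enable WASM multi-threading the page must be cross-origin isolated.
    The standard headers (an assumption about typical deployments; this
    library does not set them for you) are:

    {[
      Cross-Origin-Opener-Policy: same-origin
      Cross-Origin-Embedder-Policy: require-corp
    ]}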
7171+7272+ {1 GPU tensors}
7373+7474+ When using the WebGPU backend, tensors can reside on the GPU to avoid
7575+ CPU↔GPU transfers between chained inference calls. See {!Tensor.location},
7676+ {!Tensor.download}, and {!Session.create} with
7777+ [~preferred_output_location:Gpu_buffer].
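    A sketch of reading back a GPU-resident output (the model path and the
    tensor names ["input"]/["output"] are illustrative):

    {[
      let* session =
        Session.create "model.onnx"
          ~execution_providers:[ Webgpu; Wasm ]
          ~preferred_output_location:Gpu_buffer ()
      in
      let* outputs = Session.run session [ "input", input ] in
      let out = List.assoc "output" outputs in
      (* [out] lives in a GPUBuffer; download it before CPU access. *)
      let* data = Tensor.download Dtype.Float32 out in
      Tensor.dispose out;
      Lwt.return data
    ]}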
7878+7979+ {1 Prerequisites}
8080+8181+ The [onnxruntime-web] npm package must be loaded in the JavaScript
8282+ environment before using this library. For WebGPU support, import from
8383+ [onnxruntime-web/webgpu]. The WASM files ([ort-wasm-simd-threaded.wasm]
8484+ etc.) must be served at a path configured via {!Env.Wasm.set_wasm_paths}.
8585+*)
8686+8787+(** {1 Data types} *)
8888+8989+(** Tensor element types.
9090+9191+ Each constructor carries the correspondence between the ONNX type name,
9292+ the OCaml value type, and the {!Bigarray} element type. This allows
9393+ type-safe tensor creation and extraction via GADTs. *)
9494+module Dtype : sig
9595+ (** A tensor element type, parameterised by the OCaml value type ['ocaml]
9696+ and the Bigarray element kind ['elt]. *)
9797+ type ('ocaml, 'elt) t =
9898+ | Float32 : (float, Bigarray.float32_elt) t
9999+ (** 32-bit floating point. The most common type for ML models. *)
100100+ | Float64 : (float, Bigarray.float64_elt) t
101101+ (** 64-bit floating point. *)
102102+ | Int8 : (int, Bigarray.int8_signed_elt) t
103103+ (** Signed 8-bit integer. Used in quantized models. *)
104104+ | Uint8 : (int, Bigarray.int8_unsigned_elt) t
105105+ (** Unsigned 8-bit integer. Common for image data and quantized
106106+ models. *)
107107+ | Int16 : (int, Bigarray.int16_signed_elt) t
108108+ (** Signed 16-bit integer. *)
109109+ | Uint16 : (int, Bigarray.int16_unsigned_elt) t
110110+ (** Unsigned 16-bit integer. *)
111111+ | Int32 : (int32, Bigarray.int32_elt) t
112112+ (** 32-bit integer. Common for token IDs in NLP models. *)
113113+114114+ (** An existentially packed dtype for cases where the element type is only
115115+ known at runtime (e.g. reading a model's output dtype). *)
116116+ type packed = Pack : ('ocaml, 'elt) t -> packed
117117+118118+ val to_string : ('ocaml, 'elt) t -> string
119119+ (** [to_string dtype] returns the ONNX type name (e.g. ["float32"],
120120+ ["int32"]). *)
121121+122122+ val of_string : string -> packed option
123123+ (** [of_string s] parses an ONNX type name. Returns [None] for unsupported
124124+ types. *)
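  (* Example (sketch): dispatching on a dtype that is only known at
     runtime, e.g. when inspecting a model output:
     {[
       match Dtype.of_string "float32" with
       | Some (Dtype.Pack dt) -> print_endline (Dtype.to_string dt)
       | None -> prerr_endline "unsupported dtype"
     ]} *)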
125125+126126+ val equal : ('a, 'b) t -> ('c, 'd) t -> bool
127127+ (** [equal a b] returns [true] if [a] and [b] represent the same element
128128+ type. *)
129129+end
130130+131131+(** {1 Tensors} *)
132132+133133+(** Multi-dimensional typed arrays for model input and output.
134134+135135+ Tensors are the primary data exchange type between OCaml and the ONNX
136136+ runtime. On the CPU side, they are backed by JavaScript TypedArrays which
137137+ share memory with OCaml {!Bigarray} values (zero-copy in [js_of_ocaml]).
138138+139139+ {2 Lifecycle}
140140+141141+ Tensors obtained from {!Session.run} should be {!dispose}d when no longer
142142+ needed. For CPU tensors this is a hint to the garbage collector; for GPU
143143+ tensors it releases the underlying [GPUBuffer] and failure to dispose will
144144+ leak GPU memory.
145145+146146+ {2 GPU tensors}
147147+148148+ When a session is configured with
149149+ [~preferred_output_location:Gpu_buffer], output tensors reside on the GPU.
150150+ Their data is not accessible synchronously — use {!download} to transfer
151151+ to CPU, or pass them directly as input to another {!Session.run} call to
152152+ keep computation on the GPU. *)
153153+module Tensor : sig
154154+ (** An opaque tensor handle. *)
155155+ type t
156156+157157+ (** Where the tensor's data is stored. *)
158158+ type location =
159159+ | Cpu
160160+ (** Data is in CPU memory (a JavaScript TypedArray). Accessible
161161+ synchronously via {!to_bigarray1_exn}. *)
162162+ | Gpu_buffer
163163+ (** Data is in a WebGPU GPUBuffer. Must be {!download}ed before
164164+ CPU-side access, or passed directly to {!Session.run}. *)
165165+166166+ (** {2 Creating tensors} *)
167167+168168+ val of_bigarray1 :
169169+ ('a, 'b) Dtype.t ->
170170+ ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t ->
171171+ dims:int array ->
172172+ t
173173+ (** [of_bigarray1 dtype ba ~dims] creates a tensor from a 1-dimensional
174174+ bigarray. The [dims] array specifies the logical shape (e.g.
175175+ [[| 1; 3; 224; 224 |]]). The product of [dims] must equal the length
176176+ of [ba].
177177+178178+ The bigarray's underlying buffer is shared with the tensor (zero-copy
179179+ in [js_of_ocaml]). Modifying [ba] after tensor creation will affect the
180180+ tensor's data.
181181+182182+ @raise Invalid_argument if [Array.fold_left ( * ) 1 dims <> Bigarray.Array1.dim ba] *)
183183+184184+ val of_bigarray :
185185+ ('a, 'b) Dtype.t ->
186186+ ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t ->
187187+ t
188188+ (** [of_bigarray dtype ga] creates a tensor from a generic bigarray, using
189189+ the bigarray's dimensions as the tensor shape. Zero-copy. *)
190190+191191+ val of_float32s : float array -> dims:int array -> t
192192+ (** [of_float32s data ~dims] creates a Float32 tensor from an OCaml float
193193+ array. Copies the data into a new Float32Array.
194194+195195+ @raise Invalid_argument if [Array.length data] doesn't match the product
196196+ of [dims] *)
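  (** Example: build a [1 x 3 x 224 x 224] Float32 input tensor backed by a
      bigarray (zero-copy). A sketch; the shape is illustrative.

      {[
        let input =
          let ba =
            Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout
              (1 * 3 * 224 * 224)
          in
          Bigarray.Array1.fill ba 0.0;
          Tensor.of_bigarray1 Dtype.Float32 ba ~dims:[| 1; 3; 224; 224 |]
      ]} *)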
197197+198198+ (** {2 Reading tensor data} *)
199199+200200+ val to_bigarray1_exn :
201201+ ('a, 'b) Dtype.t ->
202202+ t ->
203203+ ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t
204204+ (** [to_bigarray1_exn dtype tensor] returns the tensor's data as a
205205+ flat 1-dimensional bigarray. Zero-copy when possible.
206206+207207+ @raise Invalid_argument if the tensor is on the GPU (use {!download}
208208+ first)
209209+ @raise Failure if [dtype] does not match the tensor's actual dtype *)
210210+211211+ val to_bigarray_exn :
212212+ ('a, 'b) Dtype.t ->
213213+ t ->
214214+ ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t
215215+ (** [to_bigarray_exn dtype tensor] returns the tensor's data as a generic
216216+ bigarray with the tensor's shape as dimensions. Zero-copy when possible.
217217+218218+ @raise Invalid_argument if the tensor is on the GPU
219219+ @raise Failure if [dtype] does not match the tensor's actual dtype *)
220220+221221+ val download :
222222+ ('a, 'b) Dtype.t ->
223223+ t ->
224224+ ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t Lwt.t
225225+ (** [download dtype tensor] retrieves the tensor's data, transferring from
226226+ GPU to CPU if necessary. For CPU tensors, this resolves immediately.
227227+228228+ This is the only way to access data from a GPU tensor.
229229+230230+ @raise Failure if [dtype] does not match the tensor's actual dtype *)
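  (** Example: read an output back as floats regardless of where it resides.
      A sketch; [read_floats] is an illustrative helper, not part of the API.

      {[
        let read_floats tensor =
          let open Lwt.Syntax in
          let* ba = Tensor.download Dtype.Float32 tensor in
          Tensor.dispose tensor;
          Lwt.return ba
      ]} *)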
231231+232232+ (** {2 Tensor metadata} *)
233233+234234+ val dims : t -> int array
235235+ (** [dims tensor] returns the tensor's shape (e.g. [[| 1; 3; 224; 224 |]]). *)
236236+237237+ val dtype : t -> Dtype.packed
238238+ (** [dtype tensor] returns the tensor's element type as a packed value.
239239+ Use pattern matching to recover the type:
240240+241241+ {[
242242+ match Tensor.dtype t with
243243+ | Dtype.Pack Float32 -> (* ... *)
244244+ | Dtype.Pack Int32 -> (* ... *)
245245+ | _ -> failwith "unexpected dtype"
246246+ ]} *)
247247+248248+ val size : t -> int
249249+ (** [size tensor] returns the total number of elements (product of dims). *)
250250+251251+ val location : t -> location
252252+ (** [location tensor] returns where the tensor's data currently resides. *)
253253+254254+ (** {2 Lifecycle} *)
255255+256256+ val dispose : t -> unit
257257+ (** [dispose tensor] releases the tensor's resources. For CPU tensors, drops
258258+ the internal reference (data may still be accessible via a bigarray alias).
259259+ For GPU tensors, destroys the underlying [GPUBuffer]. Always dispose GPU
260260+ tensors to avoid memory leaks.
261261+262262+ After disposal, any access to the tensor's data raises. *)
263263+end
264264+265265+(** {1 Inference sessions} *)
266266+267267+(** Execution providers determine how model operators are executed.
268268+269269+ Providers are specified as a preference list when creating a session. The
270270+ runtime tries each in order, falling back to the next if unavailable.
271271+ The WASM provider is always available as a final fallback. *)
272272+module Execution_provider : sig
273273+ type t =
274274+ | Wasm
275275+ (** CPU inference via WebAssembly. Supports all ONNX operators. Uses
276276+ SIMD and optional multi-threading. This is the default. *)
277277+ | Webgpu
278278+ (** GPU inference via WebGPU compute shaders. Requires a browser with
279279+ WebGPU support (Chrome 113+, Firefox 141+, Safari 26+). Operators
280280+ without WebGPU kernels fall back to WASM automatically. *)
281281+282282+ val to_string : t -> string
283283+ (** [to_string ep] returns the JavaScript name (["wasm"] or ["webgpu"]). *)
284284+end
285285+286286+(** Where session outputs should be placed. *)
287287+type output_location =
288288+ | Cpu
289289+ (** Transfer results to CPU (default). Data is immediately accessible. *)
290290+ | Gpu_buffer
291291+ (** Keep results on the GPU. Avoids GPU→CPU transfer overhead when
292292+ chaining inference calls. Use {!Tensor.download} to read the data. *)
293293+294294+(** Graph optimization level applied during session creation. *)
295295+type graph_optimization =
296296+ | Disabled (** No graph optimizations. *)
297297+ | Basic (** Basic optimizations (constant folding, redundancy elimination). *)
298298+ | Extended (** Extended optimizations (includes basic + more advanced rewrites). *)
299299+ | All (** All available optimizations (default). *)
300300+301301+(** An inference session: a loaded and optimized ONNX model ready to run.
302302+303303+ {2 Session lifecycle}
304304+305305+ 1. Create a session with {!create}, which loads the model, applies graph
306306+ optimizations, and partitions operators across execution providers.
307307+ 2. Run inference with {!run}, passing named input tensors and receiving
308308+ named output tensors.
309309+ 3. Release with {!release} when done, to free model weights and any
310310+ GPU resources.
311311+312312+ {2 Warm-up}
313313+314314+ When using WebGPU, compute shaders are compiled lazily on the first
315315+ {!run} call. The first inference will be significantly slower than
316316+ subsequent ones. Run a warm-up inference with dummy data after session
317317+ creation if latency matters.
318318+319319+ {2 Thread safety}
320320+321321+ Sessions do not support concurrent {!run} calls. Await each result before
322322+ starting the next inference. *)
323323+module Session : sig
324324+ (** An opaque inference session handle. *)
325325+ type t
326326+327327+ val create :
328328+ ?execution_providers:Execution_provider.t list ->
329329+ ?graph_optimization:graph_optimization ->
330330+ ?preferred_output_location:output_location ->
331331+ ?log_level:[ `Verbose | `Info | `Warning | `Error | `Fatal ] ->
332332+ string ->
333333+ unit ->
334334+ t Lwt.t
335335+ (** [create ?execution_providers ?graph_optimization ?preferred_output_location
336336+ ?log_level model_url ()] loads an ONNX model and creates an inference session.
337337+338338+ @param execution_providers Preference-ordered list of backends to try.
339339+ Defaults to [[Wasm]].
340340+ @param graph_optimization Level of graph optimization to apply.
341341+ Defaults to [All].
342342+ @param preferred_output_location Where to place output tensors. Defaults
343343+ to [Cpu]. Set to [Gpu_buffer] when chaining inference calls on the GPU.
344344+ @param log_level Minimum severity for runtime log messages.
345345+ Defaults to [`Warning].
346346+ @param model_url URL or path to the [.onnx] or [.ort] model file.
347347+348348+ @raise Failure if the model cannot be loaded or parsed *)
349349+350350+ val create_from_buffer :
351351+ ?execution_providers:Execution_provider.t list ->
352352+ ?graph_optimization:graph_optimization ->
353353+ ?preferred_output_location:output_location ->
354354+ ?log_level:[ `Verbose | `Info | `Warning | `Error | `Fatal ] ->
355355+ ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t ->
356356+ unit ->
357357+ t Lwt.t
358358+ (** [create_from_buffer ?... buffer ()] creates a session from model bytes
359359+ already in memory. The [buffer] should contain the raw [.onnx] or [.ort]
360360+ file content (typically fetched separately and cached in IndexedDB).
361361+362362+ Takes any Bigarray element type so you can pass [int8_unsigned] bytes
363363+ directly. *)
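  (** Example: create a WebGPU session (with WASM fallback) from model bytes
      already in memory. A sketch; how [bytes] is fetched and cached is up to
      the application.

      {[
        let of_bytes (bytes : (int, Bigarray.int8_unsigned_elt,
                               Bigarray.c_layout) Bigarray.Array1.t) =
          Session.create_from_buffer
            ~execution_providers:[ Execution_provider.Webgpu;
                                   Execution_provider.Wasm ]
            bytes ()
      ]} *)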
364364+365365+ val run :
366366+ t ->
367367+ (string * Tensor.t) list ->
368368+ (string * Tensor.t) list Lwt.t
369369+ (** [run session inputs] runs inference on the model.
370370+371371+ [inputs] is an association list mapping input names to tensors. Use
372372+ {!input_names} to discover the expected names.
373373+374374+ Returns an association list mapping output names to result tensors.
375375+ The caller is responsible for {!Tensor.dispose}ing the returned tensors.
376376+377377+ @raise Failure if an input name is not recognised, a tensor shape is
378378+ incompatible with the model, or inference fails *)
379379+380380+ val run_with_outputs :
381381+ t ->
382382+ (string * Tensor.t) list ->
383383+ output_names:string list ->
384384+ (string * Tensor.t) list Lwt.t
385385+ (** [run_with_outputs session inputs ~output_names] runs inference, fetching
386386+ only the specified outputs. This can be more efficient than {!run} when
387387+ a model has multiple outputs but you only need some of them.
388388+389389+ @raise Failure if an output name is not recognised *)
390390+391391+ val input_names : t -> string list
392392+ (** [input_names session] returns the model's expected input tensor names,
393393+ in the order defined by the model. *)
394394+395395+ val output_names : t -> string list
396396+ (** [output_names session] returns the model's output tensor names, in the
397397+ order defined by the model. *)
398398+399399+ val release : t -> unit Lwt.t
400400+ (** [release session] frees all resources held by the session, including
401401+ model weights and any GPU resources. The session must not be used after
402402+ this call. *)
403403+end
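(** Example: end-to-end inference with the WASM provider. A sketch;
    ["model.onnx"], the input name ["input"], and the [1 x 4] shape are
    placeholders for a real model's metadata.

    {[
      let infer () =
        let open Lwt.Syntax in
        let* session =
          Session.create ~execution_providers:[ Execution_provider.Wasm ]
            "model.onnx" ()
        in
        let input = Tensor.of_float32s (Array.make 4 0.0) ~dims:[| 1; 4 |] in
        let* outputs = Session.run session [ ("input", input) ] in
        List.iter (fun (_name, t) -> Tensor.dispose t) outputs;
        Session.release session
    ]} *)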
404404+405405+(** {1 Environment configuration}
406406+407407+ Global settings that affect all sessions. These {b must} be set before
408408+ the first call to {!Session.create}; changing them afterwards has no
409409+ effect. *)
410410+module Env : sig
411411+ (** WebAssembly backend configuration. *)
412412+ module Wasm : sig
413413+ val set_num_threads : int -> unit
414414+ (** [set_num_threads n] sets the number of threads for the WASM backend.
415415+416416+ - [0] (default): auto-detect ([navigator.hardwareConcurrency / 2],
417417+ capped at 4)
418418+ - [1]: single-threaded (no Web Workers, no [SharedArrayBuffer] needed)
419419+ - [n]: use [n] threads (requires cross-origin isolation)
420420+421421+ Multi-threading requires the page to be served with:
422422+ {v
423423+Cross-Origin-Opener-Policy: same-origin
424424+Cross-Origin-Embedder-Policy: require-corp
425425+ v} *)
426426+427427+ val set_simd : bool -> unit
428428+ (** [set_simd enabled] enables or disables WASM SIMD. Defaults to [true]
429429+ (auto-detect). SIMD provides ~2x speedup on supported hardware. *)
430430+431431+ val set_proxy : bool -> unit
432432+ (** [set_proxy enabled] enables the proxy worker, which offloads WASM
433433+ inference to a dedicated Web Worker for UI responsiveness.
434434+435435+ {b Incompatible with the WebGPU execution provider.}
436436+437437+ Defaults to [false]. *)
438438+439439+ val set_wasm_paths : string -> unit
440440+ (** [set_wasm_paths prefix] sets the URL prefix where [.wasm] files are
441441+ served. For example, [set_wasm_paths "/static/wasm/"] causes the
442442+ runtime to load [/static/wasm/ort-wasm-simd-threaded.wasm].
443443+444444+ By default, files are loaded relative to the current page or worker
445445+ script location. *)
446446+ end
447447+448448+ (** WebGPU backend configuration. *)
449449+ module Webgpu : sig
450450+ val set_power_preference : [ `High_performance | `Low_power ] -> unit
451451+ (** [set_power_preference pref] sets the GPU adapter power preference.
452452+ Defaults to [`High_performance].
453453+454454+ - [`High_performance]: prefer discrete GPU (better throughput)
455455+ - [`Low_power]: prefer integrated GPU (better battery life) *)
456456+ end
457457+end
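(** Example: configure the WASM backend before the first session is created.
    The path and thread count are illustrative.

    {[
      let () =
        Env.Wasm.set_wasm_paths "/static/wasm/";
        (* single-threaded: works without cross-origin isolation *)
        Env.Wasm.set_num_threads 1
    ]} *)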
458458+459459+(** {1 Errors}
460460+461461+ All functions that interact with the ONNX runtime raise [Failure] on
462462+ error with a descriptive message from the runtime. Async operations may
463463+ also reject the [Lwt.t] promise with [Failure].
464464+465465+ In a production application, wrap calls in [Lwt.catch] (the [let*] below comes from [Lwt.Syntax]):
466466+467467+ {[
468468+ Lwt.catch
469469+ (fun () ->
470470+ let* session = Session.create "model.onnx" () in
471471+ (* ... *))
472472+ (fun exn ->
473473+ Logs.err (fun m -> m "ONNX error: %s" (Printexc.to_string exn));
474474+ Lwt.return_unit)
475475+ ]}
476476+*)
onnxrt/lib/promise_lwt.ml
···11+open Js_of_ocaml
22+33+let to_lwt (promise : 'a Js.t) : 'a Lwt.t =
44+ let lwt_promise, resolver = Lwt.wait () in
55+ let on_resolve result = Lwt.wakeup resolver result in
66+ let on_reject error =
77+ let msg =
88+ Js.to_string (Js.Unsafe.meth_call error "toString" [||] : Js.js_string Js.t)
99+ in
1010+ Lwt.wakeup_exn resolver (Failure msg)
1111+ in
1212+ let _ignored : 'b Js.t =
1313+ Js.Unsafe.meth_call promise "then"
1414+ [| Js.Unsafe.inject (Js.wrap_callback on_resolve);
1515+ Js.Unsafe.inject (Js.wrap_callback on_reject) |]
1616+ in
1717+ lwt_promise
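(* Usage sketch: any Promise-returning [Js.Unsafe.meth_call] can be awaited
   as an Lwt thread, e.g.

     let release session =
       Promise_lwt.to_lwt (Js.Unsafe.meth_call session "release" [||])

   [release] here is illustrative; the real bindings are defined in the
   public [Onnxrt] module. *)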
onnxrt/lib/promise_lwt.mli
···11+(** Internal: bridge JavaScript Promises to Lwt.
22+33+ Not part of the public API. *)
44+55+val to_lwt : 'a Js_of_ocaml.Js.t -> 'a Lwt.t
66+(** [to_lwt js_promise] converts a JavaScript Promise to an Lwt thread.
77+ If the Promise rejects, the Lwt thread fails with [Failure msg]. *)
onnxrt/onnxrt.opam
···11+# This file is generated by dune, edit dune-project instead
22+opam-version: "2.0"
33+synopsis: "OCaml bindings to ONNX Runtime Web for browser-based ML inference"
44+description:
55+ "Type-safe OCaml bindings to onnxruntime-web, enabling ML model inference in the browser via js_of_ocaml or wasm_of_ocaml. Supports WebAssembly (CPU) and WebGPU (GPU) execution providers."
66+license: "ISC"
77+depends: [
88+ "dune" {>= "3.17"}
99+ "ocaml" {>= "5.2"}
1010+ "js_of_ocaml" {>= "5.8"}
1111+ "js_of_ocaml-ppx" {>= "5.8"}
1212+ "lwt" {>= "5.7"}
1313+ "js_of_ocaml-lwt" {>= "5.8"}
1414+ "odoc" {with-doc}
1515+]
1616+build: [
1717+ ["dune" "subst"] {dev}
1818+ [
1919+ "dune"
2020+ "build"
2121+ "-p"
2222+ name
2323+ "-j"
2424+ jobs
2525+ "@install"
2626+ "@runtest" {with-test}
2727+ "@doc" {with-doc}
2828+ ]
2929+]