···123 done;
124 result
125126+(* ---- kNN ---- *)
127128type knn_model = {
129 embeddings : mat;
···138 confidences : float array;
139}
(** [knn_predict model ~k query] classifies every row of [query] with a
    distance-weighted k-nearest-neighbour vote over [model.embeddings].

    For each query row:
    - the squared Euclidean distance to every training row is computed
      over the [model.embeddings.cols] feature columns;
    - the [k] closest training rows vote for their label with weight
      [1 / (squared_distance + 1e-10)];
    - the label with the largest summed weight becomes the prediction and
      its share of the total weight becomes the confidence.

    [k] is clamped to the number of training rows. When no neighbour can
    vote ([k <= 0] or an empty training set) the prediction is left at the
    default class [0] and the confidence is [0.0].

    NOTE(review): the number of columns of [query] is not checked against
    the training matrix; only the first [model.embeddings.cols] columns
    are read through [mat_get] -- confirm callers always pass matching
    shapes.

    @return a record with one entry per query row in both [predictions]
    and [confidences]. *)
let knn_predict model ~k query =
  let n_train = model.embeddings.rows in
  let n_features = model.embeddings.cols in
  (* Never ask for more neighbours than there are training rows. *)
  let k = min k n_train in
  let predictions = Array.make query.rows 0 in
  let confidences = Array.make query.rows 0.0 in
  for qi = 0 to query.rows - 1 do
    (* Squared distance from this query row to every training row, paired
       with the training row's label. The square root is skipped: it does
       not change the neighbour ordering. *)
    let dists =
      Array.init n_train (fun ti ->
          let d = ref 0.0 in
          for j = 0 to n_features - 1 do
            let diff = mat_get query qi j -. mat_get model.embeddings ti j in
            d := !d +. (diff *. diff)
          done;
          (!d, model.labels.(ti)))
    in
    (* Nearest training rows first. *)
    Array.sort (fun (d1, _) (d2, _) -> Float.compare d1 d2) dists;
    (* Accumulate inverse-distance weights per label for the k nearest. *)
    let class_weights = Hashtbl.create 16 in
    let total_weight = ref 0.0 in
    for i = 0 to k - 1 do
      let dist, label = dists.(i) in
      (* The epsilon keeps the weight finite for an exact match. *)
      let w = 1.0 /. (dist +. 1e-10) in
      total_weight := !total_weight +. w;
      let prev =
        match Hashtbl.find_opt class_weights label with
        | Some acc -> acc
        | None -> 0.0
      in
      Hashtbl.replace class_weights label (prev +. w)
    done;
    (* Pick the label with the largest accumulated weight. *)
    let best_class = ref 0 in
    let best_weight = ref Float.neg_infinity in
    Hashtbl.iter
      (fun cls w ->
        if w > !best_weight then begin
          best_class := cls;
          best_weight := w
        end)
      class_weights;
    predictions.(qi) <- !best_class;
    (* With no votes the ratio would be neg_infinity /. 0.0; report zero
       confidence instead of a non-finite value. *)
    confidences.(qi) <-
      (if Hashtbl.length class_weights = 0 then 0.0
       else !best_weight /. !total_weight)
  done;
  { predictions; confidences }
+50-5
tessera-linalg/test/test_linalg.ml
···38(* ---- PCA tests ---- *)
3940let test_pca_x_axis_points () =
41- (* 4 points along x-axis in 3D: (1,0,0), (2,0,0), (3,0,0), (4,0,0)
42- PCA to 1D should produce evenly spaced output *)
43 let data = mat_of_arrays [|
44 [| 1.0; 0.0; 0.0 |];
45 [| 2.0; 0.0; 0.0 |];
···50 let result = pca_transform model data in
51 Alcotest.(check int) "result rows" 4 result.rows;
52 Alcotest.(check int) "result cols" 1 result.cols;
53- (* Should be evenly spaced: differences between consecutive should be equal *)
54 let v0 = mat_get result 0 0 in
55 let v1 = mat_get result 1 0 in
56 let v2 = mat_get result 2 0 in
···62 Alcotest.(check (float eps)) "spacing d12~d23" d12 d23
6364let test_pca_diagonal_monotonic () =
65- (* 100 points along y=x in 2D -> PCA to 1D -> monotonic output *)
66 let rows_data = Array.init 100 (fun i ->
67 let v = Float.of_int i in
68 [| v; v |]
···71 let model = pca_fit data ~n_components:1 in
72 let result = pca_transform model data in
73 Alcotest.(check int) "result rows" 100 result.rows;
74- (* Check monotonic (either all increasing or all decreasing) *)
75 let increasing = ref true in
76 let decreasing = ref true in
77 for i = 1 to 99 do
···93 Alcotest.(check int) "rows" 3 result.rows;
94 Alcotest.(check int) "cols" 2 result.cols
9500000000000000000000000000000000000000000000096(* ---- Test runner ---- *)
9798let () =
···106 Alcotest.test_case "x-axis points evenly spaced" `Quick test_pca_x_axis_points;
107 Alcotest.test_case "diagonal monotonic" `Quick test_pca_diagonal_monotonic;
108 Alcotest.test_case "output dimensions" `Quick test_pca_output_dims;
00000109 ];
110 ]
···38(* ---- PCA tests ---- *)
3940let test_pca_x_axis_points () =
0041 let data = mat_of_arrays [|
42 [| 1.0; 0.0; 0.0 |];
43 [| 2.0; 0.0; 0.0 |];
···48 let result = pca_transform model data in
49 Alcotest.(check int) "result rows" 4 result.rows;
50 Alcotest.(check int) "result cols" 1 result.cols;
051 let v0 = mat_get result 0 0 in
52 let v1 = mat_get result 1 0 in
53 let v2 = mat_get result 2 0 in
···59 Alcotest.(check (float eps)) "spacing d12~d23" d12 d23
6061let test_pca_diagonal_monotonic () =
062 let rows_data = Array.init 100 (fun i ->
63 let v = Float.of_int i in
64 [| v; v |]
···67 let model = pca_fit data ~n_components:1 in
68 let result = pca_transform model data in
69 Alcotest.(check int) "result rows" 100 result.rows;
070 let increasing = ref true in
71 let decreasing = ref true in
72 for i = 1 to 99 do
···88 Alcotest.(check int) "rows" 3 result.rows;
89 Alcotest.(check int) "cols" 2 result.cols
9091+(* ---- kNN tests ---- *)

(* Two well-separated clusters: a query placed inside each cluster must be
   assigned that cluster's label when voting with k = 3. *)
let test_knn_two_clusters () =
  let embeddings =
    mat_of_arrays
      [|
        [| 0.0; 0.0 |];
        [| 0.1; 0.1 |];
        [| 0.2; 0.0 |];
        [| 100.0; 100.0 |];
        [| 100.1; 100.1 |];
        [| 99.9; 100.0 |];
      |]
  in
  let labels = [| 0; 0; 0; 1; 1; 1 |] in
  let model = knn_fit ~embeddings ~labels in
  let queries = mat_of_arrays [| [| 0.05; 0.05 |]; [| 99.95; 100.05 |] |] in
  let outcome = knn_predict model ~k:3 queries in
  (* (check name, expected label, query row) -- checked in order. *)
  List.iter
    (fun (name, expected, row) ->
      Alcotest.(check int) name expected outcome.predictions.(row))
    [ ("pred 0", 0, 0); ("pred 1", 1, 1) ]

(* A single label-0 point right next to the query competes against two
   far-away label-1 points with k = 3; the prediction is expected to be
   the nearby label 0. *)
let test_knn_distance_weighting () =
  let embeddings =
    mat_of_arrays [| [| 0.0; 0.0 |]; [| 100.0; 0.0 |]; [| 100.0; 1.0 |] |]
  in
  let labels = [| 0; 1; 1 |] in
  let model = knn_fit ~embeddings ~labels in
  let probe = mat_of_arrays [| [| 0.01; 0.01 |] |] in
  let outcome = knn_predict model ~k:3 probe in
  Alcotest.(check int) "close wins" 0 outcome.predictions.(0)

(* With k = 1 only the single nearest neighbour votes, so the predicted
   label is that neighbour's and the confidence is expected to be 1.0
   (within 1e-6). *)
let test_knn_k1_confidence () =
  let embeddings = mat_of_arrays [| [| 0.0; 0.0 |]; [| 10.0; 10.0 |] |] in
  let labels = [| 0; 1 |] in
  let model = knn_fit ~embeddings ~labels in
  let probe = mat_of_arrays [| [| 0.1; 0.1 |] |] in
  let outcome = knn_predict model ~k:1 probe in
  let predicted = outcome.predictions.(0) in
  Alcotest.(check int) "k=1 pred" 0 predicted;
  let confidence = outcome.confidences.(0) in
  Alcotest.(check (float 1e-6)) "k=1 confidence" 1.0 confidence
135+136(* ---- Test runner ---- *)
137138let () =
···146 Alcotest.test_case "x-axis points evenly spaced" `Quick test_pca_x_axis_points;
147 Alcotest.test_case "diagonal monotonic" `Quick test_pca_diagonal_monotonic;
148 Alcotest.test_case "output dimensions" `Quick test_pca_output_dims;
149+ ];
150+ "knn", [
151+ Alcotest.test_case "two clusters" `Quick test_knn_two_clusters;
152+ Alcotest.test_case "distance weighting" `Quick test_knn_distance_weighting;
153+ Alcotest.test_case "k=1 confidence" `Quick test_knn_k1_confidence;
154 ];
155 ]