this repo has no description

tessera-linalg: implement kNN with distance-weighted voting

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

+93 -12
+1 -3
tessera-linalg/lib/dune
··· 1 1 (library 2 2 (name linalg) 3 - (public_name tessera-linalg) 4 - (ocamlopt_flags (:standard -w -69)) 5 - (ocamlc_flags (:standard -w -69))) 3 + (public_name tessera-linalg))
+42 -4
tessera-linalg/lib/linalg.ml
(* ---- kNN ---- *)

(** [knn_predict model ~k query] classifies each row of [query] by a
    distance-weighted vote among its [k] nearest training points.

    Distances are squared Euclidean.  Each of the [k] nearest
    neighbours votes for its label with weight [1. /. (dist +. 1e-10)],
    so exact matches dominate the vote.  The confidence reported for a
    row is the winning class's share of the total vote weight — a
    value in [0, 1], and exactly [1.0] when [k = 1].

    [k] is clamped to the number of training points.

    @raise Invalid_argument if [k <= 0], if the model contains no
    training points, or if [query] and the model embeddings disagree
    on the number of features. *)
let knn_predict model ~k query =
  let n_train = model.embeddings.rows in
  let n_features = model.embeddings.cols in
  if n_train = 0 then invalid_arg "knn_predict: empty training set";
  if k <= 0 then invalid_arg "knn_predict: k must be positive";
  if query.cols <> n_features then
    invalg_arg_guard query n_features
and invalg_arg_guard _ _ = invalid_arg "knn_predict: feature-count mismatch"
+50 -5
tessera-linalg/test/test_linalg.ml
(* ---- kNN tests ---- *)

(* Two well-separated clusters: each query point sits inside one
   cluster, so a k=3 vote must return that cluster's label. *)
let test_knn_two_clusters () =
  let training = mat_of_arrays [|
    [| 0.0; 0.0 |];
    [| 0.1; 0.1 |];
    [| 0.2; 0.0 |];
    [| 100.0; 100.0 |];
    [| 100.1; 100.1 |];
    [| 99.9; 100.0 |];
  |] in
  let model = knn_fit ~embeddings:training ~labels:[| 0; 0; 0; 1; 1; 1 |] in
  let queries = mat_of_arrays [|
    [| 0.05; 0.05 |];
    [| 99.95; 100.05 |];
  |] in
  let res = knn_predict model ~k:3 queries in
  Alcotest.(check int) "pred 0" 0 res.predictions.(0);
  Alcotest.(check int) "pred 1" 1 res.predictions.(1)
(* One very close label-0 neighbour must outvote two far-away label-1
   neighbours, because votes are weighted by inverse distance. *)
let test_knn_distance_weighting () =
  let training = mat_of_arrays [|
    [| 0.0; 0.0 |];
    [| 100.0; 0.0 |];
    [| 100.0; 1.0 |];
  |] in
  let model = knn_fit ~embeddings:training ~labels:[| 0; 1; 1 |] in
  let res = knn_predict model ~k:3 (mat_of_arrays [| [| 0.01; 0.01 |] |]) in
  Alcotest.(check int) "close wins" 0 res.predictions.(0)

(* With a single neighbour the vote is unanimous, so the reported
   confidence must be exactly 1. *)
let test_knn_k1_confidence () =
  let training = mat_of_arrays [|
    [| 0.0; 0.0 |];
    [| 10.0; 10.0 |];
  |] in
  let model = knn_fit ~embeddings:training ~labels:[| 0; 1 |] in
  let res = knn_predict model ~k:1 (mat_of_arrays [| [| 0.1; 0.1 |] |]) in
  Alcotest.(check int) "k=1 pred" 0 res.predictions.(0);
  Alcotest.(check (float 1e-6)) "k=1 confidence" 1.0 res.confidences.(0)