# Vendored from Cameron's "void" repo, extracted for local testing.
1""" 2Simple in-memory vector index for memory blocks. 3 4Features: 5- Uses `ollama_adapter.get_ollama_client()` for embeddings when available. 6- Falls back to a deterministic, test-friendly embedding function when no client is available. 7- Provides add(), build(), and search() methods with cosine similarity. 8 9This module avoids external heavy deps (no numpy) so tests remain lightweight. 10""" 11import hashlib 12import json 13import logging 14import math 15from pathlib import Path 16from typing import Any, Dict, List, Optional, Tuple 17 18try: 19 import yaml 20except Exception: 21 yaml = None 22 23try: 24 import numpy as np 25except Exception: 26 np = None 27 28from ollama_adapter import get_ollama_client 29 30logger = logging.getLogger(__name__) 31 32 33def _sha256_embedding(text: str, dim: int = 32) -> List[float]: 34 """Deterministic fallback embedding using SHA256 digest. 35 36 Produces a fixed-length vector in [-1, 1]. Good for tests. 37 """ 38 h = hashlib.sha256(text.encode('utf-8')).digest() 39 # Expand or fold to get `dim` floats from digest bytes 40 floats: List[float] = [] 41 i = 0 42 while len(floats) < dim: 43 # reuse digest with counter for more entropy 44 counter = i.to_bytes(2, 'big') 45 chunk = hashlib.sha256(h + counter).digest() 46 for b in chunk: 47 if len(floats) >= dim: 48 break 49 # map byte 0-255 to float -1..1 50 floats.append((b / 255.0) * 2.0 - 1.0) 51 i += 1 52 return floats 53 54 55def _dot(a: List[float], b: List[float]) -> float: 56 return sum(x * y for x, y in zip(a, b)) 57 58 59def _norm(a: List[float]) -> float: 60 return math.sqrt(sum(x * x for x in a)) 61 62 63class EmbeddingClient: 64 """Wrapper that prefers Ollama client embeddings and falls back to SHA256 embeddings.""" 65 66 def __init__(self, dim: int = 32): 67 self.dim = dim 68 self._client = get_ollama_client() 69 70 def embed_texts(self, texts: List[str]) -> List[List[float]]: 71 # Validate inputs: empty or whitespace-only strings should not be sent 72 # to the 
embedding backend as they commonly produce a server-side 500 73 # (no input provided). Treat this as a caller error and raise. 74 if any((t is None) or (isinstance(t, str) and not t.strip()) for t in texts): 75 logger.error('Attempted to embed empty or blank text; caller must validate inputs. texts=%r', texts) 76 raise ValueError('Attempted to embed empty or blank text') 77 78 if self._client is None: 79 # fallback deterministic embeddings 80 return [_sha256_embedding(t, dim=self.dim) for t in texts] 81 82 try: 83 # Ollama adapter returns List[List[float]] or raises OllamaError 84 embs = self._client.embeddings(texts) 85 # ensure proper shape 86 if not embs or not isinstance(embs, list) or not isinstance(embs[0], list): 87 raise RuntimeError('Unexpected embeddings shape from client') 88 return embs 89 except Exception as e: 90 # Log an error when embeddings fail and we must fall back. This 91 # indicates a problem with the embedding backend that an operator 92 # should investigate (server response/body is included). 93 logger.error('Ollama embeddings unavailable; using deterministic fallback embeddings. Adapter error: %s', e) 94 return [_sha256_embedding(t, dim=self.dim) for t in texts] 95 96 97class MemoryVectorIndex: 98 """In-memory index mapping id -> vector and text. 
99 100 Usage: 101 idx = MemoryVectorIndex() 102 idx.add('id1', 'some text') 103 idx.build() 104 results = idx.search('query text', top_k=5) 105 """ 106 107 def __init__(self, embedding_dim: int = 32): 108 self.embedder = EmbeddingClient(dim=embedding_dim) 109 self._texts: Dict[str, str] = {} 110 self._vectors: Dict[str, List[float]] = {} 111 # optional numpy-backed arrays for faster math 112 self._use_numpy = np is not None 113 self._matrix: Optional[Any] = None 114 # chunk metadata: id -> list of chunk dicts {file, key, start, end, text} 115 self._chunks: Dict[str, List[Dict[str, Any]]] = {} 116 117 def add(self, id: str, text: str) -> None: 118 self._texts[id] = text 119 120 def build(self) -> None: 121 # compute embeddings for all texts 122 ids = list(self._texts.keys()) 123 texts = [self._texts[i] for i in ids] 124 if not texts: 125 return 126 embs = self.embedder.embed_texts(texts) 127 # store vectors by id 128 for i, id in enumerate(ids): 129 vec = embs[i] 130 self._vectors[id] = vec 131 # build numpy matrix if available 132 if self._use_numpy and self._vectors: 133 self._matrix = np.vstack([np.array(self._vectors[k], dtype=float) for k in ids]) 134 else: 135 self._matrix = None 136 137 def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: 138 """Return list of (id, score) sorted by descending cosine similarity. 139 140 Score is cosine similarity in [-1,1]. 
If no vectors exist, returns [] 141 """ 142 if not self._vectors: 143 return [] 144 qvec = self.embedder.embed_texts([query])[0] 145 # If numpy matrix exists, use fast batched dot products 146 if self._matrix is not None: 147 qarr = np.array(qvec, dtype=float) 148 norms = np.linalg.norm(self._matrix, axis=1) * np.linalg.norm(qarr) 149 dots = self._matrix.dot(qarr) 150 scores = [] 151 ids = list(self._vectors.keys()) 152 for i, id in enumerate(ids): 153 denom = norms[i] 154 score = 0.0 if denom == 0 else float(dots[i] / denom) 155 scores.append((id, score)) 156 else: 157 qnorm = _norm(qvec) 158 if qnorm == 0: 159 return [] 160 scores: List[Tuple[str, float]] = [] 161 for id, vec in self._vectors.items(): 162 denom = qnorm * _norm(vec) 163 if denom == 0: 164 score = 0.0 165 else: 166 score = _dot(qvec, vec) / denom 167 scores.append((id, score)) 168 169 # sort descending by score 170 scores.sort(key=lambda x: x[1], reverse=True) 171 return scores[:top_k] 172 173 def keyword_search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: 174 """Simple substring-based keyword search over chunks. Returns (id, score). 175 176 Score is heuristic: presence gives base 0.5, multiple occurrences increase up to 1.0. 
177 """ 178 q = query.lower().strip() 179 if not q: 180 return [] 181 hits: List[Tuple[str, float]] = [] 182 # iterate chunk metadata 183 for file, chunks in self._chunks.items(): 184 for c in chunks: 185 text = (c.get('text') or '').lower() 186 if not text: 187 continue 188 occ = text.count(q) 189 if occ <= 0: 190 # also check word-inclusion 191 if q in text: 192 occ = 1 193 else: 194 continue 195 # heuristic scoring 196 score = 0.5 + 0.5 * min(1.0, occ / 3.0) 197 hits.append((c.get('id'), float(score))) 198 # sort by score desc 199 hits.sort(key=lambda x: x[1], reverse=True) 200 return hits[:top_k] 201 202 def combined_search(self, query: str, top_k: int = 5, boost_keyword: float = 0.05) -> List[Tuple[str, float]]: 203 """Run semantic and keyword search in parallel and merge results. 204 205 Merging strategy: take the max score per id, apply a small boost to keyword hits 206 so exact matches are slightly prioritized. Return top_k ids sorted by score. 207 """ 208 import threading 209 210 sem_res: List[Tuple[str, float]] = [] 211 key_res: List[Tuple[str, float]] = [] 212 213 def run_sem(): 214 nonlocal sem_res 215 try: 216 sem_res = self.search(query, top_k * 3) 217 except Exception: 218 sem_res = [] 219 220 def run_key(): 221 nonlocal key_res 222 try: 223 key_res = self.keyword_search(query, top_k * 3) 224 except Exception: 225 key_res = [] 226 227 t1 = threading.Thread(target=run_sem) 228 t2 = threading.Thread(target=run_key) 229 t1.start(); t2.start() 230 t1.join(); t2.join() 231 232 merged: Dict[str, float] = {} 233 for id, score in sem_res: 234 merged[id] = max(merged.get(id, -1.0), float(score)) 235 for id, score in key_res: 236 s = float(score) + boost_keyword 237 merged[id] = max(merged.get(id, -1.0), s) 238 239 # convert to list and sort 240 items = list(merged.items()) 241 items.sort(key=lambda x: x[1], reverse=True) 242 return items[:top_k] 243 244 def reindex(self, data_dir: str = 'memory_blocks', persist_path: str = 'memory_index.json', chunk_size: int 
= 256, overlap: int = 32) -> None: 245 """Rebuild the index from memory blocks and persist it.""" 246 # clear previous 247 self._texts = {} 248 self._vectors = {} 249 self._chunks = {} 250 self._matrix = None 251 # load and build 252 self.load_memory_blocks(data_dir, chunk_size=chunk_size, overlap=overlap) 253 self.build() 254 try: 255 self.save(persist_path) 256 except Exception: 257 logger.exception('Failed to persist index') 258 259 def refresh(self, persist_path: str = 'memory_index.json') -> None: 260 """Reload from a persisted index file if present.""" 261 p = Path(persist_path) 262 if p.exists(): 263 self.load(persist_path) 264 else: 265 # nothing to refresh 266 return 267 268 # ----- persistence and memory loading ----- 269 def save(self, path: str) -> None: 270 data = { 271 'texts': self._texts, 272 'vectors': self._vectors, 273 'chunks': self._chunks, 274 # convenience flat mapping from chunk_id -> metadata for fast lookup 275 'chunk_index': {c['id']: c for file_chunks in self._chunks.values() for c in file_chunks}, 276 } 277 Path(path).write_text(json.dumps(data, ensure_ascii=False)) 278 279 def load(self, path: str) -> None: 280 p = Path(path) 281 if not p.exists(): 282 raise FileNotFoundError(path) 283 data = json.loads(p.read_text()) 284 self._texts = data.get('texts', {}) 285 self._vectors = data.get('vectors', {}) 286 self._chunks = data.get('chunks', {}) 287 # build a flat chunk index for fast lookup 288 self._chunk_index = data.get('chunk_index') or {c['id']: c for file_chunks in self._chunks.values() for c in file_chunks} 289 # rebuild numpy matrix if possible 290 if self._use_numpy and self._vectors: 291 ids = list(self._vectors.keys()) 292 self._matrix = np.vstack([np.array(self._vectors[k], dtype=float) for k in ids]) 293 294 def load_memory_blocks(self, folder: str, chunk_size: int = 256, overlap: int = 32) -> None: 295 """Load YAML memory blocks from a folder and chunk their textual content. 
296 297 Each memory block is expected to be a YAML file; keys with string values 298 will be chunked and indexed. Chunk metadata includes file and key names 299 and character offsets (start, end). 300 """ 301 base = Path(folder) 302 if not base.exists(): 303 raise FileNotFoundError(folder) 304 for p in base.glob('*.yaml'): 305 try: 306 raw = p.read_text(encoding='utf-8') 307 if yaml is None: 308 # treat whole file as single text field 309 self.add(str(p), raw) 310 self._chunks[str(p)] = [{'file': str(p), 'key': '__full__', 'start': 0, 'end': len(raw), 'text': raw}] 311 continue 312 doc = yaml.safe_load(raw) or {} 313 except Exception as e: 314 logger.warning('Failed to read YAML %s: %s', p, e) 315 continue 316 317 # recursively find string leaves 318 def walk(obj, prefix=''): 319 if isinstance(obj, dict): 320 for k, v in obj.items(): 321 yield from walk(v, prefix + ('.' + k if prefix else k)) 322 elif isinstance(obj, list): 323 for i, v in enumerate(obj): 324 yield from walk(v, prefix + f'[{i}]') 325 elif isinstance(obj, str): 326 yield prefix, obj 327 328 self._chunks[str(p)] = [] 329 for key, text in walk(doc): 330 # simple character-based chunking 331 L = len(text) 332 start = 0 333 chunk_id_base = f"{p.name}:{key}" 334 i = 0 335 while start < L: 336 end = min(start + chunk_size, L) 337 chunk_text = text[start:end] 338 chunk_id = f"{chunk_id_base}:{i}" 339 self.add(chunk_id, chunk_text) 340 self._chunks[str(p)].append({'file': str(p), 'key': key, 'start': start, 'end': end, 'text': chunk_text, 'id': chunk_id}) 341 i += 1 342 # move start forward with overlap 343 start = end - overlap if end < L else end 344 345 def resolve_flattened_key(self, file_path: str, flattened_key: str): 346 """Resolve a flattened key (dot + bracket notation) in the given YAML file. 
347 348 Example keys: 'profile.bio', 'profile.notes[0]', 'outer.inner[2].name' 349 """ 350 if yaml is None: 351 raise RuntimeError('yaml not available') 352 raw = Path(file_path).read_text(encoding='utf-8') 353 doc = yaml.safe_load(raw) or {} 354 355 parts = flattened_key.split('.') if flattened_key else [] 356 cur = doc 357 import re 358 for part in parts: 359 if '[' in part: 360 # e.g. notes[0][1] 361 base = part.split('[', 1)[0] 362 if base: 363 if not isinstance(cur, dict): 364 raise KeyError(base) 365 cur = cur[base] 366 # find all indices 367 indices = [int(m) for m in re.findall(r"\[(\d+)\]", part)] 368 for idx in indices: 369 if not isinstance(cur, list): 370 raise IndexError(f'Expected list when indexing {part}') 371 cur = cur[idx] 372 else: 373 if not isinstance(cur, dict): 374 raise KeyError(part) 375 cur = cur[part] 376 377 return cur 378 379 380def resolve_flattened_key(file_path: str, flattened_key: str): 381 """Convenience wrapper that uses the global index if available.""" 382 # Try to find a global index (if module using singleton pattern) 383 try: 384 # avoid import cycle; import here 385 from memory_vector import MemoryVectorIndex as _MVI # type: ignore 386 except Exception: 387 _MVI = None 388 # If caller created an index, they should call its method directly. Here we 389 # raise a helpful error to guide usage. 390 raise RuntimeError('Call MemoryVectorIndex.resolve_flattened_key on an index instance') 391 392 393 394__all__ = ["MemoryVectorIndex", "EmbeddingClient", "resolve_flattened_key_in_file"] 395 396 397def resolve_flattened_key_in_file(file_path: str, flattened_key: str): 398 """Public helper: resolve a flattened key directly from a YAML file. 399 400 This duplicates the navigation logic in MemoryVectorIndex.resolve_flattened_key 401 but is convenient for callers that only need to read a single file without 402 instantiating an index. 
403 """ 404 if yaml is None: 405 raise RuntimeError('yaml not available') 406 raw = Path(file_path).read_text(encoding='utf-8') 407 doc = yaml.safe_load(raw) or {} 408 409 parts = flattened_key.split('.') if flattened_key else [] 410 cur = doc 411 import re 412 for part in parts: 413 if '[' in part: 414 base = part.split('[', 1)[0] 415 if base: 416 if not isinstance(cur, dict): 417 raise KeyError(base) 418 cur = cur[base] 419 indices = [int(m) for m in re.findall(r"\[(\d+)\]", part)] 420 for idx in indices: 421 if not isinstance(cur, list): 422 raise IndexError(f'Expected list when indexing {part}') 423 cur = cur[idx] 424 else: 425 if not isinstance(cur, dict): 426 raise KeyError(part) 427 cur = cur[part] 428 429 return cur