"""Simple in-memory vector index for memory blocks.

Features:
- Uses `ollama_adapter.get_ollama_client()` for embeddings when available.
- Falls back to a deterministic, test-friendly embedding function when no
  client is available.
- Provides add(), build(), and search() methods with cosine similarity.

numpy, PyYAML and the Ollama adapter are all optional: when numpy is missing
the similarity math runs in pure Python, when PyYAML is missing memory blocks
are indexed as whole files, and when the adapter is missing the deterministic
fallback embedding is used.
"""

import hashlib
import json
import logging
import math
import re
import threading
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple

try:
    import yaml
except Exception:
    yaml = None

try:
    import numpy as np
except Exception:
    np = None

try:
    from ollama_adapter import get_ollama_client
except ImportError:
    # Keep the module importable without the adapter; EmbeddingClient then
    # uses the deterministic fallback embeddings.
    get_ollama_client = None

logger = logging.getLogger(__name__)

# Matches one "[<digits>]" index inside a flattened key part such as "notes[0][1]".
_INDEX_RE = re.compile(r"\[(\d+)\]")


def _sha256_embedding(text: str, dim: int = 32) -> List[float]:
    """Return a deterministic fallback embedding derived from SHA256.

    Produces a vector of exactly `dim` floats, each in [-1, 1]. The same text
    always yields the same vector, which makes it suitable for tests.
    """
    seed = hashlib.sha256(text.encode('utf-8')).digest()
    floats: List[float] = []
    counter = 0
    while len(floats) < dim:
        # Re-hash the seed with a counter so any `dim` can be filled.
        chunk = hashlib.sha256(seed + counter.to_bytes(2, 'big')).digest()
        for b in chunk:
            if len(floats) >= dim:
                break
            # Map byte 0..255 to float -1..1.
            floats.append((b / 255.0) * 2.0 - 1.0)
        counter += 1
    return floats


def _dot(a: List[float], b: List[float]) -> float:
    """Return the dot product of `a` and `b` (truncated to the shorter one)."""
    return sum(x * y for x, y in zip(a, b))


def _norm(a: List[float]) -> float:
    """Return the Euclidean norm of `a`."""
    return math.sqrt(sum(x * x for x in a))


def _iter_string_leaves(obj: Any, prefix: str = '') -> Iterator[Tuple[str, str]]:
    """Yield (flattened_key, text) for every string leaf in a nested document.

    Dict keys are joined with '.', list positions are written as '[i]'.
    Non-string keys are stringified so they cannot break key concatenation.
    Leaves that are not strings are ignored.
    """
    if isinstance(obj, dict):
        for k, v in obj.items():
            key = str(k)
            yield from _iter_string_leaves(v, f'{prefix}.{key}' if prefix else key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            yield from _iter_string_leaves(v, f'{prefix}[{i}]')
    elif isinstance(obj, str):
        yield prefix, obj


def _navigate_flattened_key(doc: Any, flattened_key: str) -> Any:
    """Walk `doc` along a flattened key (dot + bracket notation).

    Raises:
        KeyError: a mapping step is applied to a non-dict, or the key is absent.
        IndexError: an index step is applied to a non-list, or is out of range.
    """
    cur = doc
    parts = flattened_key.split('.') if flattened_key else []
    for part in parts:
        if '[' in part:
            # e.g. "notes[0][1]": optional mapping key followed by indices.
            base = part.split('[', 1)[0]
            if base:
                if not isinstance(cur, dict):
                    raise KeyError(base)
                cur = cur[base]
            for idx in (int(m) for m in _INDEX_RE.findall(part)):
                if not isinstance(cur, list):
                    raise IndexError(f'Expected list when indexing {part}')
                cur = cur[idx]
        else:
            if not isinstance(cur, dict):
                raise KeyError(part)
            cur = cur[part]
    return cur


class EmbeddingClient:
    """Wrapper that prefers Ollama client embeddings and falls back to SHA256 embeddings."""

    def __init__(self, dim: int = 32):
        """Create the client; `dim` is the size of the fallback vectors."""
        self.dim = dim
        self._client = get_ollama_client() if get_ollama_client is not None else None

    def _fallback(self, texts: List[str]) -> List[List[float]]:
        """Return deterministic SHA256 embeddings for `texts`."""
        return [_sha256_embedding(t, dim=self.dim) for t in texts]

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per entry of `texts`.

        Raises:
            ValueError: if any text is None, empty, or whitespace-only.
        """
        # Validate inputs: empty or whitespace-only strings should not be sent
        # to the embedding backend as they commonly produce a server-side 500
        # (no input provided). Treat this as a caller error and raise.
        if any((t is None) or (isinstance(t, str) and not t.strip()) for t in texts):
            logger.error('Attempted to embed empty or blank text; caller must validate inputs. texts=%r', texts)
            raise ValueError('Attempted to embed empty or blank text')

        if self._client is None:
            return self._fallback(texts)

        try:
            # Ollama adapter returns List[List[float]] or raises OllamaError
            embs = self._client.embeddings(texts)
            # Ensure proper shape, including one vector per input text.
            if (
                not embs
                or not isinstance(embs, list)
                or not isinstance(embs[0], list)
                or len(embs) != len(texts)
            ):
                raise RuntimeError('Unexpected embeddings shape from client')
            return embs
        except Exception as e:
            # Log an error when embeddings fail and we must fall back. This
            # indicates a problem with the embedding backend that an operator
            # should investigate (server response/body is included).
            # NOTE: fallback vectors have `self.dim` entries, which may differ
            # from the size of vectors the backend returned earlier.
            logger.error('Ollama embeddings unavailable; using deterministic fallback embeddings. Adapter error: %s', e)
            return self._fallback(texts)


class MemoryVectorIndex:
    """In-memory index mapping id -> vector and text.

    Usage:
        idx = MemoryVectorIndex()
        idx.add('id1', 'some text')
        idx.build()
        results = idx.search('query text', top_k=5)
    """

    def __init__(self, embedding_dim: int = 32):
        self.embedder = EmbeddingClient(dim=embedding_dim)
        self._texts: Dict[str, str] = {}
        self._vectors: Dict[str, List[float]] = {}
        # optional numpy-backed matrix for faster math; rows follow the
        # iteration order of self._vectors
        self._use_numpy = np is not None
        self._matrix: Optional[Any] = None
        # chunk metadata: file -> list of chunk dicts {file, key, start, end, text, id}
        self._chunks: Dict[str, List[Dict[str, Any]]] = {}
        # flat mapping chunk_id -> chunk metadata
        self._chunk_index: Dict[str, Dict[str, Any]] = {}

    # ----- internal helpers -----

    def _rebuild_matrix(self) -> None:
        """Rebuild the numpy matrix from `_vectors`, or clear it.

        Rows are stacked in `_vectors` iteration order, which is the order
        search() uses to map rows back to ids.
        """
        if self._use_numpy and self._vectors:
            self._matrix = np.vstack(
                [np.array(vec, dtype=float) for vec in self._vectors.values()]
            )
        else:
            self._matrix = None

    def _flat_chunk_index(self) -> Dict[str, Dict[str, Any]]:
        """Return chunk_id -> chunk metadata for every chunk that has an id."""
        return {
            c['id']: c
            for file_chunks in self._chunks.values()
            for c in file_chunks
            if 'id' in c
        }

    @staticmethod
    def _chunk_spans(length: int, chunk_size: int, overlap: int) -> List[Tuple[int, int]]:
        """Return (start, end) character spans covering a text of `length`.

        Consecutive spans share `overlap` characters. If the overlap would
        stop the window from advancing (overlap >= chunk_size), the next span
        starts where the previous one ended so the loop always terminates.

        Raises:
            ValueError: if `chunk_size` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError('chunk_size must be a positive integer')
        spans: List[Tuple[int, int]] = []
        start = 0
        while start < length:
            end = min(start + chunk_size, length)
            spans.append((start, end))
            if end >= length:
                break
            next_start = end - overlap
            # Guarantee forward progress for degenerate overlap values.
            start = next_start if next_start > start else end
        return spans

    # ----- indexing and search -----

    def add(self, id: str, text: str) -> None:
        """Register `text` under `id`; it is embedded on the next build()."""
        self._texts[id] = text

    def build(self) -> None:
        """Embed every registered text and refresh the numpy matrix.

        Does nothing when no texts have been added.
        """
        ids = list(self._texts)
        if not ids:
            return
        embs = self.embedder.embed_texts([self._texts[text_id] for text_id in ids])
        for text_id, vec in zip(ids, embs):
            self._vectors[text_id] = vec
        self._rebuild_matrix()

    def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Return list of (id, score) sorted by descending cosine similarity.

        Score is cosine similarity in [-1,1]. Returns [] if no vectors exist
        or the query embedding has zero norm. A non-positive `top_k` yields [].
        """
        if not self._vectors:
            return []
        qvec = self.embedder.embed_texts([query])[0]
        qnorm = _norm(qvec)
        if qnorm == 0:
            return []

        ids = list(self._vectors)
        scores: List[Tuple[str, float]] = []
        # Use the matrix only when its rows still line up with the vector ids.
        if self._matrix is not None and self._matrix.shape[0] == len(ids):
            qarr = np.array(qvec, dtype=float)
            denoms = np.linalg.norm(self._matrix, axis=1) * np.linalg.norm(qarr)
            dots = self._matrix.dot(qarr)
            for row, vec_id in enumerate(ids):
                denom = denoms[row]
                score = 0.0 if denom == 0 else float(dots[row] / denom)
                scores.append((vec_id, score))
        else:
            for vec_id, vec in self._vectors.items():
                denom = qnorm * _norm(vec)
                score = 0.0 if denom == 0 else _dot(qvec, vec) / denom
                scores.append((vec_id, score))

        scores.sort(key=lambda item: item[1], reverse=True)
        return scores[:max(top_k, 0)]

    def keyword_search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Simple substring-based keyword search over chunks. Returns (id, score).

        Score is heuristic: one occurrence scores above 0.5 and the score
        grows with the occurrence count, capped at 1.0 from three occurrences.
        Matching is case-insensitive. Chunks stored without an 'id' are
        reported under their 'file' value.
        """
        q = query.lower().strip()
        if not q:
            return []
        hits: List[Tuple[str, float]] = []
        for chunks in self._chunks.values():
            for c in chunks:
                text = (c.get('text') or '').lower()
                occ = text.count(q)
                if occ <= 0:
                    continue
                score = 0.5 + 0.5 * min(1.0, occ / 3.0)
                hits.append((c.get('id') or c.get('file'), float(score)))
        hits.sort(key=lambda item: item[1], reverse=True)
        return hits[:max(top_k, 0)]

    def combined_search(self, query: str, top_k: int = 5, boost_keyword: float = 0.05) -> List[Tuple[str, float]]:
        """Run semantic and keyword search in parallel and merge results.

        Merging strategy: take the max score per id, apply a small boost to
        keyword hits so exact matches are slightly prioritized. Return top_k
        ids sorted by score. A failing search contributes no results; the
        failure is logged instead of being raised.
        """
        results: Dict[str, List[Tuple[str, float]]] = {'semantic': [], 'keyword': []}

        def run(name: str, fn) -> None:
            try:
                results[name] = fn(query, top_k * 3)
            except Exception:
                # Best-effort: one failing strategy must not hide the other.
                logger.exception('%s search failed', name)
                results[name] = []

        workers = [
            threading.Thread(target=run, args=('semantic', self.search)),
            threading.Thread(target=run, args=('keyword', self.keyword_search)),
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        merged: Dict[str, float] = {}
        for hit_id, score in results['semantic']:
            merged[hit_id] = max(merged.get(hit_id, -1.0), float(score))
        for hit_id, score in results['keyword']:
            merged[hit_id] = max(merged.get(hit_id, -1.0), float(score) + boost_keyword)

        items = sorted(merged.items(), key=lambda item: item[1], reverse=True)
        return items[:max(top_k, 0)]

    def reindex(self, data_dir: str = 'memory_blocks', persist_path: str = 'memory_index.json', chunk_size: int = 256, overlap: int = 32) -> None:
        """Rebuild the index from memory blocks and persist it.

        A failure to persist is logged and does not abort the rebuild.
        """
        # clear previous state
        self._texts = {}
        self._vectors = {}
        self._chunks = {}
        self._chunk_index = {}
        self._matrix = None
        # load and build
        self.load_memory_blocks(data_dir, chunk_size=chunk_size, overlap=overlap)
        self.build()
        self._chunk_index = self._flat_chunk_index()
        try:
            self.save(persist_path)
        except Exception:
            logger.exception('Failed to persist index')

    def refresh(self, persist_path: str = 'memory_index.json') -> None:
        """Reload from a persisted index file if present; otherwise do nothing."""
        if Path(persist_path).exists():
            self.load(persist_path)

    # ----- persistence and memory loading -----

    def save(self, path: str) -> None:
        """Write texts, vectors and chunk metadata to `path` as UTF-8 JSON."""
        data = {
            'texts': self._texts,
            'vectors': self._vectors,
            'chunks': self._chunks,
            # convenience flat mapping from chunk_id -> metadata for fast lookup
            'chunk_index': self._flat_chunk_index(),
        }
        # ensure_ascii=False emits raw non-ASCII text, so pin the encoding.
        Path(path).write_text(json.dumps(data, ensure_ascii=False), encoding='utf-8')

    def load(self, path: str) -> None:
        """Replace the index state with the contents of a file written by save().

        Raises:
            FileNotFoundError: if `path` does not exist.
        """
        p = Path(path)
        if not p.exists():
            raise FileNotFoundError(path)
        data = json.loads(p.read_text(encoding='utf-8'))
        self._texts = data.get('texts', {})
        self._vectors = data.get('vectors', {})
        self._chunks = data.get('chunks', {})
        # build a flat chunk index for fast lookup
        self._chunk_index = data.get('chunk_index') or self._flat_chunk_index()
        # rebuild (or clear) the numpy matrix so it never outlives its vectors
        self._rebuild_matrix()

    def load_memory_blocks(self, folder: str, chunk_size: int = 256, overlap: int = 32) -> None:
        """Load YAML memory blocks from a folder and chunk their textual content.

        Each memory block is expected to be a `*.yaml` file; string leaves are
        chunked by character count and registered via add(). Chunk metadata
        includes file and key names and character offsets (start, end).
        Without PyYAML each file is indexed as a single chunk under its path.
        Whitespace-only chunks are skipped because the embedder rejects them.
        Files that cannot be read or parsed are logged and skipped.

        Raises:
            FileNotFoundError: if `folder` does not exist.
            ValueError: if `chunk_size` is not positive.
        """
        base = Path(folder)
        if not base.exists():
            raise FileNotFoundError(folder)
        if chunk_size <= 0:
            raise ValueError('chunk_size must be a positive integer')

        # Sorted so the resulting index order does not depend on the filesystem.
        for p in sorted(base.glob('*.yaml')):
            file_id = str(p)
            try:
                raw = p.read_text(encoding='utf-8')
                doc = None if yaml is None else (yaml.safe_load(raw) or {})
            except Exception as e:
                logger.warning('Failed to read YAML %s: %s', p, e)
                continue

            if yaml is None:
                # treat whole file as single text field
                if not raw.strip():
                    continue
                self.add(file_id, raw)
                self._chunks[file_id] = [{
                    'file': file_id, 'key': '__full__', 'start': 0,
                    'end': len(raw), 'text': raw, 'id': file_id,
                }]
                continue

            file_chunks: List[Dict[str, Any]] = []
            self._chunks[file_id] = file_chunks
            for key, text in _iter_string_leaves(doc):
                chunk_id_base = f"{p.name}:{key}"
                spans = self._chunk_spans(len(text), chunk_size, overlap)
                for i, (start, end) in enumerate(spans):
                    chunk_text = text[start:end]
                    if not chunk_text.strip():
                        # Keep `i` positional so the remaining ids stay stable.
                        continue
                    chunk_id = f"{chunk_id_base}:{i}"
                    self.add(chunk_id, chunk_text)
                    file_chunks.append({
                        'file': file_id, 'key': key, 'start': start,
                        'end': end, 'text': chunk_text, 'id': chunk_id,
                    })

    def resolve_flattened_key(self, file_path: str, flattened_key: str):
        """Resolve a flattened key (dot + bracket notation) in the given YAML file.

        Example keys: 'profile.bio', 'profile.notes[0]', 'outer.inner[2].name'

        Raises:
            RuntimeError: if PyYAML is not available.
            KeyError / IndexError: if the key does not match the document.
        """
        return resolve_flattened_key_in_file(file_path, flattened_key)


def resolve_flattened_key(file_path: str, flattened_key: str):
    """Placeholder wrapper; always raises RuntimeError.

    There is no global index to delegate to, so callers are directed to
    MemoryVectorIndex.resolve_flattened_key on their own index instance.
    """
    raise RuntimeError('Call MemoryVectorIndex.resolve_flattened_key on an index instance')


__all__ = ["MemoryVectorIndex", "EmbeddingClient", "resolve_flattened_key_in_file"]


def resolve_flattened_key_in_file(file_path: str, flattened_key: str):
    """Public helper: resolve a flattened key directly from a YAML file.

    Convenient for callers that only need to read a single file without
    instantiating an index. An empty key returns the whole document; an empty
    document is treated as {}.

    Raises:
        RuntimeError: if PyYAML is not available.
        KeyError / IndexError: if the key does not match the document.
    """
    if yaml is None:
        raise RuntimeError('yaml not available')
    raw = Path(file_path).read_text(encoding='utf-8')
    doc = yaml.safe_load(raw) or {}
    return _navigate_flattened_key(doc, flattened_key)