#!/usr/bin/env -S uv run --script --quiet
# /// script
# requires-python = ">=3.12"
# dependencies = ["httpx", "pydantic-settings"]
# ///
"""Populate the DiskANN vector index for semantic search.

The index only tracks writes that happen after it's created, so existing
embeddings need to be touched (UPDATE SET embedding = embedding) to get
indexed. This script does that in batches to avoid turso timeouts.

Usage:
    ./scripts/populate-vector-index                    # populate index
    ./scripts/populate-vector-index --batch-size 200   # custom batch size
    ./scripts/populate-vector-index --check            # just check index status
"""

import argparse
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import httpx
from pydantic_settings import BaseSettings, SettingsConfigDict

# Name of the DiskANN index; referenced in several statements below.
INDEX_NAME = "documents_embedding_idx"


class TursoError(Exception):
    """Raised when the Turso HTTP pipeline reports a statement error."""


class Settings(BaseSettings):
    """Turso connection settings, loaded from env vars / the .env file."""

    model_config = SettingsConfigDict(
        env_file=os.environ.get("ENV_FILE", ".env"), extra="ignore"
    )

    turso_url: str
    turso_token: str

    @property
    def turso_host(self) -> str:
        """Bare hostname for the HTTP API (strips a libsql:// scheme)."""
        url = self.turso_url
        if url.startswith("libsql://"):
            url = url[len("libsql://"):]
        return url


def query(settings: Settings, sql: str, timeout: int = 30) -> dict:
    """Run a single SQL statement via the Turso v2 pipeline API.

    Returns the raw ``result`` dict of the first pipeline response.
    Raises TursoError on a statement-level error, httpx.HTTPStatusError on
    a transport-level one.
    """
    response = httpx.post(
        f"https://{settings.turso_host}/v2/pipeline",
        headers={
            "Authorization": f"Bearer {settings.turso_token}",
            "Content-Type": "application/json",
        },
        json={
            "requests": [
                {"type": "execute", "stmt": {"sql": sql}},
                {"type": "close"},
            ]
        },
        timeout=timeout,
    )
    response.raise_for_status()
    data = response.json()
    result = data["results"][0]
    if result["type"] == "error":
        raise TursoError(f"turso error: {result['error']}")
    return result["response"]["result"]


def exec_sql(settings: Settings, sql: str, timeout: int = 120) -> int:
    """Execute a write statement; return the affected row count (0 if absent)."""
    result = query(settings, sql, timeout=timeout)
    return result.get("affected_row_count", 0)


def get_scalar(settings: Settings, sql: str) -> int:
    """Execute a query expected to yield a single integer cell and return it."""
    result = query(settings, sql)
    row = result["rows"][0]
    cell = row[0]
    # The pipeline API wraps cells as {"type": ..., "value": ...}; tolerate
    # a bare value as well.
    return int(cell["value"] if isinstance(cell, dict) else cell)


def check_status(settings: Settings) -> tuple[int, int]:
    """Return (total_embeddings, indexed_count).

    ``indexed_count`` is -1 when the vector index does not exist yet.
    """
    total = get_scalar(
        settings, "SELECT COUNT(*) FROM documents WHERE embedding IS NOT NULL"
    )
    if total == 0:
        # No embeddings: the dummy-vector subselect below would feed NULL to
        # vector_top_k, so probe index existence via sqlite_master instead.
        exists = get_scalar(
            settings,
            "SELECT COUNT(*) FROM sqlite_master "
            f"WHERE type = 'index' AND name = '{INDEX_NAME}'",
        )
        return 0, (0 if exists else -1)
    try:
        # Use a dummy vector to count indexed docs. k must cover every row or
        # the count is silently capped (the previous hard-coded 10000 would
        # under-report on larger tables).
        indexed = get_scalar(
            settings,
            f"SELECT count(*) FROM vector_top_k('{INDEX_NAME}', "
            "(SELECT embedding FROM documents WHERE embedding IS NOT NULL LIMIT 1), "
            f"{total})",
        )
    except Exception as e:
        if "not found" in str(e).lower():
            indexed = -1  # index doesn't exist
        else:
            raise
    return total, indexed


def main():
    parser = argparse.ArgumentParser(
        description="populate DiskANN vector index for semantic search"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=500,
        help="rows per UPDATE batch (default: 500)",
    )
    parser.add_argument(
        "--workers", type=int, default=8, help="concurrent workers (default: 8)"
    )
    parser.add_argument(
        "--check", action="store_true", help="just check index status and exit"
    )
    args = parser.parse_args()

    try:
        settings = Settings()  # type: ignore
    except Exception as e:
        print(f"error loading settings: {e}", file=sys.stderr)
        print("required env vars: TURSO_URL, TURSO_TOKEN", file=sys.stderr)
        sys.exit(1)

    total, indexed = check_status(settings)

    if indexed == -1:
        print("vector index does not exist — creating it...")
        exec_sql(
            settings,
            "CREATE INDEX IF NOT EXISTS documents_embedding_idx "
            "ON documents(libsql_vector_idx(embedding))",
        )
        print("index created")
        indexed = 0

    print(f"embeddings: {total}, indexed: {indexed}")
    if args.check:
        return

    remaining = total - indexed
    if remaining <= 0:
        print("index is fully populated")
        return

    batch_size = args.batch_size
    # We can't tell WHICH rows are unindexed, so touch every embedding row;
    # re-touching an already-indexed row is harmless. Progress is therefore
    # reported against `total`, not `remaining` (using `remaining` as the
    # denominator made the counter overshoot 100% and the ETA go negative on
    # partially indexed tables).
    batches = list(range(0, total, batch_size))
    num_batches = len(batches)
    workers = min(args.workers, num_batches)
    print(
        f"touching {total} rows: {num_batches} batches of {batch_size}, "
        f"{workers} workers",
        flush=True,
    )

    def touch_batch(offset: int) -> int:
        # batch_size/offset are ints from argparse/range — safe to interpolate.
        sql = (
            f"UPDATE documents SET embedding = embedding "
            f"WHERE rowid IN ("
            f"  SELECT rowid FROM documents WHERE embedding IS NOT NULL "
            f"  ORDER BY rowid LIMIT {batch_size} OFFSET {offset}"
            f")"
        )
        return exec_sql(settings, sql, timeout=300)

    touched = 0
    start = time.time()
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {pool.submit(touch_batch, o): o for o in batches}
        for future in as_completed(futures):
            touched += future.result()
            elapsed = time.time() - start
            rate = touched / elapsed if elapsed > 0 else 0
            eta = (total - touched) / rate if rate > 0 else 0
            print(
                f" {touched}/{total} ({rate:.0f} rows/s, ~{eta:.0f}s left)",
                flush=True,
            )

    total_after, indexed_after = check_status(settings)
    print(f"done — embeddings: {total_after}, indexed: {indexed_after}")


if __name__ == "__main__":
    main()