search for standard sites pub-search.waow.tech
search zig blog atproto
at main 184 lines 5.7 kB view raw
#!/usr/bin/env -S uv run --script --quiet
# /// script
# requires-python = ">=3.12"
# dependencies = ["httpx", "pydantic-settings"]
# ///
"""
Populate the DiskANN vector index for semantic search.

The index only tracks writes that happen after it's created, so existing
embeddings need to be touched (UPDATE SET embedding = embedding) to get
indexed. This script does that in batches to avoid turso timeouts.

Usage:
    ./scripts/populate-vector-index                   # populate index
    ./scripts/populate-vector-index --batch-size 200  # custom batch size
    ./scripts/populate-vector-index --check           # just check index status
"""

import argparse
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import httpx
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Turso connection settings, loaded from the environment / .env file."""

    model_config = SettingsConfigDict(
        env_file=os.environ.get("ENV_FILE", ".env"), extra="ignore"
    )
    turso_url: str
    turso_token: str

    @property
    def turso_host(self) -> str:
        """Hostname for the HTTP API: TURSO_URL with any libsql:// prefix stripped."""
        url = self.turso_url
        if url.startswith("libsql://"):
            url = url[len("libsql://") :]
        return url


def query(settings: Settings, sql: str, timeout: int = 30) -> dict:
    """Run one SQL statement via turso's v2 pipeline HTTP API.

    Returns the raw result dict for the statement. Raises on HTTP errors
    and on statement-level errors reported inside the pipeline response.
    """
    response = httpx.post(
        f"https://{settings.turso_host}/v2/pipeline",
        headers={
            "Authorization": f"Bearer {settings.turso_token}",
            "Content-Type": "application/json",
        },
        json={
            "requests": [
                {"type": "execute", "stmt": {"sql": sql}},
                {"type": "close"},
            ]
        },
        timeout=timeout,
    )
    response.raise_for_status()
    data = response.json()
    result = data["results"][0]
    # the pipeline endpoint returns 200 even for SQL errors; check the payload
    if result["type"] == "error":
        raise Exception(f"turso error: {result['error']}")
    return result["response"]["result"]


def exec_sql(settings: Settings, sql: str, timeout: int = 120) -> int:
    """Execute a write statement; return the affected row count (0 if absent)."""
    result = query(settings, sql, timeout=timeout)
    return result.get("affected_row_count", 0)


def get_scalar(settings: Settings, sql: str) -> int:
    """Run a query expected to return a single integer cell and return it."""
    result = query(settings, sql)
    row = result["rows"][0]
    cell = row[0]
    # cells may arrive as {"type": ..., "value": ...} dicts or bare values
    return int(cell["value"] if isinstance(cell, dict) else cell)


def check_status(settings: Settings) -> tuple[int, int]:
    """Returns (total_embeddings, indexed_count); indexed is -1 if no index."""
    total = get_scalar(
        settings, "SELECT COUNT(*) FROM documents WHERE embedding IS NOT NULL"
    )

    try:
        # use a dummy vector to count indexed docs
        indexed = get_scalar(
            settings,
            "SELECT count(*) FROM vector_top_k('documents_embedding_idx', "
            "(SELECT embedding FROM documents WHERE embedding IS NOT NULL LIMIT 1), 10000)",
        )
    except Exception as e:
        if "not found" in str(e).lower():
            indexed = -1  # index doesn't exist
        else:
            raise

    return total, indexed


def main():
    parser = argparse.ArgumentParser(
        description="populate DiskANN vector index for semantic search"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=500,
        help="rows per UPDATE batch (default: 500)",
    )
    parser.add_argument(
        "--workers", type=int, default=8, help="concurrent workers (default: 8)"
    )
    parser.add_argument(
        "--check", action="store_true", help="just check index status and exit"
    )
    args = parser.parse_args()

    try:
        settings = Settings()  # type: ignore
    except Exception as e:
        print(f"error loading settings: {e}", file=sys.stderr)
        print("required env vars: TURSO_URL, TURSO_TOKEN", file=sys.stderr)
        sys.exit(1)

    total, indexed = check_status(settings)

    if indexed == -1:
        print("vector index does not exist — creating it...")
        exec_sql(
            settings,
            "CREATE INDEX IF NOT EXISTS documents_embedding_idx "
            "ON documents(libsql_vector_idx(embedding))",
        )
        print("index created")
        indexed = 0

    print(f"embeddings: {total}, indexed: {indexed}")

    if args.check:
        return

    remaining = total - indexed
    if remaining <= 0:
        print("index is fully populated")
        return

    batch_size = args.batch_size

    # We can't tell WHICH rows are unindexed, only how many, so every
    # embedding row gets touched — offsets must cover the full set.
    offsets = list(range(0, total, batch_size))
    num_batches = len(offsets)
    to_touch = total  # actual number of rows the batches will touch
    workers = min(args.workers, num_batches)
    print(
        f"touching {to_touch} rows: {num_batches} batches of {batch_size}, "
        f"{workers} workers",
        flush=True,
    )

    def touch_batch(offset: int) -> int:
        # no-op UPDATE: rewriting the embedding makes the index pick it up.
        # batch_size/offset are ints from argparse/range — safe to inline.
        sql = (
            f"UPDATE documents SET embedding = embedding "
            f"WHERE rowid IN ("
            f" SELECT rowid FROM documents WHERE embedding IS NOT NULL "
            f" ORDER BY rowid LIMIT {batch_size} OFFSET {offset}"
            f")"
        )
        return exec_sql(settings, sql, timeout=300)

    touched = 0
    start = time.time()
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {pool.submit(touch_batch, o): o for o in offsets}
        for future in as_completed(futures):
            touched += future.result()
            elapsed = time.time() - start
            rate = touched / elapsed if elapsed > 0 else 0
            # measure progress against the rows actually being touched,
            # not just the unindexed count — keeps the ETA non-negative
            eta = (to_touch - touched) / rate if rate > 0 else 0
            print(
                f"  {touched}/{to_touch} ({rate:.0f} rows/s, ~{eta:.0f}s left)",
                flush=True,
            )

    total_after, indexed_after = check_status(settings)
    print(f"done — embeddings: {total_after}, indexed: {indexed_after}")


if __name__ == "__main__":
    main()