search for standard sites
pub-search.waow.tech
search
zig
blog
atproto
#!/usr/bin/env -S uv run --script --quiet
# /// script
# requires-python = ">=3.12"
# dependencies = ["httpx", "pydantic-settings"]
# ///
"""
Populate the DiskANN vector index for semantic search.

The index only tracks writes that happen after it's created, so existing
embeddings need to be touched (UPDATE SET embedding = embedding) to get
indexed. This script does that in batches to avoid turso timeouts.

Usage:
    ./scripts/populate-vector-index                    # populate index
    ./scripts/populate-vector-index --batch-size 200   # custom batch size
    ./scripts/populate-vector-index --check            # just check index status
"""
18
19import argparse
20import os
21import sys
22
23import httpx
24from pydantic_settings import BaseSettings, SettingsConfigDict
25
26
class Settings(BaseSettings):
    """Turso connection settings, loaded from the environment / .env file.

    Required env vars: TURSO_URL, TURSO_TOKEN (ENV_FILE overrides the
    .env path).
    """

    model_config = SettingsConfigDict(
        env_file=os.environ.get("ENV_FILE", ".env"), extra="ignore"
    )
    turso_url: str  # e.g. "libsql://<db>.turso.io"
    turso_token: str  # API auth token

    @property
    def turso_host(self) -> str:
        """Hostname for the HTTP API: turso_url with any libsql:// scheme stripped."""
        # str.removeprefix is a no-op when the prefix is absent, matching
        # the original startswith/slice behavior.
        return self.turso_url.removeprefix("libsql://")
40
41
def query(settings: Settings, sql: str, timeout: int = 30) -> dict:
    """Run a single SQL statement through Turso's v2 HTTP pipeline API.

    Returns the statement's result payload. Raises httpx.HTTPStatusError on
    HTTP failures and Exception on a SQL-level error from turso.
    """
    endpoint = f"https://{settings.turso_host}/v2/pipeline"
    headers = {
        "Authorization": f"Bearer {settings.turso_token}",
        "Content-Type": "application/json",
    }
    payload = {
        "requests": [
            {"type": "execute", "stmt": {"sql": sql}},
            {"type": "close"},
        ]
    }
    response = httpx.post(endpoint, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    result = response.json()["results"][0]
    if result["type"] == "error":
        raise Exception(f"turso error: {result['error']}")
    return result["response"]["result"]
63
64
def exec_sql(settings: Settings, sql: str, timeout: int = 120) -> int:
    """Execute a write statement; return the number of affected rows (0 if absent)."""
    return query(settings, sql, timeout=timeout).get("affected_row_count", 0)
68
69
def get_scalar(settings: Settings, sql: str) -> int:
    """Run *sql* and return the first cell of the first row as an int."""
    first_cell = query(settings, sql)["rows"][0][0]
    # turso may wrap values as {"type": ..., "value": ...} dicts
    if isinstance(first_cell, dict):
        first_cell = first_cell["value"]
    return int(first_cell)
75
76
def check_status(settings: Settings) -> tuple[int, int]:
    """Returns (total_embeddings, indexed_count); indexed is -1 when the index is missing."""
    total = get_scalar(
        settings, "SELECT COUNT(*) FROM documents WHERE embedding IS NOT NULL"
    )

    # probe the index with a dummy query vector to count indexed docs
    probe_sql = (
        "SELECT count(*) FROM vector_top_k('documents_embedding_idx', "
        "(SELECT embedding FROM documents WHERE embedding IS NOT NULL LIMIT 1), 10000)"
    )
    try:
        indexed = get_scalar(settings, probe_sql)
    except Exception as e:
        if "not found" not in str(e).lower():
            raise
        indexed = -1  # index doesn't exist

    return total, indexed
97
98
def main():
    """Parse CLI args, ensure the vector index exists, then touch rows to populate it.

    Exits with status 1 when settings can't be loaded. With --check, only
    reports index status. Otherwise touches every row with an embedding
    (UPDATE SET embedding = embedding) in concurrent batches.
    """
    parser = argparse.ArgumentParser(
        description="populate DiskANN vector index for semantic search"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=500,
        help="rows per UPDATE batch (default: 500)",
    )
    parser.add_argument(
        "--workers", type=int, default=8, help="concurrent workers (default: 8)"
    )
    parser.add_argument(
        "--check", action="store_true", help="just check index status and exit"
    )
    args = parser.parse_args()

    try:
        settings = Settings()  # type: ignore
    except Exception as e:
        print(f"error loading settings: {e}", file=sys.stderr)
        print("required env vars: TURSO_URL, TURSO_TOKEN", file=sys.stderr)
        sys.exit(1)

    total, indexed = check_status(settings)

    if indexed == -1:
        print("vector index does not exist — creating it...")
        exec_sql(
            settings,
            "CREATE INDEX IF NOT EXISTS documents_embedding_idx "
            "ON documents(libsql_vector_idx(embedding))",
        )
        print("index created")
        indexed = 0

    print(f"embeddings: {total}, indexed: {indexed}")

    if args.check:
        return

    remaining = total - indexed
    if remaining <= 0:
        print("index is fully populated")
        return

    import time
    from concurrent.futures import ThreadPoolExecutor, as_completed

    batch_size = args.batch_size

    # We touch ALL rows with embeddings, not just `remaining`: the indexed
    # count probe is capped (LIMIT 10000) and OFFSETs don't map to the
    # unindexed rows, so touching everything is the only safe, idempotent
    # choice. Progress and ETA are therefore reported against `total` —
    # the previous `remaining`-based messages undercounted the workload
    # whenever some rows were already indexed.
    offsets = list(range(0, total, batch_size))
    num_batches = len(offsets)
    workers = min(args.workers, num_batches)
    print(
        f"touching {total} rows: {num_batches} batches of {batch_size}, "
        f"{workers} workers",
        flush=True,
    )

    def touch_batch(offset: int) -> int:
        """UPDATE one batch of rows in place so the index records the write."""
        sql = (
            f"UPDATE documents SET embedding = embedding "
            f"WHERE rowid IN ("
            f" SELECT rowid FROM documents WHERE embedding IS NOT NULL "
            f" ORDER BY rowid LIMIT {batch_size} OFFSET {offset}"
            f")"
        )
        return exec_sql(settings, sql, timeout=300)

    touched = 0
    start = time.time()
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # the offsets themselves aren't needed after submission, so a plain
        # list of futures suffices (the old dict mapped to unused values)
        futures = [pool.submit(touch_batch, offset) for offset in offsets]
        for future in as_completed(futures):
            touched += future.result()
            elapsed = time.time() - start
            rate = touched / elapsed if elapsed > 0 else 0
            eta = (total - touched) / rate if rate > 0 else 0
            print(f" {touched}/{total} ({rate:.0f} rows/s, ~{eta:.0f}s left)", flush=True)

    total_after, indexed_after = check_status(settings)
    print(f"done — embeddings: {total_after}, indexed: {indexed_after}")


if __name__ == "__main__":
    main()