search for standard sites pub-search.waow.tech
search zig blog atproto
at main 222 lines 7.7 kB view raw
#!/usr/bin/env -S uv run --script --quiet
# /// script
# requires-python = ">=3.12"
# dependencies = ["httpx", "pydantic-settings"]
# ///
"""
Backfill embeddings for leaflet-search documents.

Usage:
    ./scripts/backfill-embeddings              # process all documents missing embeddings
    ./scripts/backfill-embeddings --limit 10   # process 10 documents
    ./scripts/backfill-embeddings --dry-run    # show what would be processed
"""

import argparse
import json
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import httpx
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Connection credentials, loaded from env vars or an env file (ENV_FILE, default .env)."""

    model_config = SettingsConfigDict(
        env_file=os.environ.get("ENV_FILE", ".env"), extra="ignore"
    )

    turso_token: str
    turso_url: str
    voyage_api_key: str

    @property
    def turso_host(self) -> str:
        """Strip libsql:// prefix if present."""
        return self.turso_url.removeprefix("libsql://")


def turso_query(settings: Settings, sql: str, args: list | None = None) -> list[dict]:
    """Execute a single query against Turso's v2 pipeline API and return rows as dicts.

    Args:
        settings: connection credentials.
        sql: the SQL statement to run.
        args: optional positional args; all are sent as Turso "text" values.

    Returns:
        One dict per row, keyed by column name.

    Raises:
        Exception: if Turso reports a statement-level error.
        httpx.HTTPStatusError: on a non-2xx HTTP response.
    """
    stmt: dict = {"sql": sql}
    if args:
        # Turso's pipeline API expects typed args; everything here is passed as text.
        stmt["args"] = [{"type": "text", "value": str(a)} for a in args]

    response = httpx.post(
        f"https://{settings.turso_host}/v2/pipeline",
        headers={
            "Authorization": f"Bearer {settings.turso_token}",
            "Content-Type": "application/json",
        },
        json={"requests": [{"type": "execute", "stmt": stmt}, {"type": "close"}]},
        timeout=30,
    )
    response.raise_for_status()
    data = response.json()

    result = data["results"][0]
    if result["type"] == "error":
        raise Exception(f"Turso error: {result['error']}")

    cols = [c["name"] for c in result["response"]["result"]["cols"]]
    rows = result["response"]["result"]["rows"]

    def extract_value(cell):
        # Cells usually arrive as {"type": ..., "value": ...}; unwrap that shape.
        if cell is None:
            return None
        if isinstance(cell, dict):
            return cell.get("value")
        # cell might be the value directly in some formats
        return cell

    return [dict(zip(cols, [extract_value(cell) for cell in row])) for row in rows]


def turso_exec(settings: Settings, sql: str, args: list | None = None, retries: int = 3) -> None:
    """Execute a statement against Turso with retry logic."""
    turso_batch_exec(settings, [(sql, args)], retries)


def turso_batch_exec(settings: Settings, statements: list[tuple[str, list | None]], retries: int = 3) -> None:
    """Execute multiple statements in a single pipeline request.

    Retries the whole pipeline on connect/read timeouts and connection errors,
    with exponential backoff (2s, 4s, ...); re-raises after the last attempt.

    Raises:
        Exception: if Turso reports an error on any statement.
        httpx.HTTPStatusError: on a non-2xx HTTP response.
    """
    requests = []
    for sql, args in statements:
        stmt: dict = {"sql": sql}
        if args:
            stmt["args"] = [{"type": "text", "value": str(a)} for a in args]
        requests.append({"type": "execute", "stmt": stmt})
    requests.append({"type": "close"})

    for attempt in range(retries):
        try:
            response = httpx.post(
                f"https://{settings.turso_host}/v2/pipeline",
                headers={
                    "Authorization": f"Bearer {settings.turso_token}",
                    "Content-Type": "application/json",
                },
                json={"requests": requests},
                timeout=120,
            )
            response.raise_for_status()
            data = response.json()
            for i, result in enumerate(data["results"][:-1]):  # skip the close result
                if result["type"] == "error":
                    raise Exception(f"Turso error on statement {i}: {result['error']}")
            return
        except (httpx.ReadTimeout, httpx.ConnectTimeout, httpx.ConnectError) as e:
            if attempt < retries - 1:
                wait = 2 ** (attempt + 1)
                print(f"  {type(e).__name__}, retrying in {wait}s...")
                time.sleep(wait)
            else:
                raise


def voyage_embed(settings: Settings, texts: list[str]) -> list[list[float]]:
    """Generate embeddings using Voyage AI.

    Uses the voyage-4-lite model with 1024-dim output and "document" input type
    (the documents side of an asymmetric search index).

    Raises:
        httpx.HTTPStatusError: on a non-2xx HTTP response.
    """
    response = httpx.post(
        "https://api.voyageai.com/v1/embeddings",
        headers={
            "Authorization": f"Bearer {settings.voyage_api_key}",
            "Content-Type": "application/json",
        },
        json={
            "input": texts,
            "model": "voyage-4-lite",
            "output_dimension": 1024,
            "input_type": "document",
        },
        timeout=60,
    )
    response.raise_for_status()
    data = response.json()
    return [item["embedding"] for item in data["data"]]


def main():
    parser = argparse.ArgumentParser(description="Backfill embeddings for leaflet-search")
    parser.add_argument("--limit", type=int, default=0, help="max documents to process (0 = all)")
    parser.add_argument("--batch-size", type=int, default=20, help="documents per Voyage API call")
    parser.add_argument("--dry-run", action="store_true", help="show what would be processed")
    args = parser.parse_args()

    try:
        settings = Settings()  # type: ignore
    except Exception as e:
        print(f"error loading settings: {e}", file=sys.stderr)
        print("required env vars: TURSO_URL, TURSO_TOKEN, VOYAGE_API_KEY", file=sys.stderr)
        sys.exit(1)

    # check if embedding column exists, add if not
    try:
        turso_query(settings, "SELECT embedding FROM documents LIMIT 1")
    except Exception as e:
        if "no such column" in str(e).lower():
            print("adding embedding column...")
            turso_exec(settings, "ALTER TABLE documents ADD COLUMN embedding F32_BLOB(1024)")
            print("done")
        else:
            raise

    # get documents needing embeddings; --limit is an int, so the f-string is injection-safe
    limit_clause = f"LIMIT {args.limit}" if args.limit > 0 else ""
    docs = turso_query(
        settings,
        f"SELECT uri, title, content FROM documents WHERE embedding IS NULL {limit_clause}",
    )

    if not docs:
        print("no documents need embeddings")
        return

    print(f"found {len(docs)} documents needing embeddings")

    if args.dry_run:
        for doc in docs[:10]:
            # NULL titles come back as None; guard so the preview never crashes
            print(f"  - {doc['uri']}: {(doc['title'] or '')[:50]}...")
        if len(docs) > 10:
            print(f"  ... and {len(docs) - 10} more")
        return

    # process in batches with concurrency
    def process_batch(batch_info):
        """Embed one batch via Voyage and write the vectors back in a single pipeline."""
        batch_num, batch = batch_info
        # NULL columns would otherwise embed the literal string "None"
        texts = [f"{doc['title'] or ''} {doc['content'] or ''}" for doc in batch]
        embeddings = voyage_embed(settings, texts)
        statements = []
        for doc, embedding in zip(batch, embeddings):
            embedding_json = json.dumps(embedding)
            statements.append((
                "UPDATE documents SET embedding = vector32(?) WHERE uri = ?",
                [embedding_json, doc["uri"]],
            ))
        turso_batch_exec(settings, statements)
        return batch_num, len(batch)

    batches = [(i // args.batch_size + 1, docs[i : i + args.batch_size])
               for i in range(0, len(docs), args.batch_size)]

    processed = 0
    workers = min(8, len(batches))  # more workers now that index is dropped
    print(f"processing {len(batches)} batches with {workers} workers...")

    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(process_batch, b): b[0] for b in batches}
        for future in as_completed(futures):
            batch_num, count = future.result()
            processed += count
            print(f"batch {batch_num} done ({processed}/{len(docs)})", flush=True)

    print(f"done! processed {processed} documents")


if __name__ == "__main__":
    main()