search for standard sites pub-search.waow.tech
search zig blog atproto
at ca756a01806bc76bc6514afb7ba67f4baa3b5491 155 lines 4.4 kB view raw
#!/usr/bin/env -S uv run --script --quiet
# /// script
# requires-python = ">=3.12"
# dependencies = ["httpx", "pydantic-settings"]
# ///
"""
Measure cosine distance distributions from tpuf for various queries.

Embeds queries via Voyage, runs ANN search on tpuf, and prints the
distance distribution so we can pick an empirical threshold.

Usage:
    ./scripts/measure-distances
"""

import os
import statistics
import subprocess
import sys
from collections import Counter

import httpx
from pydantic_settings import BaseSettings, SettingsConfigDict

TPUF_NAMESPACE = "leaflet-search"
FLY_APP = "leaflet-search-backend"

TEST_QUERIES = [
    "community builders",
    "consciousness",
    "rust programming",
    "atproto federation",
    "machine learning",
    "philosophy of mind",
    "web development",
    "decentralized social",
]

# Candidate cut-offs reported for every query ("how many hits survive at t?").
DISTANCE_THRESHOLDS = (0.3, 0.4, 0.5, 0.6, 0.7, 0.8)


class Settings(BaseSettings):
    """Script configuration, loaded from the environment or an env file.

    The env file path comes from ``ENV_FILE`` and defaults to ``.env``;
    unrelated variables in that file are ignored.
    """

    model_config = SettingsConfigDict(
        env_file=os.environ.get("ENV_FILE", ".env"), extra="ignore"
    )
    # API key sent as the bearer token to the Voyage embeddings endpoint.
    voyage_api_key: str


def get_tpuf_key() -> str:
    """Read ``TURBOPUFFER_API_KEY`` from the deployed fly app.

    Runs ``printenv`` on the app over ``fly ssh console`` and takes the last
    line of output as the key.

    Raises:
        RuntimeError: if the fly command fails, prints nothing, or the value
            does not start with ``tpuf_``.
    """
    result = subprocess.run(
        ["fly", "-a", FLY_APP, "ssh", "console", "-C", "printenv TURBOPUFFER_API_KEY"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(f"fly ssh failed: {result.stderr.strip()}")

    lines = result.stdout.strip().splitlines()
    if not lines:
        # Exit code 0 with no output (e.g. the variable is unset) would
        # otherwise surface as an opaque IndexError on lines[-1].
        raise RuntimeError("fly ssh returned no output for TURBOPUFFER_API_KEY")

    # Only the last line is used, so anything printed before it is ignored.
    key = lines[-1].strip()
    if not key.startswith("tpuf_"):
        raise RuntimeError(f"unexpected key format: {key[:10]}...")
    return key


def embed_query(settings: Settings, text: str) -> list[float]:
    """Embed ``text`` as a query vector via the Voyage embeddings API.

    Args:
        settings: loaded settings providing ``voyage_api_key``.
        text: the query string to embed.

    Returns:
        The embedding of the single input, as returned by the API.

    Raises:
        httpx.HTTPStatusError: if the API responds with an error status.
    """
    resp = httpx.post(
        "https://api.voyageai.com/v1/embeddings",
        headers={
            "Authorization": f"Bearer {settings.voyage_api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": "voyage-4-lite",
            "input_type": "query",
            "output_dimension": 1024,
            "input": [text],
        },
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["data"][0]["embedding"]


def tpuf_query(tpuf_key: str, vector: list[float], top_k: int = 40) -> list[dict]:
    """Run an ANN vector query against the tpuf namespace.

    Args:
        tpuf_key: bearer token for the turbopuffer API.
        vector: query embedding to rank by.
        top_k: maximum number of rows to request.

    Returns:
        The ``rows`` list from the response, or ``[]`` when the key is absent.

    Raises:
        httpx.HTTPStatusError: if the API responds with an error status.
    """
    resp = httpx.post(
        f"https://api.turbopuffer.com/v2/namespaces/{TPUF_NAMESPACE}/query",
        headers={
            "Authorization": f"Bearer {tpuf_key}",
            "Content-Type": "application/json",
        },
        json={
            "rank_by": ["vector", "ANN", vector],
            "top_k": top_k,
            "include_attributes": ["uri", "title"],
        },
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json().get("rows", [])


def _row_title(row: dict) -> str:
    """Return a display title for ``row``, truncated to 60 characters."""
    # `or "?"` also covers a title that is present but null/empty, which
    # `.get("title", "?")` alone would pass through and then fail to slice.
    return (row.get("title") or "?")[:60]


def _print_distribution(rows: list[dict]) -> None:
    """Print summary stats, a 0.1-wide histogram, and sample rows for one query.

    ``rows`` must be non-empty and each row must carry a ``$dist`` value.
    """
    # NOTE(review): rows are presumably ordered best-first by the API, which
    # is what makes the "top 5" / "bottom 5" labels meaningful -- confirm.
    dists = [r["$dist"] for r in rows]

    # show distribution
    print(f" results: {len(dists)}")
    print(f" min dist: {min(dists):.4f} (best match)")
    print(f" max dist: {max(dists):.4f} (worst match)")
    # True median: averages the two middle values when the count is even.
    print(f" median: {statistics.median(dists):.4f}")

    # histogram of distance buckets (nearest 0.1)
    buckets = Counter(round(d, 1) for d in dists)
    print(" buckets: ", end="")
    for b in sorted(buckets):
        print(f"[{b:.1f}]={buckets[b]}", end=" ")
    print()

    # count at various thresholds
    for t in DISTANCE_THRESHOLDS:
        n = sum(1 for d in dists if d <= t)
        print(f" dist<={t}: {n}/{len(dists)}")

    # top 5 + bottom 5
    print(" top 5:")
    for r in rows[:5]:
        print(f" {r['$dist']:.4f} {_row_title(r)}")
    print(" bottom 5:")
    for r in rows[-5:]:
        print(f" {r['$dist']:.4f} {_row_title(r)}")


def main() -> None:
    """Embed every test query, search tpuf, and print each distance report."""
    try:
        settings = Settings()  # type: ignore
    except Exception as e:
        # Top-level boundary: report missing/invalid config and exit non-zero.
        print(f"error: {e}", file=sys.stderr)
        print("required: VOYAGE_API_KEY (or .env file)", file=sys.stderr)
        sys.exit(1)

    print("getting tpuf key from fly...", end="", flush=True)
    tpuf_key = get_tpuf_key()
    print(" ok\n")

    for query in TEST_QUERIES:
        print(f"=== {query!r} ===")
        vector = embed_query(settings, query)
        rows = tpuf_query(tpuf_key, vector)

        if not rows:
            print(" (no results)")
            print()
            continue

        _print_distribution(rows)
        print()


if __name__ == "__main__":
    main()