audio streaming app plyr.fm
at 7e2fe45b42f267ed8ae4affe5dff9cb7ece99bdd 195 lines 5.6 kB view raw
#!/usr/bin/env -S uv run --script --quiet
"""backfill CLAP embeddings for existing tracks.

## Context

Mood search uses CLAP embeddings stored in turbopuffer to match text
descriptions to audio content. New uploads generate embeddings automatically,
but existing tracks need to be backfilled.

## What This Script Does

1. Queries all tracks that have a stored audio file (non-null `file_id`)
2. Downloads each track's audio from R2
3. Generates a CLAP embedding via Modal
4. Upserts the embedding into turbopuffer

Every selected track is (re-)embedded and upserted; the script does not check
whether an embedding already exists.

## Usage

```bash
# dry run (show what would be embedded, no API calls)
uv run scripts/backfill_embeddings.py --dry-run

# embed first 5 tracks
uv run scripts/backfill_embeddings.py --limit 5

# full backfill with 10 concurrent workers (default)
uv run scripts/backfill_embeddings.py

# custom concurrency
uv run scripts/backfill_embeddings.py --concurrency 20
```
"""

import argparse
import asyncio
import logging
import time
from typing import Any

import httpx
from sqlalchemy import select

from backend._internal.clients.clap import get_clap_client
from backend._internal.clients.tpuf import upsert
from backend.config import settings
from backend.models import Artist, Track
from backend.utilities.database import db_session

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


async def _embed_one(
    track: Track,
    artist: Artist,
    http: httpx.AsyncClient,
    clap_client: Any,
    sem: asyncio.Semaphore,
    counter: dict[str, int],
    total: int,
) -> None:
    """embed a single track, guarded by semaphore.

    downloads the track's audio, asks the CLAP client for an embedding and
    upserts it. failures are logged and counted, never raised, so one bad
    track cannot abort the whole backfill.

    Args:
        track: track whose audio should be embedded.
        artist: the track's artist (handle/did are stored with the embedding).
        http: shared HTTP client used to download the audio.
        clap_client: object exposing an async ``embed_audio(bytes)`` method,
            as returned by ``get_clap_client()``.
        sem: semaphore bounding how many tracks are processed at once.
        counter: shared progress counters with the keys ``"started"``,
            ``"embedded"`` and ``"failed"``; mutated in place.
        total: total number of tracks in this run (for progress logging only).
    """
    async with sem:
        r2_url = (
            f"{settings.storage.r2_public_bucket_url}"
            f"/audio/{track.file_id}.{track.file_type}"
        )
        # no await between the increment and the read, so the progress index
        # is unique per task even though the counter dict is shared.
        counter["started"] += 1
        idx = counter["started"]

        try:
            logger.info(
                "embedding [%d/%d] track %d: %s", idx, total, track.id, track.title
            )

            resp = await http.get(r2_url)
            resp.raise_for_status()
            audio_bytes = resp.content

            embed_result = await clap_client.embed_audio(audio_bytes)

            if not embed_result.success or not embed_result.embedding:
                logger.warning(
                    "embedding failed for track %d: %s", track.id, embed_result.error
                )
                counter["failed"] += 1
                return

            await upsert(
                track_id=track.id,
                embedding=embed_result.embedding,
                title=track.title,
                artist_handle=artist.handle,
                artist_did=artist.did,
            )

            counter["embedded"] += 1
            logger.info(
                "embedded track %d (%d dims)", track.id, embed_result.dimensions
            )

        except Exception:
            # per-track boundary: log with traceback and keep going.
            logger.exception("failed to process track %d", track.id)
            counter["failed"] += 1


async def backfill_embeddings(
    dry_run: bool = False,
    limit: int | None = None,
    concurrency: int = 10,
) -> None:
    """generate and upsert CLAP embeddings for every track with a stored file.

    selects all tracks whose ``file_id`` is set (ordered by id), then embeds
    them concurrently. tracks are not checked against turbopuffer first, so
    already-embedded tracks are embedded and upserted again.

    Args:
        dry_run: only log which tracks would be embedded; skips the
            Modal/turbopuffer configuration checks and makes no embedding or
            upsert calls (the database is still queried).
        limit: maximum number of tracks to process; ``None`` means no limit.
            ``0`` selects no tracks.
        concurrency: maximum number of tracks processed at the same time.
            must be at least 1 for a real (non-dry) run.

    Raises:
        ValueError: if ``concurrency`` is less than 1 on a non-dry run.
    """

    if not dry_run:
        # a zero-slot semaphore would make every worker wait forever, so
        # reject it up front instead of hanging.
        if concurrency < 1:
            raise ValueError(f"concurrency must be >= 1, got {concurrency}")
        if not settings.modal.enabled:
            logger.error("MODAL_ENABLED is not set — cannot generate embeddings")
            return
        if not settings.turbopuffer.enabled:
            logger.error("TURBOPUFFER_ENABLED is not set — cannot store embeddings")
            return

    async with db_session() as db:
        stmt = (
            select(Track, Artist)
            .join(Artist, Track.artist_did == Artist.did)
            .where(Track.file_id.isnot(None))
            .order_by(Track.id)
        )
        # `is not None` rather than truthiness: limit=0 must not silently
        # turn into an unlimited backfill.
        if limit is not None:
            stmt = stmt.limit(limit)

        result = await db.execute(stmt)
        rows = result.all()

    if not rows:
        logger.info("no tracks found to embed")
        return

    logger.info("found %d tracks to process (concurrency=%d)", len(rows), concurrency)

    if dry_run:
        for track, artist in rows:
            logger.info(
                "would embed: [%d] %s by %s",
                track.id,
                track.title,
                artist.handle,
            )
        return

    clap_client = get_clap_client()
    sem = asyncio.Semaphore(concurrency)
    counter: dict[str, int] = {"started": 0, "embedded": 0, "failed": 0}
    t0 = time.monotonic()

    async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http:
        tasks = [
            _embed_one(track, artist, http, clap_client, sem, counter, len(rows))
            for track, artist in rows
        ]
        # _embed_one handles its own errors, so gather only needs to wait.
        await asyncio.gather(*tasks)

    elapsed = time.monotonic() - t0
    logger.info(
        "backfill complete: %d embedded, %d failed, %d total in %.0fs (%.1f tracks/s)",
        counter["embedded"],
        counter["failed"],
        len(rows),
        elapsed,
        len(rows) / elapsed if elapsed > 0 else 0,
    )


def _positive_int(value: str) -> int:
    """argparse type: parse *value* as an integer that is at least 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


async def main() -> None:
    """parse command-line arguments and run the backfill."""
    parser = argparse.ArgumentParser(description="backfill CLAP embeddings")
    parser.add_argument(
        "--dry-run", action="store_true", help="show what would be done"
    )
    parser.add_argument(
        "--limit", type=_positive_int, default=None, help="max tracks to process"
    )
    parser.add_argument(
        "--concurrency", type=_positive_int, default=10, help="concurrent workers"
    )
    args = parser.parse_args()

    if args.dry_run:
        logger.info("running in DRY RUN mode — no API calls will be made")

    await backfill_embeddings(
        dry_run=args.dry_run,
        limit=args.limit,
        concurrency=args.concurrency,
    )


if __name__ == "__main__":
    asyncio.run(main())