audio streaming app
plyr.fm
1#!/usr/bin/env -S uv run --script --quiet
2"""backfill CLAP embeddings for existing tracks.
3
4## Context
5
6Mood search uses CLAP embeddings stored in turbopuffer to match text
7descriptions to audio content. New uploads generate embeddings automatically,
8but existing tracks need to be backfilled.
9
10## What This Script Does
11
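1. Queries all tracks that have an audio file (non-null `file_id`); tracks that
   already have an embedding in turbopuffer are not skipped, so they are embedded again
2. Downloads each track's audio from the R2 public bucket
3. Generates a CLAP embedding via Modal
4. Upserts the embedding into turbopuffer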
16
17## Usage
18
19```bash
20# dry run (show what would be embedded, no API calls)
21uv run scripts/backfill_embeddings.py --dry-run
22
23# embed first 5 tracks
24uv run scripts/backfill_embeddings.py --limit 5
25
26# full backfill with 10 concurrent workers (default)
27uv run scripts/backfill_embeddings.py
28
29# custom concurrency
30uv run scripts/backfill_embeddings.py --concurrency 20
31```
32"""
33
34import argparse
35import asyncio
36import logging
37import time
38
39import httpx
40from sqlalchemy import select
41
42from backend._internal.clients.clap import get_clap_client
43from backend._internal.clients.tpuf import upsert
44from backend.config import settings
45from backend.models import Artist, Track
46from backend.utilities.database import db_session
47
48logging.basicConfig(
49 level=logging.INFO,
50 format="%(asctime)s - %(levelname)s - %(message)s",
51)
52logger = logging.getLogger(__name__)
53
54
async def _embed_one(
    track: Track,
    artist: Artist,
    http: httpx.AsyncClient,
    clap_client: object,
    sem: asyncio.Semaphore,
    counter: dict[str, int],
    total: int,
) -> None:
    """Download, embed, and upsert a single track while holding a slot in ``sem``.

    Progress is recorded in the shared ``counter`` dict: ``"started"`` is bumped
    on entry (its new value is the ``[i/total]`` position in the log line), then
    exactly one of ``"embedded"`` or ``"failed"`` is bumped for the outcome.
    Any ``Exception`` is logged and counted instead of propagated, so one bad
    track does not abort the other tasks.
    """
    async with sem:
        audio_url = f"{settings.storage.r2_public_bucket_url}/audio/{track.file_id}.{track.file_type}"
        # no await between the increment and the read, so the position is
        # unique per task on a single event loop
        counter["started"] += 1
        position = counter["started"]

        try:
            logger.info(
                "embedding [%d/%d] track %d: %s",
                position,
                total,
                track.id,
                track.title,
            )

            response = await http.get(audio_url)
            response.raise_for_status()

            result = await clap_client.embed_audio(response.content)

            if result.success and result.embedding:
                await upsert(
                    track_id=track.id,
                    embedding=result.embedding,
                    title=track.title,
                    artist_handle=artist.handle,
                    artist_did=artist.did,
                )
                counter["embedded"] += 1
                logger.info(
                    "embedded track %d (%d dims)", track.id, result.dimensions
                )
            else:
                # the embedding service reported failure (or returned nothing)
                logger.warning(
                    "embedding failed for track %d: %s", track.id, result.error
                )
                counter["failed"] += 1
        except Exception:
            logger.exception("failed to process track %d", track.id)
            counter["failed"] += 1
104
105
106async def backfill_embeddings(
107 dry_run: bool = False,
108 limit: int | None = None,
109 concurrency: int = 10,
110) -> None:
111 """backfill CLAP embeddings for tracks missing from turbopuffer."""
112
113 if not dry_run:
114 if not settings.modal.enabled:
115 logger.error("MODAL_ENABLED is not set — cannot generate embeddings")
116 return
117 if not settings.turbopuffer.enabled:
118 logger.error("TURBOPUFFER_ENABLED is not set — cannot store embeddings")
119 return
120
121 async with db_session() as db:
122 stmt = (
123 select(Track, Artist)
124 .join(Artist, Track.artist_did == Artist.did)
125 .where(Track.file_id.isnot(None))
126 .order_by(Track.id)
127 )
128 if limit:
129 stmt = stmt.limit(limit)
130
131 result = await db.execute(stmt)
132 rows = result.all()
133
134 if not rows:
135 logger.info("no tracks found to embed")
136 return
137
138 logger.info("found %d tracks to process (concurrency=%d)", len(rows), concurrency)
139
140 if dry_run:
141 for track, artist in rows:
142 logger.info(
143 "would embed: [%d] %s by %s",
144 track.id,
145 track.title,
146 artist.handle,
147 )
148 return
149
150 clap_client = get_clap_client()
151 sem = asyncio.Semaphore(concurrency)
152 counter: dict[str, int] = {"started": 0, "embedded": 0, "failed": 0}
153 t0 = time.monotonic()
154
155 async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http:
156 tasks = [
157 _embed_one(track, artist, http, clap_client, sem, counter, len(rows))
158 for track, artist in rows
159 ]
160 await asyncio.gather(*tasks)
161
162 elapsed = time.monotonic() - t0
163 logger.info(
164 "backfill complete: %d embedded, %d failed, %d total in %.0fs (%.1f tracks/s)",
165 counter["embedded"],
166 counter["failed"],
167 len(rows),
168 elapsed,
169 len(rows) / elapsed if elapsed > 0 else 0,
170 )
171
172
async def main() -> None:
    """Parse command-line arguments and run the embedding backfill.

    ``--limit`` and ``--concurrency`` must be positive integers when given;
    otherwise argparse exits with a usage error before any work starts.
    """
    parser = argparse.ArgumentParser(description="backfill CLAP embeddings")
    parser.add_argument(
        "--dry-run", action="store_true", help="show what would be done"
    )
    parser.add_argument("--limit", type=int, default=None, help="max tracks to process")
    parser.add_argument(
        "--concurrency", type=int, default=10, help="concurrent workers"
    )
    args = parser.parse_args()

    # fail fast on values that cannot work: a concurrency of 0 builds a
    # semaphore that no worker can ever acquire (the run hangs), and a limit
    # below 1 never describes a useful run
    if args.limit is not None and args.limit < 1:
        parser.error("--limit must be a positive integer")
    if args.concurrency < 1:
        parser.error("--concurrency must be a positive integer")

    if args.dry_run:
        logger.info("running in DRY RUN mode — no API calls will be made")

    await backfill_embeddings(
        dry_run=args.dry_run,
        limit=args.limit,
        concurrency=args.concurrency,
    )


if __name__ == "__main__":
    asyncio.run(main())