plyr.fm — audio streaming app
1#!/usr/bin/env -S uv run --script --quiet
2"""backfill genre classifications for existing tracks.
3
4## Context
5
6Genre classification uses the effnet-discogs model on Replicate to classify
7audio into genre labels. New uploads classify automatically, but existing
8tracks need to be backfilled.
9
10## What This Script Does
11
121. Queries all tracks with an R2 audio URL missing genre_predictions in extra
132. Classifies each track via Replicate
143. Stores predictions in track.extra["genre_predictions"]
15
16## Usage
17
18```bash
19# dry run (show what would be classified, no API calls)
20uv run scripts/backfill_genres.py --dry-run
21
22# classify first 5 tracks
23uv run scripts/backfill_genres.py --limit 5
24
25# full backfill with 5 concurrent workers (default)
26uv run scripts/backfill_genres.py
27
28# custom concurrency
29uv run scripts/backfill_genres.py --concurrency 10
30```
31"""
32
33import argparse
34import asyncio
35import logging
36import time
37
38from sqlalchemy import select, text
39
40from backend._internal.clients.replicate import get_replicate_client
41from backend.config import settings
42from backend.models import Artist, Track
43from backend.utilities.database import db_session
44
# module-level logging: timestamped INFO output so long-running backfills
# show per-track progress on stdout/stderr
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
51
async def _classify_one(
    track: Track,
    artist: Artist,
    client: object,
    sem: asyncio.Semaphore,
    counter: dict[str, int],
    total: int,
) -> None:
    """classify a single track via Replicate and persist its predictions.

    Guarded by *sem* so at most `concurrency` classifications run at once.
    Exactly one of counter["classified"] / counter["failed"] is incremented
    per call; exceptions are logged, never propagated to the gather().

    Args:
        track: track row snapshot from the backfill query (needs id/title/r2_url).
        artist: joined artist row, used only for progress logging.
        client: Replicate client exposing `await classify(url)`.
        sem: shared semaphore bounding concurrency.
        counter: shared mutable tallies ("started", "classified", "failed").
        total: total number of tracks in this backfill, for [i/total] logs.
    """
    async with sem:
        # single-threaded event loop with no await in between, so this
        # read-modify-write is race-free
        counter["started"] += 1
        idx = counter["started"]

        try:
            logger.info(
                "classifying [%d/%d] track %d: %s by %s",
                idx,
                total,
                track.id,
                track.title,
                artist.handle,
            )

            result = await client.classify(track.r2_url)

            if not result.success or not result.genres:
                logger.warning(
                    "classification failed for track %d: %s",
                    track.id,
                    result.error,
                )
                counter["failed"] += 1
                return

            predictions = [
                {"name": g.name, "confidence": g.confidence} for g in result.genres
            ]

            async with db_session() as db:
                # re-fetch inside a fresh session: the snapshot row is detached
                # and may be stale after the (slow) classification call
                db_result = await db.execute(select(Track).where(Track.id == track.id))
                db_track = db_result.scalar_one_or_none()
                if db_track is None:
                    # fix: previously a deleted track still counted as
                    # "classified" — report it as a failure instead
                    logger.warning(
                        "track %d disappeared before predictions could be stored",
                        track.id,
                    )
                    counter["failed"] += 1
                    return

                # copy-on-write so SQLAlchemy detects the JSON column change
                extra = dict(db_track.extra) if db_track.extra else {}
                extra["genre_predictions"] = predictions
                # remember which audio file these predictions were computed
                # for, so a re-upload can invalidate them
                extra["genre_predictions_file_id"] = db_track.file_id
                db_track.extra = extra
                await db.commit()

            counter["classified"] += 1
            logger.info(
                "classified track %d: top genre = %s (%.2f)",
                track.id,
                predictions[0]["name"],
                predictions[0]["confidence"],
            )

        except Exception:
            # best-effort batch job: record and continue with remaining tracks
            logger.exception("failed to process track %d", track.id)
            counter["failed"] += 1
112
async def backfill_genres(
    dry_run: bool = False,
    limit: int | None = None,
    concurrency: int = 5,
) -> None:
    """backfill genre classifications for tracks missing predictions.

    Queries tracks that have an R2 audio URL but no "genre_predictions" key
    in their `extra` JSON, then classifies them concurrently via Replicate.

    Args:
        dry_run: log what would be classified without calling Replicate.
        limit: maximum number of tracks to process; None means all.
            (fix: an explicit 0 now means "process nothing" instead of "all".)
        concurrency: number of tracks classified concurrently.
    """
    # guard clause: a real run needs Replicate; a dry run does not
    if not dry_run and not settings.replicate.enabled:
        logger.error("REPLICATE_ENABLED is not set — cannot classify genres")
        return

    async with db_session() as db:
        stmt = (
            select(Track, Artist)
            .join(Artist, Track.artist_did == Artist.did)
            .where(
                Track.r2_url.isnot(None),
                # postgres jsonb `?` operator: key-existence test on extra;
                # raw text() because SQLAlchemy's `?` would be a bind param
                text("NOT (tracks.extra ? 'genre_predictions')"),
            )
            .order_by(Track.id)
        )
        # `is not None` (not truthiness) so --limit 0 is honored literally
        if limit is not None:
            stmt = stmt.limit(limit)

        result = await db.execute(stmt)
        rows = result.all()

    if not rows:
        logger.info("no tracks found to classify")
        return

    logger.info("found %d tracks to classify (concurrency=%d)", len(rows), concurrency)

    if dry_run:
        for track, artist in rows:
            logger.info(
                "would classify: [%d] %s by %s",
                track.id,
                track.title,
                artist.handle,
            )
        return

    client = get_replicate_client()
    sem = asyncio.Semaphore(concurrency)
    counter: dict[str, int] = {"started": 0, "classified": 0, "failed": 0}
    t0 = time.monotonic()

    # _classify_one swallows its own exceptions, so a plain gather is safe
    tasks = [
        _classify_one(track, artist, client, sem, counter, len(rows))
        for track, artist in rows
    ]
    await asyncio.gather(*tasks)

    elapsed = time.monotonic() - t0
    logger.info(
        "backfill complete: %d classified, %d failed, %d total in %.0fs (%.1f tracks/s)",
        counter["classified"],
        counter["failed"],
        len(rows),
        elapsed,
        len(rows) / elapsed if elapsed > 0 else 0,
    )
178
179
async def main() -> None:
    """entry point: parse CLI flags and kick off the backfill."""
    parser = argparse.ArgumentParser(description="backfill genre classifications")
    parser.add_argument("--dry-run", action="store_true", help="show what would be done")
    parser.add_argument("--limit", type=int, default=None, help="max tracks to process")
    parser.add_argument("--concurrency", type=int, default=5, help="concurrent workers")
    opts = parser.parse_args()

    if opts.dry_run:
        logger.info("running in DRY RUN mode — no API calls will be made")

    await backfill_genres(
        dry_run=opts.dry_run,
        limit=opts.limit,
        concurrency=opts.concurrency,
    )
197
198
# script entry: run the async main() under a fresh event loop
if __name__ == "__main__":
    asyncio.run(main())