audio streaming app plyr.fm
scripts/backfill_genres.py (branch: main) — 200 lines, 5.8 kB
#!/usr/bin/env -S uv run --script --quiet
"""backfill genre classifications for existing tracks.

## Context

Genre classification uses the effnet-discogs model on Replicate to classify
audio into genre labels. New uploads classify automatically, but existing
tracks need to be backfilled.

## What This Script Does

1. Queries all tracks with an R2 audio URL missing genre_predictions in extra
2. Classifies each track via Replicate
3. Stores predictions in track.extra["genre_predictions"]

## Usage

```bash
# dry run (show what would be classified, no API calls)
uv run scripts/backfill_genres.py --dry-run

# classify first 5 tracks
uv run scripts/backfill_genres.py --limit 5

# full backfill with 5 concurrent workers (default)
uv run scripts/backfill_genres.py

# custom concurrency
uv run scripts/backfill_genres.py --concurrency 10
```
"""

import argparse
import asyncio
import logging
import time

from sqlalchemy import select, text

from backend._internal.clients.replicate import get_replicate_client
from backend.config import settings
from backend.models import Artist, Track
from backend.utilities.database import db_session

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


async def _classify_one(
    track: Track,
    artist: Artist,
    client: object,
    sem: asyncio.Semaphore,
    counter: dict[str, int],
    total: int,
) -> None:
    """classify a single track, guarded by semaphore.

    Args:
        track: track row fetched by the backfill query (read-only copy; the
            row is re-fetched in a fresh session before writing).
        artist: artist joined to the track, used for log context only.
        client: replicate client exposing `classify(url)` — annotated as
            `object` because the concrete client type lives in backend
            internals; assumed to return a result with `.success`,
            `.genres`, and `.error` attributes.
        sem: limits the number of in-flight classifications.
        counter: shared mutable tally with keys "started", "classified",
            "failed" — safe under asyncio because each read-modify-write
            below has no `await` between read and write.
        total: total track count, for "[i/total]" progress logging.
    """
    async with sem:
        # claim a 1-based progress index; no await between read and
        # increment, so concurrent tasks cannot observe the same value
        idx = counter["started"] + 1
        counter["started"] += 1

        try:
            logger.info(
                "classifying [%d/%d] track %d: %s by %s",
                idx,
                total,
                track.id,
                track.title,
                artist.handle,
            )

            result = await client.classify(track.r2_url)

            if not result.success or not result.genres:
                logger.warning(
                    "classification failed for track %d: %s",
                    track.id,
                    result.error,
                )
                counter["failed"] += 1
                return

            predictions = [
                {"name": g.name, "confidence": g.confidence} for g in result.genres
            ]

            async with db_session() as db:
                # re-fetch in this session so the write attaches to a live row
                db_result = await db.execute(select(Track).where(Track.id == track.id))
                db_track = db_result.scalar_one_or_none()
                if db_track is None:
                    # row vanished between the initial query and now
                    # (e.g. track deleted mid-backfill) — count as failed
                    # rather than silently claiming success
                    logger.warning(
                        "track %d no longer exists; skipping update", track.id
                    )
                    counter["failed"] += 1
                    return

                # reassign (rather than mutate) extra so SQLAlchemy detects
                # the change on the JSON column
                extra = dict(db_track.extra) if db_track.extra else {}
                extra["genre_predictions"] = predictions
                extra["genre_predictions_file_id"] = db_track.file_id
                db_track.extra = extra
                await db.commit()

            counter["classified"] += 1
            logger.info(
                "classified track %d: top genre = %s (%.2f)",
                track.id,
                predictions[0]["name"],
                predictions[0]["confidence"],
            )

        except Exception:
            # one bad track must not abort the whole gather()
            logger.exception("failed to process track %d", track.id)
            counter["failed"] += 1


async def backfill_genres(
    dry_run: bool = False,
    limit: int | None = None,
    concurrency: int = 5,
) -> None:
    """backfill genre classifications for tracks missing predictions.

    Args:
        dry_run: log what would be classified without calling Replicate.
        limit: max number of tracks to process; None means all. (A value
            of 0 is honored as "process nothing", not "no limit".)
        concurrency: number of simultaneous Replicate calls.
    """

    if not dry_run:
        if not settings.replicate.enabled:
            logger.error("REPLICATE_ENABLED is not set — cannot classify genres")
            return

    async with db_session() as db:
        stmt = (
            select(Track, Artist)
            .join(Artist, Track.artist_did == Artist.did)
            .where(
                Track.r2_url.isnot(None),
                # jsonb `?` is the postgres key-existence operator:
                # select tracks whose extra lacks genre_predictions
                text("NOT (tracks.extra ? 'genre_predictions')"),
            )
            .order_by(Track.id)
        )
        # explicit None check so --limit 0 is not silently ignored
        if limit is not None:
            stmt = stmt.limit(limit)

        result = await db.execute(stmt)
        rows = result.all()

    if not rows:
        logger.info("no tracks found to classify")
        return

    logger.info("found %d tracks to classify (concurrency=%d)", len(rows), concurrency)

    if dry_run:
        for track, artist in rows:
            logger.info(
                "would classify: [%d] %s by %s",
                track.id,
                track.title,
                artist.handle,
            )
        return

    client = get_replicate_client()
    sem = asyncio.Semaphore(concurrency)
    counter: dict[str, int] = {"started": 0, "classified": 0, "failed": 0}
    t0 = time.monotonic()

    tasks = [
        _classify_one(track, artist, client, sem, counter, len(rows))
        for track, artist in rows
    ]
    await asyncio.gather(*tasks)

    elapsed = time.monotonic() - t0
    logger.info(
        "backfill complete: %d classified, %d failed, %d total in %.0fs (%.1f tracks/s)",
        counter["classified"],
        counter["failed"],
        len(rows),
        elapsed,
        len(rows) / elapsed if elapsed > 0 else 0,
    )


async def main() -> None:
    """parse CLI flags and run the backfill."""
    parser = argparse.ArgumentParser(description="backfill genre classifications")
    parser.add_argument(
        "--dry-run", action="store_true", help="show what would be done"
    )
    parser.add_argument("--limit", type=int, default=None, help="max tracks to process")
    parser.add_argument("--concurrency", type=int, default=5, help="concurrent workers")
    args = parser.parse_args()

    if args.dry_run:
        logger.info("running in DRY RUN mode — no API calls will be made")

    await backfill_genres(
        dry_run=args.dry_run,
        limit=args.limit,
        concurrency=args.concurrency,
    )


if __name__ == "__main__":
    asyncio.run(main())