audio streaming app plyr.fm
at main 186 lines 6.2 kB view raw
"""tests for playlist track recommendation logic."""

from unittest.mock import AsyncMock, patch

from backend._internal.clients.tpuf import VectorSearchResult
from backend._internal.recommendations import (
    _kmeans,
    get_playlist_recommendations,
    rrf_merge,
)

# --- rrf_merge tests ---


def test_rrf_merge_single_list() -> None:
    """rrf_merge with a single list just filters excluded IDs."""
    results = [
        VectorSearchResult(track_id=1, distance=0.1),
        VectorSearchResult(track_id=2, distance=0.2),
        VectorSearchResult(track_id=3, distance=0.3),
    ]
    merged = rrf_merge([results], exclude_ids={2})

    # the surviving tracks keep their original relative order
    assert [r.track_id for r in merged] == [1, 3]


def test_rrf_merge_multiple_lists() -> None:
    """rrf_merge combines rankings — tracks appearing in multiple lists rank higher."""
    list_a = [
        VectorSearchResult(track_id=10, distance=0.1),
        VectorSearchResult(track_id=20, distance=0.2),
        VectorSearchResult(track_id=30, distance=0.3),
    ]
    list_b = [
        VectorSearchResult(track_id=20, distance=0.15),
        VectorSearchResult(track_id=40, distance=0.25),
        VectorSearchResult(track_id=10, distance=0.35),
    ]

    merged = rrf_merge([list_a, list_b], exclude_ids=set())
    ids = [r.track_id for r in merged]

    # tracks 20 and 10 appear in both lists so should take the top two slots
    assert set(ids[:2]) == {10, 20}
    # each track appears exactly once: ids shared across lists are merged, not
    # duplicated (a bare length check would accept e.g. [20, 10, 20, 10])
    assert sorted(ids) == [10, 20, 30, 40]


def test_rrf_merge_excludes_playlist_tracks() -> None:
    """tracks in exclude_ids are not in the result."""
    results = [
        VectorSearchResult(track_id=1, distance=0.1),
        VectorSearchResult(track_id=2, distance=0.2),
    ]
    merged = rrf_merge([results], exclude_ids={1, 2})
    assert merged == []


def test_rrf_merge_empty_lists() -> None:
    """rrf_merge with empty input returns empty."""
    # both "no lists" and "one empty list" must be handled
    assert rrf_merge([], exclude_ids=set()) == []
    assert rrf_merge([[]], exclude_ids=set()) == []


def test_rrf_merge_keeps_best_distance() -> None:
    """when a track appears in multiple lists, keep the best (lowest) distance."""
    list_a = [VectorSearchResult(track_id=1, distance=0.5)]
    list_b = [VectorSearchResult(track_id=1, distance=0.2)]

    merged = rrf_merge([list_a, list_b], exclude_ids=set())
    assert len(merged) == 1
    # exact float comparison is safe: 0.2 is passed through, not computed
    assert merged[0].distance == 0.2


# --- k-means tests ---


def test_kmeans_basic() -> None:
    """k-means produces correct number of centroids and finds clusters."""
    # two clusters: around [0,0] and [10,10]
    vectors: list[list[float]] = [
        [0.0, 0.1],
        [0.1, 0.0],
        [-0.1, 0.0],
        [10.0, 10.1],
        [10.1, 10.0],
        [9.9, 10.0],
    ]
    centroids = _kmeans(vectors, k=2)
    assert len(centroids) == 2
    # every centroid must keep the input dimensionality, not just the first
    assert all(len(c) == 2 for c in centroids)

    # centroids should be near [0,0] and [10,10]; check both coordinates,
    # since a centroid can match on x alone while being wrong on y
    low, high = sorted(centroids, key=lambda c: c[0])
    assert abs(low[0]) < 1.0
    assert abs(low[1]) < 1.0
    assert abs(high[0] - 10.0) < 1.0
    assert abs(high[1] - 10.0) < 1.0


def test_kmeans_single_cluster() -> None:
    """k-means with k=1 returns the mean."""
    vectors: list[list[float]] = [[1.0, 2.0], [3.0, 4.0]]
    centroids = _kmeans(vectors, k=1)
    assert len(centroids) == 1
    # mean of [1, 3] and [2, 4], compared with a tolerance for float math
    assert abs(centroids[0][0] - 2.0) < 1e-6
    assert abs(centroids[0][1] - 3.0) < 1e-6


# --- adaptive strategy tests ---


@patch("backend._internal.recommendations.get_vectors")
@patch("backend._internal.recommendations.query")
async def test_single_track_strategy(
    mock_query: AsyncMock, mock_get_vectors: AsyncMock
) -> None:
    """1 track: queries turbopuffer directly with the track's embedding."""
    # stacked @patch decorators inject mocks bottom-up: query first, then get_vectors
    mock_get_vectors.return_value = {1: [0.1] * 512}
    mock_query.return_value = [
        VectorSearchResult(track_id=99, distance=0.1),
        VectorSearchResult(track_id=1, distance=0.0),  # should be excluded
    ]

    results = await get_playlist_recommendations([1], limit=3)

    mock_query.assert_called_once()
    assert len(results) == 1
    assert results[0].track_id == 99
@patch("backend._internal.recommendations.get_vectors")
@patch("backend._internal.recommendations.query")
async def test_multi_track_rrf_strategy(
    mock_query: AsyncMock, mock_get_vectors: AsyncMock
) -> None:
    """2-5 tracks: queries each embedding, merges with RRF."""
    mock_get_vectors.return_value = {
        1: [0.1] * 512,
        2: [0.2] * 512,
        3: [0.3] * 512,
    }
    # every per-track query returns this same ranking
    mock_query.return_value = [
        VectorSearchResult(track_id=99, distance=0.1),
        VectorSearchResult(track_id=98, distance=0.2),
    ]

    results = await get_playlist_recommendations([1, 2, 3], limit=3)

    # should have queried 3 times (one per track)
    assert mock_query.call_count == 3
    # identical rankings fuse to the same order; pinning the ids also rules out
    # an empty result, which the exclusion check below would vacuously accept
    assert [r.track_id for r in results] == [99, 98]
    assert all(r.track_id not in {1, 2, 3} for r in results)


@patch("backend._internal.recommendations.get_vectors")
@patch("backend._internal.recommendations.query")
async def test_large_playlist_kmeans_strategy(
    mock_query: AsyncMock, mock_get_vectors: AsyncMock
) -> None:
    """6+ tracks: clusters into centroids, queries each centroid."""
    # 8 tracks -> min(3, 8//2) = 3 clusters
    playlist = list(range(1, 9))
    vecs = {i: [float(i) / 10] * 512 for i in playlist}
    mock_get_vectors.return_value = vecs
    mock_query.return_value = [
        VectorSearchResult(track_id=99, distance=0.1),
    ]

    results = await get_playlist_recommendations(playlist, limit=3)

    # should have queried 3 times (one per cluster centroid)
    assert mock_query.call_count == 3
    # 99 is the only candidate; pinning it rules out a vacuous empty result
    assert [r.track_id for r in results] == [99]
    excluded = set(playlist)
    assert all(r.track_id not in excluded for r in results)


@patch("backend._internal.recommendations.get_vectors")
@patch("backend._internal.recommendations.query")
async def test_empty_playlist_returns_empty(
    mock_query: AsyncMock, mock_get_vectors: AsyncMock
) -> None:
    """empty playlist returns empty recommendations."""
    results = await get_playlist_recommendations([], limit=3)
    assert results == []
    # short-circuits before any embedding lookup or vector search; query is
    # patched so a regression fails here instead of reaching the real client
    mock_get_vectors.assert_not_called()
    mock_query.assert_not_called()


@patch("backend._internal.recommendations.get_vectors")
@patch("backend._internal.recommendations.query")
async def test_no_embeddings_returns_empty(
    mock_query: AsyncMock, mock_get_vectors: AsyncMock
) -> None:
    """playlist with no embedded tracks returns empty."""
    mock_get_vectors.return_value = {}
    results = await get_playlist_recommendations([1, 2], limit=3)
    assert results == []
    # no embeddings means no vector search should be issued; query is patched
    # so a regression fails here instead of reaching the real client
    mock_query.assert_not_called()