audio streaming app
plyr.fm
1"""tests for playlist track recommendation logic."""
2
3from unittest.mock import AsyncMock, patch
4
5from backend._internal.clients.tpuf import VectorSearchResult
6from backend._internal.recommendations import (
7 _kmeans,
8 get_playlist_recommendations,
9 rrf_merge,
10)
11
12# --- rrf_merge tests ---
13
14
def test_rrf_merge_single_list() -> None:
    """a lone ranked list passes through rrf_merge minus the excluded IDs."""
    ranked = [
        VectorSearchResult(track_id=track_id, distance=distance)
        for track_id, distance in ((1, 0.1), (2, 0.2), (3, 0.3))
    ]

    surviving_ids = [hit.track_id for hit in rrf_merge([ranked], exclude_ids={2})]

    # track 2 is dropped; tracks 1 and 3 come back in their input order
    assert surviving_ids == [1, 3]
25
26
def test_rrf_merge_multiple_lists() -> None:
    """tracks ranked by more than one list outrank tracks seen only once."""

    def as_results(*pairs: tuple[int, float]) -> list[VectorSearchResult]:
        # build a ranked list from (track_id, distance) pairs, best match first
        return [VectorSearchResult(track_id=tid, distance=dist) for tid, dist in pairs]

    first = as_results((10, 0.1), (20, 0.2), (30, 0.3))
    second = as_results((20, 0.15), (40, 0.25), (10, 0.35))

    merged_ids = [hit.track_id for hit in rrf_merge([first, second], exclude_ids=set())]

    # 10 and 20 are the only tracks present in both lists, so they must fill
    # the top two slots (in either order)
    top_two = merged_ids[:2]
    assert 20 in top_two
    assert 10 in top_two
    # four entries overall: 10, 20, 30, 40
    assert len(merged_ids) == 4
47
48
def test_rrf_merge_excludes_playlist_tracks() -> None:
    """excluding every candidate leaves nothing to recommend."""
    candidates = [
        VectorSearchResult(track_id=1, distance=0.1),
        VectorSearchResult(track_id=2, distance=0.2),
    ]
    # exclude exactly the tracks that were offered: {1, 2}
    already_in_playlist = {candidate.track_id for candidate in candidates}

    assert rrf_merge([candidates], exclude_ids=already_in_playlist) == []
57
58
def test_rrf_merge_empty_lists() -> None:
    """no lists at all, or a single empty list, both merge to an empty result."""
    for ranked_lists in ([], [[]]):
        assert rrf_merge(ranked_lists, exclude_ids=set()) == []
63
64
def test_rrf_merge_keeps_best_distance() -> None:
    """a track seen in several lists is merged once, with its lowest distance."""
    farther = VectorSearchResult(track_id=1, distance=0.5)
    closer = VectorSearchResult(track_id=1, distance=0.2)

    merged = rrf_merge([[farther], [closer]], exclude_ids=set())

    # both inputs describe track 1, so a single entry is expected ...
    assert len(merged) == 1
    # ... carrying the smaller of the two distances
    assert merged[0].distance == 0.2
73
74
75# --- k-means tests ---
76
77
def test_kmeans_basic() -> None:
    """k-means returns k centroids that land on the two obvious clusters.

    the six points form two tight groups, one around [0, 0] and one around
    [10, 10]. with k=2, each centroid must keep the input dimensionality and
    sit within 1.0 of one group's center in every coordinate.
    """
    # two clusters: around [0,0] and [10,10]
    vectors: list[list[float]] = [
        [0.0, 0.1],
        [0.1, 0.0],
        [-0.1, 0.0],
        [10.0, 10.1],
        [10.1, 10.0],
        [9.9, 10.0],
    ]
    centroids = _kmeans(vectors, k=2)
    assert len(centroids) == 2
    # check the dimensionality of every centroid, not only the first
    assert all(len(centroid) == 2 for centroid in centroids)

    # sort by first coordinate so the low cluster comes first regardless of
    # the order _kmeans returns the centroids in
    low, high = sorted(centroids, key=lambda c: c[0])
    # compare every coordinate, so an error in the second dimension is caught
    assert all(abs(coord) < 1.0 for coord in low)
    assert all(abs(coord - 10.0) < 1.0 for coord in high)
97
98
def test_kmeans_single_cluster() -> None:
    """with k=1 the only centroid is the per-dimension mean of the inputs."""
    points: list[list[float]] = [[1.0, 2.0], [3.0, 4.0]]
    tolerance = 1e-6

    result = _kmeans(points, k=1)

    assert len(result) == 1
    mean = result[0]
    # (1 + 3) / 2 == 2 and (2 + 4) / 2 == 3
    assert abs(mean[0] - 2.0) < tolerance
    assert abs(mean[1] - 3.0) < tolerance
106
107
108# --- adaptive strategy tests ---
109
110
@patch("backend._internal.recommendations.get_vectors")
@patch("backend._internal.recommendations.query")
async def test_single_track_strategy(
    mock_query: AsyncMock, mock_get_vectors: AsyncMock
) -> None:
    """1 track: exactly one query is issued and the seed track is filtered out."""
    # patch decorators apply bottom-up, so the `query` mock arrives first
    seed_track = 1
    mock_get_vectors.return_value = {seed_track: [0.1] * 512}

    neighbour = VectorSearchResult(track_id=99, distance=0.1)
    # the playlist's own track comes back as a hit and should be excluded
    echo_of_seed = VectorSearchResult(track_id=seed_track, distance=0.0)
    mock_query.return_value = [neighbour, echo_of_seed]

    recommendations = await get_playlist_recommendations([seed_track], limit=3)

    mock_query.assert_called_once()
    # only the neighbour survives the exclusion
    assert len(recommendations) == 1
    assert recommendations[0].track_id == 99
128
129
@patch("backend._internal.recommendations.get_vectors")
@patch("backend._internal.recommendations.query")
async def test_multi_track_rrf_strategy(
    mock_query: AsyncMock, mock_get_vectors: AsyncMock
) -> None:
    """2-5 tracks: one query per playlist track; playlist tracks never come back.

    every mocked query returns the same two candidates (99 and 98), so the
    recommendations must be non-empty and must not contain tracks 1, 2 or 3.
    """
    # patch decorators apply bottom-up, so the `query` mock arrives first
    playlist_ids = [1, 2, 3]
    mock_get_vectors.return_value = {
        1: [0.1] * 512,
        2: [0.2] * 512,
        3: [0.3] * 512,
    }
    mock_query.return_value = [
        VectorSearchResult(track_id=99, distance=0.1),
        VectorSearchResult(track_id=98, distance=0.2),
    ]

    results = await get_playlist_recommendations(playlist_ids, limit=3)

    # should have queried 3 times (one per track)
    assert mock_query.call_count == 3
    # all() is vacuously true for an empty iterable, so require at least one
    # recommendation before checking the exclusion
    assert results, "expected at least one recommendation from the mocked hits"
    excluded = set(playlist_ids)
    assert all(r.track_id not in excluded for r in results)
151
152
@patch("backend._internal.recommendations.get_vectors")
@patch("backend._internal.recommendations.query")
async def test_large_playlist_kmeans_strategy(
    mock_query: AsyncMock, mock_get_vectors: AsyncMock
) -> None:
    """6+ tracks: one query per cluster centroid; playlist tracks never come back.

    every mocked query returns candidate 99, so the recommendations must be
    non-empty and must not contain any of the eight playlist tracks.
    """
    # patch decorators apply bottom-up, so the `query` mock arrives first
    playlist_ids = list(range(1, 9))
    # 8 tracks -> min(3, 8//2) = 3 clusters
    mock_get_vectors.return_value = {i: [float(i) / 10] * 512 for i in playlist_ids}
    mock_query.return_value = [
        VectorSearchResult(track_id=99, distance=0.1),
    ]

    results = await get_playlist_recommendations(playlist_ids, limit=3)

    # should have queried 3 times (one per cluster centroid)
    assert mock_query.call_count == 3
    # all() is vacuously true for an empty iterable, so require at least one
    # recommendation before checking the exclusion
    assert results, "expected at least one recommendation from the mocked hits"
    # build the exclusion set once instead of once per result
    excluded = set(playlist_ids)
    assert all(r.track_id not in excluded for r in results)
171
172
@patch("backend._internal.recommendations.get_vectors")
async def test_empty_playlist_returns_empty(mock_get_vectors: AsyncMock) -> None:
    """an empty playlist yields no recommendations and no vector lookup."""
    recommendations = await get_playlist_recommendations([], limit=3)

    # with no tracks there is nothing to look up
    mock_get_vectors.assert_not_called()
    assert recommendations == []
179
180
@patch("backend._internal.recommendations.get_vectors")
async def test_no_embeddings_returns_empty(mock_get_vectors: AsyncMock) -> None:
    """no stored vectors for the playlist's tracks means no recommendations."""
    # the lookup succeeds but finds nothing for tracks 1 and 2
    no_vectors: dict[int, list[float]] = {}
    mock_get_vectors.return_value = no_vectors

    assert await get_playlist_recommendations([1, 2], limit=3) == []