# plyr.fm — audio streaming app
1"""pytest configuration for relay tests."""
2
3import asyncio
4import contextlib
5import os
6from collections.abc import AsyncGenerator, Callable, Generator
7from contextlib import asynccontextmanager
8from datetime import UTC, datetime
9from io import BytesIO
10from typing import BinaryIO
11from urllib.parse import urlsplit, urlunsplit
12
13import asyncpg
14import pytest
15import redis as sync_redis_lib
16import sqlalchemy as sa
17from fastapi import FastAPI
18from fastapi.testclient import TestClient
19from sqlalchemy.ext.asyncio import (
20 AsyncConnection,
21 AsyncEngine,
22 AsyncSession,
23 create_async_engine,
24)
25from sqlalchemy.orm import sessionmaker
26
27from backend.config import settings
28from backend.models import Base
29from backend.storage.r2 import R2Storage
30from backend.utilities.redis import clear_client_cache
31
32
class MockStorage(R2Storage):
    """In-memory stand-in for R2Storage so tests need no R2 credentials."""

    def __init__(self):
        # deliberately do NOT call R2Storage.__init__ (it requires credentials)
        pass

    async def save(
        self,
        file: BinaryIO | BytesIO,
        filename: str,
        progress_callback: Callable[[float], None] | None = None,
    ) -> str:
        """Pretend to upload; always hand back the same fake file_id."""
        return "mock_file_id_123"

    async def get_url(
        self,
        file_id: str,
        *,
        file_type: str | None = None,
        extension: str | None = None,
    ) -> str | None:
        """Build a deterministic fake public URL for the given file_id."""
        return "https://mock.r2.dev/" + file_id

    async def delete(self, file_id: str, file_type: str | None = None) -> bool:
        """Pretend to delete; always report success."""
        return True
62
63
def pytest_configure(config):
    """install the mock storage backend before any test modules import it."""
    import backend.storage

    # prime the module-level cache so the real R2Storage is never constructed
    backend.storage._storage = MockStorage()
70
71
72def _database_from_url(url: str) -> str:
73 """extract database name from connection URL."""
74 _, _, path, _, _ = urlsplit(url)
75 return path.strip("/")
76
77
78def _postgres_admin_url(database_url: str) -> str:
79 """convert async database URL to sync postgres URL for admin operations."""
80 scheme, netloc, _, query, fragment = urlsplit(database_url)
81 # asyncpg -> postgres for direct connection
82 scheme = scheme.replace("+asyncpg", "").replace("postgresql", "postgres")
83 return urlunsplit((scheme, netloc, "/postgres", query, fragment))
84
85
@asynccontextmanager
async def session_context(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """yield an AsyncSession bound to *engine* (expire_on_commit disabled)."""
    make_session = sessionmaker(
        class_=AsyncSession,
        bind=engine,
        expire_on_commit=False,
    )
    async with make_session() as db_session:
        yield db_session
96
97
async def _create_clear_database_procedure(
    connection: AsyncConnection,
) -> None:
    """creates a stored procedure in the test database used for quickly clearing
    the database between tests.
    """

    def schema(table: sa.Table) -> str:
        return table.schema or "public"

    def timestamp_column(table: sa.Table) -> str | None:
        """find the timestamp column to use for filtering"""
        for candidate in ("created_at", "updated_at"):
            if candidate in table.columns:
                return candidate
        # if no timestamp column, delete all rows
        return None

    delete_statements: list[str] = []
    # iterate children before parents so FK constraints never block a delete
    for table in reversed(Base.metadata.sorted_tables):
        ts_col = timestamp_column(table)
        if ts_col is not None:
            delete_statements.append(
                f"""
            BEGIN
                DELETE FROM {schema(table)}.{table.name}
                WHERE {ts_col} > _test_start_time;
            EXCEPTION WHEN OTHERS THEN
                RAISE EXCEPTION 'Error clearing table {schema(table)}.{table.name}: %', SQLERRM;
            END;
            """
            )
        else:
            # no timestamp column - delete all rows
            delete_statements.append(
                f"""
            BEGIN
                DELETE FROM {schema(table)}.{table.name};
            EXCEPTION WHEN OTHERS THEN
                RAISE EXCEPTION 'Error clearing table {schema(table)}.{table.name}: %', SQLERRM;
            END;
            """
            )

    deletes = "\n".join(delete_statements)

    signature = "clear_database(_test_start_time timestamptz)"
    procedure_body = f"""
    CREATE PROCEDURE {signature}
    LANGUAGE PLPGSQL
    AS $$
    BEGIN
    {deletes}
    END;
    $$;
    """

    await connection.execute(sa.text(f"DROP PROCEDURE IF EXISTS {signature};"))
    await connection.execute(sa.text(procedure_body))
160
161
async def _truncate_tables(connection: AsyncConnection) -> None:
    """truncate all tables to ensure a clean slate at start of session."""
    # collect every mapped table name from the model metadata
    names = [t.name for t in Base.metadata.sorted_tables]
    if not names:
        # no tables mapped yet - nothing to do
        return

    # CASCADE handles foreign keys; RESTART IDENTITY resets sequences
    await connection.execute(
        sa.text(f"TRUNCATE TABLE {', '.join(names)} RESTART IDENTITY CASCADE;")
    )
173
174
async def _setup_template_database(template_url: str) -> None:
    """initialize database schema and helper procedure on template database."""
    template_engine = create_async_engine(template_url, echo=False)
    try:
        async with template_engine.begin() as conn:
            # schema first, then a clean slate, then the fast-clear procedure
            await conn.run_sync(Base.metadata.create_all)
            await _truncate_tables(conn)
            await _create_clear_database_procedure(conn)
    finally:
        await template_engine.dispose()
185
186
async def _ensure_template_database(base_url: str) -> str:
    """ensure template database exists and is migrated.

    uses advisory lock to coordinate between xdist workers.
    returns the template database name.

    if schema setup fails partway through, the half-built template database is
    dropped before re-raising so a later run does not mistake it for a valid,
    fully-migrated template.
    """
    base_db_name = _database_from_url(base_url)
    template_db_name = f"{base_db_name}_template"
    postgres_url = _postgres_admin_url(base_url)

    conn = await asyncpg.connect(postgres_url)
    try:
        # advisory lock prevents race condition between workers
        await conn.execute("SELECT pg_advisory_lock(hashtext($1))", template_db_name)

        # check if template exists
        exists = await conn.fetchval(
            "SELECT 1 FROM pg_database WHERE datname = $1", template_db_name
        )

        if not exists:
            # create template database
            await conn.execute(f'CREATE DATABASE "{template_db_name}"')

            # build URL for template and set it up
            scheme, netloc, _, query, fragment = urlsplit(base_url)
            template_url = urlunsplit(
                (scheme, netloc, f"/{template_db_name}", query, fragment)
            )
            try:
                await _setup_template_database(template_url)
            except BaseException:
                # don't leave a broken template behind: other workers would see
                # it "exists" and skip migrations entirely
                await conn.execute(f'DROP DATABASE IF EXISTS "{template_db_name}"')
                raise

        # release lock (other workers waiting will see template exists)
        await conn.execute("SELECT pg_advisory_unlock(hashtext($1))", template_db_name)

        return template_db_name
    finally:
        # closing the session also releases the advisory lock on error paths
        await conn.close()
224
225
async def _create_worker_database_from_template(
    base_url: str, worker_id: str, template_db_name: str
) -> str:
    """create worker database by cloning the template (instant file copy)."""
    base_db_name = _database_from_url(base_url)
    worker_db_name = f"{base_db_name}_{worker_id}"
    postgres_url = _postgres_admin_url(base_url)

    conn = await asyncpg.connect(postgres_url)
    try:
        # terminate sessions on the worker db (stale from a previous run) and
        # on the template db (postgres refuses to clone a db with open sessions)
        for db_name in (worker_db_name, template_db_name):
            await conn.execute(
                """
                SELECT pg_terminate_backend(pid)
                FROM pg_stat_activity
                WHERE datname = $1 AND pid <> pg_backend_pid()
                """,
                db_name,
            )

        # drop and recreate from template (instant - just file copy)
        await conn.execute(f'DROP DATABASE IF EXISTS "{worker_db_name}"')
        await conn.execute(
            f'CREATE DATABASE "{worker_db_name}" WITH TEMPLATE "{template_db_name}"'
        )

        return worker_db_name
    finally:
        await conn.close()
265
266
@pytest.fixture(scope="session")
def test_database_url(worker_id: str) -> str:
    """generate a unique test database URL for each pytest worker.

    uses template database pattern for fast parallel test execution:
    1. first worker creates template db with migrations (once)
    2. each worker clones from template (instant file copy)

    also patches settings.database.url so all production code uses test db.
    """
    # note: asyncio and os are module-level imports; the previous redundant
    # function-local imports have been removed.
    base_url = settings.database.url

    # single worker - just use base database
    if worker_id == "master":
        asyncio.run(_setup_database_direct(base_url))
        return base_url

    # xdist workers - use template pattern
    template_db_name = asyncio.run(_ensure_template_database(base_url))
    asyncio.run(
        _create_worker_database_from_template(base_url, worker_id, template_db_name)
    )

    # build URL for worker database
    scheme, netloc, _, query, fragment = urlsplit(base_url)
    base_db_name = _database_from_url(base_url)
    worker_db_name = f"{base_db_name}_{worker_id}"
    worker_url = urlunsplit((scheme, netloc, f"/{worker_db_name}", query, fragment))

    # patch settings so all production code uses this URL
    # this is safe because each xdist worker is a separate process
    settings.database.url = worker_url
    os.environ["DATABASE_URL"] = worker_url

    return worker_url
305
306
async def _setup_database_direct(database_url: str) -> None:
    """set up database directly (for single worker mode).

    the steps are identical to template setup (create schema, truncate tables,
    install the clear_database procedure), so delegate to
    _setup_template_database to keep the two code paths from drifting apart.
    """
    await _setup_template_database(database_url)
317
318
@pytest.fixture(scope="session")
def _database_setup(test_database_url: str) -> None:
    """marker fixture - database is set up by test_database_url fixture."""
    del test_database_url  # requested only for the dependency chain
323
324
@pytest.fixture()
async def _engine(
    test_database_url: str, _database_setup: None
) -> AsyncGenerator[AsyncEngine, None]:
    """create a database engine for each test (to avoid event loop issues)."""
    from backend.utilities.database import ENGINES

    async def _drop_cached_engines() -> None:
        # engines cached by production code may be bound to a dead event loop
        for cached in list(ENGINES.values()):
            await cached.dispose()
        ENGINES.clear()

    # clear any cached engines left over from previous tests
    await _drop_cached_engines()

    engine = create_async_engine(
        test_database_url,
        echo=False,
        pool_size=2,
        max_overflow=0,
    )
    try:
        yield engine
    finally:
        await engine.dispose()
        # clean up engines cached during the test itself
        await _drop_cached_engines()
351
352
@pytest.fixture()
async def _clear_db(_engine: AsyncEngine) -> AsyncGenerator[None, None]:
    """clear the database after each test."""
    test_started_at = datetime.now(UTC)

    try:
        yield
    finally:
        # delete only rows created during this test (see clear_database proc)
        async with _engine.begin() as conn:
            await conn.execute(
                sa.text("CALL clear_database(:start_time)"),
                {"start_time": test_started_at},
            )
367
368
@pytest.fixture
async def db_session(
    _engine: AsyncEngine, _clear_db: None
) -> AsyncGenerator[AsyncSession, None]:
    """provide a database session for each test.

    the _clear_db fixture is used as a dependency to ensure proper cleanup order.
    """
    async with session_context(engine=_engine) as test_session:
        yield test_session
379
380
@pytest.fixture(scope="session")
def fastapi_app() -> Generator[FastAPI, None, None]:
    """provides the FastAPI app with a test lifespan that skips docket worker.

    docket Worker binds asyncio.Tasks to the TestClient's portal loop; under
    xdist, session teardown runs on a different loop → RuntimeError. no test
    needs a live worker (all docket usage is mocked), so skip it.
    """
    from backend.main import app as main_app

    saved_lifespan = main_app.router.lifespan_context
    main_app.router.lifespan_context = _test_lifespan
    try:
        yield main_app
    finally:
        # restore even if the fixture is torn down via an exception
        # (GeneratorExit / test error); the old code skipped the restore then
        main_app.router.lifespan_context = saved_lifespan
395
396
@asynccontextmanager
async def _test_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """test lifespan — skips docket worker to avoid event loop issues."""
    from backend._internal import jam_service, notification_service, queue_service

    services = (notification_service, queue_service, jam_service)
    for service in services:
        await service.setup()

    yield

    for service in services:
        with contextlib.suppress(TimeoutError):
            await asyncio.wait_for(service.shutdown(), timeout=2.0)
411
412
@pytest.fixture(scope="session")
def client(fastapi_app: FastAPI) -> Generator[TestClient, None, None]:
    """provides a TestClient for testing the FastAPI application.

    session-scoped to avoid the overhead of starting the full lifespan
    (database init, services) for each test.
    """
    with TestClient(fastapi_app) as test_client:
        yield test_client
422
423
424def _redis_db_for_worker(worker_id: str) -> int:
425 """determine redis database number based on xdist worker id.
426
427 uses different DB numbers for each worker to isolate parallel tests:
428 - master/gw0: db 1
429 - gw1: db 2
430 - gw2: db 3
431 - etc.
432
433 db 0 is reserved for local development.
434 """
435 if worker_id == "master" or not worker_id:
436 return 1
437 if "gw" in worker_id:
438 return 1 + int(worker_id.replace("gw", ""))
439 return 1
440
441
442def _redis_url_with_db(base_url: str, db: int) -> str:
443 """replace database number in redis URL."""
444 # redis://host:port/db -> redis://host:port/{new_db}
445 if "/" in base_url.rsplit(":", 1)[-1]:
446 # has db number, replace it
447 base = base_url.rsplit("/", 1)[0]
448 return f"{base}/{db}"
449 else:
450 # no db number, append it
451 return f"{base_url}/{db}"
452
453
@pytest.fixture(scope="session", autouse=True)
def redis_database(worker_id: str) -> Generator[None, None, None]:
    """use isolated redis databases for parallel test execution.

    each xdist worker gets its own redis database to prevent cache pollution
    between tests running in parallel. flushes the db before and after tests.

    if redis is not available, silently skips - tests that actually need redis
    will fail on their own with a more specific error.
    """

    def _flush(url: str) -> bool:
        """flush the redis db at *url*; returns False if redis is unreachable."""
        try:
            client = sync_redis_lib.Redis.from_url(url, socket_connect_timeout=1)
            try:
                client.flushdb()
            finally:
                # close even when flushdb raises - the old code leaked the
                # connection on that path
                client.close()
        except sync_redis_lib.ConnectionError:
            return False
        return True

    # skip if no redis configured
    if not settings.docket.url:
        yield
        return

    db = _redis_db_for_worker(worker_id)
    new_url = _redis_url_with_db(settings.docket.url, db)

    # patch settings for this worker process
    settings.docket.url = new_url
    os.environ["DOCKET_URL"] = new_url

    # clear any cached clients (they have old URL)
    clear_client_cache()

    # flush db before tests - if redis unavailable, skip silently
    if not _flush(new_url):
        # tests that need redis will fail with more specific errors
        yield
        return

    yield

    # flush db after tests and clear cached clients
    clear_client_cache()
    _flush(new_url)  # best-effort: redis may have gone away during the run