pytest conftest from the plyr.fm audio streaming app (relay backend), main branch — 498 lines, 16 kB.
1"""pytest configuration for relay tests.""" 2 3import asyncio 4import contextlib 5import os 6from collections.abc import AsyncGenerator, Callable, Generator 7from contextlib import asynccontextmanager 8from datetime import UTC, datetime 9from io import BytesIO 10from typing import BinaryIO 11from urllib.parse import urlsplit, urlunsplit 12 13import asyncpg 14import pytest 15import redis as sync_redis_lib 16import sqlalchemy as sa 17from fastapi import FastAPI 18from fastapi.testclient import TestClient 19from sqlalchemy.ext.asyncio import ( 20 AsyncConnection, 21 AsyncEngine, 22 AsyncSession, 23 create_async_engine, 24) 25from sqlalchemy.orm import sessionmaker 26 27from backend.config import settings 28from backend.models import Base 29from backend.storage.r2 import R2Storage 30from backend.utilities.redis import clear_client_cache 31 32 33class MockStorage(R2Storage): 34 """Mock storage for tests - no R2 credentials needed.""" 35 36 def __init__(self): 37 # skip R2Storage.__init__ which requires credentials 38 pass 39 40 async def save( 41 self, 42 file: BinaryIO | BytesIO, 43 filename: str, 44 progress_callback: Callable[[float], None] | None = None, 45 ) -> str: 46 """Mock save - returns a fake file_id.""" 47 return "mock_file_id_123" 48 49 async def get_url( 50 self, 51 file_id: str, 52 *, 53 file_type: str | None = None, 54 extension: str | None = None, 55 ) -> str | None: 56 """Mock get_url - returns a fake URL.""" 57 return f"https://mock.r2.dev/{file_id}" 58 59 async def delete(self, file_id: str, file_type: str | None = None) -> bool: 60 """Mock delete.""" 61 return True 62 63 64def pytest_configure(config): 65 """Set mock storage before any test modules are imported.""" 66 import backend.storage 67 68 # set _storage directly to prevent R2Storage initialization 69 backend.storage._storage = MockStorage() 70 71 72def _database_from_url(url: str) -> str: 73 """extract database name from connection URL.""" 74 _, _, path, _, _ = urlsplit(url) 75 return 
path.strip("/") 76 77 78def _postgres_admin_url(database_url: str) -> str: 79 """convert async database URL to sync postgres URL for admin operations.""" 80 scheme, netloc, _, query, fragment = urlsplit(database_url) 81 # asyncpg -> postgres for direct connection 82 scheme = scheme.replace("+asyncpg", "").replace("postgresql", "postgres") 83 return urlunsplit((scheme, netloc, "/postgres", query, fragment)) 84 85 86@asynccontextmanager 87async def session_context(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: 88 """create a database session context.""" 89 async_session_maker = sessionmaker( 90 bind=engine, 91 class_=AsyncSession, 92 expire_on_commit=False, 93 ) 94 async with async_session_maker() as session: 95 yield session 96 97 98async def _create_clear_database_procedure( 99 connection: AsyncConnection, 100) -> None: 101 """creates a stored procedure in the test database used for quickly clearing 102 the database between tests. 103 """ 104 tables = list(reversed(Base.metadata.sorted_tables)) 105 106 def schema(table: sa.Table) -> str: 107 return table.schema or "public" 108 109 def timestamp_column(table: sa.Table) -> str | None: 110 """find the timestamp column to use for filtering""" 111 if "created_at" in table.columns: 112 return "created_at" 113 elif "updated_at" in table.columns: 114 return "updated_at" 115 else: 116 # if no timestamp column, delete all rows 117 return None 118 119 delete_statements = [] 120 for table in tables: 121 ts_col = timestamp_column(table) 122 if ts_col: 123 delete_statements.append( 124 f""" 125 BEGIN 126 DELETE FROM {schema(table)}.{table.name} 127 WHERE {ts_col} > _test_start_time; 128 EXCEPTION WHEN OTHERS THEN 129 RAISE EXCEPTION 'Error clearing table {schema(table)}.{table.name}: %', SQLERRM; 130 END; 131 """ 132 ) 133 else: 134 # no timestamp column - delete all rows 135 delete_statements.append( 136 f""" 137 BEGIN 138 DELETE FROM {schema(table)}.{table.name}; 139 EXCEPTION WHEN OTHERS THEN 140 RAISE EXCEPTION 
'Error clearing table {schema(table)}.{table.name}: %', SQLERRM; 141 END; 142 """ 143 ) 144 145 deletes = "\n".join(delete_statements) 146 147 signature = "clear_database(_test_start_time timestamptz)" 148 procedure_body = f""" 149 CREATE PROCEDURE {signature} 150 LANGUAGE PLPGSQL 151 AS $$ 152 BEGIN 153 {deletes} 154 END; 155 $$; 156 """ 157 158 await connection.execute(sa.text(f"DROP PROCEDURE IF EXISTS {signature};")) 159 await connection.execute(sa.text(procedure_body)) 160 161 162async def _truncate_tables(connection: AsyncConnection) -> None: 163 """truncate all tables to ensure a clean slate at start of session.""" 164 # get all table names from metadata 165 tables = [table.name for table in Base.metadata.sorted_tables] 166 if not tables: 167 return 168 169 # truncate all tables with cascade to handle foreign keys 170 # restart identity resets auto-increment counters 171 stmt = f"TRUNCATE TABLE {', '.join(tables)} RESTART IDENTITY CASCADE;" 172 await connection.execute(sa.text(stmt)) 173 174 175async def _setup_template_database(template_url: str) -> None: 176 """initialize database schema and helper procedure on template database.""" 177 engine = create_async_engine(template_url, echo=False) 178 try: 179 async with engine.begin() as conn: 180 await conn.run_sync(Base.metadata.create_all) 181 await _truncate_tables(conn) 182 await _create_clear_database_procedure(conn) 183 finally: 184 await engine.dispose() 185 186 187async def _ensure_template_database(base_url: str) -> str: 188 """ensure template database exists and is migrated. 189 190 uses advisory lock to coordinate between xdist workers. 191 returns the template database name. 
192 """ 193 base_db_name = _database_from_url(base_url) 194 template_db_name = f"{base_db_name}_template" 195 postgres_url = _postgres_admin_url(base_url) 196 197 conn = await asyncpg.connect(postgres_url) 198 try: 199 # advisory lock prevents race condition between workers 200 await conn.execute("SELECT pg_advisory_lock(hashtext($1))", template_db_name) 201 202 # check if template exists 203 exists = await conn.fetchval( 204 "SELECT 1 FROM pg_database WHERE datname = $1", template_db_name 205 ) 206 207 if not exists: 208 # create template database 209 await conn.execute(f'CREATE DATABASE "{template_db_name}"') 210 211 # build URL for template and set it up 212 scheme, netloc, _, query, fragment = urlsplit(base_url) 213 template_url = urlunsplit( 214 (scheme, netloc, f"/{template_db_name}", query, fragment) 215 ) 216 await _setup_template_database(template_url) 217 218 # release lock (other workers waiting will see template exists) 219 await conn.execute("SELECT pg_advisory_unlock(hashtext($1))", template_db_name) 220 221 return template_db_name 222 finally: 223 await conn.close() 224 225 226async def _create_worker_database_from_template( 227 base_url: str, worker_id: str, template_db_name: str 228) -> str: 229 """create worker database by cloning the template (instant file copy).""" 230 base_db_name = _database_from_url(base_url) 231 worker_db_name = f"{base_db_name}_{worker_id}" 232 postgres_url = _postgres_admin_url(base_url) 233 234 conn = await asyncpg.connect(postgres_url) 235 try: 236 # kill connections to worker db (if it exists from previous run) 237 await conn.execute( 238 """ 239 SELECT pg_terminate_backend(pid) 240 FROM pg_stat_activity 241 WHERE datname = $1 AND pid <> pg_backend_pid() 242 """, 243 worker_db_name, 244 ) 245 246 # kill connections to template db (required for cloning) 247 await conn.execute( 248 """ 249 SELECT pg_terminate_backend(pid) 250 FROM pg_stat_activity 251 WHERE datname = $1 AND pid <> pg_backend_pid() 252 """, 253 
template_db_name, 254 ) 255 256 # drop and recreate from template (instant - just file copy) 257 await conn.execute(f'DROP DATABASE IF EXISTS "{worker_db_name}"') 258 await conn.execute( 259 f'CREATE DATABASE "{worker_db_name}" WITH TEMPLATE "{template_db_name}"' 260 ) 261 262 return worker_db_name 263 finally: 264 await conn.close() 265 266 267@pytest.fixture(scope="session") 268def test_database_url(worker_id: str) -> str: 269 """generate a unique test database URL for each pytest worker. 270 271 uses template database pattern for fast parallel test execution: 272 1. first worker creates template db with migrations (once) 273 2. each worker clones from template (instant file copy) 274 275 also patches settings.database.url so all production code uses test db. 276 """ 277 import asyncio 278 import os 279 280 base_url = settings.database.url 281 282 # single worker - just use base database 283 if worker_id == "master": 284 asyncio.run(_setup_database_direct(base_url)) 285 return base_url 286 287 # xdist workers - use template pattern 288 template_db_name = asyncio.run(_ensure_template_database(base_url)) 289 asyncio.run( 290 _create_worker_database_from_template(base_url, worker_id, template_db_name) 291 ) 292 293 # build URL for worker database 294 scheme, netloc, _, query, fragment = urlsplit(base_url) 295 base_db_name = _database_from_url(base_url) 296 worker_db_name = f"{base_db_name}_{worker_id}" 297 worker_url = urlunsplit((scheme, netloc, f"/{worker_db_name}", query, fragment)) 298 299 # patch settings so all production code uses this URL 300 # this is safe because each xdist worker is a separate process 301 settings.database.url = worker_url 302 os.environ["DATABASE_URL"] = worker_url 303 304 return worker_url 305 306 307async def _setup_database_direct(database_url: str) -> None: 308 """set up database directly (for single worker mode).""" 309 engine = create_async_engine(database_url, echo=False) 310 try: 311 async with engine.begin() as conn: 312 await 
conn.run_sync(Base.metadata.create_all) 313 await _truncate_tables(conn) 314 await _create_clear_database_procedure(conn) 315 finally: 316 await engine.dispose() 317 318 319@pytest.fixture(scope="session") 320def _database_setup(test_database_url: str) -> None: 321 """marker fixture - database is set up by test_database_url fixture.""" 322 _ = test_database_url # ensure dependency chain 323 324 325@pytest.fixture() 326async def _engine( 327 test_database_url: str, _database_setup: None 328) -> AsyncGenerator[AsyncEngine, None]: 329 """create a database engine for each test (to avoid event loop issues).""" 330 from backend.utilities.database import ENGINES 331 332 # clear any cached engines from previous tests 333 for cached_engine in list(ENGINES.values()): 334 await cached_engine.dispose() 335 ENGINES.clear() 336 337 engine = create_async_engine( 338 test_database_url, 339 echo=False, 340 pool_size=2, 341 max_overflow=0, 342 ) 343 try: 344 yield engine 345 finally: 346 await engine.dispose() 347 # clean up cached engines 348 for cached_engine in list(ENGINES.values()): 349 await cached_engine.dispose() 350 ENGINES.clear() 351 352 353@pytest.fixture() 354async def _clear_db(_engine: AsyncEngine) -> AsyncGenerator[None, None]: 355 """clear the database after each test.""" 356 start_time = datetime.now(UTC) 357 358 try: 359 yield 360 finally: 361 # clear the database after the test 362 async with _engine.begin() as conn: 363 await conn.execute( 364 sa.text("CALL clear_database(:start_time)"), 365 {"start_time": start_time}, 366 ) 367 368 369@pytest.fixture 370async def db_session( 371 _engine: AsyncEngine, _clear_db: None 372) -> AsyncGenerator[AsyncSession, None]: 373 """provide a database session for each test. 374 375 the _clear_db fixture is used as a dependency to ensure proper cleanup order. 
376 """ 377 async with session_context(engine=_engine) as session: 378 yield session 379 380 381@pytest.fixture(scope="session") 382def fastapi_app() -> Generator[FastAPI, None, None]: 383 """provides the FastAPI app with a test lifespan that skips docket worker. 384 385 docket Worker binds asyncio.Tasks to the TestClient's portal loop; under 386 xdist, session teardown runs on a different loop → RuntimeError. no test 387 needs a live worker (all docket usage is mocked), so skip it. 388 """ 389 from backend.main import app as main_app 390 391 original_lifespan = main_app.router.lifespan_context 392 main_app.router.lifespan_context = _test_lifespan 393 yield main_app 394 main_app.router.lifespan_context = original_lifespan 395 396 397@asynccontextmanager 398async def _test_lifespan(app: FastAPI) -> AsyncGenerator[None, None]: 399 """test lifespan — skips docket worker to avoid event loop issues.""" 400 from backend._internal import jam_service, notification_service, queue_service 401 402 await notification_service.setup() 403 await queue_service.setup() 404 await jam_service.setup() 405 406 yield 407 408 for service in (notification_service, queue_service, jam_service): 409 with contextlib.suppress(TimeoutError): 410 await asyncio.wait_for(service.shutdown(), timeout=2.0) 411 412 413@pytest.fixture(scope="session") 414def client(fastapi_app: FastAPI) -> Generator[TestClient, None, None]: 415 """provides a TestClient for testing the FastAPI application. 416 417 session-scoped to avoid the overhead of starting the full lifespan 418 (database init, services) for each test. 419 """ 420 with TestClient(fastapi_app) as tc: 421 yield tc 422 423 424def _redis_db_for_worker(worker_id: str) -> int: 425 """determine redis database number based on xdist worker id. 426 427 uses different DB numbers for each worker to isolate parallel tests: 428 - master/gw0: db 1 429 - gw1: db 2 430 - gw2: db 3 431 - etc. 432 433 db 0 is reserved for local development. 
434 """ 435 if worker_id == "master" or not worker_id: 436 return 1 437 if "gw" in worker_id: 438 return 1 + int(worker_id.replace("gw", "")) 439 return 1 440 441 442def _redis_url_with_db(base_url: str, db: int) -> str: 443 """replace database number in redis URL.""" 444 # redis://host:port/db -> redis://host:port/{new_db} 445 if "/" in base_url.rsplit(":", 1)[-1]: 446 # has db number, replace it 447 base = base_url.rsplit("/", 1)[0] 448 return f"{base}/{db}" 449 else: 450 # no db number, append it 451 return f"{base_url}/{db}" 452 453 454@pytest.fixture(scope="session", autouse=True) 455def redis_database(worker_id: str) -> Generator[None, None, None]: 456 """use isolated redis databases for parallel test execution. 457 458 each xdist worker gets its own redis database to prevent cache pollution 459 between tests running in parallel. flushes the db before and after tests. 460 461 if redis is not available, silently skips - tests that actually need redis 462 will fail on their own with a more specific error. 
463 """ 464 # skip if no redis configured 465 if not settings.docket.url: 466 yield 467 return 468 469 db = _redis_db_for_worker(worker_id) 470 new_url = _redis_url_with_db(settings.docket.url, db) 471 472 # patch settings for this worker process 473 settings.docket.url = new_url 474 os.environ["DOCKET_URL"] = new_url 475 476 # clear any cached clients (they have old URL) 477 clear_client_cache() 478 479 # try to flush db before tests - if redis unavailable, skip silently 480 try: 481 client = sync_redis_lib.Redis.from_url(new_url, socket_connect_timeout=1) 482 client.flushdb() 483 client.close() 484 except sync_redis_lib.ConnectionError: 485 # redis not available - tests that need it will fail with specific errors 486 yield 487 return 488 489 yield 490 491 # flush db after tests and clear cached clients 492 clear_client_cache() 493 try: 494 client = sync_redis_lib.Redis.from_url(new_url, socket_connect_timeout=1) 495 client.flushdb() 496 client.close() 497 except sync_redis_lib.ConnectionError: 498 pass # redis went away during tests, nothing to clean up