diff --git a/tests/conftest.py b/tests/conftest.py index e69de29..133f726 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -0,0 +1,354 @@ +"""Shared test fixtures with in-memory SQLite database. + +Session-based auth: tests create users and sessions directly in the DB, +matching the Better-Auth session validation flow. +""" + +import secrets +import uuid +from datetime import UTC, datetime, timedelta + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import String, TypeDecorator, Uuid, create_engine, event, text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.types import CHAR + +from cartsnitch_api.config import settings as cartsnitch_settings +from cartsnitch_api.database import get_db +from cartsnitch_api.main import create_app +from cartsnitch_api.middleware import rate_limit as _rate_limit_module +from cartsnitch_api.models import Base + + +class _StringUUID(TypeDecorator): + """TypeDecorator that lets Text/String/UUID columns accept uuid.UUID on bind. + + SQLite has no native UUID type — passing a ``uuid.UUID`` raises + ``type 'UUID' is not supported``. This stores UUID values as their hex + string in the DB, accepts either uuid.UUID or str at bind time, and + returns uuid.UUID on read so existing test assertions like + ``isinstance(store.id, uuid.UUID)`` still work. + """ + + impl = CHAR(36) + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + if isinstance(value, uuid.UUID): + return str(value) + return str(value) + + def process_result_value(self, value, dialect): + if value is None: + return None + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(value) + + +def _set_timestamp_defaults(mapper, connection, target): + """Populate created_at/updated_at and missing PK IDs for SQLite. + + SQLite can't bind ``uuid.UUID`` objects to Text/String columns, and has + no server-side default for ``func.now()`` or ``gen_random_uuid()``. We + strip those server_defaults elsewhere; this listener fills in + Python-side timestamp defaults at insert time, generates IDs for PK + columns that have no default, and populates ``func.now()`` columns + whose server_default was stripped (e.g. ``ingested_at``). UUID values + for non-PK columns are converted by the ``_StringUUID`` TypeDecorator. + """ + now = datetime.now(UTC) + for col in mapper.columns: + key = col.key + if key in ("created_at", "updated_at"): + if getattr(target, key, None) is None: + setattr(target, key, now) + continue + if col.primary_key and getattr(target, key, None) is None: + setattr(target, key, str(uuid.uuid4())) + continue + if getattr(col, "_sqlite_default_now", False) and getattr(target, key, None) is None: + setattr(target, key, now) + + +def _adapt_columns_for_sqlite(): + """Strip Postgres-only server_defaults and adapt UUID columns for SQLite. + + Must be called BEFORE ``Base.metadata.create_all`` so the DDL reflects + the adapted column types. + """ + for tbl in Base.metadata.tables.values(): + for col in tbl.columns.values(): + # Strip PostgreSQL-specific function server_defaults (gen_random_uuid, + # gen_random_bytes, now()) but keep simple string-literal defaults + # like ``server_default="false"`` since they work in SQLite. + sd = col.server_default + if sd is not None: + sd_text = str(sd.arg) if hasattr(sd, "arg") else str(sd) + sd_text = sd_text.lower() + if any(x in sd_text for x in ["gen_random_uuid", "gen_random_bytes", "now()"]): + col.server_default = None + if "now()" in sd_text and not col.nullable: + col._sqlite_default_now = True # type: ignore[attr-defined] + + # Replace UUID column types with a SQLite-compatible TypeDecorator + if isinstance(col.type, Uuid): + col.type = _StringUUID() + + # Text/String PK columns without a default need the _StringUUID type + # so the before_insert listener can generate hex-string IDs. + if col.primary_key and col.default is None and col.server_default is None: + if not isinstance(col.type, _StringUUID): + col.type = _StringUUID() + + # FK columns that may receive uuid.UUID values from test code + if col.foreign_keys and not col.primary_key and isinstance(col.type, String): + col.type = _StringUUID() + + +def _register_event_listeners(): + """Attach before_insert listener to every mapped class.""" + for cls in Base.registry._class_registry.values(): + if hasattr(cls, "__mapper__"): + event.listen(cls, "before_insert", _set_timestamp_defaults) + + +TEST_JWT_SECRET = secrets.token_urlsafe(32) +TEST_SERVICE_KEY = secrets.token_urlsafe(32) +TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=" + + +@pytest.fixture(autouse=True) +def setup_test_settings(): + original_jwt = cartsnitch_settings.jwt_secret_key + original_service = cartsnitch_settings.service_key + original_fernet = cartsnitch_settings.fernet_key + cartsnitch_settings.jwt_secret_key = TEST_JWT_SECRET + cartsnitch_settings.service_key = TEST_SERVICE_KEY + cartsnitch_settings.fernet_key = TEST_FERNET_KEY + yield + cartsnitch_settings.jwt_secret_key = original_jwt + cartsnitch_settings.service_key = original_service + cartsnitch_settings.fernet_key = original_fernet + + +TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" + + +@pytest.fixture(autouse=True) +def disable_rate_limiting(): + """Disable rate limiting for all tests to prevent 429 interference. + + The rate_limit module creates its Redis client at import time when + ``settings.rate_limit_redis_enabled`` is true. We can't undo that by + flipping the setting inside the fixture — the client and the + Redis-backed limiters are already constructed. So we swap them out + for the in-memory limiters directly on the module, which also + prevents "Event loop is closed" errors when the redis client tries + to disconnect after the test event loop ends. + """ + cartsnitch_settings.rate_limit_enabled = False + cartsnitch_settings.rate_limit_redis_enabled = False + original_public = _rate_limit_module._public_limiter + original_auth = _rate_limit_module._auth_limiter + original_auth_strict = _rate_limit_module._auth_strict_limiter + _rate_limit_module._redis_client = None + _rate_limit_module._use_redis = False + _rate_limit_module._public_limiter = _rate_limit_module.InMemorySlidingWindow( + cartsnitch_settings.rate_limit_requests, cartsnitch_settings.rate_limit_window_seconds + ) + _rate_limit_module._auth_limiter = _rate_limit_module.InMemorySlidingWindow( + cartsnitch_settings.rate_limit_requests * 5, cartsnitch_settings.rate_limit_window_seconds + ) + _rate_limit_module._auth_strict_limiter = _rate_limit_module.InMemorySlidingWindow( + cartsnitch_settings.rate_limit_auth_requests, + cartsnitch_settings.rate_limit_auth_window_seconds, + ) + yield + cartsnitch_settings.rate_limit_enabled = True + cartsnitch_settings.rate_limit_redis_enabled = True + _rate_limit_module._public_limiter = original_public + _rate_limit_module._auth_limiter = original_auth + _rate_limit_module._auth_strict_limiter = original_auth_strict + + +@pytest.fixture +def engine(): + """Sync in-memory SQLite engine for model unit tests. + + Strips PostgreSQL-specific server_default expressions, replaces UUID + column types with a SQLite-compatible TypeDecorator, and registers a + before_insert event listener to populate timestamps. + """ + eng = create_engine("sqlite:///:memory:") + _adapt_columns_for_sqlite() + _register_event_listeners() + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def session(engine): + """Sync SQLAlchemy session for model unit tests.""" + factory = sessionmaker(bind=engine) + with factory() as sess: + yield sess + + +@pytest.fixture +async def db_engine(): + engine = create_async_engine(TEST_DATABASE_URL, echo=False) + + @event.listens_for(engine.sync_engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + _adapt_columns_for_sqlite() + _register_event_listeners() + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + await conn.execute( + text(""" + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + token TEXT NOT NULL UNIQUE, + user_id TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + ip_address TEXT, + user_agent TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """) + ) + await conn.execute( + text(""" + CREATE TABLE IF NOT EXISTS accounts ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + account_id TEXT NOT NULL, + provider_id TEXT NOT NULL, + access_token TEXT, + refresh_token TEXT, + access_token_expires_at TIMESTAMP, + refresh_token_expires_at TIMESTAMP, + scope TEXT, + id_token TEXT, + password TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """) + ) + await conn.execute( + text(""" + CREATE TABLE IF NOT EXISTS verifications ( + id TEXT PRIMARY KEY, + identifier TEXT NOT NULL, + value TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """) + ) + + yield engine + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + await engine.dispose() + + +@pytest.fixture +async def db_session(db_engine): + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + yield session + + +@pytest.fixture +async def client(db_engine): + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + + async def override_get_db(): + async with factory() as session: + yield session + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + app.dependency_overrides.clear() + + +async def _create_test_user_and_session( + client: AsyncClient, db_engine, **user_overrides +) -> tuple[dict, str]: + """Create a test user and a valid session directly in the DB. + + Returns (user_dict, session_token). Better-Auth stores the raw token + in the DB, so we insert it as-is. + """ + user_id = str(uuid.uuid4()) + email = user_overrides.get("email", "test@example.com") + display_name = user_overrides.get("display_name", "Test User") + session_token = secrets.token_urlsafe(32) + session_id = str(uuid.uuid4()) + now = datetime.now(UTC).isoformat() + expires = (datetime.now(UTC) + timedelta(days=7)).isoformat() + + async with db_engine.begin() as conn: + await conn.execute( + text( + "INSERT INTO users (id, email, hashed_password, display_name, " + "email_verified, email_inbound_token, created_at, updated_at) " + "VALUES (:id, :email, :hashed_password, :display_name, " + ":email_verified, :email_inbound_token, :created_at, :updated_at)" + ), + { + "id": user_id, + "email": email, + "hashed_password": "not-used-with-better-auth", + "display_name": display_name, + "email_verified": False, + "email_inbound_token": secrets.token_urlsafe(16), + "created_at": now, + "updated_at": now, + }, + ) + await conn.execute( + text( + "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " + "VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)" + ), + { + "id": session_id, + "token": session_token, + "user_id": user_id, + "expires_at": expires, + "created_at": now, + "updated_at": now, + }, + ) + + return {"id": user_id, "email": email, "display_name": display_name}, session_token + + +@pytest.fixture +async def auth_headers(client, db_engine): + """Create a test user with a valid session and return auth headers.""" + _, session_token = await _create_test_user_and_session(client, db_engine) + return {"Cookie": f"better-auth.session_token={session_token}"}