"""Shared test fixtures with in-memory SQLite database. Session-based auth: tests create users and sessions directly in the DB, matching the Better-Auth session validation flow. """ import secrets import uuid from datetime import UTC, datetime, timedelta import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy import String, TypeDecorator, Uuid, create_engine, event, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.types import CHAR from cartsnitch_api.config import settings as cartsnitch_settings from cartsnitch_api.database import get_db from cartsnitch_api.main import create_app from cartsnitch_api.middleware import rate_limit as _rate_limit_module from cartsnitch_api.models import Base class _StringUUID(TypeDecorator): """TypeDecorator that lets Text/String/UUID columns accept uuid.UUID on bind. SQLite has no native UUID type — passing a ``uuid.UUID`` raises ``type 'UUID' is not supported``. This stores UUID values as their hex string in the DB, accepts either uuid.UUID or str at bind time, and returns uuid.UUID on read so existing test assertions like ``isinstance(store.id, uuid.UUID)`` still work. """ impl = CHAR(36) cache_ok = True def process_bind_param(self, value, dialect): if value is None: return None if isinstance(value, uuid.UUID): return str(value) return str(value) def process_result_value(self, value, dialect): if value is None: return None if isinstance(value, uuid.UUID): return value return uuid.UUID(value) def _set_timestamp_defaults(mapper, connection, target): """Populate created_at/updated_at and missing PK IDs for SQLite. SQLite can't bind ``uuid.UUID`` objects to Text/String columns, and has no server-side default for ``func.now()`` or ``gen_random_uuid()``. We strip those server_defaults elsewhere; this listener fills in Python-side timestamp defaults at insert time, generates IDs for PK columns that have no default, and populates ``func.now()`` columns whose server_default was stripped (e.g. ``ingested_at``). UUID values for non-PK columns are converted by the ``_StringUUID`` TypeDecorator. """ now = datetime.now(UTC) for col in mapper.columns: key = col.key if key in ("created_at", "updated_at"): if getattr(target, key, None) is None: setattr(target, key, now) continue if col.primary_key and getattr(target, key, None) is None: setattr(target, key, str(uuid.uuid4())) continue if getattr(col, "_sqlite_default_now", False) and getattr(target, key, None) is None: setattr(target, key, now) def _adapt_columns_for_sqlite(): """Strip Postgres-only server_defaults and adapt UUID columns for SQLite. Must be called BEFORE ``Base.metadata.create_all`` so the DDL reflects the adapted column types. """ for tbl in Base.metadata.tables.values(): for col in tbl.columns.values(): # Strip PostgreSQL-specific function server_defaults (gen_random_uuid, # gen_random_bytes, now()) but keep simple string-literal defaults # like ``server_default="false"`` since they work in SQLite. sd = col.server_default if sd is not None: sd_text = str(sd.arg) if hasattr(sd, "arg") else str(sd) sd_text = sd_text.lower() if any(x in sd_text for x in ["gen_random_uuid", "gen_random_bytes", "now()"]): col.server_default = None if "now()" in sd_text and not col.nullable: col._sqlite_default_now = True # type: ignore[attr-defined] # Replace UUID column types with a SQLite-compatible TypeDecorator if isinstance(col.type, Uuid): col.type = _StringUUID() # Text/String PK columns without a default need the _StringUUID type # so the before_insert listener can generate hex-string IDs. if col.primary_key and col.default is None and col.server_default is None: if not isinstance(col.type, _StringUUID): col.type = _StringUUID() # FK columns that may receive uuid.UUID values from test code if col.foreign_keys and not col.primary_key and isinstance(col.type, String): col.type = _StringUUID() def _register_event_listeners(): """Attach before_insert listener to every mapped class.""" for cls in Base.registry._class_registry.values(): if hasattr(cls, "__mapper__"): event.listen(cls, "before_insert", _set_timestamp_defaults) TEST_JWT_SECRET = secrets.token_urlsafe(32) TEST_SERVICE_KEY = secrets.token_urlsafe(32) TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=" @pytest.fixture(autouse=True) def setup_test_settings(): original_jwt = cartsnitch_settings.jwt_secret_key original_service = cartsnitch_settings.service_key original_fernet = cartsnitch_settings.fernet_key cartsnitch_settings.jwt_secret_key = TEST_JWT_SECRET cartsnitch_settings.service_key = TEST_SERVICE_KEY cartsnitch_settings.fernet_key = TEST_FERNET_KEY yield cartsnitch_settings.jwt_secret_key = original_jwt cartsnitch_settings.service_key = original_service cartsnitch_settings.fernet_key = original_fernet TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" @pytest.fixture(autouse=True) def disable_rate_limiting(): """Disable rate limiting for all tests to prevent 429 interference. The rate_limit module creates its Redis client at import time when ``settings.rate_limit_redis_enabled`` is true. We can't undo that by flipping the setting inside the fixture — the client and the Redis-backed limiters are already constructed. So we swap them out for the in-memory limiters directly on the module, which also prevents "Event loop is closed" errors when the redis client tries to disconnect after the test event loop ends. """ cartsnitch_settings.rate_limit_enabled = False cartsnitch_settings.rate_limit_redis_enabled = False original_public = _rate_limit_module._public_limiter original_auth = _rate_limit_module._auth_limiter original_auth_strict = _rate_limit_module._auth_strict_limiter _rate_limit_module._redis_client = None _rate_limit_module._use_redis = False _rate_limit_module._public_limiter = _rate_limit_module.InMemorySlidingWindow( cartsnitch_settings.rate_limit_requests, cartsnitch_settings.rate_limit_window_seconds ) _rate_limit_module._auth_limiter = _rate_limit_module.InMemorySlidingWindow( cartsnitch_settings.rate_limit_requests * 5, cartsnitch_settings.rate_limit_window_seconds ) _rate_limit_module._auth_strict_limiter = _rate_limit_module.InMemorySlidingWindow( cartsnitch_settings.rate_limit_auth_requests, cartsnitch_settings.rate_limit_auth_window_seconds, ) yield cartsnitch_settings.rate_limit_enabled = True cartsnitch_settings.rate_limit_redis_enabled = True _rate_limit_module._public_limiter = original_public _rate_limit_module._auth_limiter = original_auth _rate_limit_module._auth_strict_limiter = original_auth_strict @pytest.fixture def engine(): """Sync in-memory SQLite engine for model unit tests. Strips PostgreSQL-specific server_default expressions, replaces UUID column types with a SQLite-compatible TypeDecorator, and registers a before_insert event listener to populate timestamps. """ eng = create_engine("sqlite:///:memory:") _adapt_columns_for_sqlite() _register_event_listeners() Base.metadata.create_all(eng) yield eng eng.dispose() @pytest.fixture def session(engine): """Sync SQLAlchemy session for model unit tests.""" factory = sessionmaker(bind=engine) with factory() as sess: yield sess @pytest.fixture async def db_engine(): engine = create_async_engine(TEST_DATABASE_URL, echo=False) @event.listens_for(engine.sync_engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() _adapt_columns_for_sqlite() _register_event_listeners() async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) await conn.execute( text(""" CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, token TEXT NOT NULL UNIQUE, user_id TEXT NOT NULL, expires_at TIMESTAMP NOT NULL, ip_address TEXT, user_agent TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ) """) ) await conn.execute( text(""" CREATE TABLE IF NOT EXISTS accounts ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, account_id TEXT NOT NULL, provider_id TEXT NOT NULL, access_token TEXT, refresh_token TEXT, access_token_expires_at TIMESTAMP, refresh_token_expires_at TIMESTAMP, scope TEXT, id_token TEXT, password TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ) """) ) await conn.execute( text(""" CREATE TABLE IF NOT EXISTS verifications ( id TEXT PRIMARY KEY, identifier TEXT NOT NULL, value TEXT NOT NULL, expires_at TIMESTAMP NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ) """) ) yield engine async with engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) await engine.dispose() @pytest.fixture async def db_session(db_engine): factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) async with factory() as session: yield session @pytest.fixture async def client(db_engine): factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) async def override_get_db(): async with factory() as session: yield session app = create_app() app.dependency_overrides[get_db] = override_get_db transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() async def _create_test_user_and_session( client: AsyncClient, db_engine, **user_overrides ) -> tuple[dict, str]: """Create a test user and a valid session directly in the DB. Returns (user_dict, session_token). Better-Auth stores the raw token in the DB, so we insert it as-is. """ user_id = str(uuid.uuid4()) email = user_overrides.get("email", "test@example.com") display_name = user_overrides.get("display_name", "Test User") session_token = secrets.token_urlsafe(32) session_id = str(uuid.uuid4()) now = datetime.now(UTC).isoformat() expires = (datetime.now(UTC) + timedelta(days=7)).isoformat() async with db_engine.begin() as conn: await conn.execute( text( "INSERT INTO users (id, email, hashed_password, display_name, " "email_verified, email_inbound_token, created_at, updated_at) " "VALUES (:id, :email, :hashed_password, :display_name, " ":email_verified, :email_inbound_token, :created_at, :updated_at)" ), { "id": user_id, "email": email, "hashed_password": "not-used-with-better-auth", "display_name": display_name, "email_verified": False, "email_inbound_token": secrets.token_urlsafe(16), "created_at": now, "updated_at": now, }, ) await conn.execute( text( "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " "VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)" ), { "id": session_id, "token": session_token, "user_id": user_id, "expires_at": expires, "created_at": now, "updated_at": now, }, ) return {"id": user_id, "email": email, "display_name": display_name}, session_token @pytest.fixture async def auth_headers(client, db_engine): """Create a test user with a valid session and return auth headers.""" _, session_token = await _create_test_user_and_session(client, db_engine) return {"Cookie": f"better-auth.session_token={session_token}"}