"""Shared test fixtures with in-memory SQLite database. Session-based auth: tests create users and sessions directly in the DB, matching the Better-Auth session validation flow. """ import secrets import uuid as uuid_lib from datetime import UTC, datetime, timedelta import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy import create_engine, event, text, TypeDecorator, String import sqlalchemy as sa from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.types import CHAR, TypeEngine from cartsnitch_api.database import get_db from cartsnitch_api.main import create_app from cartsnitch_api.models import Base cartsnitch_settings = None class SQLiteCompatibleUUID(TypeDecorator): """Adapts PostgreSQL UUID for use with SQLite at test runtime. Stores as CHAR(32) hex string, converts to Python uuid.UUID on read. """ impl = CHAR(32) cache_ok = True def process_bind_param(self, value, dialect): if value is None: return None if isinstance(value, uuid_lib.UUID): return value.hex return str(value) def process_result_value(self, value, dialect): if value is None: return None return uuid_lib.UUID(hex=value) class _StringUUID(TypeDecorator): """Fallback TypeDecorator that accepts both UUID objects and strings. Used when the model uses Text but tests pass uuid.UUID values. Stores as native string, returns uuid.UUID on read. """ impl = String(36) cache_ok = True def process_bind_param(self, value, dialect): if value is None: return None if isinstance(value, uuid_lib.UUID): return value.hex return str(value) def process_result_value(self, value, dialect): if value is None: return None try: return uuid_lib.UUID(hex=value) except Exception: return value def _adapt_uuid_columns_for_sqlite(googol=None): """Replace PostgreSQL UUID column types with SQLiteCompatibleUUID. PostgreSQL UUID columns generate DDL using gen_random_uuid() which SQLite doesn't support. This replaces those column types so SQLite can bind UUIDs as hex strings. Also sets a Python-side default so INSERTs without explicit id succeed. Accepts an optional connection arg from run_sync. """ for table in Base.metadata.tables.values(): for column in table.columns.values(): if isinstance(column.type, PostgresUUID): column.type = SQLiteCompatibleUUID() if column.server_default is None and column.default is None: column.default = uuid_lib.uuid4 def _adapt_text_pk_columns_for_uuid(googol=None): """Replace Text primary key columns that tests bind uuid.UUID values into. User.id is a Text column but tests pass uuid.UUID objects. SQLite can't bind UUID directly so we swap these to a compatible type. Accepts an optional connection arg from run_sync. """ for table in Base.metadata.tables.values(): for column in table.columns.values(): if column.primary_key and isinstance(column.type, sa.Text): column.type = _StringUUID() def _adapt_fk_columns_for_uuid(googol=None): """Replace FK columns that reference UUID PKs but are typed as Text. purchase.user_id, user_store_accounts.user_id/store_id, etc. are typed as str but tests pass uuid.UUID objects for them. Accepts an optional connection arg from run_sync. """ for table in Base.metadata.tables.values(): for column in table.columns.values(): if column.foreign_keys: for fk in column.foreign_keys: if isinstance(column.type, (sa.Text, sa.String)) and column.type.length in (None, 255): column.type = _StringUUID() def _strip_postgres_server_defaults(googol=None): """Remove PostgreSQL-specific server_default expressions for SQLite compatibility. PostgreSQL functions like gen_random_uuid() and the base64 encoding expression in email_inbound_token are not valid in SQLite. Accepts an optional connection arg from run_sync. """ for table in Base.metadata.tables.values(): for column in table.columns.values(): if column.server_default is not None: sd = str(column.server_default.arg) if "gen_random_uuid" in sd or "gen_random_bytes" in sd: column.server_default = None TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" TEST_JWT_SECRET = secrets.token_urlsafe(32) TEST_SERVICE_KEY = secrets.token_urlsafe(32) TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=" @pytest.fixture(autouse=True) def setup_test_settings(): from cartsnitch_api.config import settings as real_settings original_jwt = real_settings.jwt_secret_key original_service = real_settings.service_key original_fernet = real_settings.fernet_key real_settings.jwt_secret_key = TEST_JWT_SECRET real_settings.service_key = TEST_SERVICE_KEY real_settings.fernet_key = TEST_FERNET_KEY yield real_settings.jwt_secret_key = original_jwt real_settings.service_key = original_service real_settings.fernet_key = original_fernet @pytest.fixture(autouse=True) def disable_rate_limiting(): from cartsnitch_api.config import settings as real_settings real_settings.rate_limit_enabled = False yield real_settings.rate_limit_enabled = True @pytest.fixture def engine(): """Sync in-memory SQLite engine for model unit tests.""" eng = create_engine("sqlite:///:memory:") _adapt_uuid_columns_for_sqlite() _adapt_text_pk_columns_for_uuid() _adapt_fk_columns_for_uuid() _strip_postgres_server_defaults() Base.metadata.create_all(eng) yield eng eng.dispose() @pytest.fixture def session(engine): """Sync SQLAlchemy session for model unit tests.""" factory = sessionmaker(bind=engine) with factory() as sess: yield sess @pytest.fixture async def db_engine(): engine = create_async_engine(TEST_DATABASE_URL, echo=False) @event.listens_for(engine.sync_engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() async with engine.begin() as conn: await conn.run_sync(_adapt_uuid_columns_for_sqlite) await conn.run_sync(_adapt_text_pk_columns_for_uuid) await conn.run_sync(_adapt_fk_columns_for_uuid) await conn.run_sync(_strip_postgres_server_defaults) await conn.run_sync(Base.metadata.create_all) # Create Better-Auth tables (not managed by SQLAlchemy models) await conn.execute( text(""" CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, token TEXT NOT NULL UNIQUE, user_id TEXT NOT NULL, expires_at TIMESTAMP NOT NULL, ip_address TEXT, user_agent TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ) """) ) await conn.execute( text(""" CREATE TABLE IF NOT EXISTS accounts ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, account_id TEXT NOT NULL, provider_id TEXT NOT NULL, access_token TEXT, refresh_token TEXT, access_token_expires_at TIMESTAMP, refresh_token_expires_at TIMESTAMP, scope TEXT, id_token TEXT, password TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ) """) ) await conn.execute( text(""" CREATE TABLE IF NOT EXISTS verifications ( id TEXT PRIMARY KEY, identifier TEXT NOT NULL, value TEXT NOT NULL, expires_at TIMESTAMP NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ) """) ) yield engine async with engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) await engine.dispose() @pytest.fixture async def db_session(db_engine): factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) async with factory() as session: yield session @pytest.fixture async def client(db_engine): factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) async def override_get_db(): async with factory() as session: yield session app = create_app() app.dependency_overrides[get_db] = override_get_db transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() async def _create_test_user_and_session( client: AsyncClient, db_engine, **user_overrides ) -> tuple[dict, str]: """Create a test user and a valid session directly in the DB. Returns (user_dict, session_token). Better-Auth stores the raw token in the DB, so we insert it as-is. """ user_id = str(uuid_lib.uuid4()) email = user_overrides.get("email", "test@example.com") display_name = user_overrides.get("display_name", "Test User") session_token = secrets.token_urlsafe(32) session_id = str(uuid_lib.uuid4()) now = datetime.now(UTC).isoformat() expires = (datetime.now(UTC) + timedelta(days=7)).isoformat() async with db_engine.begin() as conn: await conn.execute( text( "INSERT INTO users (id, email, hashed_password, display_name, " "email_verified, email_inbound_token, created_at, updated_at) " "VALUES (:id, :email, :hashed_password, :display_name, " ":email_verified, :email_inbound_token, :created_at, :updated_at)" ), { "id": user_id, "email": email, "hashed_password": "not-used-with-better-auth", "display_name": display_name, "email_verified": False, "email_inbound_token": secrets.token_urlsafe(16), "created_at": now, "updated_at": now, }, ) await conn.execute( text( "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " "VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)" ), { "id": session_id, "token": session_token, "user_id": user_id, "expires_at": expires, "created_at": now, "updated_at": now, }, ) return {"id": user_id, "email": email, "display_name": display_name}, session_token @pytest.fixture async def auth_headers(client, db_engine): """Create a test user with a valid session and return auth headers.""" _, session_token = await _create_test_user_and_session(client, db_engine) return {"Cookie": f"better-auth.session_token={session_token}"}