Files
api/tests/conftest.py
T
Barcode Betty 69d7fe1508
CI / lint (pull_request) Failing after 3s
CI / typecheck (pull_request) Successful in 26s
CI / test (pull_request) Successful in 34s
CI / build-and-push (pull_request) Has been skipped
Swap Redis limiters for in-memory in test fixture
The conftest was setting rate_limit_redis_enabled=False but the
rate_limit module's _redis_client and the RedisSlidingWindow limiters
are constructed at module import. Flipping the setting inside the
fixture doesn't undo that, so the Redis client was still being
constructed and torn down at the end of the test event loop, raising
RuntimeError('Event loop is closed').

This swaps the limiters directly on the module in the fixture setup
and restores the originals in teardown. Local: 164 passed, 7
skipped.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-02 13:42:48 +00:00

355 lines
13 KiB
Python

"""Shared test fixtures with in-memory SQLite database.
Session-based auth: tests create users and sessions directly in the DB,
matching the Better-Auth session validation flow.
"""
import secrets
import uuid
from datetime import UTC, datetime, timedelta
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import String, TypeDecorator, Uuid, create_engine, event, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.types import CHAR
from cartsnitch_api.config import settings as cartsnitch_settings
from cartsnitch_api.database import get_db
from cartsnitch_api.main import create_app
from cartsnitch_api.middleware import rate_limit as _rate_limit_module
from cartsnitch_api.models import Base
class _StringUUID(TypeDecorator):
    """SQLite stand-in for UUID columns, persisted as 36-character text.

    SQLite has no native UUID type, and binding a ``uuid.UUID`` directly
    fails with ``type 'UUID' is not supported``. Outgoing values are written
    as ``str(value)``, so both ``uuid.UUID`` and ``str`` inputs are accepted.
    Incoming values are returned as ``uuid.UUID`` so that assertions such as
    ``isinstance(store.id, uuid.UUID)`` keep working.
    """

    impl = CHAR(36)
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return the string form of an outgoing parameter; ``None`` passes through."""
        # str() handles uuid.UUID and plain-string inputs alike.
        return None if value is None else str(value)

    def process_result_value(self, value, dialect):
        """Return a stored value as ``uuid.UUID``; ``None`` passes through."""
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)
def _set_timestamp_defaults(mapper, connection, target):
"""Populate created_at/updated_at and missing PK IDs for SQLite.
SQLite can't bind ``uuid.UUID`` objects to Text/String columns, and has
no server-side default for ``func.now()`` or ``gen_random_uuid()``. We
strip those server_defaults elsewhere; this listener fills in
Python-side timestamp defaults at insert time, generates IDs for PK
columns that have no default, and populates ``func.now()`` columns
whose server_default was stripped (e.g. ``ingested_at``). UUID values
for non-PK columns are converted by the ``_StringUUID`` TypeDecorator.
"""
now = datetime.now(UTC)
for col in mapper.columns:
key = col.key
if key in ("created_at", "updated_at"):
if getattr(target, key, None) is None:
setattr(target, key, now)
continue
if col.primary_key and getattr(target, key, None) is None:
setattr(target, key, str(uuid.uuid4()))
continue
if getattr(col, "_sqlite_default_now", False) and getattr(target, key, None) is None:
setattr(target, key, now)
def _adapt_columns_for_sqlite():
    """Rewrite ``Base.metadata`` in place so the schema can be created on SQLite.

    Drops server defaults that call PostgreSQL-only functions and swaps
    UUID-like column types for ``_StringUUID``. Call this BEFORE
    ``Base.metadata.create_all`` so the emitted DDL uses the adapted columns.
    """
    # Function-style server defaults that are stripped. Plain literal defaults
    # such as ``server_default="false"`` are kept.
    pg_only_markers = ("gen_random_uuid", "gen_random_bytes", "now()")

    all_columns = (
        column
        for table in Base.metadata.tables.values()
        for column in table.columns.values()
    )
    for column in all_columns:
        default_clause = column.server_default
        if default_clause is not None:
            raw_default = default_clause.arg if hasattr(default_clause, "arg") else default_clause
            rendered = str(raw_default).lower()
            if any(marker in rendered for marker in pg_only_markers):
                column.server_default = None
                if "now()" in rendered and not column.nullable:
                    # Flag read by the before_insert listener, which then
                    # supplies the timestamp from Python.
                    column._sqlite_default_now = True  # type: ignore[attr-defined]

        # Native UUID columns become the SQLite-compatible decorator type.
        if isinstance(column.type, Uuid):
            column.type = _StringUUID()

        # Primary keys left without any default also get _StringUUID, so IDs
        # generated by the before_insert listener can be stored.
        lacks_any_default = column.default is None and column.server_default is None
        if column.primary_key and lacks_any_default and not isinstance(column.type, _StringUUID):
            column.type = _StringUUID()

        # String-typed foreign keys may be handed uuid.UUID values by tests.
        if column.foreign_keys and not column.primary_key and isinstance(column.type, String):
            column.type = _StringUUID()
def _register_event_listeners():
    """Hook ``_set_timestamp_defaults`` into ``before_insert`` for each mapped class."""
    # The class registry also holds entries without a ``__mapper__``; skip those.
    mapped_classes = [
        entry
        for entry in Base.registry._class_registry.values()
        if hasattr(entry, "__mapper__")
    ]
    for mapped_class in mapped_classes:
        event.listen(mapped_class, "before_insert", _set_timestamp_defaults)
# Random per-run secrets that the ``setup_test_settings`` fixture installs on
# the shared settings object.
TEST_JWT_SECRET = secrets.token_urlsafe(32)
TEST_SERVICE_KEY = secrets.token_urlsafe(32)
# Fixed (non-random) key installed as ``settings.fernet_key`` during tests.
TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
@pytest.fixture(autouse=True)
def setup_test_settings():
    """Install the test secrets on the shared settings object for every test.

    The values present before the test are captured first and written back
    once the test has finished.
    """
    overrides = {
        "jwt_secret_key": TEST_JWT_SECRET,
        "service_key": TEST_SERVICE_KEY,
        "fernet_key": TEST_FERNET_KEY,
    }
    saved = {name: getattr(cartsnitch_settings, name) for name in overrides}
    for name, test_value in overrides.items():
        setattr(cartsnitch_settings, name, test_value)
    yield
    for name, previous_value in saved.items():
        setattr(cartsnitch_settings, name, previous_value)
# In-memory SQLite URL (aiosqlite driver) used by the async ``db_engine`` fixture.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
@pytest.fixture(autouse=True)
def disable_rate_limiting():
    """Disable rate limiting for all tests to prevent 429 interference.

    The rate_limit module creates its Redis client at import time when
    ``settings.rate_limit_redis_enabled`` is true. Flipping the setting
    inside the fixture cannot undo that, because the client and the
    Redis-backed limiters are already constructed. The limiters are
    therefore swapped for in-memory ones directly on the module, which also
    prevents "Event loop is closed" errors when the redis client tries to
    disconnect after the test event loop ends.

    On teardown the two settings flags and the three limiter objects are
    restored to the values captured at setup. The flags are not forced to
    ``True``, so a configuration that starts with rate limiting or Redis
    disabled is left as it was found.
    """
    # Capture everything that gets restored before touching any of it.
    original_enabled = cartsnitch_settings.rate_limit_enabled
    original_redis_enabled = cartsnitch_settings.rate_limit_redis_enabled
    original_public = _rate_limit_module._public_limiter
    original_auth = _rate_limit_module._auth_limiter
    original_auth_strict = _rate_limit_module._auth_strict_limiter

    cartsnitch_settings.rate_limit_enabled = False
    cartsnitch_settings.rate_limit_redis_enabled = False

    # NOTE(review): ``_redis_client`` and ``_use_redis`` stay overridden after
    # teardown, exactly as before this change. Confirm against the rate_limit
    # module whether they should be restored as well.
    _rate_limit_module._redis_client = None
    _rate_limit_module._use_redis = False

    _rate_limit_module._public_limiter = _rate_limit_module.InMemorySlidingWindow(
        cartsnitch_settings.rate_limit_requests,
        cartsnitch_settings.rate_limit_window_seconds,
    )
    _rate_limit_module._auth_limiter = _rate_limit_module.InMemorySlidingWindow(
        cartsnitch_settings.rate_limit_requests * 5,
        cartsnitch_settings.rate_limit_window_seconds,
    )
    _rate_limit_module._auth_strict_limiter = _rate_limit_module.InMemorySlidingWindow(
        cartsnitch_settings.rate_limit_auth_requests,
        cartsnitch_settings.rate_limit_auth_window_seconds,
    )

    yield

    cartsnitch_settings.rate_limit_enabled = original_enabled
    cartsnitch_settings.rate_limit_redis_enabled = original_redis_enabled
    _rate_limit_module._public_limiter = original_public
    _rate_limit_module._auth_limiter = original_auth
    _rate_limit_module._auth_strict_limiter = original_auth_strict
@pytest.fixture
def engine():
    """Provide a sync in-memory SQLite engine for model unit tests.

    Before the schema is created, the ORM metadata is adapted for SQLite
    (PostgreSQL-only server defaults stripped, UUID column types replaced)
    and the ``before_insert`` listener that fills in timestamps is
    registered. The engine is disposed after the test.
    """
    _adapt_columns_for_sqlite()
    _register_event_listeners()
    sqlite_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(sqlite_engine)
    yield sqlite_engine
    sqlite_engine.dispose()
@pytest.fixture
def session(engine):
    """Yield a sync SQLAlchemy session bound to ``engine``; it is closed after the test."""
    make_session = sessionmaker(bind=engine)
    orm_session = make_session()
    try:
        yield orm_session
    finally:
        orm_session.close()
@pytest.fixture
async def db_engine():
    """Provide an async in-memory SQLite engine with the schema created.

    Setup turns on foreign-key enforcement for each new DBAPI connection,
    adapts the ORM metadata for SQLite, registers the insert listener, and
    creates the ORM tables. It then creates the ``sessions``, ``accounts``
    and ``verifications`` tables with raw DDL. After the test the ORM tables
    are dropped and the engine is disposed.
    """
    async_engine = create_async_engine(TEST_DATABASE_URL, echo=False)

    def set_sqlite_pragma(dbapi_connection, connection_record):
        # Runs for each new DBAPI connection made by the engine.
        pragma_cursor = dbapi_connection.cursor()
        pragma_cursor.execute("PRAGMA foreign_keys=ON")
        pragma_cursor.close()

    event.listen(async_engine.sync_engine, "connect", set_sqlite_pragma)

    _adapt_columns_for_sqlite()
    _register_event_listeners()

    # Tables created here with raw SQL rather than through ``Base.metadata``.
    raw_table_ddl = (
        """
            CREATE TABLE IF NOT EXISTS sessions (
                id TEXT PRIMARY KEY,
                token TEXT NOT NULL UNIQUE,
                user_id TEXT NOT NULL,
                expires_at TIMESTAMP NOT NULL,
                ip_address TEXT,
                user_agent TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
            )
        """,
        """
            CREATE TABLE IF NOT EXISTS accounts (
                id TEXT PRIMARY KEY,
                user_id TEXT NOT NULL,
                account_id TEXT NOT NULL,
                provider_id TEXT NOT NULL,
                access_token TEXT,
                refresh_token TEXT,
                access_token_expires_at TIMESTAMP,
                refresh_token_expires_at TIMESTAMP,
                scope TEXT,
                id_token TEXT,
                password TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
            )
        """,
        """
            CREATE TABLE IF NOT EXISTS verifications (
                id TEXT PRIMARY KEY,
                identifier TEXT NOT NULL,
                value TEXT NOT NULL,
                expires_at TIMESTAMP NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
            )
        """,
    )

    async with async_engine.begin() as setup_conn:
        await setup_conn.run_sync(Base.metadata.create_all)
        for ddl in raw_table_ddl:
            await setup_conn.execute(text(ddl))

    yield async_engine

    async with async_engine.begin() as teardown_conn:
        await teardown_conn.run_sync(Base.metadata.drop_all)
    await async_engine.dispose()
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to ``db_engine`` (``expire_on_commit`` off)."""
    make_session = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
    async with make_session() as active_session:
        yield active_session
@pytest.fixture
async def client(db_engine):
    """Yield an HTTP client wired to a fresh app that uses the test database.

    The app's ``get_db`` dependency is overridden so that each request gets
    a session from ``db_engine``. The overrides are cleared after the test.
    """
    session_factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)

    async def override_get_db():
        # Replacement for ``get_db`` that opens a new session per call.
        async with session_factory() as request_session:
            yield request_session

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db

    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as http_client:
        yield http_client

    app.dependency_overrides.clear()
async def _create_test_user_and_session(
    client: AsyncClient, db_engine, **user_overrides
) -> tuple[dict, str]:
    """Insert a user row and a valid session row directly into the database.

    Args:
        client: Accepted for call-site symmetry; not used by this helper.
        db_engine: Async engine used to run both INSERT statements in one
            transaction.
        **user_overrides: Optional ``email`` and ``display_name`` values.
            They default to ``"test@example.com"`` and ``"Test User"``.

    Returns:
        ``(user_dict, session_token)``, where ``user_dict`` holds ``id``,
        ``email`` and ``display_name``. Better-Auth stores the raw token in
        the DB, so the token is inserted as-is.
    """
    user_id = str(uuid.uuid4())
    email = user_overrides.get("email", "test@example.com")
    display_name = user_overrides.get("display_name", "Test User")
    session_token = secrets.token_urlsafe(32)
    session_id = str(uuid.uuid4())

    # Timestamps are bound as ISO-8601 strings; the session expires in 7 days.
    now = datetime.now(UTC).isoformat()
    expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()

    insert_user = text(
        "INSERT INTO users (id, email, hashed_password, display_name, "
        "email_verified, email_inbound_token, created_at, updated_at) "
        "VALUES (:id, :email, :hashed_password, :display_name, "
        ":email_verified, :email_inbound_token, :created_at, :updated_at)"
    )
    user_row = {
        "id": user_id,
        "email": email,
        "hashed_password": "not-used-with-better-auth",
        "display_name": display_name,
        "email_verified": False,
        "email_inbound_token": secrets.token_urlsafe(16),
        "created_at": now,
        "updated_at": now,
    }

    insert_session = text(
        "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
        "VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
    )
    session_row = {
        "id": session_id,
        "token": session_token,
        "user_id": user_id,
        "expires_at": expires,
        "created_at": now,
        "updated_at": now,
    }

    async with db_engine.begin() as conn:
        await conn.execute(insert_user, user_row)
        await conn.execute(insert_session, session_row)

    user = {"id": user_id, "email": email, "display_name": display_name}
    return user, session_token
@pytest.fixture
async def auth_headers(client, db_engine):
    """Return request headers carrying the session cookie of a newly created test user."""
    created = await _create_test_user_and_session(client, db_engine)
    session_token = created[1]
    return {"Cookie": f"better-auth.session_token={session_token}"}