ce23ee18b8
The rate-limit middleware creates a Redis client at module import time when rate_limit_redis_enabled is true. The conftest disables rate_limit_enabled but not the redis flag, so the client still gets created. After the test event loop closes, the client's async disconnect raises 'Event loop is closed', surfacing as 500s on test_validation_error_returns_422_with_field_errors and test_error_stats_with_valid_key. Setting rate_limit_redis_enabled=False in the autouse fixture prevents the Redis client from being created in the first place. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
327 lines
12 KiB
Python
327 lines
12 KiB
Python
"""Shared test fixtures with in-memory SQLite database.

Session-based auth: tests create users and sessions directly in the DB,
matching the Better-Auth session validation flow.
"""
|
|
|
|
import secrets
|
|
import uuid
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import pytest
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy import String, TypeDecorator, Uuid, create_engine, event, text
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.types import CHAR
|
|
|
|
from cartsnitch_api.config import settings as cartsnitch_settings
|
|
from cartsnitch_api.database import get_db
|
|
from cartsnitch_api.main import create_app
|
|
from cartsnitch_api.models import Base
|
|
|
|
|
|
class _StringUUID(TypeDecorator):
    """Persist UUID-like values as 36-character text so SQLite can bind them.

    SQLite has no native UUID type — passing a ``uuid.UUID`` raises
    ``type 'UUID' is not supported``. Values are written with ``str()``
    (the canonical string form of a ``uuid.UUID``), so either a
    ``uuid.UUID`` or a ``str`` is accepted at bind time. Values are read
    back as ``uuid.UUID`` so existing test assertions like
    ``isinstance(store.id, uuid.UUID)`` still work.
    """

    impl = CHAR(36)
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return the string form of an outgoing value; ``None`` passes through."""
        # str() handles uuid.UUID and plain-string inputs alike.
        return None if value is None else str(value)

    def process_result_value(self, value, dialect):
        """Return a fetched value as ``uuid.UUID``; ``None`` passes through.

        ``uuid.UUID(value)`` raises ``ValueError`` when the stored text is
        not a valid UUID string.
        """
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)
|
|
|
|
|
|
def _set_timestamp_defaults(mapper, connection, target):
|
|
"""Populate created_at/updated_at and missing PK IDs for SQLite.
|
|
|
|
SQLite can't bind ``uuid.UUID`` objects to Text/String columns, and has
|
|
no server-side default for ``func.now()`` or ``gen_random_uuid()``. We
|
|
strip those server_defaults elsewhere; this listener fills in
|
|
Python-side timestamp defaults at insert time, generates IDs for PK
|
|
columns that have no default, and populates ``func.now()`` columns
|
|
whose server_default was stripped (e.g. ``ingested_at``). UUID values
|
|
for non-PK columns are converted by the ``_StringUUID`` TypeDecorator.
|
|
"""
|
|
now = datetime.now(UTC)
|
|
for col in mapper.columns:
|
|
key = col.key
|
|
if key in ("created_at", "updated_at"):
|
|
if getattr(target, key, None) is None:
|
|
setattr(target, key, now)
|
|
continue
|
|
if col.primary_key and getattr(target, key, None) is None:
|
|
setattr(target, key, str(uuid.uuid4()))
|
|
continue
|
|
if getattr(col, "_sqlite_default_now", False) and getattr(target, key, None) is None:
|
|
setattr(target, key, now)
|
|
|
|
|
|
def _adapt_columns_for_sqlite():
    """Rewrite ``Base.metadata`` columns in place so the schema works on SQLite.

    Must be called BEFORE ``Base.metadata.create_all`` so the DDL reflects
    the adapted column types. For every column of every table:

    * a server default whose text mentions ``gen_random_uuid``,
      ``gen_random_bytes`` or ``now()`` is removed (plain literals such as
      ``server_default="false"`` are kept); a non-nullable column that
      lost a ``now()`` default is flagged with ``_sqlite_default_now`` so
      the ``before_insert`` listener can fill it in;
    * ``Uuid``-typed columns are switched to ``_StringUUID``;
    * primary-key columns left with neither a Python-side default nor a
      server default are switched to ``_StringUUID``;
    * non-primary-key foreign-key columns typed as ``String`` are switched
      to ``_StringUUID``.
    """
    pg_only_markers = ("gen_random_uuid", "gen_random_bytes", "now()")

    for table in Base.metadata.tables.values():
        for column in table.columns.values():
            server_default = column.server_default
            if server_default is not None:
                raw = server_default.arg if hasattr(server_default, "arg") else server_default
                default_sql = str(raw).lower()
                if any(marker in default_sql for marker in pg_only_markers):
                    column.server_default = None
                    if "now()" in default_sql and not column.nullable:
                        # Remember that this column still needs a timestamp at insert time.
                        column._sqlite_default_now = True  # type: ignore[attr-defined]

            # Native UUID columns cannot be bound by SQLite.
            if isinstance(column.type, Uuid):
                column.type = _StringUUID()

            # Evaluated after the stripping above, so a PK whose only default
            # was a removed Postgres function also qualifies.
            lacks_any_default = column.default is None and column.server_default is None
            if column.primary_key and lacks_any_default:
                if not isinstance(column.type, _StringUUID):
                    column.type = _StringUUID()

            # FK columns that may receive uuid.UUID values from test code.
            is_string_fk = column.foreign_keys and not column.primary_key
            if is_string_fk and isinstance(column.type, String):
                column.type = _StringUUID()
|
|
|
|
|
|
def _register_event_listeners():
    """Attach the ``before_insert`` default-filling listener to every mapped class.

    Mapped classes are discovered through the public ``registry.mappers``
    collection rather than the private ``_class_registry`` dict, which
    also contains non-class marker entries that had to be filtered out.

    The engine fixtures call this once per test, so registration is
    guarded with ``event.contains`` to keep repeated calls from stacking
    the same listener on a class.
    """
    for mapper in Base.registry.mappers:
        mapped_cls = mapper.class_
        if not event.contains(mapped_cls, "before_insert", _set_timestamp_defaults):
            event.listen(mapped_cls, "before_insert", _set_timestamp_defaults)
|
|
|
|
|
|
# Random URL-safe secrets generated once when this module is imported; the
# ``setup_test_settings`` fixture assigns them to the application settings.
TEST_JWT_SECRET = secrets.token_urlsafe(32)
TEST_SERVICE_KEY = secrets.token_urlsafe(32)
# Fixed value assigned to ``settings.fernet_key`` by ``setup_test_settings``.
# NOTE(review): presumably a throwaway key used only by the test suite — confirm
# it is not shared with any real environment.
TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def setup_test_settings():
    """Swap the test secrets into the application settings for each test.

    The JWT secret, service key and Fernet key on ``cartsnitch_settings``
    are replaced with the module-level ``TEST_*`` values for the duration
    of the test, then set back to whatever they held beforehand.
    """
    overrides = {
        "jwt_secret_key": TEST_JWT_SECRET,
        "service_key": TEST_SERVICE_KEY,
        "fernet_key": TEST_FERNET_KEY,
    }
    # Snapshot every original before touching any of them.
    saved = {name: getattr(cartsnitch_settings, name) for name in overrides}
    for name, value in overrides.items():
        setattr(cartsnitch_settings, name, value)
    yield
    for name, value in saved.items():
        setattr(cartsnitch_settings, name, value)
|
|
|
|
|
|
# In-memory SQLite URL (aiosqlite driver) used by the async ``db_engine`` fixture.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def disable_rate_limiting():
    """Disable rate limiting for all tests to prevent 429 interference.

    Both ``rate_limit_enabled`` and ``rate_limit_redis_enabled`` are
    switched off for the duration of the test. On teardown the flags are
    restored to the values they held before the test, rather than being
    forced to ``True``; forcing them on would enable the Redis-backed
    limiter even in environments where it was configured off.
    """
    original_enabled = cartsnitch_settings.rate_limit_enabled
    original_redis_enabled = cartsnitch_settings.rate_limit_redis_enabled
    cartsnitch_settings.rate_limit_enabled = False
    cartsnitch_settings.rate_limit_redis_enabled = False
    yield
    cartsnitch_settings.rate_limit_enabled = original_enabled
    cartsnitch_settings.rate_limit_redis_enabled = original_redis_enabled
|
|
|
|
|
|
@pytest.fixture
def engine():
    """Yield a synchronous in-memory SQLite engine for model unit tests.

    Before the schema is created, the model metadata is adapted for SQLite
    (PostgreSQL-specific server defaults stripped, UUID columns replaced
    with a SQLite-compatible TypeDecorator) and the ``before_insert``
    listener that populates timestamps is registered. The engine is
    disposed once the test has finished.
    """
    # The metadata must be adapted before create_all emits any DDL.
    _adapt_columns_for_sqlite()
    _register_event_listeners()

    sqlite_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(sqlite_engine)
    yield sqlite_engine
    sqlite_engine.dispose()
|
|
|
|
|
|
@pytest.fixture
def session(engine):
    """Yield a synchronous SQLAlchemy session bound to the ``engine`` fixture.

    The session is closed by its context manager when the test finishes.
    """
    with sessionmaker(bind=engine)() as orm_session:
        yield orm_session
|
|
|
|
|
|
@pytest.fixture
async def db_engine():
    """Yield an async in-memory SQLite engine with the full test schema.

    Setup:
      * enables ``PRAGMA foreign_keys=ON`` on every new DBAPI connection;
      * adapts the model metadata for SQLite and registers the
        ``before_insert`` listener;
      * creates all ``Base.metadata`` tables, then creates the
        ``sessions``, ``accounts`` and ``verifications`` tables with raw
        ``CREATE TABLE IF NOT EXISTS`` statements (presumably because the
        Better-Auth tables are not part of ``Base.metadata`` — confirm).

    Teardown drops the ``Base.metadata`` tables and disposes the engine.
    """
    auth_table_ddl = (
        """
        CREATE TABLE IF NOT EXISTS sessions (
            id TEXT PRIMARY KEY,
            token TEXT NOT NULL UNIQUE,
            user_id TEXT NOT NULL,
            expires_at TIMESTAMP NOT NULL,
            ip_address TEXT,
            user_agent TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS accounts (
            id TEXT PRIMARY KEY,
            user_id TEXT NOT NULL,
            account_id TEXT NOT NULL,
            provider_id TEXT NOT NULL,
            access_token TEXT,
            refresh_token TEXT,
            access_token_expires_at TIMESTAMP,
            refresh_token_expires_at TIMESTAMP,
            scope TEXT,
            id_token TEXT,
            password TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS verifications (
            id TEXT PRIMARY KEY,
            identifier TEXT NOT NULL,
            value TEXT NOT NULL,
            expires_at TIMESTAMP NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
        )
        """,
    )

    async_engine = create_async_engine(TEST_DATABASE_URL, echo=False)

    def _enable_foreign_keys(dbapi_connection, connection_record):
        # Runs on each new DBAPI connection so FK constraints are enforced.
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()

    event.listen(async_engine.sync_engine, "connect", _enable_foreign_keys)

    _adapt_columns_for_sqlite()
    _register_event_listeners()

    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
        for ddl in auth_table_ddl:
            await conn.execute(text(ddl))

    yield async_engine

    # Only the metadata-managed tables are dropped explicitly.
    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)

    await async_engine.dispose()
|
|
|
|
|
|
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to the ``db_engine`` fixture.

    The session is created with ``expire_on_commit=False`` and is closed by
    its context manager when the test finishes.
    """
    make_session = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
    async with make_session() as active_session:
        yield active_session
|
|
|
|
|
|
@pytest.fixture
async def client(db_engine):
    """Yield an ``httpx.AsyncClient`` wired to a fresh app instance.

    The app's ``get_db`` dependency is overridden so each call receives a
    new ``AsyncSession`` bound to the ``db_engine`` fixture. Requests go
    through an in-process ASGI transport with base URL ``http://test``.
    The dependency overrides are cleared on teardown.
    """
    session_factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)

    async def override_get_db():
        async with session_factory() as request_session:
            yield request_session

    app = create_app()
    app.dependency_overrides[get_db] = override_get_db

    asgi_transport = ASGITransport(app=app)
    async with AsyncClient(transport=asgi_transport, base_url="http://test") as http_client:
        yield http_client

    app.dependency_overrides.clear()
|
|
|
|
|
|
async def _create_test_user_and_session(
    client: AsyncClient, db_engine, **user_overrides
) -> tuple[dict, str]:
    """Insert a test user and a valid session row directly into the DB.

    Better-Auth stores the raw token in the DB, so the generated session
    token is inserted as-is.

    Args:
        client: Accepted for call-site symmetry; not used by this helper.
        db_engine: Async engine whose ``begin()`` transaction receives both
            inserts.
        **user_overrides: Only ``email`` and ``display_name`` are read;
            they default to ``"test@example.com"`` and ``"Test User"``.

    Returns:
        ``(user_dict, session_token)`` where ``user_dict`` has the keys
        ``id``, ``email`` and ``display_name``. The session expires seven
        days after creation.
    """
    email = user_overrides.get("email", "test@example.com")
    display_name = user_overrides.get("display_name", "Test User")

    user_id = str(uuid.uuid4())
    session_id = str(uuid.uuid4())
    session_token = secrets.token_urlsafe(32)

    created_at = datetime.now(UTC).isoformat()
    expires_at = (datetime.now(UTC) + timedelta(days=7)).isoformat()

    insert_user = text(
        "INSERT INTO users (id, email, hashed_password, display_name, "
        "email_verified, email_inbound_token, created_at, updated_at) "
        "VALUES (:id, :email, :hashed_password, :display_name, "
        ":email_verified, :email_inbound_token, :created_at, :updated_at)"
    )
    user_row = {
        "id": user_id,
        "email": email,
        "hashed_password": "not-used-with-better-auth",
        "display_name": display_name,
        "email_verified": False,
        "email_inbound_token": secrets.token_urlsafe(16),
        "created_at": created_at,
        "updated_at": created_at,
    }

    insert_session = text(
        "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
        "VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
    )
    session_row = {
        "id": session_id,
        "token": session_token,
        "user_id": user_id,
        "expires_at": expires_at,
        "created_at": created_at,
        "updated_at": created_at,
    }

    # Both rows are written in one transaction: user first, then its session.
    async with db_engine.begin() as conn:
        await conn.execute(insert_user, user_row)
        await conn.execute(insert_session, session_row)

    user = {"id": user_id, "email": email, "display_name": display_name}
    return user, session_token
|
|
|
|
|
|
@pytest.fixture
async def auth_headers(client, db_engine):
    """Return request headers carrying a valid session cookie for a new test user.

    A user and session are created directly in the DB; the session token is
    sent as the ``better-auth.session_token`` cookie.
    """
    created = await _create_test_user_and_session(client, db_engine)
    session_token = created[1]
    cookie = f"better-auth.session_token={session_token}"
    return {"Cookie": cookie}
|