diff --git a/src/cartsnitch_api/database.py b/src/cartsnitch_api/database.py index 3c6043c..1168f4b 100644 --- a/src/cartsnitch_api/database.py +++ b/src/cartsnitch_api/database.py @@ -6,14 +6,21 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn from cartsnitch_api.config import settings -engine = create_async_engine( - settings.database_url, - echo=False, - pool_size=10, - max_overflow=20, - pool_pre_ping=True, - pool_recycle=3600, -) + +def _build_engine_kwargs() -> dict: + url = settings.database_url + kwargs: dict = {"echo": False} + if not url.startswith("sqlite"): + kwargs.update( + pool_size=10, + max_overflow=20, + pool_pre_ping=True, + pool_recycle=3600, + ) + return kwargs + + +engine = create_async_engine(settings.database_url, **_build_engine_kwargs()) async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) diff --git a/src/cartsnitch_api/models/base.py b/src/cartsnitch_api/models/base.py index f93cf79..7381b58 100644 --- a/src/cartsnitch_api/models/base.py +++ b/src/cartsnitch_api/models/base.py @@ -1,30 +1,43 @@ """Base model and mixins for all CartSnitch ORM models.""" import uuid -from datetime import datetime +from datetime import UTC, datetime from sqlalchemy import DateTime, func from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from cartsnitch_api.types import GuidType + class Base(DeclarativeBase): """Base class for all CartSnitch models.""" +def _utcnow(): + return datetime.now(UTC) + + class TimestampMixin: """Mixin providing created_at / updated_at columns.""" created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False + DateTime(timezone=True), + server_default=func.now(), + default=_utcnow, + nullable=False, ) updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False + DateTime(timezone=True), + server_default=func.now(), + onupdate=_utcnow, + default=_utcnow, + nullable=False, ) class UUIDPrimaryKeyMixin: - """Mixin providing a UUID primary key.""" + """Mixin providing a UUID primary key using GuidType for cross-DB compatibility.""" id: Mapped[uuid.UUID] = mapped_column( - primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid() + GuidType(), primary_key=True, default=uuid.uuid4 ) diff --git a/src/cartsnitch_api/models/purchase.py b/src/cartsnitch_api/models/purchase.py index 97f577d..5f59694 100644 --- a/src/cartsnitch_api/models/purchase.py +++ b/src/cartsnitch_api/models/purchase.py @@ -18,7 +18,7 @@ from sqlalchemy import ( ) from sqlalchemy.orm import Mapped, mapped_column, relationship -from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin, _utcnow if TYPE_CHECKING: from cartsnitch_api.models.price import PriceHistory @@ -46,6 +46,7 @@ class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base): ingested_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), + default=_utcnow, nullable=False, ) diff --git a/src/cartsnitch_api/models/user.py b/src/cartsnitch_api/models/user.py index 5b51778..6d70c1c 100644 --- a/src/cartsnitch_api/models/user.py +++ b/src/cartsnitch_api/models/user.py @@ -1,6 +1,7 @@ """User and UserStoreAccount models.""" import secrets +import uuid from datetime import datetime from typing import TYPE_CHECKING @@ -10,7 +11,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from cartsnitch_api.constants import AccountStatus from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin -from cartsnitch_api.types import EncryptedJSON +from cartsnitch_api.types import EncryptedJSON, GuidType if TYPE_CHECKING: from cartsnitch_api.models.purchase import Purchase @@ -22,11 +23,13 @@ class User(TimestampMixin, Base): __tablename__ = "users" - id: Mapped[str] = mapped_column(Text, primary_key=True) + id: Mapped[uuid.UUID] = mapped_column(GuidType(), primary_key=True, default=uuid.uuid4) email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) hashed_password: Mapped[str | None] = mapped_column(String(255), nullable=True) display_name: Mapped[str | None] = mapped_column(String(100)) - email_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="false") + email_verified: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False, server_default="false" + ) image: Mapped[str | None] = mapped_column(Text, nullable=True) email_inbound_token: Mapped[str] = mapped_column( String(22), diff --git a/src/cartsnitch_api/types.py b/src/cartsnitch_api/types.py index 13a7820..7b11225 100644 --- a/src/cartsnitch_api/types.py +++ b/src/cartsnitch_api/types.py @@ -1,9 +1,10 @@ """Custom SQLAlchemy column types.""" import json +import uuid as uuid_lib from cryptography.fernet import Fernet -from sqlalchemy import Text +from sqlalchemy import String, Text from sqlalchemy.types import TypeDecorator from cartsnitch_api.config import settings @@ -34,3 +35,27 @@ class EncryptedJSON(TypeDecorator): return None decrypted = _get_fernet().decrypt(value.encode()) return json.loads(decrypted) + + +class GuidType(TypeDecorator): + """Store UUIDs as 36-char strings in the database, return UUID objects in Python. + + Uses PostgreSQL UUID type when available, String(36) otherwise (SQLite). + """ + + impl = String(36) + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + if isinstance(value, uuid_lib.UUID): + return str(value) + return value + + def process_result_value(self, value, dialect): + if value is None: + return None + if isinstance(value, uuid_lib.UUID): + return value + return uuid_lib.UUID(value) diff --git a/tests/conftest.py b/tests/conftest.py index 6439552..341c36f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ import secrets import uuid from datetime import UTC, datetime, timedelta +import aiosqlite import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy import create_engine, event, text @@ -19,6 +20,8 @@ from cartsnitch_api.database import get_db from cartsnitch_api.main import create_app from cartsnitch_api.models import Base +aiosqlite.register_adapter(uuid.UUID, lambda u: str(u)) + TEST_JWT_SECRET = secrets.token_urlsafe(32) TEST_SERVICE_KEY = secrets.token_urlsafe(32) TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=" @@ -53,17 +56,27 @@ def disable_rate_limiting(): def engine(): """Sync in-memory SQLite engine for model unit tests. - Strips PostgreSQL-specific server_default expressions so SQLite can + Strips ALL PostgreSQL-specific server_default expressions so SQLite can handle all column inserts without missing-function errors. """ eng = create_engine("sqlite:///:memory:") - for table in Base.metadata.tables.values(): - for col in table.columns.values(): + @event.listens_for(eng, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + for metadata_table in Base.metadata.tables.values(): + for col in metadata_table.columns.values(): sd = col.server_default if sd is not None: + if not hasattr(sd, "expression"): + col.server_default = None + continue expr_str = str(sd.expression).lower() - if "gen_random_uuid" in expr_str or "gen_random_bytes" in expr_str: + _pg_fns = ("gen_random_uuid", "gen_random_bytes", "now()") + if any(pg_fn in expr_str for pg_fn in _pg_fns): col.server_default = None Base.metadata.create_all(eng) @@ -93,8 +106,12 @@ async def db_engine(): for col in table.columns.values(): sd = col.server_default if sd is not None: + if not hasattr(sd, "expression"): + col.server_default = None + continue expr_str = str(sd.expression).lower() - if "gen_random_uuid" in expr_str or "gen_random_bytes" in expr_str: + _pg_fns = ("gen_random_uuid", "gen_random_bytes", "now()") + if any(pg_fn in expr_str for pg_fn in _pg_fns): col.server_default = None async with engine.begin() as conn: diff --git a/tests/test_encrypted_json.py b/tests/test_encrypted_json.py index 07cf44c..5b08a65 100644 --- a/tests/test_encrypted_json.py +++ b/tests/test_encrypted_json.py @@ -18,12 +18,16 @@ from cartsnitch_api.models.user import User, UserStoreAccount def engine(): eng = create_engine("sqlite:///:memory:") - for table in Base.metadata.tables.values(): - for col in table.columns.values(): + for metadata_table in Base.metadata.tables.values(): + for col in metadata_table.columns.values(): sd = col.server_default if sd is not None: + if not hasattr(sd, "expression"): + col.server_default = None + continue expr_str = str(sd.expression).lower() - if "gen_random_uuid" in expr_str or "gen_random_bytes" in expr_str: + _pg_fns = ("gen_random_uuid", "gen_random_bytes", "now()") + if any(pg_fn in expr_str for pg_fn in _pg_fns): col.server_default = None Base.metadata.create_all(eng)