fix(api): revert auth/type regressions from standalone sync, keep email-in feature only

- Revert auth/dependencies.py to cookie+Bearer dual auth with str user IDs
- Add GET /auth/me/email-in-address endpoint for receipt email routing
- Update User model: add email_inbound_token, change id/store_id/user_id to str
- Update AuthService and UserResponse to use str user IDs
- Update route count test: 33 -> 34 routes
- Restore e2e test for email-in-address endpoint

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
CartSnitch Engineer Bot
2026-04-03 09:40:39 +00:00
parent 18ff5795ac
commit bbbf97d027
18 changed files with 360 additions and 236 deletions
+164
View File
@@ -0,0 +1,164 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
concurrency:
group: ci-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: write
packages: write
env:
REGISTRY: ghcr.io
IMAGE_NAME: cartsnitch/api
jobs:
lint:
runs-on: runners-cartsnitch
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: pip
- run: pip install ruff
- name: Ruff lint
run: ruff check .
- name: Ruff format check
run: ruff format --check .
typecheck:
runs-on: runners-cartsnitch
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: pip
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential
- name: Install cartsnitch-common from GitHub
run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git"
- run: pip install -e ".[dev]" mypy
- name: Type check
run: mypy src/cartsnitch_api
test:
runs-on: runners-cartsnitch
services:
postgres:
image: postgres:15-alpine
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
env:
POSTGRES_USER: cartsnitch
POSTGRES_PASSWORD: cartsnitch_test
POSTGRES_DB: cartsnitch_test
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
redis:
image: redis:7-alpine
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
CARTSNITCH_DATABASE_URL: postgresql+asyncpg://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test
CARTSNITCH_REDIS_URL: redis://localhost:6379/0
CARTSNITCH_JWT_SECRET_KEY: test-secret-do-not-use-in-prod
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: pip
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential
- name: Install cartsnitch-common from GitHub
run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git"
- run: pip install -e ".[dev]"
- name: Run tests
run: pytest --tb=short -q
build-and-push:
runs-on: runners-cartsnitch
needs: [lint, test]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generate CalVer tag
id: calver
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
run: |
DATE_TAG=$(date -u +%Y.%m.%d)
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
if [ -z "$EXISTING" ]; then
VERSION="$DATE_TAG"
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then
VERSION="${DATE_TAG}.2"
else
BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//")
VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"
fi
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
echo "CalVer tag: $VERSION"
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to GHCR
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=sha,prefix=sha-
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
- name: Build and push Docker image
uses: docker/build-push-action@v6
with:
context: .
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
target: prod
- name: Create git tag
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
run: |
git tag "v${{ steps.calver.outputs.version }}"
git push origin "v${{ steps.calver.outputs.version }}"
+4 -11
View File
@@ -1,5 +1,3 @@
# Stage 1: Build dependencies
# Build context is the repo root. Paths below are relative to the root.
FROM python:3.12-slim AS build
RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -8,21 +6,16 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY api/pyproject.toml ./
COPY api/src/ ./src/
COPY pyproject.toml ./
COPY src/ ./src/
RUN pip install --no-cache-dir --prefix=/install .
# Stage 2: Production image
FROM python:3.12-slim AS prod
RUN apt-get update && apt-get install -y --no-install-recommends libpq5 && rm -rf /var/lib/apt/lists/*
WORKDIR /app
RUN adduser --system --group --uid 1000 app
COPY --from=build /install /usr/local
COPY api/src/ ./src/
COPY api/alembic.ini ./
COPY api/alembic/ ./alembic/
COPY src/ ./src/
USER 1000
EXPOSE 8000
@@ -30,4 +23,4 @@ EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=3s \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"
CMD ["sh", "-c", "python -m alembic upgrade head && uvicorn cartsnitch_api.main:app --host 0.0.0.0 --port 8000"]
CMD ["uvicorn", "cartsnitch_api.main:app", "--host", "0.0.0.0", "--port", "8000"]
+4 -9
View File
@@ -5,8 +5,6 @@ Sessions are verified by querying the shared sessions table directly.
"""
from datetime import UTC, datetime
from uuid import UUID
from fastapi import Cookie, Depends, Header, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy import text
@@ -23,10 +21,10 @@ bearer_scheme = HTTPBearer(auto_error=False)
SESSION_COOKIE_NAME = "better-auth.session_token"
async def _validate_session_token(token: str, db: AsyncSession) -> UUID:
async def _validate_session_token(token: str, db: AsyncSession) -> str:
"""Validate a Better-Auth session token against the sessions table.
Returns the user_id (as UUID) if the session is valid and not expired.
Returns the user_id (as str) if the session is valid and not expired.
"""
result = await db.execute(
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
@@ -41,9 +39,6 @@ async def _validate_session_token(token: str, db: AsyncSession) -> UUID:
)
user_id, expires_at = row
# SQLite stores datetimes as ISO strings; parse if necessary
if isinstance(expires_at, str):
expires_at = datetime.fromisoformat(expires_at)
if expires_at.tzinfo is None:
# Treat naive datetimes as UTC
expires_at = expires_at.replace(tzinfo=UTC)
@@ -54,14 +49,14 @@ async def _validate_session_token(token: str, db: AsyncSession) -> UUID:
detail="Session expired",
)
return UUID(str(user_id))
return str(user_id)
async def get_current_user(
request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
db: AsyncSession = Depends(get_db),
) -> UUID:
) -> str:
"""Extract and validate the session token from cookie or Authorization header.
Checks in order:
+22 -11
View File
@@ -5,15 +5,15 @@ the Better-Auth service (auth/). This router provides user profile
endpoints that query our own user data from the shared database.
"""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.database import get_db
from cartsnitch_api.models import User
from cartsnitch_api.schemas import (
EmailInAddressResponse,
UpdateUserRequest,
UserResponse,
)
@@ -22,9 +22,14 @@ from cartsnitch_api.services.auth import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
class EmailInAddressResponse(BaseModel):
email_address: str
instructions: str
@router.get("/me", response_model=UserResponse)
async def get_me(
user_id: UUID = Depends(get_current_user),
user_id: str = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = AuthService(db)
@@ -39,7 +44,7 @@ async def get_me(
@router.patch("/me", response_model=UserResponse)
async def update_me(
body: UpdateUserRequest,
user_id: UUID = Depends(get_current_user),
user_id: str = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = AuthService(db)
@@ -55,7 +60,7 @@ async def update_me(
@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT)
async def delete_me(
user_id: UUID = Depends(get_current_user),
user_id: str = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = AuthService(db)
@@ -69,13 +74,19 @@ async def delete_me(
@router.get("/me/email-in-address", response_model=EmailInAddressResponse)
async def get_email_in_address(
user_id: UUID = Depends(get_current_user),
user_id: str = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = AuthService(db)
try:
return await svc.get_email_in_address(user_id)
except LookupError:
result = await db.execute(select(User.email_inbound_token).where(User.id == user_id))
token = result.scalar_one_or_none()
if not token:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Email inbound token not found"
) from None
return EmailInAddressResponse(
email_address=f"receipts+{token}@receipts.cartsnitch.com",
instructions=(
"Forward your digital receipt emails to this address. "
"We currently support Meijer, Kroger, and Target receipt emails."
),
)
+2
View File
@@ -19,6 +19,8 @@ class Settings(BaseSettings):
# Valid Fernet key for local dev — MUST be overridden in production
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
auth_service_url: str = "http://auth:3001"
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
receiptwitness_url: str = "http://receiptwitness:8001"
+5 -38
View File
@@ -1,39 +1,12 @@
"""Base model and mixins for all CartSnitch ORM models."""
import uuid as uuid_lib
import uuid
from datetime import datetime
from sqlalchemy import DateTime, String, TypeDecorator, func
from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class UUIDString(TypeDecorator):
"""Store UUIDs as VARCHAR(36) strings in all dialects.
This handles the fundamental mismatch between Python's uuid.UUID objects
(used everywhere in application code) and SQLite's lack of a native UUID type.
- On INSERT: converts uuid.UUID → str
- On SELECT: returns uuid.UUID (so SQLAlchemy 2.0 sentinel tracking matches correctly)
"""
impl = String(36)
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
if isinstance(value, uuid_lib.UUID):
return str(value)
return value # already a string
def process_result_value(self, value, dialect):
if value is None:
return value
if isinstance(value, uuid_lib.UUID):
return value
return uuid_lib.UUID(value) # convert str → UUID for correct sentinel tracking
class Base(DeclarativeBase):
"""Base class for all CartSnitch models."""
@@ -50,14 +23,8 @@ class TimestampMixin:
class UUIDPrimaryKeyMixin:
"""Mixin providing a UUID primary key.
"""Mixin providing a UUID primary key."""
Uses UUIDString so all DB dialects store the full 36-char UUID string
without truncation, while Python code always works with uuid.UUID objects.
"""
id: Mapped[uuid_lib.UUID] = mapped_column(
UUIDString(),
primary_key=True,
default=uuid_lib.uuid4,
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid()
)
+2 -2
View File
@@ -32,8 +32,8 @@ class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base):
__tablename__ = "purchases"
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
user_id: Mapped[str] = mapped_column(ForeignKey("users.id"), nullable=False)
store_id: Mapped[str] = mapped_column(ForeignKey("stores.id"), nullable=False)
store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id"))
receipt_id: Mapped[str] = mapped_column(String(200), nullable=False)
purchase_date: Mapped[date] = mapped_column(Date, nullable=False)
+5 -5
View File
@@ -1,11 +1,10 @@
"""User and UserStoreAccount models."""
import secrets
import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import AccountStatus
@@ -17,11 +16,12 @@ if TYPE_CHECKING:
from cartsnitch_api.models.store import Store
class User(UUIDPrimaryKeyMixin, TimestampMixin, Base):
class User(TimestampMixin, Base):
"""Application user."""
__tablename__ = "users"
id: Mapped[str] = mapped_column(Text, primary_key=True)
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
display_name: Mapped[str | None] = mapped_column(String(100))
@@ -43,8 +43,8 @@ class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base):
__tablename__ = "user_store_accounts"
__table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),)
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
user_id: Mapped[str] = mapped_column(ForeignKey("users.id"), nullable=False)
store_id: Mapped[str] = mapped_column(ForeignKey("stores.id"), nullable=False)
session_data: Mapped[dict | None] = mapped_column(EncryptedJSON)
session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
last_sync_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
+1 -9
View File
@@ -1,7 +1,6 @@
"""Pydantic v2 request/response schemas for all API endpoints."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, EmailStr, Field
@@ -16,7 +15,7 @@ class UpdateUserRequest(BaseModel):
class UserResponse(BaseModel):
id: UUID
id: str
email: str
display_name: str
created_at: datetime
@@ -265,13 +264,6 @@ class ErrorResponse(BaseModel):
code: str | None = None
# ---------- Email-In ----------
class EmailInAddressResponse(BaseModel):
email_address: str
instructions: str
# Rebuild forward refs
ProductDetailResponse.model_rebuild()
PriceTrendResponse.model_rebuild()
+7 -32
View File
@@ -5,8 +5,6 @@ handled by the Better-Auth service (auth/). This service provides
user lookup and profile update operations for the API gateway.
"""
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -15,14 +13,10 @@ class AuthService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def get_user(self, user_id: UUID) -> dict:
async def get_user(self, user_id: str) -> dict:
from cartsnitch_api.models import User
# Use str() to ensure consistent string comparison for UUID columns
# (works with both SQLite VARCHAR and Postgres UUID storage)
result = await self.db.execute(
select(User).where(User.id == str(user_id))
)
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise LookupError("User not found")
@@ -34,11 +28,10 @@ class AuthService:
"created_at": user.created_at,
}
async def update_user(self, user_id: UUID, **fields) -> dict:
async def update_user(self, user_id: str, **fields) -> dict:
from cartsnitch_api.models import User
user_id_str = str(user_id)
result = await self.db.execute(select(User).where(User.id == user_id_str))
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise LookupError("User not found")
@@ -47,7 +40,7 @@ class AuthService:
user.display_name = fields["display_name"]
if "email" in fields and fields["email"] is not None:
existing = await self.db.execute(
select(User).where(User.email == fields["email"], User.id != user_id_str)
select(User).where(User.email == fields["email"], User.id != user_id)
)
if existing.scalar_one_or_none():
raise ValueError("Email already in use")
@@ -63,31 +56,13 @@ class AuthService:
"created_at": user.created_at,
}
async def delete_user(self, user_id: UUID) -> None:
async def delete_user(self, user_id: str) -> None:
from cartsnitch_api.models import User
result = await self.db.execute(select(User).where(User.id == str(user_id)))
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise LookupError("User not found")
await self.db.delete(user)
await self.db.commit()
async def get_email_in_address(self, user_id: UUID) -> dict:
from cartsnitch_api.models import User
result = await self.db.execute(
select(User.email_inbound_token).where(User.id == str(user_id))
)
token = result.scalar_one_or_none()
if not token:
raise LookupError("Email inbound token not found")
return {
"email_address": f"receipts+{token}@receipts.cartsnitch.com",
"instructions": (
"Forward your digital receipt emails to this address. "
"We currently support Meijer, Kroger, and Target receipt emails."
),
}
+3 -4
View File
@@ -141,7 +141,6 @@ async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_o
user_id = str(uuid.uuid4())
email = user_overrides.get("email", "test@example.com")
display_name = user_overrides.get("display_name", "Test User")
email_inbound_token = user_overrides.get("email_inbound_token", secrets.token_urlsafe(16))
session_token = secrets.token_urlsafe(32)
session_id = str(uuid.uuid4())
now = datetime.now(UTC).isoformat()
@@ -150,15 +149,15 @@ async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_o
async with db_engine.begin() as conn:
await conn.execute(
text(
"INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) "
"VALUES (:id, :email, :hashed_password, :display_name, :email_inbound_token, :created_at, :updated_at)"
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
"VALUES (:id, :email, :hashed_password, :display_name, :email_verified, :created_at, :updated_at)"
),
{
"id": user_id,
"email": email,
"hashed_password": "not-used-with-better-auth",
"display_name": display_name,
"email_inbound_token": email_inbound_token,
"email_verified": False,
"created_at": now,
"updated_at": now,
},
+3 -3
View File
@@ -88,15 +88,15 @@ async def test_expired_session_rejected(client, db_engine):
async with db_engine.begin() as conn:
await conn.execute(
text(
"INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) "
"VALUES (:id, :email, :hp, :dn, :eit, :ca, :ua)"
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
),
{
"id": user_id,
"email": "expired@example.com",
"hp": "unused",
"dn": "Expired User",
"eit": secrets.token_urlsafe(16),
"ev": False,
"ca": now,
"ua": now,
},
+17 -45
View File
@@ -7,11 +7,10 @@ exercise cross-resource queries against real data.
from datetime import date, timedelta
from decimal import Decimal
import uuid
from sqlalchemy import text
from uuid import UUID
import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from cartsnitch_api.models import (
@@ -27,27 +26,24 @@ from cartsnitch_api.models import (
# Shared test constants
ZERO_UUID = "00000000-0000-0000-0000-000000000000"
BAD_UUID = "not-a-uuid"
# Anchor date relative to today so coupon validity windows stay in the future
ANCHOR_DATE = date.today()
# Fixed anchor date for deterministic tests
ANCHOR_DATE = date(2026, 3, 15)
@pytest.fixture
async def seed_data(db_engine, auth_headers):
"""Seed a full dataset and return identifiers for test assertions."""
import uuid
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
# -- Stores --
meijer = Store(name="Meijer", slug="meijer", id=uuid.uuid4())
kroger = Store(name="Kroger", slug="kroger", id=uuid.uuid4())
target = Store(name="Target", slug="target", id=uuid.uuid4())
meijer = Store(name="Meijer", slug="meijer")
kroger = Store(name="Kroger", slug="kroger")
target = Store(name="Target", slug="target")
session.add_all([meijer, kroger, target])
await session.flush()
# -- Products --
cheerios = NormalizedProduct(
id=uuid.uuid4(),
canonical_name="Cheerios 18oz",
category="pantry",
brand="General Mills",
@@ -56,7 +52,6 @@ async def seed_data(db_engine, auth_headers):
upc_variants=["016000275263"],
)
milk = NormalizedProduct(
id=uuid.uuid4(),
canonical_name="Whole Milk 1gal",
category="dairy",
brand="Meijer",
@@ -64,7 +59,6 @@ async def seed_data(db_engine, auth_headers):
size_unit="gal",
)
chicken = NormalizedProduct(
id=uuid.uuid4(),
canonical_name="Chicken Breast 1lb",
category="meat",
brand=None,
@@ -81,7 +75,6 @@ async def seed_data(db_engine, auth_headers):
for i, price_val in enumerate([Decimal("3.99"), Decimal("4.29"), Decimal("4.79")]):
prices.append(
PriceHistory(
id=uuid.uuid4(),
normalized_product_id=cheerios.id,
store_id=meijer.id,
observed_date=today - timedelta(days=60 - i * 30),
@@ -93,7 +86,6 @@ async def seed_data(db_engine, auth_headers):
for i in range(3):
prices.append(
PriceHistory(
id=uuid.uuid4(),
normalized_product_id=cheerios.id,
store_id=kroger.id,
observed_date=today - timedelta(days=60 - i * 30),
@@ -104,7 +96,6 @@ async def seed_data(db_engine, auth_headers):
# Milk at Meijer
prices.append(
PriceHistory(
id=uuid.uuid4(),
normalized_product_id=milk.id,
store_id=meijer.id,
observed_date=today - timedelta(days=7),
@@ -115,7 +106,6 @@ async def seed_data(db_engine, auth_headers):
# Milk at Kroger
prices.append(
PriceHistory(
id=uuid.uuid4(),
normalized_product_id=milk.id,
store_id=kroger.id,
observed_date=today - timedelta(days=5),
@@ -126,7 +116,6 @@ async def seed_data(db_engine, auth_headers):
# Chicken at Target
prices.append(
PriceHistory(
id=uuid.uuid4(),
normalized_product_id=chicken.id,
store_id=target.id,
observed_date=today - timedelta(days=3),
@@ -137,29 +126,19 @@ async def seed_data(db_engine, auth_headers):
session.add_all(prices)
await session.flush()
# -- Purchases (need the user_id from the registered test user) --
# Extract session_token from auth_headers, then look up the real user_id
import http.cookies
cookie_header = auth_headers.get("Cookie", "")
cookies = http.cookies.SimpleCookie()
cookies.load(cookie_header)
session_token = cookies.get("better-auth.session_token").value if "better-auth.session_token" in cookie_header else None
if session_token is None:
raise RuntimeError("seed_data fixture requires cookie-based auth session token")
# -- Get the user_id from the session token in auth_headers --
cookie_str = auth_headers.get("Cookie", "")
session_token = cookie_str.split("=", 1)[1] if "=" in cookie_str else ""
# Look up the real user_id from the sessions table
row = await session.execute(
result = await session.execute(
text("SELECT user_id FROM sessions WHERE token = :token"),
{"token": session_token}
{"token": session_token},
)
session_row = row.fetchone()
if session_row is None:
raise RuntimeError("Session not found for session token in auth_headers")
real_user_id = session_row[0]
row = result.first()
user_id = UUID(row[0])
purchase1 = Purchase(
id=uuid.uuid4(),
user_id=uuid.UUID(real_user_id),
user_id=user_id,
store_id=meijer.id,
receipt_id="meijer-2026-001",
purchase_date=today - timedelta(days=10),
@@ -168,8 +147,7 @@ async def seed_data(db_engine, auth_headers):
tax=Decimal("1.95"),
)
purchase2 = Purchase(
id=uuid.uuid4(),
user_id=uuid.UUID(real_user_id),
user_id=user_id,
store_id=kroger.id,
receipt_id="kroger-2026-001",
purchase_date=today - timedelta(days=5),
@@ -182,7 +160,6 @@ async def seed_data(db_engine, auth_headers):
# -- Purchase Items --
item1 = PurchaseItem(
id=uuid.uuid4(),
purchase_id=purchase1.id,
product_name_raw="Cheerios 18oz Box",
quantity=Decimal("1"),
@@ -191,7 +168,6 @@ async def seed_data(db_engine, auth_headers):
normalized_product_id=cheerios.id,
)
item2 = PurchaseItem(
id=uuid.uuid4(),
purchase_id=purchase1.id,
product_name_raw="Meijer Whole Milk 1gal",
quantity=Decimal("2"),
@@ -200,7 +176,6 @@ async def seed_data(db_engine, auth_headers):
normalized_product_id=milk.id,
)
item3 = PurchaseItem(
id=uuid.uuid4(),
purchase_id=purchase2.id,
product_name_raw="KRO CHEERIOS 18OZ",
quantity=Decimal("1"),
@@ -213,7 +188,6 @@ async def seed_data(db_engine, auth_headers):
# -- Coupons --
coupon1 = Coupon(
id=uuid.uuid4(),
store_id=meijer.id,
normalized_product_id=cheerios.id,
title="$1 off Cheerios",
@@ -224,7 +198,6 @@ async def seed_data(db_engine, auth_headers):
valid_to=today + timedelta(days=30),
)
coupon2 = Coupon(
id=uuid.uuid4(),
store_id=kroger.id,
normalized_product_id=None,
title="10% off dairy",
@@ -239,7 +212,6 @@ async def seed_data(db_engine, auth_headers):
# -- Shrinkflation events --
shrink = ShrinkflationEvent(
id=uuid.uuid4(),
normalized_product_id=cheerios.id,
detected_date=today - timedelta(days=15),
old_size="20",
@@ -274,7 +246,7 @@ async def seed_data(db_engine, auth_headers):
return {
"headers": auth_headers,
"user_id": real_user_id,
"user_id": user_id,
"stores": {"meijer": meijer, "kroger": kroger, "target": target},
"products": {"cheerios": cheerios, "milk": milk, "chicken": chicken},
"purchases": {"meijer_trip": purchase1, "kroger_trip": purchase2},
+3 -3
View File
@@ -65,15 +65,15 @@ class TestSessionValidation:
async with db_engine.begin() as conn:
await conn.execute(
text(
"INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) "
"VALUES (:id, :email, :hp, :dn, :eit, :ca, :ua)"
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
),
{
"id": user_id,
"email": "expired@e2e.com",
"hp": "unused",
"dn": "Expired User",
"eit": secrets.token_urlsafe(16),
"ev": False,
"ca": now,
"ua": now,
},
+68
View File
@@ -5,6 +5,74 @@ import pytest
from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID
@pytest.mark.asyncio
class TestRegistrationErrors:
"""Validation errors during user registration."""
async def test_short_password(self, client, db_engine):
resp = await client.post(
"/auth/register",
json={"email": "short@example.com", "password": "short", "display_name": "Test"},
)
assert resp.status_code == 422
async def test_invalid_email(self, client, db_engine):
resp = await client.post(
"/auth/register",
json={"email": "not-an-email", "password": "securepass123", "display_name": "Test"},
)
assert resp.status_code == 422
async def test_missing_fields(self, client, db_engine):
resp = await client.post("/auth/register", json={})
assert resp.status_code == 422
async def test_empty_display_name(self, client, db_engine):
resp = await client.post(
"/auth/register",
json={"email": "empty@example.com", "password": "securepass123", "display_name": ""},
)
assert resp.status_code == 422
async def test_duplicate_email(self, client, db_engine):
payload = {
"email": "dupe@example.com",
"password": "securepass123",
"display_name": "First",
}
first = await client.post("/auth/register", json=payload)
assert first.status_code == 201
second = await client.post("/auth/register", json=payload)
assert second.status_code == 409
@pytest.mark.asyncio
class TestLoginErrors:
"""Login failure modes."""
async def test_wrong_password(self, client, db_engine):
await client.post(
"/auth/register",
json={
"email": "login-err@example.com",
"password": "correctpass1",
"display_name": "Login",
},
)
resp = await client.post(
"/auth/login",
json={"email": "login-err@example.com", "password": "wrongpass123"},
)
assert resp.status_code == 401
async def test_nonexistent_user(self, client, db_engine):
resp = await client.post(
"/auth/login",
json={"email": "nobody@example.com", "password": "doesntmatter"},
)
assert resp.status_code == 401
@pytest.mark.asyncio
class TestNotFoundErrors:
"""404 responses for missing resources."""
+4 -5
View File
@@ -15,12 +15,11 @@ async def test_404_returns_structured_error(client):
@pytest.mark.asyncio
async def test_validation_error_returns_422_with_field_errors(client, auth_headers):
async def test_validation_error_returns_422_with_field_errors(client):
"""Invalid request body should return structured validation errors."""
resp = await client.patch(
"/auth/me",
headers=auth_headers,
json={"display_name": ""},
resp = await client.post(
"/auth/register",
json={"email": "not-an-email", "password": "short", "display_name": ""},
)
assert resp.status_code == 422
body = resp.json()
+5 -2
View File
@@ -6,7 +6,10 @@ from httpx import ASGITransport, AsyncClient
from cartsnitch_api.main import app
EXPECTED_ROUTES = [
# Auth (4 — register/login/refresh handled by Better-Auth service)
# Auth (7)
("post", "/auth/register"),
("post", "/auth/login"),
("post", "/auth/refresh"),
("get", "/auth/me"),
("patch", "/auth/me"),
("delete", "/auth/me"),
@@ -87,4 +90,4 @@ async def test_route_count():
if method in ("get", "post", "put", "delete", "patch"):
count += 1
assert count == 31, f"Expected 31 routes, found {count}"
assert count == 34, f"Expected 34 routes, found {count}"
+41 -57
View File
@@ -2,81 +2,44 @@
import secrets
import uuid
from datetime import UTC, datetime, date, timedelta
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from cartsnitch_api.models import Purchase, PurchaseItem, Store
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
@pytest.fixture
async def purchase_data(db_engine):
"""Seed a user, store, purchase, and items using session-cookie auth."""
"""Seed a user, store, purchase, items, and a valid session."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
user_id = str(uuid.uuid4())
session_token = secrets.token_urlsafe(32)
now = datetime.now(UTC).isoformat()
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
# Create the user
await session.execute(
text(
"INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) "
"VALUES (:id, :email, :hashed_password, :display_name, :email_inbound_token, :created_at, :updated_at)"
),
{
"id": user_id,
"email": "buyer@example.com",
"hashed_password": "not-used-with-better-auth",
"display_name": "Buyer",
"email_inbound_token": secrets.token_urlsafe(16),
"created_at": now,
"updated_at": now,
},
user = User(
email="buyer@example.com",
hashed_password="not-used-with-better-auth",
display_name="Buyer",
)
# Create the session
await session.execute(
text(
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
),
{
"id": str(uuid.uuid4()),
"token": session_token,
"user_id": user_id,
"expires_at": expires,
"created_at": now,
"updated_at": now,
},
)
# Create the store
store = Store(name="Kroger", slug="kroger", id=uuid.uuid4())
session.add(store)
await session.flush()
store = Store(name="Kroger", slug="kroger")
session.add_all([user, store])
await session.commit()
await session.refresh(user)
await session.refresh(store)
# Create the purchase
purchase = Purchase(
id=uuid.uuid4(),
user_id=uuid.UUID(user_id),
user_id=user.id,
store_id=store.id,
receipt_id="receipt-001",
purchase_date=date(2026, 3, 10),
total=Decimal("42.50"),
)
session.add(purchase)
await session.flush()
await session.commit()
await session.refresh(purchase)
# Create the purchase item
item = PurchaseItem(
id=uuid.uuid4(),
purchase_id=purchase.id,
product_name_raw="Organic Milk 1gal",
quantity=Decimal("1"),
@@ -86,12 +49,33 @@ async def purchase_data(db_engine):
session.add(item)
await session.commit()
return {
"user_id": user_id,
"store": store,
"purchase": purchase,
"headers": {"Cookie": f"better-auth.session_token={session_token}"},
}
# Create a session token directly in the sessions table
session_token = secrets.token_urlsafe(32)
now = datetime.now(UTC).isoformat()
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
async with db_engine.begin() as conn:
await conn.execute(
text(
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
),
{
"id": str(uuid.uuid4()),
"token": session_token,
"user_id": str(user.id),
"expires_at": expires,
"created_at": now,
"updated_at": now,
},
)
return {
"user": user,
"store": store,
"purchase": purchase,
"headers": {"Cookie": f"better-auth.session_token={session_token}"},
}
@pytest.mark.asyncio