diff --git a/api/src/cartsnitch_api/auth/dependencies.py b/api/src/cartsnitch_api/auth/dependencies.py index 61735ee..8799dfd 100644 --- a/api/src/cartsnitch_api/auth/dependencies.py +++ b/api/src/cartsnitch_api/auth/dependencies.py @@ -1,34 +1,91 @@ -"""FastAPI dependency injection for authentication.""" +"""FastAPI dependency injection for authentication. +Validates Better-Auth session tokens from cookies or Bearer header. +Sessions are verified by querying the shared sessions table directly. +""" + +from datetime import UTC, datetime from uuid import UUID -from fastapi import Depends, Header, HTTPException, status +from fastapi import Cookie, Depends, Header, HTTPException, Request, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession -from cartsnitch_api.auth.jwt import decode_token from cartsnitch_api.config import settings +from cartsnitch_api.database import get_db -bearer_scheme = HTTPBearer() +# Keep Bearer scheme as optional — Better-Auth primarily uses cookies, +# but we support Bearer tokens for service-to-service or mobile clients. +bearer_scheme = HTTPBearer(auto_error=False) + +# Better-Auth session cookie name +SESSION_COOKIE_NAME = "better-auth.session_token" + + +async def _validate_session_token(token: str, db: AsyncSession) -> UUID: + """Validate a Better-Auth session token against the sessions table. + + Returns the user_id (as UUID) if the session is valid and not expired. + """ + result = await db.execute( + text("SELECT user_id, expires_at FROM sessions WHERE token = :token"), + {"token": token}, + ) + row = result.first() + + if not row: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid session token", + ) + + user_id, expires_at = row + # SQLite stores datetimes as ISO strings; parse if necessary + if isinstance(expires_at, str): + expires_at = datetime.fromisoformat(expires_at) + if expires_at.tzinfo is None: + # Treat naive datetimes as UTC + expires_at = expires_at.replace(tzinfo=UTC) + + if expires_at < datetime.now(UTC): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Session expired", + ) + + return UUID(str(user_id)) async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme), + request: Request, + credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme), + db: AsyncSession = Depends(get_db), ) -> UUID: - try: - payload = decode_token(credentials.credentials) - except ValueError: + """Extract and validate the session token from cookie or Authorization header. + + Checks in order: + 1. Better-Auth session cookie (primary — web clients) + 2. Bearer token in Authorization header (fallback — API clients) + """ + token: str | None = None + + # 1. Check session cookie + cookie_token = request.cookies.get(SESSION_COOKIE_NAME) + if cookie_token: + token = cookie_token + + # 2. Fall back to Bearer header + if not token and credentials: + token = credentials.credentials + + if not token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - ) from None + detail="Authentication required", + ) - if payload.get("type") != "access": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token type", - ) from None - - return UUID(payload["sub"]) + return await _validate_session_token(token, db) async def verify_service_key(x_service_key: str = Header()) -> None: diff --git a/api/src/cartsnitch_api/auth/routes.py b/api/src/cartsnitch_api/auth/routes.py index 472325e..40ccda4 100644 --- a/api/src/cartsnitch_api/auth/routes.py +++ b/api/src/cartsnitch_api/auth/routes.py @@ -1,20 +1,19 @@ -"""Auth routes: register, login, refresh, me, update, delete.""" +"""Auth routes: user profile management. + +Registration, login, refresh, and session management are handled by +the Better-Auth service (auth/). This router provides user profile +endpoints that query our own user data from the shared database. +""" from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from cartsnitch_api.auth.dependencies import get_current_user from cartsnitch_api.database import get_db -from cartsnitch_api.models import User from cartsnitch_api.schemas import ( - LoginRequest, - RefreshRequest, - RegisterRequest, - TokenResponse, + EmailInAddressResponse, UpdateUserRequest, UserResponse, ) @@ -23,37 +22,6 @@ from cartsnitch_api.services.auth import AuthService router = APIRouter(prefix="/auth", tags=["auth"]) -@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED) -async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)): - svc = AuthService(db) - try: - return await svc.register(body.email, body.password, body.display_name) - except ValueError as e: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e - - -@router.post("/login", response_model=TokenResponse) -async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)): - svc = AuthService(db) - try: - return await svc.login(body.email, body.password) - except ValueError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password" - ) from None - - -@router.post("/refresh", response_model=TokenResponse) -async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)): - svc = AuthService(db) - try: - return await svc.refresh(body.refresh_token) - except ValueError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" - ) from None - - @router.get("/me", response_model=UserResponse) async def get_me( user_id: UUID = Depends(get_current_user), @@ -99,26 +67,15 @@ async def delete_me( ) from None -class EmailInAddressResponse(BaseModel): - email_address: str - instructions: str - - @router.get("/me/email-in-address", response_model=EmailInAddressResponse) async def get_email_in_address( user_id: UUID = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): - result = await db.execute(select(User.email_inbound_token).where(User.id == user_id)) - token = result.scalar_one_or_none() - if not token: + svc = AuthService(db) + try: + return await svc.get_email_in_address(user_id) + except LookupError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Email inbound token not found" ) from None - return EmailInAddressResponse( - email_address=f"receipts+{token}@receipts.cartsnitch.com", - instructions=( - "Forward your digital receipt emails to this address. " - "We currently support Meijer, Kroger, and Target receipt emails." - ), - ) diff --git a/api/src/cartsnitch_api/models/base.py b/api/src/cartsnitch_api/models/base.py index f93cf79..f4945bd 100644 --- a/api/src/cartsnitch_api/models/base.py +++ b/api/src/cartsnitch_api/models/base.py @@ -1,12 +1,39 @@ """Base model and mixins for all CartSnitch ORM models.""" -import uuid +import uuid as uuid_lib from datetime import datetime -from sqlalchemy import DateTime, func +from sqlalchemy import DateTime, String, TypeDecorator, func from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +class UUIDString(TypeDecorator): + """Store UUIDs as VARCHAR(36) strings in all dialects. + + This handles the fundamental mismatch between Python's uuid.UUID objects + (used everywhere in application code) and SQLite's lack of a native UUID type. + - On INSERT: converts uuid.UUID → str + - On SELECT: returns uuid.UUID (so SQLAlchemy 2.0 sentinel tracking matches correctly) + """ + + impl = String(36) + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return value + if isinstance(value, uuid_lib.UUID): + return str(value) + return value # already a string + + def process_result_value(self, value, dialect): + if value is None: + return value + if isinstance(value, uuid_lib.UUID): + return value + return uuid_lib.UUID(value) # convert str → UUID for correct sentinel tracking + + class Base(DeclarativeBase): """Base class for all CartSnitch models.""" @@ -23,8 +50,14 @@ class TimestampMixin: class UUIDPrimaryKeyMixin: - """Mixin providing a UUID primary key.""" + """Mixin providing a UUID primary key. - id: Mapped[uuid.UUID] = mapped_column( - primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid() + Uses UUIDString so all DB dialects store the full 36-char UUID string + without truncation, while Python code always works with uuid.UUID objects. + """ + + id: Mapped[uuid_lib.UUID] = mapped_column( + UUIDString(), + primary_key=True, + default=uuid_lib.uuid4, ) diff --git a/api/src/cartsnitch_api/schemas.py b/api/src/cartsnitch_api/schemas.py index 19e351a..21a40e3 100644 --- a/api/src/cartsnitch_api/schemas.py +++ b/api/src/cartsnitch_api/schemas.py @@ -6,28 +6,8 @@ from uuid import UUID from pydantic import BaseModel, EmailStr, Field # ---------- Auth ---------- - - -class RegisterRequest(BaseModel): - email: EmailStr - password: str = Field(min_length=8, max_length=128) - display_name: str = Field(min_length=1, max_length=100) - - -class LoginRequest(BaseModel): - email: EmailStr - password: str - - -class RefreshRequest(BaseModel): - refresh_token: str - - -class TokenResponse(BaseModel): - access_token: str - refresh_token: str - token_type: str = "bearer" - expires_in: int +# Registration, login, and session management are handled by Better-Auth (auth/ service). +# These schemas are for the profile management endpoints only. class UpdateUserRequest(BaseModel): @@ -285,6 +265,13 @@ class ErrorResponse(BaseModel): code: str | None = None +# ---------- Email-In ---------- + +class EmailInAddressResponse(BaseModel): + email_address: str + instructions: str + + # Rebuild forward refs ProductDetailResponse.model_rebuild() PriceTrendResponse.model_rebuild() diff --git a/api/src/cartsnitch_api/services/auth.py b/api/src/cartsnitch_api/services/auth.py index 5ea6b77..adb474f 100644 --- a/api/src/cartsnitch_api/services/auth.py +++ b/api/src/cartsnitch_api/services/auth.py @@ -1,71 +1,28 @@ -"""Auth service — user registration, login, token management.""" +"""Auth service — user profile management. + +Registration, login, token management, and session handling are now +handled by the Better-Auth service (auth/). This service provides +user lookup and profile update operations for the API gateway. +""" from uuid import UUID from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token -from cartsnitch_api.auth.passwords import hash_password, verify_password -from cartsnitch_api.config import settings - class AuthService: def __init__(self, db: AsyncSession) -> None: self.db = db - async def register(self, email: str, password: str, display_name: str) -> dict: - from cartsnitch_api.models import User - - existing = await self.db.execute(select(User).where(User.email == email)) - if existing.scalar_one_or_none(): - raise ValueError("Email already registered") - - user = User( - email=email, - hashed_password=hash_password(password), - display_name=display_name, - ) - self.db.add(user) - await self.db.commit() - await self.db.refresh(user) - - return self._make_token_response(user.id) - - async def login(self, email: str, password: str) -> dict: - from cartsnitch_api.models import User - - result = await self.db.execute(select(User).where(User.email == email)) - user = result.scalar_one_or_none() - if not user or not verify_password(password, user.hashed_password): - raise ValueError("Invalid email or password") - - return self._make_token_response(user.id) - - async def refresh(self, refresh_token: str) -> dict: - from cartsnitch_api.models import User - - try: - payload = decode_token(refresh_token) - except ValueError: - raise ValueError("Invalid refresh token") from None - - if payload.get("type") != "refresh": - raise ValueError("Invalid token type") from None - - user_id = UUID(payload["sub"]) - - # Verify the user still exists before issuing new tokens - result = await self.db.execute(select(User).where(User.id == user_id)) - if not result.scalar_one_or_none(): - raise ValueError("User no longer exists") - - return self._make_token_response(user_id) - async def get_user(self, user_id: UUID) -> dict: from cartsnitch_api.models import User - result = await self.db.execute(select(User).where(User.id == user_id)) + # Use str() to ensure consistent string comparison for UUID columns + # (works with both SQLite VARCHAR and Postgres UUID storage) + result = await self.db.execute( + select(User).where(User.id == str(user_id)) + ) user = result.scalar_one_or_none() if not user: raise LookupError("User not found") @@ -80,7 +37,8 @@ class AuthService: async def update_user(self, user_id: UUID, **fields) -> dict: from cartsnitch_api.models import User - result = await self.db.execute(select(User).where(User.id == user_id)) + user_id_str = str(user_id) + result = await self.db.execute(select(User).where(User.id == user_id_str)) user = result.scalar_one_or_none() if not user: raise LookupError("User not found") @@ -89,7 +47,7 @@ class AuthService: user.display_name = fields["display_name"] if "email" in fields and fields["email"] is not None: existing = await self.db.execute( - select(User).where(User.email == fields["email"], User.id != user_id) + select(User).where(User.email == fields["email"], User.id != user_id_str) ) if existing.scalar_one_or_none(): raise ValueError("Email already in use") @@ -108,7 +66,7 @@ class AuthService: async def delete_user(self, user_id: UUID) -> None: from cartsnitch_api.models import User - result = await self.db.execute(select(User).where(User.id == user_id)) + result = await self.db.execute(select(User).where(User.id == str(user_id))) user = result.scalar_one_or_none() if not user: raise LookupError("User not found") @@ -116,10 +74,20 @@ class AuthService: await self.db.delete(user) await self.db.commit() - def _make_token_response(self, user_id: UUID) -> dict: + async def get_email_in_address(self, user_id: UUID) -> dict: + from cartsnitch_api.models import User + + result = await self.db.execute( + select(User.email_inbound_token).where(User.id == str(user_id)) + ) + token = result.scalar_one_or_none() + if not token: + raise LookupError("Email inbound token not found") + return { - "access_token": create_access_token(user_id), - "refresh_token": create_refresh_token(user_id), - "token_type": "bearer", - "expires_in": settings.jwt_access_token_expire_minutes * 60, + "email_address": f"receipts+{token}@receipts.cartsnitch.com", + "instructions": ( + "Forward your digital receipt emails to this address. " + "We currently support Meijer, Kroger, and Target receipt emails." + ), } diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 9873903..accfc77 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -1,8 +1,16 @@ -"""Shared test fixtures with in-memory SQLite database.""" +"""Shared test fixtures with in-memory SQLite database. + +Session-based auth: tests create users and sessions directly in the DB, +matching the Better-Auth session validation flow. +""" + +import secrets +import uuid +from datetime import UTC, datetime, timedelta import pytest from httpx import ASGITransport, AsyncClient -from sqlalchemy import create_engine, event +from sqlalchemy import create_engine, event, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import sessionmaker @@ -51,6 +59,46 @@ async def db_engine(): async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) + # Create Better-Auth tables (not managed by SQLAlchemy models) + await conn.execute(text(""" + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + token TEXT NOT NULL UNIQUE, + user_id TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + ip_address TEXT, + user_agent TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """)) + await conn.execute(text(""" + CREATE TABLE IF NOT EXISTS accounts ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + account_id TEXT NOT NULL, + provider_id TEXT NOT NULL, + access_token TEXT, + refresh_token TEXT, + access_token_expires_at TIMESTAMP, + refresh_token_expires_at TIMESTAMP, + scope TEXT, + id_token TEXT, + password TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """)) + await conn.execute(text(""" + CREATE TABLE IF NOT EXISTS verifications ( + id TEXT PRIMARY KEY, + identifier TEXT NOT NULL, + value TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """)) yield engine @@ -85,17 +133,56 @@ async def client(db_engine): app.dependency_overrides.clear() +async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]: + """Create a test user and a valid session directly in the DB. + + Returns (user_dict, session_token). + """ + user_id = str(uuid.uuid4()) + email = user_overrides.get("email", "test@example.com") + display_name = user_overrides.get("display_name", "Test User") + email_inbound_token = user_overrides.get("email_inbound_token", secrets.token_urlsafe(16)) + session_token = secrets.token_urlsafe(32) + session_id = str(uuid.uuid4()) + now = datetime.now(UTC).isoformat() + expires = (datetime.now(UTC) + timedelta(days=7)).isoformat() + + async with db_engine.begin() as conn: + await conn.execute( + text( + "INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) " + "VALUES (:id, :email, :hashed_password, :display_name, :email_inbound_token, :created_at, :updated_at)" + ), + { + "id": user_id, + "email": email, + "hashed_password": "not-used-with-better-auth", + "display_name": display_name, + "email_inbound_token": email_inbound_token, + "created_at": now, + "updated_at": now, + }, + ) + await conn.execute( + text( + "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " + "VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)" + ), + { + "id": session_id, + "token": session_token, + "user_id": user_id, + "expires_at": expires, + "created_at": now, + "updated_at": now, + }, + ) + + return {"id": user_id, "email": email, "display_name": display_name}, session_token + + @pytest.fixture -async def auth_headers(client): - """Register a test user and return auth headers.""" - resp = await client.post( - "/auth/register", - json={ - "email": "test@example.com", - "password": "testpass123", - "display_name": "Test User", - }, - ) - assert resp.status_code == 201 - token = resp.json()["access_token"] - return {"Authorization": f"Bearer {token}"} +async def auth_headers(client, db_engine): + """Create a test user with a valid session and return auth headers.""" + _, session_token = await _create_test_user_and_session(client, db_engine) + return {"Cookie": f"better-auth.session_token={session_token}"} diff --git a/api/tests/test_auth/test_auth_endpoints.py b/api/tests/test_auth/test_auth_endpoints.py index 878cbc5..1504c86 100644 --- a/api/tests/test_auth/test_auth_endpoints.py +++ b/api/tests/test_auth/test_auth_endpoints.py @@ -1,146 +1,13 @@ -"""Integration tests for auth endpoints.""" +"""Integration tests for auth profile endpoints. + +Registration, login, and session management are handled by the Better-Auth +service. These tests cover the profile endpoints (GET/PATCH/DELETE /auth/me) +which validate sessions via the shared sessions table. +""" import pytest -@pytest.mark.asyncio -async def test_register_success(client): - resp = await client.post( - "/auth/register", - json={ - "email": "new@example.com", - "password": "securepass123", - "display_name": "New User", - }, - ) - assert resp.status_code == 201 - data = resp.json() - assert "access_token" in data - assert "refresh_token" in data - assert data["token_type"] == "bearer" - assert data["expires_in"] == 900 # 15 min * 60 - - -@pytest.mark.asyncio -async def test_register_duplicate_email(client): - await client.post( - "/auth/register", - json={ - "email": "dupe@example.com", - "password": "securepass123", - "display_name": "User One", - }, - ) - resp = await client.post( - "/auth/register", - json={ - "email": "dupe@example.com", - "password": "securepass456", - "display_name": "User Two", - }, - ) - assert resp.status_code == 409 - - -@pytest.mark.asyncio -async def test_register_short_password(client): - resp = await client.post( - "/auth/register", - json={ - "email": "short@example.com", - "password": "short", - "display_name": "Short Pass", - }, - ) - assert resp.status_code == 422 - - -@pytest.mark.asyncio -async def test_login_success(client): - await client.post( - "/auth/register", - json={ - "email": "login@example.com", - "password": "securepass123", - "display_name": "Login User", - }, - ) - resp = await client.post( - "/auth/login", - json={ - "email": "login@example.com", - "password": "securepass123", - }, - ) - assert resp.status_code == 200 - assert "access_token" in resp.json() - - -@pytest.mark.asyncio -async def test_login_wrong_password(client): - await client.post( - "/auth/register", - json={ - "email": "wrong@example.com", - "password": "securepass123", - "display_name": "Wrong Pass", - }, - ) - resp = await client.post( - "/auth/login", - json={ - "email": "wrong@example.com", - "password": "badpassword1", - }, - ) - assert resp.status_code == 401 - - -@pytest.mark.asyncio -async def test_login_nonexistent_user(client): - resp = await client.post( - "/auth/login", - json={ - "email": "ghost@example.com", - "password": "doesntmatter", - }, - ) - assert resp.status_code == 401 - - -@pytest.mark.asyncio -async def test_refresh_token(client): - reg = await client.post( - "/auth/register", - json={ - "email": "refresh@example.com", - "password": "securepass123", - "display_name": "Refresh User", - }, - ) - refresh_token = reg.json()["refresh_token"] - - resp = await client.post( - "/auth/refresh", - json={ - "refresh_token": refresh_token, - }, - ) - assert resp.status_code == 200 - assert "access_token" in resp.json() - - -@pytest.mark.asyncio -async def test_refresh_with_invalid_token(client): - resp = await client.post( - "/auth/refresh", - json={ - "refresh_token": "invalid.token.here", - }, - ) - assert resp.status_code == 401 - - @pytest.mark.asyncio async def test_get_me(client, auth_headers): resp = await client.get("/auth/me", headers=auth_headers) @@ -155,7 +22,32 @@ async def test_get_me(client, auth_headers): @pytest.mark.asyncio async def test_get_me_unauthorized(client): resp = await client.get("/auth/me") - assert resp.status_code in (401, 403) # No auth header + assert resp.status_code in (401, 403) + + +@pytest.mark.asyncio +async def test_get_me_invalid_session(client): + resp = await client.get( + "/auth/me", + headers={"Cookie": "better-auth.session_token=invalid-token"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_get_me_with_bearer_token(client, db_engine): + """Session tokens can also be passed as Bearer tokens for API clients.""" + from tests.conftest import _create_test_user_and_session + + _, session_token = await _create_test_user_and_session( + client, db_engine, email="bearer@example.com", display_name="Bearer User" + ) + resp = await client.get( + "/auth/me", + headers={"Authorization": f"Bearer {session_token}"}, + ) + assert resp.status_code == 200 + assert resp.json()["email"] == "bearer@example.com" @pytest.mark.asyncio @@ -163,9 +55,7 @@ async def test_update_me(client, auth_headers): resp = await client.patch( "/auth/me", headers=auth_headers, - json={ - "display_name": "Updated Name", - }, + json={"display_name": "Updated Name"}, ) assert resp.status_code == 200 assert resp.json()["display_name"] == "Updated Name" @@ -176,34 +66,58 @@ async def test_delete_me(client, auth_headers): resp = await client.delete("/auth/me", headers=auth_headers) assert resp.status_code == 204 - # Verify user is gone (token still valid but user deleted) + # Session is still valid but user is gone resp = await client.get("/auth/me", headers=auth_headers) assert resp.status_code == 404 @pytest.mark.asyncio -async def test_refresh_after_delete_fails(client): - """Refresh token for a deleted user must be rejected.""" - reg = await client.post( - "/auth/register", - json={ - "email": "ghost@example.com", - "password": "securepass123", - "display_name": "Ghost User", - }, - ) - tokens = reg.json() - headers = {"Authorization": f"Bearer {tokens['access_token']}"} +async def test_expired_session_rejected(client, db_engine): + """Expired sessions must be rejected.""" + import secrets + import uuid + from datetime import UTC, datetime, timedelta - # Delete the user - resp = await client.delete("/auth/me", headers=headers) - assert resp.status_code == 204 + from sqlalchemy import text - # Refresh token should now fail - resp = await client.post( - "/auth/refresh", - json={ - "refresh_token": tokens["refresh_token"], - }, + user_id = str(uuid.uuid4()) + session_token = secrets.token_urlsafe(32) + now = datetime.now(UTC).isoformat() + expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat() + + async with db_engine.begin() as conn: + await conn.execute( + text( + "INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) " + "VALUES (:id, :email, :hp, :dn, :eit, :ca, :ua)" + ), + { + "id": user_id, + "email": "expired@example.com", + "hp": "unused", + "dn": "Expired User", + "eit": secrets.token_urlsafe(16), + "ca": now, + "ua": now, + }, + ) + await conn.execute( + text( + "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " + "VALUES (:id, :token, :uid, :ea, :ca, :ua)" + ), + { + "id": str(uuid.uuid4()), + "token": session_token, + "uid": user_id, + "ea": expired, + "ca": now, + "ua": now, + }, + ) + + resp = await client.get( + "/auth/me", + headers={"Cookie": f"better-auth.session_token={session_token}"}, ) assert resp.status_code == 401 diff --git a/api/tests/test_e2e/conftest.py b/api/tests/test_e2e/conftest.py index a48418d..29ae3d4 100644 --- a/api/tests/test_e2e/conftest.py +++ b/api/tests/test_e2e/conftest.py @@ -7,12 +7,13 @@ exercise cross-resource queries against real data. from datetime import date, timedelta from decimal import Decimal -from uuid import UUID +import uuid + +from sqlalchemy import text import pytest from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from cartsnitch_api.auth.jwt import decode_token from cartsnitch_api.models import ( Coupon, NormalizedProduct, @@ -33,17 +34,20 @@ ANCHOR_DATE = date.today() @pytest.fixture async def seed_data(db_engine, auth_headers): """Seed a full dataset and return identifiers for test assertions.""" + import uuid + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) async with factory() as session: # -- Stores -- - meijer = Store(name="Meijer", slug="meijer") - kroger = Store(name="Kroger", slug="kroger") - target = Store(name="Target", slug="target") + meijer = Store(name="Meijer", slug="meijer", id=uuid.uuid4()) + kroger = Store(name="Kroger", slug="kroger", id=uuid.uuid4()) + target = Store(name="Target", slug="target", id=uuid.uuid4()) session.add_all([meijer, kroger, target]) await session.flush() # -- Products -- cheerios = NormalizedProduct( + id=uuid.uuid4(), canonical_name="Cheerios 18oz", category="pantry", brand="General Mills", @@ -52,6 +56,7 @@ async def seed_data(db_engine, auth_headers): upc_variants=["016000275263"], ) milk = NormalizedProduct( + id=uuid.uuid4(), canonical_name="Whole Milk 1gal", category="dairy", brand="Meijer", @@ -59,6 +64,7 @@ async def seed_data(db_engine, auth_headers): size_unit="gal", ) chicken = NormalizedProduct( + id=uuid.uuid4(), canonical_name="Chicken Breast 1lb", category="meat", brand=None, @@ -75,6 +81,7 @@ async def seed_data(db_engine, auth_headers): for i, price_val in enumerate([Decimal("3.99"), Decimal("4.29"), Decimal("4.79")]): prices.append( PriceHistory( + id=uuid.uuid4(), normalized_product_id=cheerios.id, store_id=meijer.id, observed_date=today - timedelta(days=60 - i * 30), @@ -86,6 +93,7 @@ async def seed_data(db_engine, auth_headers): for i in range(3): prices.append( PriceHistory( + id=uuid.uuid4(), normalized_product_id=cheerios.id, store_id=kroger.id, observed_date=today - timedelta(days=60 - i * 30), @@ -96,6 +104,7 @@ async def seed_data(db_engine, auth_headers): # Milk at Meijer prices.append( PriceHistory( + id=uuid.uuid4(), normalized_product_id=milk.id, store_id=meijer.id, observed_date=today - timedelta(days=7), @@ -106,6 +115,7 @@ async def seed_data(db_engine, auth_headers): # Milk at Kroger prices.append( PriceHistory( + id=uuid.uuid4(), normalized_product_id=milk.id, store_id=kroger.id, observed_date=today - timedelta(days=5), @@ -116,6 +126,7 @@ async def seed_data(db_engine, auth_headers): # Chicken at Target prices.append( PriceHistory( + id=uuid.uuid4(), normalized_product_id=chicken.id, store_id=target.id, observed_date=today - timedelta(days=3), @@ -127,12 +138,28 @@ async def seed_data(db_engine, auth_headers): await session.flush() # -- Purchases (need the user_id from the registered test user) -- - token = auth_headers["Authorization"].split(" ")[1] - payload = decode_token(token) - user_id = UUID(payload["sub"]) + # Extract session_token from auth_headers, then look up the real user_id + import http.cookies + cookie_header = auth_headers.get("Cookie", "") + cookies = http.cookies.SimpleCookie() + cookies.load(cookie_header) + session_token = cookies.get("better-auth.session_token").value if "better-auth.session_token" in cookie_header else None + if session_token is None: + raise RuntimeError("seed_data fixture requires cookie-based auth session token") + + # Look up the real user_id from the sessions table + row = await session.execute( + text("SELECT user_id FROM sessions WHERE token = :token"), + {"token": session_token} + ) + session_row = row.fetchone() + if session_row is None: + raise RuntimeError("Session not found for session token in auth_headers") + real_user_id = session_row[0] purchase1 = Purchase( - user_id=user_id, + id=uuid.uuid4(), + user_id=uuid.UUID(real_user_id), store_id=meijer.id, receipt_id="meijer-2026-001", purchase_date=today - timedelta(days=10), @@ -141,7 +168,8 @@ async def seed_data(db_engine, auth_headers): tax=Decimal("1.95"), ) purchase2 = Purchase( - user_id=user_id, + id=uuid.uuid4(), + user_id=uuid.UUID(real_user_id), store_id=kroger.id, receipt_id="kroger-2026-001", purchase_date=today - timedelta(days=5), @@ -154,6 +182,7 @@ async def seed_data(db_engine, auth_headers): # -- Purchase Items -- item1 = PurchaseItem( + id=uuid.uuid4(), purchase_id=purchase1.id, product_name_raw="Cheerios 18oz Box", quantity=Decimal("1"), @@ -162,6 +191,7 @@ async def seed_data(db_engine, auth_headers): normalized_product_id=cheerios.id, ) item2 = PurchaseItem( + id=uuid.uuid4(), purchase_id=purchase1.id, product_name_raw="Meijer Whole Milk 1gal", quantity=Decimal("2"), @@ -170,6 +200,7 @@ async def seed_data(db_engine, auth_headers): normalized_product_id=milk.id, ) item3 = PurchaseItem( + id=uuid.uuid4(), purchase_id=purchase2.id, product_name_raw="KRO CHEERIOS 18OZ", quantity=Decimal("1"), @@ -182,6 +213,7 @@ async def seed_data(db_engine, auth_headers): # -- Coupons -- coupon1 = Coupon( + id=uuid.uuid4(), store_id=meijer.id, normalized_product_id=cheerios.id, title="$1 off Cheerios", @@ -192,6 +224,7 @@ async def seed_data(db_engine, auth_headers): valid_to=today + timedelta(days=30), ) coupon2 = Coupon( + id=uuid.uuid4(), store_id=kroger.id, normalized_product_id=None, title="10% off dairy", @@ -206,6 +239,7 @@ async def seed_data(db_engine, auth_headers): # -- Shrinkflation events -- shrink = ShrinkflationEvent( + id=uuid.uuid4(), normalized_product_id=cheerios.id, detected_date=today - timedelta(days=15), old_size="20", @@ -240,7 +274,7 @@ async def seed_data(db_engine, auth_headers): return { "headers": auth_headers, - "user_id": user_id, + "user_id": real_user_id, "stores": {"meijer": meijer, "kroger": kroger, "target": target}, "products": {"cheerios": cheerios, "milk": milk, "chicken": chicken}, "purchases": {"meijer_trip": purchase1, "kroger_trip": purchase2}, diff --git a/api/tests/test_e2e/test_auth_validation.py b/api/tests/test_e2e/test_auth_validation.py index bbded83..23c28d6 100644 --- a/api/tests/test_e2e/test_auth_validation.py +++ b/api/tests/test_e2e/test_auth_validation.py @@ -1,132 +1,103 @@ -"""E2E: Auth and token validation flows.""" +"""E2E: Auth and session validation flows. -import asyncio +Registration and login are handled by the Better-Auth service. +These tests validate session token handling at the API gateway level. +""" import pytest - -@pytest.mark.asyncio -class TestAuthRegistrationLogin: - """Full registration → login → token refresh → profile flow.""" - - async def test_full_auth_lifecycle(self, client, db_engine): - """Register → login → get profile → refresh → get profile again.""" - # Register - reg = await client.post( - "/auth/register", - json={ - "email": "lifecycle@example.com", - "password": "securepass123", - "display_name": "Lifecycle User", - }, - ) - assert reg.status_code == 201 - tokens = reg.json() - assert "access_token" in tokens - assert "refresh_token" in tokens - assert tokens["token_type"] == "bearer" - assert tokens["expires_in"] > 0 - - headers = {"Authorization": f"Bearer {tokens['access_token']}"} - - # Get profile with access token - me = await client.get("/auth/me", headers=headers) - assert me.status_code == 200 - assert me.json()["email"] == "lifecycle@example.com" - assert me.json()["display_name"] == "Lifecycle User" - - # Sleep 1s so the new token has a different exp than the registration token - await asyncio.sleep(1) - - # Login with same credentials - login = await client.post( - "/auth/login", - json={"email": "lifecycle@example.com", "password": "securepass123"}, - ) - assert login.status_code == 200 - login_tokens = login.json() - assert login_tokens["access_token"] != tokens["access_token"] - - # Refresh token - refresh = await client.post( - "/auth/refresh", - json={"refresh_token": tokens["refresh_token"]}, - ) - assert refresh.status_code == 200 - new_tokens = refresh.json() - assert new_tokens["access_token"] != tokens["access_token"] - - # Use refreshed token to access profile - new_headers = {"Authorization": f"Bearer {new_tokens['access_token']}"} - me2 = await client.get("/auth/me", headers=new_headers) - assert me2.status_code == 200 - assert me2.json()["email"] == "lifecycle@example.com" +from tests.conftest import _create_test_user_and_session @pytest.mark.asyncio -class TestTokenValidation: - """Token edge cases and error responses.""" +class TestSessionValidation: + """Session edge cases and error responses.""" - async def test_expired_token_rejected(self, client, db_engine): - """Manually craft an expired token and verify rejection.""" - import uuid - from datetime import UTC, datetime, timedelta - - from jose import jwt - - from cartsnitch_api.config import settings - - payload = { - "sub": str(uuid.uuid4()), - "exp": datetime.now(UTC) - timedelta(minutes=5), - "type": "access", - } - token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm) - resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"}) + async def test_invalid_session_token_rejected(self, client, db_engine): + resp = await client.get( + "/auth/me", + headers={"Cookie": "better-auth.session_token=not-a-real-token"}, + ) assert resp.status_code == 401 - async def test_invalid_token_rejected(self, client, db_engine): - resp = await client.get("/auth/me", headers={"Authorization": "Bearer not-a-real-token"}) - assert resp.status_code == 401 - - async def test_missing_auth_header(self, client, db_engine): + async def test_missing_auth(self, client, db_engine): resp = await client.get("/auth/me") assert resp.status_code in (401, 403) - async def test_refresh_token_cannot_access_endpoints(self, client, db_engine): - """A refresh token should not work as an access token.""" - reg = await client.post( - "/auth/register", - json={ - "email": "refresh-test@example.com", - "password": "securepass123", - "display_name": "Refresh Test", - }, + async def test_bearer_token_also_works(self, client, db_engine): + """Session tokens passed as Bearer tokens should also be accepted.""" + _, session_token = await _create_test_user_and_session( + client, db_engine, email="bearer@e2e.com", display_name="Bearer E2E" ) - refresh_token = reg.json()["refresh_token"] - resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {refresh_token}"}) - assert resp.status_code == 401 - - async def test_deleted_user_token_invalid(self, client, db_engine): - """After deleting an account, tokens should no longer work.""" - reg = await client.post( - "/auth/register", - json={ - "email": "delete-me@example.com", - "password": "securepass123", - "display_name": "Delete Me", - }, + resp = await client.get( + "/auth/me", + headers={"Authorization": f"Bearer {session_token}"}, ) - tokens = reg.json() - headers = {"Authorization": f"Bearer {tokens['access_token']}"} + assert resp.status_code == 200 + assert resp.json()["email"] == "bearer@e2e.com" + + async def test_deleted_user_session_returns_not_found(self, client, db_engine): + """After deleting a user, their session should result in 404 for profile.""" + _, session_token = await _create_test_user_and_session( + client, db_engine, email="delete-me@e2e.com", display_name="Delete Me" + ) + headers = {"Cookie": f"better-auth.session_token={session_token}"} - # Delete account delete_resp = await client.delete("/auth/me", headers=headers) assert delete_resp.status_code == 204 - # Profile should fail me = await client.get("/auth/me", headers=headers) - assert me.status_code in (401, 404) + assert me.status_code == 404 + + async def test_expired_session_rejected(self, client, db_engine): + """Expired sessions must be rejected.""" + import secrets + import uuid + from datetime import UTC, datetime, timedelta + + from sqlalchemy import text + + user_id = str(uuid.uuid4()) + session_token = secrets.token_urlsafe(32) + now = datetime.now(UTC).isoformat() + expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat() + + async with db_engine.begin() as conn: + await conn.execute( + text( + "INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) " + "VALUES (:id, :email, :hp, :dn, :eit, :ca, :ua)" + ), + { + "id": user_id, + "email": "expired@e2e.com", + "hp": "unused", + "dn": "Expired User", + "eit": secrets.token_urlsafe(16), + "ca": now, + "ua": now, + }, + ) + await conn.execute( + text( + "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " + "VALUES (:id, :token, :uid, :ea, :ca, :ua)" + ), + { + "id": str(uuid.uuid4()), + "token": session_token, + "uid": user_id, + "ea": expired, + "ca": now, + "ua": now, + }, + ) + + resp = await client.get( + "/auth/me", + headers={"Cookie": f"better-auth.session_token={session_token}"}, + ) + assert resp.status_code == 401 @pytest.mark.asyncio @@ -154,60 +125,38 @@ class TestAuthProtectedEndpoints: class TestCrossUserDataIsolation: """Verify that users cannot access other users' data.""" - async def test_user_b_cannot_access_user_a_purchases(self, client, seed_data): - """Register a second user and verify they cannot see User A's purchases.""" - # User A's purchase (from seed_data) + async def test_user_b_cannot_access_user_a_purchases(self, client, db_engine, seed_data): + """A second user cannot see User A's purchases.""" purchase_id = str(seed_data["purchases"]["meijer_trip"].id) - # Register User B - reg = await client.post( - "/auth/register", - json={ - "email": "userb@example.com", - "password": "securepass123", - "display_name": "User B", - }, + _, session_token = await _create_test_user_and_session( + client, db_engine, email="userb@e2e.com", display_name="User B" ) - assert reg.status_code == 201 - user_b_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + user_b_headers = {"Cookie": f"better-auth.session_token={session_token}"} - # User B tries to access User A's specific purchase resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers) assert resp.status_code in (403, 404), ( "User B should not be able to access User A's purchase" ) - async def test_user_b_purchase_list_is_empty(self, client, seed_data): - """A new user should see no purchases (not User A's purchases).""" - reg = await client.post( - "/auth/register", - json={ - "email": "userc@example.com", - "password": "securepass123", - "display_name": "User C", - }, + async def test_user_b_purchase_list_is_empty(self, client, db_engine, seed_data): + """A new user should see no purchases.""" + _, session_token = await _create_test_user_and_session( + client, db_engine, email="userc@e2e.com", display_name="User C" ) - assert reg.status_code == 201 - user_c_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + user_c_headers = {"Cookie": f"better-auth.session_token={session_token}"} resp = await client.get("/purchases", headers=user_c_headers) assert resp.status_code == 200 assert len(resp.json()) == 0, "New user should have no purchases" - async def test_user_b_stores_isolated(self, client, seed_data): + async def test_user_b_stores_isolated(self, client, db_engine, seed_data): """User B's connected stores should be independent from User A.""" - reg = await client.post( - "/auth/register", - json={ - "email": "userd@example.com", - "password": "securepass123", - "display_name": "User D", - }, + _, session_token = await _create_test_user_and_session( + client, db_engine, email="userd@e2e.com", display_name="User D" ) - assert reg.status_code == 201 - user_d_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + user_d_headers = {"Cookie": f"better-auth.session_token={session_token}"} - # User D should have no connected stores resp = await client.get("/me/stores", headers=user_d_headers) assert resp.status_code == 200 assert len(resp.json()) == 0, "New user should have no connected stores" diff --git a/api/tests/test_e2e/test_error_responses.py b/api/tests/test_e2e/test_error_responses.py index c3ad16e..98c46fc 100644 --- a/api/tests/test_e2e/test_error_responses.py +++ b/api/tests/test_e2e/test_error_responses.py @@ -5,74 +5,6 @@ import pytest from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID -@pytest.mark.asyncio -class TestRegistrationErrors: - """Validation errors during user registration.""" - - async def test_short_password(self, client, db_engine): - resp = await client.post( - "/auth/register", - json={"email": "short@example.com", "password": "short", "display_name": "Test"}, - ) - assert resp.status_code == 422 - - async def test_invalid_email(self, client, db_engine): - resp = await client.post( - "/auth/register", - json={"email": "not-an-email", "password": "securepass123", "display_name": "Test"}, - ) - assert resp.status_code == 422 - - async def test_missing_fields(self, client, db_engine): - resp = await client.post("/auth/register", json={}) - assert resp.status_code == 422 - - async def test_empty_display_name(self, client, db_engine): - resp = await client.post( - "/auth/register", - json={"email": "empty@example.com", "password": "securepass123", "display_name": ""}, - ) - assert resp.status_code == 422 - - async def test_duplicate_email(self, client, db_engine): - payload = { - "email": "dupe@example.com", - "password": "securepass123", - "display_name": "First", - } - first = await client.post("/auth/register", json=payload) - assert first.status_code == 201 - second = await client.post("/auth/register", json=payload) - assert second.status_code == 409 - - -@pytest.mark.asyncio -class TestLoginErrors: - """Login failure modes.""" - - async def test_wrong_password(self, client, db_engine): - await client.post( - "/auth/register", - json={ - "email": "login-err@example.com", - "password": "correctpass1", - "display_name": "Login", - }, - ) - resp = await client.post( - "/auth/login", - json={"email": "login-err@example.com", "password": "wrongpass123"}, - ) - assert resp.status_code == 401 - - async def test_nonexistent_user(self, client, db_engine): - resp = await client.post( - "/auth/login", - json={"email": "nobody@example.com", "password": "doesntmatter"}, - ) - assert resp.status_code == 401 - - @pytest.mark.asyncio class TestNotFoundErrors: """404 responses for missing resources.""" diff --git a/api/tests/test_middleware/test_error_handler.py b/api/tests/test_middleware/test_error_handler.py index 950351d..549f6b2 100644 --- a/api/tests/test_middleware/test_error_handler.py +++ b/api/tests/test_middleware/test_error_handler.py @@ -15,11 +15,12 @@ async def test_404_returns_structured_error(client): @pytest.mark.asyncio -async def test_validation_error_returns_422_with_field_errors(client): +async def test_validation_error_returns_422_with_field_errors(client, auth_headers): """Invalid request body should return structured validation errors.""" - resp = await client.post( - "/auth/register", - json={"email": "not-an-email", "password": "short", "display_name": ""}, + resp = await client.patch( + "/auth/me", + headers=auth_headers, + json={"display_name": ""}, ) assert resp.status_code == 422 body = resp.json() diff --git a/api/tests/test_openapi.py b/api/tests/test_openapi.py index 5684ee0..21ce0f7 100644 --- a/api/tests/test_openapi.py +++ b/api/tests/test_openapi.py @@ -6,13 +6,11 @@ from httpx import ASGITransport, AsyncClient from cartsnitch_api.main import app EXPECTED_ROUTES = [ - # Auth (6) - ("post", "/auth/register"), - ("post", "/auth/login"), - ("post", "/auth/refresh"), + # Auth (4 — register/login/refresh handled by Better-Auth service) ("get", "/auth/me"), ("patch", "/auth/me"), ("delete", "/auth/me"), + ("get", "/auth/me/email-in-address"), # Stores (4) ("get", "/stores"), ("get", "/me/stores"), @@ -89,4 +87,4 @@ async def test_route_count(): if method in ("get", "post", "put", "delete", "patch"): count += 1 - assert count == 34, f"Expected 34 routes, found {count}" + assert count == 31, f"Expected 31 routes, found {count}" diff --git a/api/tests/test_routes/test_purchases.py b/api/tests/test_routes/test_purchases.py index 14d5eb6..3589783 100644 --- a/api/tests/test_routes/test_purchases.py +++ b/api/tests/test_routes/test_purchases.py @@ -1,46 +1,82 @@ """Integration tests for purchase endpoints.""" +import secrets import uuid -from datetime import date +from datetime import UTC, datetime, date, timedelta from decimal import Decimal import pytest from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlalchemy import text -from cartsnitch_api.auth.jwt import create_access_token -from cartsnitch_api.models import Purchase, PurchaseItem, Store, User +from cartsnitch_api.models import Purchase, PurchaseItem, Store @pytest.fixture async def purchase_data(db_engine): - """Seed a user, store, purchase, and items.""" + """Seed a user, store, purchase, and items using session-cookie auth.""" factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) async with factory() as session: - from cartsnitch_api.auth.passwords import hash_password + user_id = str(uuid.uuid4()) + session_token = secrets.token_urlsafe(32) + now = datetime.now(UTC).isoformat() + expires = (datetime.now(UTC) + timedelta(days=7)).isoformat() - user = User( - email="buyer@example.com", - hashed_password=hash_password("testpass123"), - display_name="Buyer", + # Create the user + await session.execute( + text( + "INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) " + "VALUES (:id, :email, :hashed_password, :display_name, :email_inbound_token, :created_at, :updated_at)" + ), + { + "id": user_id, + "email": "buyer@example.com", + "hashed_password": "not-used-with-better-auth", + "display_name": "Buyer", + "email_inbound_token": secrets.token_urlsafe(16), + "created_at": now, + "updated_at": now, + }, ) - store = Store(name="Kroger", slug="kroger") - session.add_all([user, store]) - await session.commit() - await session.refresh(user) + + # Create the session + await session.execute( + text( + "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " + "VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)" + ), + { + "id": str(uuid.uuid4()), + "token": session_token, + "user_id": user_id, + "expires_at": expires, + "created_at": now, + "updated_at": now, + }, + ) + + # Create the store + store = Store(name="Kroger", slug="kroger", id=uuid.uuid4()) + session.add(store) + await session.flush() await session.refresh(store) + # Create the purchase purchase = Purchase( - user_id=user.id, + id=uuid.uuid4(), + user_id=uuid.UUID(user_id), store_id=store.id, receipt_id="receipt-001", purchase_date=date(2026, 3, 10), total=Decimal("42.50"), ) session.add(purchase) - await session.commit() + await session.flush() await session.refresh(purchase) + # Create the purchase item item = PurchaseItem( + id=uuid.uuid4(), purchase_id=purchase.id, product_name_raw="Organic Milk 1gal", quantity=Decimal("1"), @@ -50,12 +86,11 @@ async def purchase_data(db_engine): session.add(item) await session.commit() - token = create_access_token(user.id) return { - "user": user, + "user_id": user_id, "store": store, "purchase": purchase, - "headers": {"Authorization": f"Bearer {token}"}, + "headers": {"Cookie": f"better-auth.session_token={session_token}"}, }