From 906d7023a94b536d1fe4b9f43edfe615898367af Mon Sep 17 00:00:00 2001 From: Coupon Carl Date: Sat, 28 Mar 2026 04:46:10 +0000 Subject: [PATCH] feat: migrate authentication to Better-Auth (Phase 1) Replace hand-rolled JWT auth with Better-Auth session-based authentication. - Scaffold auth/ Node.js service with Better-Auth, bcrypt password compat, Postgres adapter mapped to existing users table - Add Alembic migration (002) creating sessions, accounts, verifications tables and migrating password hashes to accounts table - Update FastAPI auth dependency to validate sessions via shared DB (supports both cookie and Bearer token) - Remove registration/login/refresh endpoints from API gateway (now handled by Better-Auth service) - Update frontend to use better-auth/react client with httpOnly cookies (no tokens in localStorage or memory) - Rewrite auth store, Login, Register, Dashboard, Settings, ProtectedRoute to use session-based auth - Update all tests to create sessions directly in DB instead of JWT tokens Resolves CAR-27 See plan: CAR-26#document-plan Co-Authored-By: Paperclip --- alembic/versions/002_better_auth_tables.py | 101 +++++++++ src/cartsnitch_api/auth/dependencies.py | 88 ++++++-- src/cartsnitch_api/auth/routes.py | 42 +--- src/cartsnitch_api/config.py | 2 + src/cartsnitch_api/schemas.py | 24 +- src/cartsnitch_api/services/auth.py | 67 +----- tests/conftest.py | 116 ++++++++-- tests/test_auth/test_auth_endpoints.py | 244 +++++++-------------- tests/test_e2e/conftest.py | 16 +- tests/test_e2e/test_auth_validation.py | 239 ++++++++------------ tests/test_routes/test_purchases.py | 45 ++-- 11 files changed, 505 insertions(+), 479 deletions(-) create mode 100644 alembic/versions/002_better_auth_tables.py diff --git a/alembic/versions/002_better_auth_tables.py b/alembic/versions/002_better_auth_tables.py new file mode 100644 index 0000000..aa5dd93 --- /dev/null +++ b/alembic/versions/002_better_auth_tables.py @@ -0,0 +1,101 @@ +"""Add Better-Auth tables and extend users table. + +Creates sessions, accounts, and verifications tables for Better-Auth. +Adds email_verified and image columns to existing users table. +Migrates password hashes from users.hashed_password to accounts.password. + +Revision ID: 002_better_auth_tables +Revises: 001_encrypt_session_data +Create Date: 2026-03-28 +""" + +import sqlalchemy as sa +from sqlalchemy import text + +from alembic import op + +revision = "002_better_auth_tables" +down_revision = "001_encrypt_session_data" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # --- Extend users table for Better-Auth compatibility --- + op.add_column("users", sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false")) + op.add_column("users", sa.Column("image", sa.Text(), nullable=True)) + + # --- Create sessions table --- + op.create_table( + "sessions", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("token", sa.Text(), nullable=False), + sa.Column("user_id", sa.Text(), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("ip_address", sa.Text(), nullable=True), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_sessions_token", "sessions", ["token"], unique=True) + op.create_index("ix_sessions_user_id", "sessions", ["user_id"]) + + # --- Create accounts table --- + op.create_table( + "accounts", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("user_id", sa.Text(), nullable=False), + sa.Column("account_id", sa.Text(), nullable=False), + sa.Column("provider_id", sa.Text(), nullable=False), + sa.Column("access_token", sa.Text(), nullable=True), + sa.Column("refresh_token", sa.Text(), nullable=True), + sa.Column("access_token_expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("refresh_token_expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("scope", sa.Text(), nullable=True), + sa.Column("id_token", sa.Text(), nullable=True), + sa.Column("password", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_accounts_user_id", "accounts", ["user_id"]) + + # --- Create verifications table --- + op.create_table( + "verifications", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("identifier", sa.Text(), nullable=False), + sa.Column("value", sa.Text(), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + + # --- Migrate existing password hashes to accounts table --- + # For each user with a hashed_password, create a 'credential' account row + conn = op.get_bind() + users = conn.execute( + text("SELECT id, hashed_password FROM users WHERE hashed_password IS NOT NULL") + ).fetchall() + + for user_id, hashed_password in users: + user_id_str = str(user_id) + conn.execute( + text( + "INSERT INTO accounts (id, user_id, account_id, provider_id, password, created_at, updated_at) " + "VALUES (gen_random_uuid()::text, :user_id, :account_id, 'credential', :password, now(), now())" + ), + {"user_id": user_id_str, "account_id": user_id_str, "password": hashed_password}, + ) + + +def downgrade() -> None: + op.drop_table("verifications") + op.drop_table("accounts") + op.drop_index("ix_sessions_user_id", table_name="sessions") + op.drop_index("ix_sessions_token", table_name="sessions") + op.drop_table("sessions") + op.drop_column("users", "image") + op.drop_column("users", "email_verified") diff --git a/src/cartsnitch_api/auth/dependencies.py b/src/cartsnitch_api/auth/dependencies.py index 61735ee..93f8eb8 100644 --- a/src/cartsnitch_api/auth/dependencies.py +++ b/src/cartsnitch_api/auth/dependencies.py @@ -1,34 +1,88 @@ -"""FastAPI dependency injection for authentication.""" +"""FastAPI dependency injection for authentication. +Validates Better-Auth session tokens from cookies or Bearer header. +Sessions are verified by querying the shared sessions table directly. +""" + +from datetime import UTC, datetime from uuid import UUID -from fastapi import Depends, Header, HTTPException, status +from fastapi import Cookie, Depends, Header, HTTPException, Request, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession -from cartsnitch_api.auth.jwt import decode_token from cartsnitch_api.config import settings +from cartsnitch_api.database import get_db -bearer_scheme = HTTPBearer() +# Keep Bearer scheme as optional — Better-Auth primarily uses cookies, +# but we support Bearer tokens for service-to-service or mobile clients. +bearer_scheme = HTTPBearer(auto_error=False) + +# Better-Auth session cookie name +SESSION_COOKIE_NAME = "better-auth.session_token" + + +async def _validate_session_token(token: str, db: AsyncSession) -> UUID: + """Validate a Better-Auth session token against the sessions table. + + Returns the user_id (as UUID) if the session is valid and not expired. + """ + result = await db.execute( + text("SELECT user_id, expires_at FROM sessions WHERE token = :token"), + {"token": token}, + ) + row = result.first() + + if not row: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid session token", + ) + + user_id, expires_at = row + if expires_at.tzinfo is None: + # Treat naive datetimes as UTC + expires_at = expires_at.replace(tzinfo=UTC) + + if expires_at < datetime.now(UTC): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Session expired", + ) + + return UUID(str(user_id)) async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme), + request: Request, + credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme), + db: AsyncSession = Depends(get_db), ) -> UUID: - try: - payload = decode_token(credentials.credentials) - except ValueError: + """Extract and validate the session token from cookie or Authorization header. + + Checks in order: + 1. Better-Auth session cookie (primary — web clients) + 2. Bearer token in Authorization header (fallback — API clients) + """ + token: str | None = None + + # 1. Check session cookie + cookie_token = request.cookies.get(SESSION_COOKIE_NAME) + if cookie_token: + token = cookie_token + + # 2. Fall back to Bearer header + if not token and credentials: + token = credentials.credentials + + if not token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - ) from None + detail="Authentication required", + ) - if payload.get("type") != "access": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token type", - ) from None - - return UUID(payload["sub"]) + return await _validate_session_token(token, db) async def verify_service_key(x_service_key: str = Header()) -> None: diff --git a/src/cartsnitch_api/auth/routes.py b/src/cartsnitch_api/auth/routes.py index ab34c3e..81cae2f 100644 --- a/src/cartsnitch_api/auth/routes.py +++ b/src/cartsnitch_api/auth/routes.py @@ -1,4 +1,9 @@ -"""Auth routes: register, login, refresh, me, update, delete.""" +"""Auth routes: user profile management. + +Registration, login, refresh, and session management are handled by +the Better-Auth service (auth/). This router provides user profile +endpoints that query our own user data from the shared database. +""" from uuid import UUID @@ -8,10 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from cartsnitch_api.auth.dependencies import get_current_user from cartsnitch_api.database import get_db from cartsnitch_api.schemas import ( - LoginRequest, - RefreshRequest, - RegisterRequest, - TokenResponse, UpdateUserRequest, UserResponse, ) @@ -20,37 +21,6 @@ from cartsnitch_api.services.auth import AuthService router = APIRouter(prefix="/auth", tags=["auth"]) -@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED) -async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)): - svc = AuthService(db) - try: - return await svc.register(body.email, body.password, body.display_name) - except ValueError as e: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e - - -@router.post("/login", response_model=TokenResponse) -async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)): - svc = AuthService(db) - try: - return await svc.login(body.email, body.password) - except ValueError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password" - ) from None - - -@router.post("/refresh", response_model=TokenResponse) -async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)): - svc = AuthService(db) - try: - return await svc.refresh(body.refresh_token) - except ValueError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" - ) from None - - @router.get("/me", response_model=UserResponse) async def get_me( user_id: UUID = Depends(get_current_user), diff --git a/src/cartsnitch_api/config.py b/src/cartsnitch_api/config.py index 52474b2..5111997 100644 --- a/src/cartsnitch_api/config.py +++ b/src/cartsnitch_api/config.py @@ -19,6 +19,8 @@ class Settings(BaseSettings): # Valid Fernet key for local dev — MUST be overridden in production fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=" + auth_service_url: str = "http://auth:3001" + cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"] receiptwitness_url: str = "http://receiptwitness:8001" diff --git a/src/cartsnitch_api/schemas.py b/src/cartsnitch_api/schemas.py index 19e351a..1ba727e 100644 --- a/src/cartsnitch_api/schemas.py +++ b/src/cartsnitch_api/schemas.py @@ -6,28 +6,8 @@ from uuid import UUID from pydantic import BaseModel, EmailStr, Field # ---------- Auth ---------- - - -class RegisterRequest(BaseModel): - email: EmailStr - password: str = Field(min_length=8, max_length=128) - display_name: str = Field(min_length=1, max_length=100) - - -class LoginRequest(BaseModel): - email: EmailStr - password: str - - -class RefreshRequest(BaseModel): - refresh_token: str - - -class TokenResponse(BaseModel): - access_token: str - refresh_token: str - token_type: str = "bearer" - expires_in: int +# Registration, login, and session management are handled by Better-Auth (auth/ service). +# These schemas are for the profile management endpoints only. class UpdateUserRequest(BaseModel): diff --git a/src/cartsnitch_api/services/auth.py b/src/cartsnitch_api/services/auth.py index 5ea6b77..91724af 100644 --- a/src/cartsnitch_api/services/auth.py +++ b/src/cartsnitch_api/services/auth.py @@ -1,67 +1,20 @@ -"""Auth service — user registration, login, token management.""" +"""Auth service — user profile management. + +Registration, login, token management, and session handling are now +handled by the Better-Auth service (auth/). This service provides +user lookup and profile update operations for the API gateway. +""" from uuid import UUID from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token -from cartsnitch_api.auth.passwords import hash_password, verify_password -from cartsnitch_api.config import settings - class AuthService: def __init__(self, db: AsyncSession) -> None: self.db = db - async def register(self, email: str, password: str, display_name: str) -> dict: - from cartsnitch_api.models import User - - existing = await self.db.execute(select(User).where(User.email == email)) - if existing.scalar_one_or_none(): - raise ValueError("Email already registered") - - user = User( - email=email, - hashed_password=hash_password(password), - display_name=display_name, - ) - self.db.add(user) - await self.db.commit() - await self.db.refresh(user) - - return self._make_token_response(user.id) - - async def login(self, email: str, password: str) -> dict: - from cartsnitch_api.models import User - - result = await self.db.execute(select(User).where(User.email == email)) - user = result.scalar_one_or_none() - if not user or not verify_password(password, user.hashed_password): - raise ValueError("Invalid email or password") - - return self._make_token_response(user.id) - - async def refresh(self, refresh_token: str) -> dict: - from cartsnitch_api.models import User - - try: - payload = decode_token(refresh_token) - except ValueError: - raise ValueError("Invalid refresh token") from None - - if payload.get("type") != "refresh": - raise ValueError("Invalid token type") from None - - user_id = UUID(payload["sub"]) - - # Verify the user still exists before issuing new tokens - result = await self.db.execute(select(User).where(User.id == user_id)) - if not result.scalar_one_or_none(): - raise ValueError("User no longer exists") - - return self._make_token_response(user_id) - async def get_user(self, user_id: UUID) -> dict: from cartsnitch_api.models import User @@ -115,11 +68,3 @@ class AuthService: await self.db.delete(user) await self.db.commit() - - def _make_token_response(self, user_id: UUID) -> dict: - return { - "access_token": create_access_token(user_id), - "refresh_token": create_refresh_token(user_id), - "token_type": "bearer", - "expires_in": settings.jwt_access_token_expire_minutes * 60, - } diff --git a/tests/conftest.py b/tests/conftest.py index 9873903..61810e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,16 @@ -"""Shared test fixtures with in-memory SQLite database.""" +"""Shared test fixtures with in-memory SQLite database. + +Session-based auth: tests create users and sessions directly in the DB, +matching the Better-Auth session validation flow. +""" + +import secrets +import uuid +from datetime import UTC, datetime, timedelta import pytest from httpx import ASGITransport, AsyncClient -from sqlalchemy import create_engine, event +from sqlalchemy import create_engine, event, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import sessionmaker @@ -51,6 +59,46 @@ async def db_engine(): async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) + # Create Better-Auth tables (not managed by SQLAlchemy models) + await conn.execute(text(""" + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + token TEXT NOT NULL UNIQUE, + user_id TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + ip_address TEXT, + user_agent TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """)) + await conn.execute(text(""" + CREATE TABLE IF NOT EXISTS accounts ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + account_id TEXT NOT NULL, + provider_id TEXT NOT NULL, + access_token TEXT, + refresh_token TEXT, + access_token_expires_at TIMESTAMP, + refresh_token_expires_at TIMESTAMP, + scope TEXT, + id_token TEXT, + password TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """)) + await conn.execute(text(""" + CREATE TABLE IF NOT EXISTS verifications ( + id TEXT PRIMARY KEY, + identifier TEXT NOT NULL, + value TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """)) yield engine @@ -85,17 +133,55 @@ async def client(db_engine): app.dependency_overrides.clear() +async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]: + """Create a test user and a valid session directly in the DB. + + Returns (user_dict, session_token). + """ + user_id = str(uuid.uuid4()) + email = user_overrides.get("email", "test@example.com") + display_name = user_overrides.get("display_name", "Test User") + session_token = secrets.token_urlsafe(32) + session_id = str(uuid.uuid4()) + now = datetime.now(UTC).isoformat() + expires = (datetime.now(UTC) + timedelta(days=7)).isoformat() + + async with db_engine.begin() as conn: + await conn.execute( + text( + "INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) " + "VALUES (:id, :email, :hashed_password, :display_name, :email_verified, :created_at, :updated_at)" + ), + { + "id": user_id, + "email": email, + "hashed_password": "not-used-with-better-auth", + "display_name": display_name, + "email_verified": False, + "created_at": now, + "updated_at": now, + }, + ) + await conn.execute( + text( + "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " + "VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)" + ), + { + "id": session_id, + "token": session_token, + "user_id": user_id, + "expires_at": expires, + "created_at": now, + "updated_at": now, + }, + ) + + return {"id": user_id, "email": email, "display_name": display_name}, session_token + + @pytest.fixture -async def auth_headers(client): - """Register a test user and return auth headers.""" - resp = await client.post( - "/auth/register", - json={ - "email": "test@example.com", - "password": "testpass123", - "display_name": "Test User", - }, - ) - assert resp.status_code == 201 - token = resp.json()["access_token"] - return {"Authorization": f"Bearer {token}"} +async def auth_headers(client, db_engine): + """Create a test user with a valid session and return auth headers.""" + _, session_token = await _create_test_user_and_session(client, db_engine) + return {"Cookie": f"better-auth.session_token={session_token}"} diff --git a/tests/test_auth/test_auth_endpoints.py b/tests/test_auth/test_auth_endpoints.py index 878cbc5..7b096ae 100644 --- a/tests/test_auth/test_auth_endpoints.py +++ b/tests/test_auth/test_auth_endpoints.py @@ -1,146 +1,13 @@ -"""Integration tests for auth endpoints.""" +"""Integration tests for auth profile endpoints. + +Registration, login, and session management are handled by the Better-Auth +service. These tests cover the profile endpoints (GET/PATCH/DELETE /auth/me) +which validate sessions via the shared sessions table. +""" import pytest -@pytest.mark.asyncio -async def test_register_success(client): - resp = await client.post( - "/auth/register", - json={ - "email": "new@example.com", - "password": "securepass123", - "display_name": "New User", - }, - ) - assert resp.status_code == 201 - data = resp.json() - assert "access_token" in data - assert "refresh_token" in data - assert data["token_type"] == "bearer" - assert data["expires_in"] == 900 # 15 min * 60 - - -@pytest.mark.asyncio -async def test_register_duplicate_email(client): - await client.post( - "/auth/register", - json={ - "email": "dupe@example.com", - "password": "securepass123", - "display_name": "User One", - }, - ) - resp = await client.post( - "/auth/register", - json={ - "email": "dupe@example.com", - "password": "securepass456", - "display_name": "User Two", - }, - ) - assert resp.status_code == 409 - - -@pytest.mark.asyncio -async def test_register_short_password(client): - resp = await client.post( - "/auth/register", - json={ - "email": "short@example.com", - "password": "short", - "display_name": "Short Pass", - }, - ) - assert resp.status_code == 422 - - -@pytest.mark.asyncio -async def test_login_success(client): - await client.post( - "/auth/register", - json={ - "email": "login@example.com", - "password": "securepass123", - "display_name": "Login User", - }, - ) - resp = await client.post( - "/auth/login", - json={ - "email": "login@example.com", - "password": "securepass123", - }, - ) - assert resp.status_code == 200 - assert "access_token" in resp.json() - - -@pytest.mark.asyncio -async def test_login_wrong_password(client): - await client.post( - "/auth/register", - json={ - "email": "wrong@example.com", - "password": "securepass123", - "display_name": "Wrong Pass", - }, - ) - resp = await client.post( - "/auth/login", - json={ - "email": "wrong@example.com", - "password": "badpassword1", - }, - ) - assert resp.status_code == 401 - - -@pytest.mark.asyncio -async def test_login_nonexistent_user(client): - resp = await client.post( - "/auth/login", - json={ - "email": "ghost@example.com", - "password": "doesntmatter", - }, - ) - assert resp.status_code == 401 - - -@pytest.mark.asyncio -async def test_refresh_token(client): - reg = await client.post( - "/auth/register", - json={ - "email": "refresh@example.com", - "password": "securepass123", - "display_name": "Refresh User", - }, - ) - refresh_token = reg.json()["refresh_token"] - - resp = await client.post( - "/auth/refresh", - json={ - "refresh_token": refresh_token, - }, - ) - assert resp.status_code == 200 - assert "access_token" in resp.json() - - -@pytest.mark.asyncio -async def test_refresh_with_invalid_token(client): - resp = await client.post( - "/auth/refresh", - json={ - "refresh_token": "invalid.token.here", - }, - ) - assert resp.status_code == 401 - - @pytest.mark.asyncio async def test_get_me(client, auth_headers): resp = await client.get("/auth/me", headers=auth_headers) @@ -155,7 +22,32 @@ async def test_get_me(client, auth_headers): @pytest.mark.asyncio async def test_get_me_unauthorized(client): resp = await client.get("/auth/me") - assert resp.status_code in (401, 403) # No auth header + assert resp.status_code in (401, 403) + + +@pytest.mark.asyncio +async def test_get_me_invalid_session(client): + resp = await client.get( + "/auth/me", + headers={"Cookie": "better-auth.session_token=invalid-token"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_get_me_with_bearer_token(client, db_engine): + """Session tokens can also be passed as Bearer tokens for API clients.""" + from tests.conftest import _create_test_user_and_session + + _, session_token = await _create_test_user_and_session( + client, db_engine, email="bearer@example.com", display_name="Bearer User" + ) + resp = await client.get( + "/auth/me", + headers={"Authorization": f"Bearer {session_token}"}, + ) + assert resp.status_code == 200 + assert resp.json()["email"] == "bearer@example.com" @pytest.mark.asyncio @@ -163,9 +55,7 @@ async def test_update_me(client, auth_headers): resp = await client.patch( "/auth/me", headers=auth_headers, - json={ - "display_name": "Updated Name", - }, + json={"display_name": "Updated Name"}, ) assert resp.status_code == 200 assert resp.json()["display_name"] == "Updated Name" @@ -176,34 +66,58 @@ async def test_delete_me(client, auth_headers): resp = await client.delete("/auth/me", headers=auth_headers) assert resp.status_code == 204 - # Verify user is gone (token still valid but user deleted) + # Session is still valid but user is gone resp = await client.get("/auth/me", headers=auth_headers) assert resp.status_code == 404 @pytest.mark.asyncio -async def test_refresh_after_delete_fails(client): - """Refresh token for a deleted user must be rejected.""" - reg = await client.post( - "/auth/register", - json={ - "email": "ghost@example.com", - "password": "securepass123", - "display_name": "Ghost User", - }, - ) - tokens = reg.json() - headers = {"Authorization": f"Bearer {tokens['access_token']}"} +async def test_expired_session_rejected(client, db_engine): + """Expired sessions must be rejected.""" + import secrets + import uuid + from datetime import UTC, datetime, timedelta - # Delete the user - resp = await client.delete("/auth/me", headers=headers) - assert resp.status_code == 204 + from sqlalchemy import text - # Refresh token should now fail - resp = await client.post( - "/auth/refresh", - json={ - "refresh_token": tokens["refresh_token"], - }, + user_id = str(uuid.uuid4()) + session_token = secrets.token_urlsafe(32) + now = datetime.now(UTC).isoformat() + expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat() + + async with db_engine.begin() as conn: + await conn.execute( + text( + "INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) " + "VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)" + ), + { + "id": user_id, + "email": "expired@example.com", + "hp": "unused", + "dn": "Expired User", + "ev": False, + "ca": now, + "ua": now, + }, + ) + await conn.execute( + text( + "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " + "VALUES (:id, :token, :uid, :ea, :ca, :ua)" + ), + { + "id": str(uuid.uuid4()), + "token": session_token, + "uid": user_id, + "ea": expired, + "ca": now, + "ua": now, + }, + ) + + resp = await client.get( + "/auth/me", + headers={"Cookie": f"better-auth.session_token={session_token}"}, ) assert resp.status_code == 401 diff --git a/tests/test_e2e/conftest.py b/tests/test_e2e/conftest.py index f1390fd..d352344 100644 --- a/tests/test_e2e/conftest.py +++ b/tests/test_e2e/conftest.py @@ -10,9 +10,9 @@ from decimal import Decimal from uuid import UUID import pytest +from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from cartsnitch_api.auth.jwt import decode_token from cartsnitch_api.models import ( Coupon, NormalizedProduct, @@ -126,10 +126,16 @@ async def seed_data(db_engine, auth_headers): session.add_all(prices) await session.flush() - # -- Purchases (need the user_id from the registered test user) -- - token = auth_headers["Authorization"].split(" ")[1] - payload = decode_token(token) - user_id = UUID(payload["sub"]) + # -- Get the user_id from the session token in auth_headers -- + cookie_str = auth_headers.get("Cookie", "") + session_token = cookie_str.split("=", 1)[1] if "=" in cookie_str else "" + + result = await session.execute( + text("SELECT user_id FROM sessions WHERE token = :token"), + {"token": session_token}, + ) + row = result.first() + user_id = UUID(row[0]) purchase1 = Purchase( user_id=user_id, diff --git a/tests/test_e2e/test_auth_validation.py b/tests/test_e2e/test_auth_validation.py index bbded83..f0e38cd 100644 --- a/tests/test_e2e/test_auth_validation.py +++ b/tests/test_e2e/test_auth_validation.py @@ -1,132 +1,103 @@ -"""E2E: Auth and token validation flows.""" +"""E2E: Auth and session validation flows. -import asyncio +Registration and login are handled by the Better-Auth service. +These tests validate session token handling at the API gateway level. +""" import pytest - -@pytest.mark.asyncio -class TestAuthRegistrationLogin: - """Full registration → login → token refresh → profile flow.""" - - async def test_full_auth_lifecycle(self, client, db_engine): - """Register → login → get profile → refresh → get profile again.""" - # Register - reg = await client.post( - "/auth/register", - json={ - "email": "lifecycle@example.com", - "password": "securepass123", - "display_name": "Lifecycle User", - }, - ) - assert reg.status_code == 201 - tokens = reg.json() - assert "access_token" in tokens - assert "refresh_token" in tokens - assert tokens["token_type"] == "bearer" - assert tokens["expires_in"] > 0 - - headers = {"Authorization": f"Bearer {tokens['access_token']}"} - - # Get profile with access token - me = await client.get("/auth/me", headers=headers) - assert me.status_code == 200 - assert me.json()["email"] == "lifecycle@example.com" - assert me.json()["display_name"] == "Lifecycle User" - - # Sleep 1s so the new token has a different exp than the registration token - await asyncio.sleep(1) - - # Login with same credentials - login = await client.post( - "/auth/login", - json={"email": "lifecycle@example.com", "password": "securepass123"}, - ) - assert login.status_code == 200 - login_tokens = login.json() - assert login_tokens["access_token"] != tokens["access_token"] - - # Refresh token - refresh = await client.post( - "/auth/refresh", - json={"refresh_token": tokens["refresh_token"]}, - ) - assert refresh.status_code == 200 - new_tokens = refresh.json() - assert new_tokens["access_token"] != tokens["access_token"] - - # Use refreshed token to access profile - new_headers = {"Authorization": f"Bearer {new_tokens['access_token']}"} - me2 = await client.get("/auth/me", headers=new_headers) - assert me2.status_code == 200 - assert me2.json()["email"] == "lifecycle@example.com" +from tests.conftest import _create_test_user_and_session @pytest.mark.asyncio -class TestTokenValidation: - """Token edge cases and error responses.""" +class TestSessionValidation: + """Session edge cases and error responses.""" - async def test_expired_token_rejected(self, client, db_engine): - """Manually craft an expired token and verify rejection.""" - import uuid - from datetime import UTC, datetime, timedelta - - from jose import jwt - - from cartsnitch_api.config import settings - - payload = { - "sub": str(uuid.uuid4()), - "exp": datetime.now(UTC) - timedelta(minutes=5), - "type": "access", - } - token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm) - resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"}) + async def test_invalid_session_token_rejected(self, client, db_engine): + resp = await client.get( + "/auth/me", + headers={"Cookie": "better-auth.session_token=not-a-real-token"}, + ) assert resp.status_code == 401 - async def test_invalid_token_rejected(self, client, db_engine): - resp = await client.get("/auth/me", headers={"Authorization": "Bearer not-a-real-token"}) - assert resp.status_code == 401 - - async def test_missing_auth_header(self, client, db_engine): + async def test_missing_auth(self, client, db_engine): resp = await client.get("/auth/me") assert resp.status_code in (401, 403) - async def test_refresh_token_cannot_access_endpoints(self, client, db_engine): - """A refresh token should not work as an access token.""" - reg = await client.post( - "/auth/register", - json={ - "email": "refresh-test@example.com", - "password": "securepass123", - "display_name": "Refresh Test", - }, + async def test_bearer_token_also_works(self, client, db_engine): + """Session tokens passed as Bearer tokens should also be accepted.""" + _, session_token = await _create_test_user_and_session( + client, db_engine, email="bearer@e2e.com", display_name="Bearer E2E" ) - refresh_token = reg.json()["refresh_token"] - resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {refresh_token}"}) - assert resp.status_code == 401 - - async def test_deleted_user_token_invalid(self, client, db_engine): - """After deleting an account, tokens should no longer work.""" - reg = await client.post( - "/auth/register", - json={ - "email": "delete-me@example.com", - "password": "securepass123", - "display_name": "Delete Me", - }, + resp = await client.get( + "/auth/me", + headers={"Authorization": f"Bearer {session_token}"}, ) - tokens = reg.json() - headers = {"Authorization": f"Bearer {tokens['access_token']}"} + assert resp.status_code == 200 + assert resp.json()["email"] == "bearer@e2e.com" + + async def test_deleted_user_session_returns_not_found(self, client, db_engine): + """After deleting a user, their session should result in 404 for profile.""" + _, session_token = await _create_test_user_and_session( + client, db_engine, email="delete-me@e2e.com", display_name="Delete Me" + ) + headers = {"Cookie": f"better-auth.session_token={session_token}"} - # Delete account delete_resp = await client.delete("/auth/me", headers=headers) assert delete_resp.status_code == 204 - # Profile should fail me = await client.get("/auth/me", headers=headers) - assert me.status_code in (401, 404) + assert me.status_code == 404 + + async def test_expired_session_rejected(self, client, db_engine): + """Expired sessions must be rejected.""" + import secrets + import uuid + from datetime import UTC, datetime, timedelta + + from sqlalchemy import text + + user_id = str(uuid.uuid4()) + session_token = secrets.token_urlsafe(32) + now = datetime.now(UTC).isoformat() + expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat() + + async with db_engine.begin() as conn: + await conn.execute( + text( + "INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) " + "VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)" + ), + { + "id": user_id, + "email": "expired@e2e.com", + "hp": "unused", + "dn": "Expired User", + "ev": False, + "ca": now, + "ua": now, + }, + ) + await conn.execute( + text( + "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " + "VALUES (:id, :token, :uid, :ea, :ca, :ua)" + ), + { + "id": str(uuid.uuid4()), + "token": session_token, + "uid": user_id, + "ea": expired, + "ca": now, + "ua": now, + }, + ) + + resp = await client.get( + "/auth/me", + headers={"Cookie": f"better-auth.session_token={session_token}"}, + ) + assert resp.status_code == 401 @pytest.mark.asyncio @@ -154,60 +125,38 @@ class TestAuthProtectedEndpoints: class TestCrossUserDataIsolation: """Verify that users cannot access other users' data.""" - async def test_user_b_cannot_access_user_a_purchases(self, client, seed_data): - """Register a second user and verify they cannot see User A's purchases.""" - # User A's purchase (from seed_data) + async def test_user_b_cannot_access_user_a_purchases(self, client, db_engine, seed_data): + """A second user cannot see User A's purchases.""" purchase_id = str(seed_data["purchases"]["meijer_trip"].id) - # Register User B - reg = await client.post( - "/auth/register", - json={ - "email": "userb@example.com", - "password": "securepass123", - "display_name": "User B", - }, + _, session_token = await _create_test_user_and_session( + client, db_engine, email="userb@e2e.com", display_name="User B" ) - assert reg.status_code == 201 - user_b_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + user_b_headers = {"Cookie": f"better-auth.session_token={session_token}"} - # User B tries to access User A's specific purchase resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers) assert resp.status_code in (403, 404), ( "User B should not be able to access User A's purchase" ) - async def test_user_b_purchase_list_is_empty(self, client, seed_data): - """A new user should see no purchases (not User A's purchases).""" - reg = await client.post( - "/auth/register", - json={ - "email": "userc@example.com", - "password": "securepass123", - "display_name": "User C", - }, + async def test_user_b_purchase_list_is_empty(self, client, db_engine, seed_data): + """A new user should see no purchases.""" + _, session_token = await _create_test_user_and_session( + client, db_engine, email="userc@e2e.com", display_name="User C" ) - assert reg.status_code == 201 - user_c_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + user_c_headers = {"Cookie": f"better-auth.session_token={session_token}"} resp = await client.get("/purchases", headers=user_c_headers) assert resp.status_code == 200 assert len(resp.json()) == 0, "New user should have no purchases" - async def test_user_b_stores_isolated(self, client, seed_data): + async def test_user_b_stores_isolated(self, client, db_engine, seed_data): """User B's connected stores should be independent from User A.""" - reg = await client.post( - "/auth/register", - json={ - "email": "userd@example.com", - "password": "securepass123", - "display_name": "User D", - }, + _, session_token = await _create_test_user_and_session( + client, db_engine, email="userd@e2e.com", display_name="User D" ) - assert reg.status_code == 201 - user_d_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + user_d_headers = {"Cookie": f"better-auth.session_token={session_token}"} - # User D should have no connected stores resp = await client.get("/me/stores", headers=user_d_headers) assert resp.status_code == 200 assert len(resp.json()) == 0, "New user should have no connected stores" diff --git a/tests/test_routes/test_purchases.py b/tests/test_routes/test_purchases.py index 14d5eb6..2b1f47b 100644 --- a/tests/test_routes/test_purchases.py +++ b/tests/test_routes/test_purchases.py @@ -1,26 +1,25 @@ """Integration tests for purchase endpoints.""" +import secrets import uuid -from datetime import date +from datetime import UTC, date, datetime, timedelta from decimal import Decimal import pytest +from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from cartsnitch_api.auth.jwt import create_access_token from cartsnitch_api.models import Purchase, PurchaseItem, Store, User @pytest.fixture async def purchase_data(db_engine): - """Seed a user, store, purchase, and items.""" + """Seed a user, store, purchase, items, and a valid session.""" factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) async with factory() as session: - from cartsnitch_api.auth.passwords import hash_password - user = User( email="buyer@example.com", - hashed_password=hash_password("testpass123"), + hashed_password="not-used-with-better-auth", display_name="Buyer", ) store = Store(name="Kroger", slug="kroger") @@ -50,13 +49,33 @@ async def purchase_data(db_engine): session.add(item) await session.commit() - token = create_access_token(user.id) - return { - "user": user, - "store": store, - "purchase": purchase, - "headers": {"Authorization": f"Bearer {token}"}, - } + # Create a session token directly in the sessions table + session_token = secrets.token_urlsafe(32) + now = datetime.now(UTC).isoformat() + expires = (datetime.now(UTC) + timedelta(days=7)).isoformat() + + async with db_engine.begin() as conn: + await conn.execute( + text( + "INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) " + "VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)" + ), + { + "id": str(uuid.uuid4()), + "token": session_token, + "user_id": str(user.id), + "expires_at": expires, + "created_at": now, + "updated_at": now, + }, + ) + + return { + "user": user, + "store": store, + "purchase": purchase, + "headers": {"Cookie": f"better-auth.session_token={session_token}"}, + } @pytest.mark.asyncio