forked from cartsnitch/cartsnitch
fix(auth): revert to Better-Auth session-cookie auth, preserve email-in feature
- Revert auth/dependencies.py, auth/routes.py, services/auth.py, schemas.py to Better-Auth session-cookie auth (removed JWT register/login/refresh) - Preserve GET /auth/me/email-in-address endpoint - Fix UUIDString TypeDecorator: process_result_value returns uuid.UUID (not str) so SQLAlchemy 2.0 sentinel tracking matches UUID-to-UUID - Fix seed_data fixture: look up real user_id from session token via sessions table; purchases now reference actual user FK - Update purchase_data fixture to use session-cookie auth - Update test_auth_endpoints, test_auth_validation to cookie-based tests - Remove TestRegistrationErrors and TestLoginErrors (no longer applicable) - Update test_openapi.py expected routes and count - Update test_error_handler.py to use PATCH /auth/me validation Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -1,34 +1,91 @@
|
||||
"""FastAPI dependency injection for authentication."""
|
||||
"""FastAPI dependency injection for authentication.
|
||||
|
||||
Validates Better-Auth session tokens from cookies or Bearer header.
|
||||
Sessions are verified by querying the shared sessions table directly.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi import Cookie, Depends, Header, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.jwt import decode_token
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.database import get_db
|
||||
|
||||
bearer_scheme = HTTPBearer()
|
||||
# Keep Bearer scheme as optional — Better-Auth primarily uses cookies,
|
||||
# but we support Bearer tokens for service-to-service or mobile clients.
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
# Better-Auth session cookie name
|
||||
SESSION_COOKIE_NAME = "better-auth.session_token"
|
||||
|
||||
|
||||
async def _validate_session_token(token: str, db: AsyncSession) -> UUID:
|
||||
"""Validate a Better-Auth session token against the sessions table.
|
||||
|
||||
Returns the user_id (as UUID) if the session is valid and not expired.
|
||||
"""
|
||||
result = await db.execute(
|
||||
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
|
||||
{"token": token},
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid session token",
|
||||
)
|
||||
|
||||
user_id, expires_at = row
|
||||
# SQLite stores datetimes as ISO strings; parse if necessary
|
||||
if isinstance(expires_at, str):
|
||||
expires_at = datetime.fromisoformat(expires_at)
|
||||
if expires_at.tzinfo is None:
|
||||
# Treat naive datetimes as UTC
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < datetime.now(UTC):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session expired",
|
||||
)
|
||||
|
||||
return UUID(str(user_id))
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> UUID:
|
||||
try:
|
||||
payload = decode_token(credentials.credentials)
|
||||
except ValueError:
|
||||
"""Extract and validate the session token from cookie or Authorization header.
|
||||
|
||||
Checks in order:
|
||||
1. Better-Auth session cookie (primary — web clients)
|
||||
2. Bearer token in Authorization header (fallback — API clients)
|
||||
"""
|
||||
token: str | None = None
|
||||
|
||||
# 1. Check session cookie
|
||||
cookie_token = request.cookies.get(SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
token = cookie_token
|
||||
|
||||
# 2. Fall back to Bearer header
|
||||
if not token and credentials:
|
||||
token = credentials.credentials
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
) from None
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type",
|
||||
) from None
|
||||
|
||||
return UUID(payload["sub"])
|
||||
return await _validate_session_token(token, db)
|
||||
|
||||
|
||||
async def verify_service_key(x_service_key: str = Header()) -> None:
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
"""Auth routes: register, login, refresh, me, update, delete."""
|
||||
"""Auth routes: user profile management.
|
||||
|
||||
Registration, login, refresh, and session management are handled by
|
||||
the Better-Auth service (auth/). This router provides user profile
|
||||
endpoints that query our own user data from the shared database.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.models import User
|
||||
from cartsnitch_api.schemas import (
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
RegisterRequest,
|
||||
TokenResponse,
|
||||
EmailInAddressResponse,
|
||||
UpdateUserRequest,
|
||||
UserResponse,
|
||||
)
|
||||
@@ -23,37 +22,6 @@ from cartsnitch_api.services.auth import AuthService
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.register(body.email, body.password, body.display_name)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.login(body.email, body.password)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password"
|
||||
) from None
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.refresh(body.refresh_token)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
@@ -99,26 +67,15 @@ async def delete_me(
|
||||
) from None
|
||||
|
||||
|
||||
class EmailInAddressResponse(BaseModel):
|
||||
email_address: str
|
||||
instructions: str
|
||||
|
||||
|
||||
@router.get("/me/email-in-address", response_model=EmailInAddressResponse)
|
||||
async def get_email_in_address(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(User.email_inbound_token).where(User.id == user_id))
|
||||
token = result.scalar_one_or_none()
|
||||
if not token:
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.get_email_in_address(user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Email inbound token not found"
|
||||
) from None
|
||||
return EmailInAddressResponse(
|
||||
email_address=f"receipts+{token}@receipts.cartsnitch.com",
|
||||
instructions=(
|
||||
"Forward your digital receipt emails to this address. "
|
||||
"We currently support Meijer, Kroger, and Target receipt emails."
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,12 +1,39 @@
|
||||
"""Base model and mixins for all CartSnitch ORM models."""
|
||||
|
||||
import uuid
|
||||
import uuid as uuid_lib
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy import DateTime, String, TypeDecorator, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class UUIDString(TypeDecorator):
|
||||
"""Store UUIDs as VARCHAR(36) strings in all dialects.
|
||||
|
||||
This handles the fundamental mismatch between Python's uuid.UUID objects
|
||||
(used everywhere in application code) and SQLite's lack of a native UUID type.
|
||||
- On INSERT: converts uuid.UUID → str
|
||||
- On SELECT: returns uuid.UUID (so SQLAlchemy 2.0 sentinel tracking matches correctly)
|
||||
"""
|
||||
|
||||
impl = String(36)
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, uuid_lib.UUID):
|
||||
return str(value)
|
||||
return value # already a string
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, uuid_lib.UUID):
|
||||
return value
|
||||
return uuid_lib.UUID(value) # convert str → UUID for correct sentinel tracking
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all CartSnitch models."""
|
||||
|
||||
@@ -23,8 +50,14 @@ class TimestampMixin:
|
||||
|
||||
|
||||
class UUIDPrimaryKeyMixin:
|
||||
"""Mixin providing a UUID primary key."""
|
||||
"""Mixin providing a UUID primary key.
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid()
|
||||
Uses UUIDString so all DB dialects store the full 36-char UUID string
|
||||
without truncation, while Python code always works with uuid.UUID objects.
|
||||
"""
|
||||
|
||||
id: Mapped[uuid_lib.UUID] = mapped_column(
|
||||
UUIDString(),
|
||||
primary_key=True,
|
||||
default=uuid_lib.uuid4,
|
||||
)
|
||||
|
||||
@@ -6,28 +6,8 @@ from uuid import UUID
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
# ---------- Auth ----------
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=8, max_length=128)
|
||||
display_name: str = Field(min_length=1, max_length=100)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
# Registration, login, and session management are handled by Better-Auth (auth/ service).
|
||||
# These schemas are for the profile management endpoints only.
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
@@ -285,6 +265,13 @@ class ErrorResponse(BaseModel):
|
||||
code: str | None = None
|
||||
|
||||
|
||||
# ---------- Email-In ----------
|
||||
|
||||
class EmailInAddressResponse(BaseModel):
|
||||
email_address: str
|
||||
instructions: str
|
||||
|
||||
|
||||
# Rebuild forward refs
|
||||
ProductDetailResponse.model_rebuild()
|
||||
PriceTrendResponse.model_rebuild()
|
||||
|
||||
@@ -1,71 +1,28 @@
|
||||
"""Auth service — user registration, login, token management."""
|
||||
"""Auth service — user profile management.
|
||||
|
||||
Registration, login, token management, and session handling are now
|
||||
handled by the Better-Auth service (auth/). This service provides
|
||||
user lookup and profile update operations for the API gateway.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token
|
||||
from cartsnitch_api.auth.passwords import hash_password, verify_password
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def register(self, email: str, password: str, display_name: str) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
existing = await self.db.execute(select(User).where(User.email == email))
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
hashed_password=hash_password(password),
|
||||
display_name=display_name,
|
||||
)
|
||||
self.db.add(user)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user)
|
||||
|
||||
return self._make_token_response(user.id)
|
||||
|
||||
async def login(self, email: str, password: str) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise ValueError("Invalid email or password")
|
||||
|
||||
return self._make_token_response(user.id)
|
||||
|
||||
async def refresh(self, refresh_token: str) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
try:
|
||||
payload = decode_token(refresh_token)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid refresh token") from None
|
||||
|
||||
if payload.get("type") != "refresh":
|
||||
raise ValueError("Invalid token type") from None
|
||||
|
||||
user_id = UUID(payload["sub"])
|
||||
|
||||
# Verify the user still exists before issuing new tokens
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
if not result.scalar_one_or_none():
|
||||
raise ValueError("User no longer exists")
|
||||
|
||||
return self._make_token_response(user_id)
|
||||
|
||||
async def get_user(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
# Use str() to ensure consistent string comparison for UUID columns
|
||||
# (works with both SQLite VARCHAR and Postgres UUID storage)
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.id == str(user_id))
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
@@ -80,7 +37,8 @@ class AuthService:
|
||||
async def update_user(self, user_id: UUID, **fields) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user_id_str = str(user_id)
|
||||
result = await self.db.execute(select(User).where(User.id == user_id_str))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
@@ -89,7 +47,7 @@ class AuthService:
|
||||
user.display_name = fields["display_name"]
|
||||
if "email" in fields and fields["email"] is not None:
|
||||
existing = await self.db.execute(
|
||||
select(User).where(User.email == fields["email"], User.id != user_id)
|
||||
select(User).where(User.email == fields["email"], User.id != user_id_str)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Email already in use")
|
||||
@@ -108,7 +66,7 @@ class AuthService:
|
||||
async def delete_user(self, user_id: UUID) -> None:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
result = await self.db.execute(select(User).where(User.id == str(user_id)))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
@@ -116,10 +74,20 @@ class AuthService:
|
||||
await self.db.delete(user)
|
||||
await self.db.commit()
|
||||
|
||||
def _make_token_response(self, user_id: UUID) -> dict:
|
||||
async def get_email_in_address(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(
|
||||
select(User.email_inbound_token).where(User.id == str(user_id))
|
||||
)
|
||||
token = result.scalar_one_or_none()
|
||||
if not token:
|
||||
raise LookupError("Email inbound token not found")
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(user_id),
|
||||
"refresh_token": create_refresh_token(user_id),
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.jwt_access_token_expire_minutes * 60,
|
||||
"email_address": f"receipts+{token}@receipts.cartsnitch.com",
|
||||
"instructions": (
|
||||
"Forward your digital receipt emails to this address. "
|
||||
"We currently support Meijer, Kroger, and Target receipt emails."
|
||||
),
|
||||
}
|
||||
|
||||
+102
-15
@@ -1,8 +1,16 @@
|
||||
"""Shared test fixtures with in-memory SQLite database."""
|
||||
"""Shared test fixtures with in-memory SQLite database.
|
||||
|
||||
Session-based auth: tests create users and sessions directly in the DB,
|
||||
matching the Better-Auth session validation flow.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@@ -51,6 +59,46 @@ async def db_engine():
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
||||
await conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
user_id TEXT NOT NULL,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
ip_address TEXT,
|
||||
user_agent TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
await conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
account_id TEXT NOT NULL,
|
||||
provider_id TEXT NOT NULL,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
access_token_expires_at TIMESTAMP,
|
||||
refresh_token_expires_at TIMESTAMP,
|
||||
scope TEXT,
|
||||
id_token TEXT,
|
||||
password TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
await conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS verifications (
|
||||
id TEXT PRIMARY KEY,
|
||||
identifier TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
|
||||
yield engine
|
||||
|
||||
@@ -85,17 +133,56 @@ async def client(db_engine):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
||||
"""Create a test user and a valid session directly in the DB.
|
||||
|
||||
Returns (user_dict, session_token).
|
||||
"""
|
||||
user_id = str(uuid.uuid4())
|
||||
email = user_overrides.get("email", "test@example.com")
|
||||
display_name = user_overrides.get("display_name", "Test User")
|
||||
email_inbound_token = user_overrides.get("email_inbound_token", secrets.token_urlsafe(16))
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
session_id = str(uuid.uuid4())
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hashed_password, :display_name, :email_inbound_token, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": email,
|
||||
"hashed_password": "not-used-with-better-auth",
|
||||
"display_name": display_name,
|
||||
"email_inbound_token": email_inbound_token,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": session_id,
|
||||
"token": session_token,
|
||||
"user_id": user_id,
|
||||
"expires_at": expires,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
|
||||
return {"id": user_id, "email": email, "display_name": display_name}, session_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client):
|
||||
"""Register a test user and return auth headers."""
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "test@example.com",
|
||||
"password": "testpass123",
|
||||
"display_name": "Test User",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
token = resp.json()["access_token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
async def auth_headers(client, db_engine):
|
||||
"""Create a test user with a valid session and return auth headers."""
|
||||
_, session_token = await _create_test_user_and_session(client, db_engine)
|
||||
return {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
@@ -1,146 +1,13 @@
|
||||
"""Integration tests for auth endpoints."""
|
||||
"""Integration tests for auth profile endpoints.
|
||||
|
||||
Registration, login, and session management are handled by the Better-Auth
|
||||
service. These tests cover the profile endpoints (GET/PATCH/DELETE /auth/me)
|
||||
which validate sessions via the shared sessions table.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_success(client):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "new@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "New User",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
assert data["expires_in"] == 900 # 15 min * 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_duplicate_email(client):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "dupe@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "User One",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "dupe@example.com",
|
||||
"password": "securepass456",
|
||||
"display_name": "User Two",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_short_password(client):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "short@example.com",
|
||||
"password": "short",
|
||||
"display_name": "Short Pass",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_success(client):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "login@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Login User",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "login@example.com",
|
||||
"password": "securepass123",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "access_token" in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_wrong_password(client):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "wrong@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Wrong Pass",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "wrong@example.com",
|
||||
"password": "badpassword1",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_nonexistent_user(client):
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "ghost@example.com",
|
||||
"password": "doesntmatter",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token(client):
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "refresh@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Refresh User",
|
||||
},
|
||||
)
|
||||
refresh_token = reg.json()["refresh_token"]
|
||||
|
||||
resp = await client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "access_token" in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_with_invalid_token(client):
|
||||
resp = await client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": "invalid.token.here",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me(client, auth_headers):
|
||||
resp = await client.get("/auth/me", headers=auth_headers)
|
||||
@@ -155,7 +22,32 @@ async def test_get_me(client, auth_headers):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_unauthorized(client):
|
||||
resp = await client.get("/auth/me")
|
||||
assert resp.status_code in (401, 403) # No auth header
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_invalid_session(client):
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": "better-auth.session_token=invalid-token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_with_bearer_token(client, db_engine):
|
||||
"""Session tokens can also be passed as Bearer tokens for API clients."""
|
||||
from tests.conftest import _create_test_user_and_session
|
||||
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="bearer@example.com", display_name="Bearer User"
|
||||
)
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Authorization": f"Bearer {session_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == "bearer@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -163,9 +55,7 @@ async def test_update_me(client, auth_headers):
|
||||
resp = await client.patch(
|
||||
"/auth/me",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"display_name": "Updated Name",
|
||||
},
|
||||
json={"display_name": "Updated Name"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["display_name"] == "Updated Name"
|
||||
@@ -176,34 +66,58 @@ async def test_delete_me(client, auth_headers):
|
||||
resp = await client.delete("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Verify user is gone (token still valid but user deleted)
|
||||
# Session is still valid but user is gone
|
||||
resp = await client.get("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_after_delete_fails(client):
|
||||
"""Refresh token for a deleted user must be rejected."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "ghost@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Ghost User",
|
||||
},
|
||||
)
|
||||
tokens = reg.json()
|
||||
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
async def test_expired_session_rejected(client, db_engine):
|
||||
"""Expired sessions must be rejected."""
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
# Delete the user
|
||||
resp = await client.delete("/auth/me", headers=headers)
|
||||
assert resp.status_code == 204
|
||||
from sqlalchemy import text
|
||||
|
||||
# Refresh token should now fail
|
||||
resp = await client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": tokens["refresh_token"],
|
||||
},
|
||||
user_id = str(uuid.uuid4())
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hp, :dn, :eit, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": "expired@example.com",
|
||||
"hp": "unused",
|
||||
"dn": "Expired User",
|
||||
"eit": secrets.token_urlsafe(16),
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"uid": user_id,
|
||||
"ea": expired,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@@ -7,12 +7,13 @@ exercise cross-resource queries against real data.
|
||||
|
||||
from datetime import date, timedelta
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.auth.jwt import decode_token
|
||||
from cartsnitch_api.models import (
|
||||
Coupon,
|
||||
NormalizedProduct,
|
||||
@@ -33,17 +34,20 @@ ANCHOR_DATE = date.today()
|
||||
@pytest.fixture
|
||||
async def seed_data(db_engine, auth_headers):
|
||||
"""Seed a full dataset and return identifiers for test assertions."""
|
||||
import uuid
|
||||
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
# -- Stores --
|
||||
meijer = Store(name="Meijer", slug="meijer")
|
||||
kroger = Store(name="Kroger", slug="kroger")
|
||||
target = Store(name="Target", slug="target")
|
||||
meijer = Store(name="Meijer", slug="meijer", id=uuid.uuid4())
|
||||
kroger = Store(name="Kroger", slug="kroger", id=uuid.uuid4())
|
||||
target = Store(name="Target", slug="target", id=uuid.uuid4())
|
||||
session.add_all([meijer, kroger, target])
|
||||
await session.flush()
|
||||
|
||||
# -- Products --
|
||||
cheerios = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Cheerios 18oz",
|
||||
category="pantry",
|
||||
brand="General Mills",
|
||||
@@ -52,6 +56,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
upc_variants=["016000275263"],
|
||||
)
|
||||
milk = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Whole Milk 1gal",
|
||||
category="dairy",
|
||||
brand="Meijer",
|
||||
@@ -59,6 +64,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
size_unit="gal",
|
||||
)
|
||||
chicken = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Chicken Breast 1lb",
|
||||
category="meat",
|
||||
brand=None,
|
||||
@@ -75,6 +81,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
for i, price_val in enumerate([Decimal("3.99"), Decimal("4.29"), Decimal("4.79")]):
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=cheerios.id,
|
||||
store_id=meijer.id,
|
||||
observed_date=today - timedelta(days=60 - i * 30),
|
||||
@@ -86,6 +93,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
for i in range(3):
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=cheerios.id,
|
||||
store_id=kroger.id,
|
||||
observed_date=today - timedelta(days=60 - i * 30),
|
||||
@@ -96,6 +104,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
# Milk at Meijer
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=milk.id,
|
||||
store_id=meijer.id,
|
||||
observed_date=today - timedelta(days=7),
|
||||
@@ -106,6 +115,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
# Milk at Kroger
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=milk.id,
|
||||
store_id=kroger.id,
|
||||
observed_date=today - timedelta(days=5),
|
||||
@@ -116,6 +126,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
# Chicken at Target
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=chicken.id,
|
||||
store_id=target.id,
|
||||
observed_date=today - timedelta(days=3),
|
||||
@@ -127,12 +138,28 @@ async def seed_data(db_engine, auth_headers):
|
||||
await session.flush()
|
||||
|
||||
# -- Purchases (need the user_id from the registered test user) --
|
||||
token = auth_headers["Authorization"].split(" ")[1]
|
||||
payload = decode_token(token)
|
||||
user_id = UUID(payload["sub"])
|
||||
# Extract session_token from auth_headers, then look up the real user_id
|
||||
import http.cookies
|
||||
cookie_header = auth_headers.get("Cookie", "")
|
||||
cookies = http.cookies.SimpleCookie()
|
||||
cookies.load(cookie_header)
|
||||
session_token = cookies.get("better-auth.session_token").value if "better-auth.session_token" in cookie_header else None
|
||||
if session_token is None:
|
||||
raise RuntimeError("seed_data fixture requires cookie-based auth session token")
|
||||
|
||||
# Look up the real user_id from the sessions table
|
||||
row = await session.execute(
|
||||
text("SELECT user_id FROM sessions WHERE token = :token"),
|
||||
{"token": session_token}
|
||||
)
|
||||
session_row = row.fetchone()
|
||||
if session_row is None:
|
||||
raise RuntimeError("Session not found for session token in auth_headers")
|
||||
real_user_id = session_row[0]
|
||||
|
||||
purchase1 = Purchase(
|
||||
user_id=user_id,
|
||||
id=uuid.uuid4(),
|
||||
user_id=uuid.UUID(real_user_id),
|
||||
store_id=meijer.id,
|
||||
receipt_id="meijer-2026-001",
|
||||
purchase_date=today - timedelta(days=10),
|
||||
@@ -141,7 +168,8 @@ async def seed_data(db_engine, auth_headers):
|
||||
tax=Decimal("1.95"),
|
||||
)
|
||||
purchase2 = Purchase(
|
||||
user_id=user_id,
|
||||
id=uuid.uuid4(),
|
||||
user_id=uuid.UUID(real_user_id),
|
||||
store_id=kroger.id,
|
||||
receipt_id="kroger-2026-001",
|
||||
purchase_date=today - timedelta(days=5),
|
||||
@@ -154,6 +182,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
|
||||
# -- Purchase Items --
|
||||
item1 = PurchaseItem(
|
||||
id=uuid.uuid4(),
|
||||
purchase_id=purchase1.id,
|
||||
product_name_raw="Cheerios 18oz Box",
|
||||
quantity=Decimal("1"),
|
||||
@@ -162,6 +191,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
normalized_product_id=cheerios.id,
|
||||
)
|
||||
item2 = PurchaseItem(
|
||||
id=uuid.uuid4(),
|
||||
purchase_id=purchase1.id,
|
||||
product_name_raw="Meijer Whole Milk 1gal",
|
||||
quantity=Decimal("2"),
|
||||
@@ -170,6 +200,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
normalized_product_id=milk.id,
|
||||
)
|
||||
item3 = PurchaseItem(
|
||||
id=uuid.uuid4(),
|
||||
purchase_id=purchase2.id,
|
||||
product_name_raw="KRO CHEERIOS 18OZ",
|
||||
quantity=Decimal("1"),
|
||||
@@ -182,6 +213,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
|
||||
# -- Coupons --
|
||||
coupon1 = Coupon(
|
||||
id=uuid.uuid4(),
|
||||
store_id=meijer.id,
|
||||
normalized_product_id=cheerios.id,
|
||||
title="$1 off Cheerios",
|
||||
@@ -192,6 +224,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
valid_to=today + timedelta(days=30),
|
||||
)
|
||||
coupon2 = Coupon(
|
||||
id=uuid.uuid4(),
|
||||
store_id=kroger.id,
|
||||
normalized_product_id=None,
|
||||
title="10% off dairy",
|
||||
@@ -206,6 +239,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
|
||||
# -- Shrinkflation events --
|
||||
shrink = ShrinkflationEvent(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=cheerios.id,
|
||||
detected_date=today - timedelta(days=15),
|
||||
old_size="20",
|
||||
@@ -240,7 +274,7 @@ async def seed_data(db_engine, auth_headers):
|
||||
|
||||
return {
|
||||
"headers": auth_headers,
|
||||
"user_id": user_id,
|
||||
"user_id": real_user_id,
|
||||
"stores": {"meijer": meijer, "kroger": kroger, "target": target},
|
||||
"products": {"cheerios": cheerios, "milk": milk, "chicken": chicken},
|
||||
"purchases": {"meijer_trip": purchase1, "kroger_trip": purchase2},
|
||||
|
||||
@@ -1,132 +1,103 @@
|
||||
"""E2E: Auth and token validation flows."""
|
||||
"""E2E: Auth and session validation flows.
|
||||
|
||||
import asyncio
|
||||
Registration and login are handled by the Better-Auth service.
|
||||
These tests validate session token handling at the API gateway level.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthRegistrationLogin:
|
||||
"""Full registration → login → token refresh → profile flow."""
|
||||
|
||||
async def test_full_auth_lifecycle(self, client, db_engine):
|
||||
"""Register → login → get profile → refresh → get profile again."""
|
||||
# Register
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "lifecycle@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Lifecycle User",
|
||||
},
|
||||
)
|
||||
assert reg.status_code == 201
|
||||
tokens = reg.json()
|
||||
assert "access_token" in tokens
|
||||
assert "refresh_token" in tokens
|
||||
assert tokens["token_type"] == "bearer"
|
||||
assert tokens["expires_in"] > 0
|
||||
|
||||
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
|
||||
# Get profile with access token
|
||||
me = await client.get("/auth/me", headers=headers)
|
||||
assert me.status_code == 200
|
||||
assert me.json()["email"] == "lifecycle@example.com"
|
||||
assert me.json()["display_name"] == "Lifecycle User"
|
||||
|
||||
# Sleep 1s so the new token has a different exp than the registration token
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Login with same credentials
|
||||
login = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "lifecycle@example.com", "password": "securepass123"},
|
||||
)
|
||||
assert login.status_code == 200
|
||||
login_tokens = login.json()
|
||||
assert login_tokens["access_token"] != tokens["access_token"]
|
||||
|
||||
# Refresh token
|
||||
refresh = await client.post(
|
||||
"/auth/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
assert refresh.status_code == 200
|
||||
new_tokens = refresh.json()
|
||||
assert new_tokens["access_token"] != tokens["access_token"]
|
||||
|
||||
# Use refreshed token to access profile
|
||||
new_headers = {"Authorization": f"Bearer {new_tokens['access_token']}"}
|
||||
me2 = await client.get("/auth/me", headers=new_headers)
|
||||
assert me2.status_code == 200
|
||||
assert me2.json()["email"] == "lifecycle@example.com"
|
||||
from tests.conftest import _create_test_user_and_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTokenValidation:
|
||||
"""Token edge cases and error responses."""
|
||||
class TestSessionValidation:
|
||||
"""Session edge cases and error responses."""
|
||||
|
||||
async def test_expired_token_rejected(self, client, db_engine):
|
||||
"""Manually craft an expired token and verify rejection."""
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from jose import jwt
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
payload = {
|
||||
"sub": str(uuid.uuid4()),
|
||||
"exp": datetime.now(UTC) - timedelta(minutes=5),
|
||||
"type": "access",
|
||||
}
|
||||
token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||
async def test_invalid_session_token_rejected(self, client, db_engine):
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": "better-auth.session_token=not-a-real-token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_invalid_token_rejected(self, client, db_engine):
|
||||
resp = await client.get("/auth/me", headers={"Authorization": "Bearer not-a-real-token"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_missing_auth_header(self, client, db_engine):
|
||||
async def test_missing_auth(self, client, db_engine):
|
||||
resp = await client.get("/auth/me")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
async def test_refresh_token_cannot_access_endpoints(self, client, db_engine):
|
||||
"""A refresh token should not work as an access token."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "refresh-test@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Refresh Test",
|
||||
},
|
||||
async def test_bearer_token_also_works(self, client, db_engine):
|
||||
"""Session tokens passed as Bearer tokens should also be accepted."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="bearer@e2e.com", display_name="Bearer E2E"
|
||||
)
|
||||
refresh_token = reg.json()["refresh_token"]
|
||||
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {refresh_token}"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_deleted_user_token_invalid(self, client, db_engine):
|
||||
"""After deleting an account, tokens should no longer work."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "delete-me@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Delete Me",
|
||||
},
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Authorization": f"Bearer {session_token}"},
|
||||
)
|
||||
tokens = reg.json()
|
||||
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == "bearer@e2e.com"
|
||||
|
||||
async def test_deleted_user_session_returns_not_found(self, client, db_engine):
|
||||
"""After deleting a user, their session should result in 404 for profile."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="delete-me@e2e.com", display_name="Delete Me"
|
||||
)
|
||||
headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
# Delete account
|
||||
delete_resp = await client.delete("/auth/me", headers=headers)
|
||||
assert delete_resp.status_code == 204
|
||||
|
||||
# Profile should fail
|
||||
me = await client.get("/auth/me", headers=headers)
|
||||
assert me.status_code in (401, 404)
|
||||
assert me.status_code == 404
|
||||
|
||||
async def test_expired_session_rejected(self, client, db_engine):
|
||||
"""Expired sessions must be rejected."""
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
user_id = str(uuid.uuid4())
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hp, :dn, :eit, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": "expired@e2e.com",
|
||||
"hp": "unused",
|
||||
"dn": "Expired User",
|
||||
"eit": secrets.token_urlsafe(16),
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"uid": user_id,
|
||||
"ea": expired,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -154,60 +125,38 @@ class TestAuthProtectedEndpoints:
|
||||
class TestCrossUserDataIsolation:
|
||||
"""Verify that users cannot access other users' data."""
|
||||
|
||||
async def test_user_b_cannot_access_user_a_purchases(self, client, seed_data):
|
||||
"""Register a second user and verify they cannot see User A's purchases."""
|
||||
# User A's purchase (from seed_data)
|
||||
async def test_user_b_cannot_access_user_a_purchases(self, client, db_engine, seed_data):
|
||||
"""A second user cannot see User A's purchases."""
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
|
||||
# Register User B
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "userb@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "User B",
|
||||
},
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="userb@e2e.com", display_name="User B"
|
||||
)
|
||||
assert reg.status_code == 201
|
||||
user_b_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
||||
user_b_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
# User B tries to access User A's specific purchase
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
|
||||
assert resp.status_code in (403, 404), (
|
||||
"User B should not be able to access User A's purchase"
|
||||
)
|
||||
|
||||
async def test_user_b_purchase_list_is_empty(self, client, seed_data):
|
||||
"""A new user should see no purchases (not User A's purchases)."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "userc@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "User C",
|
||||
},
|
||||
async def test_user_b_purchase_list_is_empty(self, client, db_engine, seed_data):
|
||||
"""A new user should see no purchases."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="userc@e2e.com", display_name="User C"
|
||||
)
|
||||
assert reg.status_code == 201
|
||||
user_c_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
||||
user_c_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
resp = await client.get("/purchases", headers=user_c_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0, "New user should have no purchases"
|
||||
|
||||
async def test_user_b_stores_isolated(self, client, seed_data):
|
||||
async def test_user_b_stores_isolated(self, client, db_engine, seed_data):
|
||||
"""User B's connected stores should be independent from User A."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "userd@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "User D",
|
||||
},
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="userd@e2e.com", display_name="User D"
|
||||
)
|
||||
assert reg.status_code == 201
|
||||
user_d_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
||||
user_d_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
# User D should have no connected stores
|
||||
resp = await client.get("/me/stores", headers=user_d_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0, "New user should have no connected stores"
|
||||
|
||||
@@ -5,74 +5,6 @@ import pytest
|
||||
from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRegistrationErrors:
|
||||
"""Validation errors during user registration."""
|
||||
|
||||
async def test_short_password(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "short@example.com", "password": "short", "display_name": "Test"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_email(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "not-an-email", "password": "securepass123", "display_name": "Test"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_missing_fields(self, client, db_engine):
|
||||
resp = await client.post("/auth/register", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_empty_display_name(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "empty@example.com", "password": "securepass123", "display_name": ""},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_duplicate_email(self, client, db_engine):
|
||||
payload = {
|
||||
"email": "dupe@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "First",
|
||||
}
|
||||
first = await client.post("/auth/register", json=payload)
|
||||
assert first.status_code == 201
|
||||
second = await client.post("/auth/register", json=payload)
|
||||
assert second.status_code == 409
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLoginErrors:
|
||||
"""Login failure modes."""
|
||||
|
||||
async def test_wrong_password(self, client, db_engine):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "login-err@example.com",
|
||||
"password": "correctpass1",
|
||||
"display_name": "Login",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "login-err@example.com", "password": "wrongpass123"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_nonexistent_user(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "nobody@example.com", "password": "doesntmatter"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestNotFoundErrors:
|
||||
"""404 responses for missing resources."""
|
||||
|
||||
@@ -15,11 +15,12 @@ async def test_404_returns_structured_error(client):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_error_returns_422_with_field_errors(client):
|
||||
async def test_validation_error_returns_422_with_field_errors(client, auth_headers):
|
||||
"""Invalid request body should return structured validation errors."""
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "not-an-email", "password": "short", "display_name": ""},
|
||||
resp = await client.patch(
|
||||
"/auth/me",
|
||||
headers=auth_headers,
|
||||
json={"display_name": ""},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
body = resp.json()
|
||||
|
||||
@@ -6,13 +6,11 @@ from httpx import ASGITransport, AsyncClient
|
||||
from cartsnitch_api.main import app
|
||||
|
||||
EXPECTED_ROUTES = [
|
||||
# Auth (6)
|
||||
("post", "/auth/register"),
|
||||
("post", "/auth/login"),
|
||||
("post", "/auth/refresh"),
|
||||
# Auth (4 — register/login/refresh handled by Better-Auth service)
|
||||
("get", "/auth/me"),
|
||||
("patch", "/auth/me"),
|
||||
("delete", "/auth/me"),
|
||||
("get", "/auth/me/email-in-address"),
|
||||
# Stores (4)
|
||||
("get", "/stores"),
|
||||
("get", "/me/stores"),
|
||||
@@ -89,4 +87,4 @@ async def test_route_count():
|
||||
if method in ("get", "post", "put", "delete", "patch"):
|
||||
count += 1
|
||||
|
||||
assert count == 34, f"Expected 34 routes, found {count}"
|
||||
assert count == 31, f"Expected 31 routes, found {count}"
|
||||
|
||||
@@ -1,46 +1,82 @@
|
||||
"""Integration tests for purchase endpoints."""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import date
|
||||
from datetime import UTC, datetime, date, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy import text
|
||||
|
||||
from cartsnitch_api.auth.jwt import create_access_token
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def purchase_data(db_engine):
|
||||
"""Seed a user, store, purchase, and items."""
|
||||
"""Seed a user, store, purchase, and items using session-cookie auth."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
from cartsnitch_api.auth.passwords import hash_password
|
||||
user_id = str(uuid.uuid4())
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
||||
|
||||
user = User(
|
||||
email="buyer@example.com",
|
||||
hashed_password=hash_password("testpass123"),
|
||||
display_name="Buyer",
|
||||
# Create the user
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_inbound_token, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hashed_password, :display_name, :email_inbound_token, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": "buyer@example.com",
|
||||
"hashed_password": "not-used-with-better-auth",
|
||||
"display_name": "Buyer",
|
||||
"email_inbound_token": secrets.token_urlsafe(16),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
store = Store(name="Kroger", slug="kroger")
|
||||
session.add_all([user, store])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
# Create the session
|
||||
await session.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"user_id": user_id,
|
||||
"expires_at": expires,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
|
||||
# Create the store
|
||||
store = Store(name="Kroger", slug="kroger", id=uuid.uuid4())
|
||||
session.add(store)
|
||||
await session.flush()
|
||||
await session.refresh(store)
|
||||
|
||||
# Create the purchase
|
||||
purchase = Purchase(
|
||||
user_id=user.id,
|
||||
id=uuid.uuid4(),
|
||||
user_id=uuid.UUID(user_id),
|
||||
store_id=store.id,
|
||||
receipt_id="receipt-001",
|
||||
purchase_date=date(2026, 3, 10),
|
||||
total=Decimal("42.50"),
|
||||
)
|
||||
session.add(purchase)
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
await session.refresh(purchase)
|
||||
|
||||
# Create the purchase item
|
||||
item = PurchaseItem(
|
||||
id=uuid.uuid4(),
|
||||
purchase_id=purchase.id,
|
||||
product_name_raw="Organic Milk 1gal",
|
||||
quantity=Decimal("1"),
|
||||
@@ -50,12 +86,11 @@ async def purchase_data(db_engine):
|
||||
session.add(item)
|
||||
await session.commit()
|
||||
|
||||
token = create_access_token(user.id)
|
||||
return {
|
||||
"user": user,
|
||||
"user_id": user_id,
|
||||
"store": store,
|
||||
"purchase": purchase,
|
||||
"headers": {"Authorization": f"Bearer {token}"},
|
||||
"headers": {"Cookie": f"better-auth.session_token={session_token}"},
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user