forked from cartsnitch/api
fix(auth): revert to Better-Auth session-cookie auth, preserve email-in feature
- Revert auth/dependencies.py, auth/routes.py, services/auth.py, schemas.py to Better-Auth session-cookie auth (removed JWT register/login/refresh) - Preserve GET /auth/me/email-in-address endpoint - Fix UUIDString TypeDecorator: process_result_value returns uuid.UUID (not str) so SQLAlchemy 2.0 sentinel tracking matches UUID-to-UUID - Fix seed_data fixture: look up real user_id from session token via sessions table; purchases now reference actual user FK - Update purchase_data fixture to use session-cookie auth - Update test_auth_endpoints, test_auth_validation to cookie-based tests - Remove TestRegistrationErrors and TestLoginErrors (no longer applicable) - Update test_openapi.py expected routes and count - Update test_error_handler.py to use PATCH /auth/me validation Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -1,34 +1,91 @@
|
||||
"""FastAPI dependency injection for authentication."""
|
||||
"""FastAPI dependency injection for authentication.
|
||||
|
||||
Validates Better-Auth session tokens from cookies or Bearer header.
|
||||
Sessions are verified by querying the shared sessions table directly.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi import Cookie, Depends, Header, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.jwt import decode_token
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.database import get_db
|
||||
|
||||
bearer_scheme = HTTPBearer()
|
||||
# Keep Bearer scheme as optional — Better-Auth primarily uses cookies,
|
||||
# but we support Bearer tokens for service-to-service or mobile clients.
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
# Better-Auth session cookie name
|
||||
SESSION_COOKIE_NAME = "better-auth.session_token"
|
||||
|
||||
|
||||
async def _validate_session_token(token: str, db: AsyncSession) -> UUID:
|
||||
"""Validate a Better-Auth session token against the sessions table.
|
||||
|
||||
Returns the user_id (as UUID) if the session is valid and not expired.
|
||||
"""
|
||||
result = await db.execute(
|
||||
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
|
||||
{"token": token},
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid session token",
|
||||
)
|
||||
|
||||
user_id, expires_at = row
|
||||
# SQLite stores datetimes as ISO strings; parse if necessary
|
||||
if isinstance(expires_at, str):
|
||||
expires_at = datetime.fromisoformat(expires_at)
|
||||
if expires_at.tzinfo is None:
|
||||
# Treat naive datetimes as UTC
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < datetime.now(UTC):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session expired",
|
||||
)
|
||||
|
||||
return UUID(str(user_id))
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> UUID:
|
||||
try:
|
||||
payload = decode_token(credentials.credentials)
|
||||
except ValueError:
|
||||
"""Extract and validate the session token from cookie or Authorization header.
|
||||
|
||||
Checks in order:
|
||||
1. Better-Auth session cookie (primary — web clients)
|
||||
2. Bearer token in Authorization header (fallback — API clients)
|
||||
"""
|
||||
token: str | None = None
|
||||
|
||||
# 1. Check session cookie
|
||||
cookie_token = request.cookies.get(SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
token = cookie_token
|
||||
|
||||
# 2. Fall back to Bearer header
|
||||
if not token and credentials:
|
||||
token = credentials.credentials
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
) from None
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type",
|
||||
) from None
|
||||
|
||||
return UUID(payload["sub"])
|
||||
return await _validate_session_token(token, db)
|
||||
|
||||
|
||||
async def verify_service_key(x_service_key: str = Header()) -> None:
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
"""Auth routes: register, login, refresh, me, update, delete."""
|
||||
"""Auth routes: user profile management.
|
||||
|
||||
Registration, login, refresh, and session management are handled by
|
||||
the Better-Auth service (auth/). This router provides user profile
|
||||
endpoints that query our own user data from the shared database.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.models import User
|
||||
from cartsnitch_api.schemas import (
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
RegisterRequest,
|
||||
TokenResponse,
|
||||
EmailInAddressResponse,
|
||||
UpdateUserRequest,
|
||||
UserResponse,
|
||||
)
|
||||
@@ -23,37 +22,6 @@ from cartsnitch_api.services.auth import AuthService
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.register(body.email, body.password, body.display_name)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.login(body.email, body.password)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password"
|
||||
) from None
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.refresh(body.refresh_token)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
@@ -99,26 +67,15 @@ async def delete_me(
|
||||
) from None
|
||||
|
||||
|
||||
class EmailInAddressResponse(BaseModel):
|
||||
email_address: str
|
||||
instructions: str
|
||||
|
||||
|
||||
@router.get("/me/email-in-address", response_model=EmailInAddressResponse)
|
||||
async def get_email_in_address(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(User.email_inbound_token).where(User.id == user_id))
|
||||
token = result.scalar_one_or_none()
|
||||
if not token:
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.get_email_in_address(user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Email inbound token not found"
|
||||
) from None
|
||||
return EmailInAddressResponse(
|
||||
email_address=f"receipts+{token}@receipts.cartsnitch.com",
|
||||
instructions=(
|
||||
"Forward your digital receipt emails to this address. "
|
||||
"We currently support Meijer, Kroger, and Target receipt emails."
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,12 +1,39 @@
|
||||
"""Base model and mixins for all CartSnitch ORM models."""
|
||||
|
||||
import uuid
|
||||
import uuid as uuid_lib
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy import DateTime, String, TypeDecorator, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class UUIDString(TypeDecorator):
|
||||
"""Store UUIDs as VARCHAR(36) strings in all dialects.
|
||||
|
||||
This handles the fundamental mismatch between Python's uuid.UUID objects
|
||||
(used everywhere in application code) and SQLite's lack of a native UUID type.
|
||||
- On INSERT: converts uuid.UUID → str
|
||||
- On SELECT: returns uuid.UUID (so SQLAlchemy 2.0 sentinel tracking matches correctly)
|
||||
"""
|
||||
|
||||
impl = String(36)
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, uuid_lib.UUID):
|
||||
return str(value)
|
||||
return value # already a string
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, uuid_lib.UUID):
|
||||
return value
|
||||
return uuid_lib.UUID(value) # convert str → UUID for correct sentinel tracking
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all CartSnitch models."""
|
||||
|
||||
@@ -23,8 +50,14 @@ class TimestampMixin:
|
||||
|
||||
|
||||
class UUIDPrimaryKeyMixin:
|
||||
"""Mixin providing a UUID primary key."""
|
||||
"""Mixin providing a UUID primary key.
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid()
|
||||
Uses UUIDString so all DB dialects store the full 36-char UUID string
|
||||
without truncation, while Python code always works with uuid.UUID objects.
|
||||
"""
|
||||
|
||||
id: Mapped[uuid_lib.UUID] = mapped_column(
|
||||
UUIDString(),
|
||||
primary_key=True,
|
||||
default=uuid_lib.uuid4,
|
||||
)
|
||||
|
||||
@@ -6,28 +6,8 @@ from uuid import UUID
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
# ---------- Auth ----------
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=8, max_length=128)
|
||||
display_name: str = Field(min_length=1, max_length=100)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
# Registration, login, and session management are handled by Better-Auth (auth/ service).
|
||||
# These schemas are for the profile management endpoints only.
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
@@ -285,6 +265,13 @@ class ErrorResponse(BaseModel):
|
||||
code: str | None = None
|
||||
|
||||
|
||||
# ---------- Email-In ----------
|
||||
|
||||
class EmailInAddressResponse(BaseModel):
|
||||
email_address: str
|
||||
instructions: str
|
||||
|
||||
|
||||
# Rebuild forward refs
|
||||
ProductDetailResponse.model_rebuild()
|
||||
PriceTrendResponse.model_rebuild()
|
||||
|
||||
@@ -1,71 +1,28 @@
|
||||
"""Auth service — user registration, login, token management."""
|
||||
"""Auth service — user profile management.
|
||||
|
||||
Registration, login, token management, and session handling are now
|
||||
handled by the Better-Auth service (auth/). This service provides
|
||||
user lookup and profile update operations for the API gateway.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token
|
||||
from cartsnitch_api.auth.passwords import hash_password, verify_password
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def register(self, email: str, password: str, display_name: str) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
existing = await self.db.execute(select(User).where(User.email == email))
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
hashed_password=hash_password(password),
|
||||
display_name=display_name,
|
||||
)
|
||||
self.db.add(user)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user)
|
||||
|
||||
return self._make_token_response(user.id)
|
||||
|
||||
async def login(self, email: str, password: str) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise ValueError("Invalid email or password")
|
||||
|
||||
return self._make_token_response(user.id)
|
||||
|
||||
async def refresh(self, refresh_token: str) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
try:
|
||||
payload = decode_token(refresh_token)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid refresh token") from None
|
||||
|
||||
if payload.get("type") != "refresh":
|
||||
raise ValueError("Invalid token type") from None
|
||||
|
||||
user_id = UUID(payload["sub"])
|
||||
|
||||
# Verify the user still exists before issuing new tokens
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
if not result.scalar_one_or_none():
|
||||
raise ValueError("User no longer exists")
|
||||
|
||||
return self._make_token_response(user_id)
|
||||
|
||||
async def get_user(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
# Use str() to ensure consistent string comparison for UUID columns
|
||||
# (works with both SQLite VARCHAR and Postgres UUID storage)
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.id == str(user_id))
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
@@ -80,7 +37,8 @@ class AuthService:
|
||||
async def update_user(self, user_id: UUID, **fields) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user_id_str = str(user_id)
|
||||
result = await self.db.execute(select(User).where(User.id == user_id_str))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
@@ -89,7 +47,7 @@ class AuthService:
|
||||
user.display_name = fields["display_name"]
|
||||
if "email" in fields and fields["email"] is not None:
|
||||
existing = await self.db.execute(
|
||||
select(User).where(User.email == fields["email"], User.id != user_id)
|
||||
select(User).where(User.email == fields["email"], User.id != user_id_str)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Email already in use")
|
||||
@@ -108,7 +66,7 @@ class AuthService:
|
||||
async def delete_user(self, user_id: UUID) -> None:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
result = await self.db.execute(select(User).where(User.id == str(user_id)))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
@@ -116,10 +74,20 @@ class AuthService:
|
||||
await self.db.delete(user)
|
||||
await self.db.commit()
|
||||
|
||||
def _make_token_response(self, user_id: UUID) -> dict:
|
||||
async def get_email_in_address(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(
|
||||
select(User.email_inbound_token).where(User.id == str(user_id))
|
||||
)
|
||||
token = result.scalar_one_or_none()
|
||||
if not token:
|
||||
raise LookupError("Email inbound token not found")
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(user_id),
|
||||
"refresh_token": create_refresh_token(user_id),
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.jwt_access_token_expire_minutes * 60,
|
||||
"email_address": f"receipts+{token}@receipts.cartsnitch.com",
|
||||
"instructions": (
|
||||
"Forward your digital receipt emails to this address. "
|
||||
"We currently support Meijer, Kroger, and Target receipt emails."
|
||||
),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user