fix: align auth client basePath with server config
fix: align auth client basePath with server config
This commit is contained in:
@@ -0,0 +1,101 @@
|
|||||||
|
"""Add Better-Auth tables and extend users table.
|
||||||
|
|
||||||
|
Creates sessions, accounts, and verifications tables for Better-Auth.
|
||||||
|
Adds email_verified and image columns to existing users table.
|
||||||
|
Migrates password hashes from users.hashed_password to accounts.password.
|
||||||
|
|
||||||
|
Revision ID: 002_better_auth_tables
|
||||||
|
Revises: 001_encrypt_session_data
|
||||||
|
Create Date: 2026-03-28
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "002_better_auth_tables"
|
||||||
|
down_revision = "001_encrypt_session_data"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# --- Extend users table for Better-Auth compatibility ---
|
||||||
|
op.add_column("users", sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false"))
|
||||||
|
op.add_column("users", sa.Column("image", sa.Text(), nullable=True))
|
||||||
|
|
||||||
|
# --- Create sessions table ---
|
||||||
|
op.create_table(
|
||||||
|
"sessions",
|
||||||
|
sa.Column("id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("token", sa.Text(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("ip_address", sa.Text(), nullable=True),
|
||||||
|
sa.Column("user_agent", sa.Text(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_sessions_token", "sessions", ["token"], unique=True)
|
||||||
|
op.create_index("ix_sessions_user_id", "sessions", ["user_id"])
|
||||||
|
|
||||||
|
# --- Create accounts table ---
|
||||||
|
op.create_table(
|
||||||
|
"accounts",
|
||||||
|
sa.Column("id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("account_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("provider_id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("access_token", sa.Text(), nullable=True),
|
||||||
|
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||||
|
sa.Column("access_token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("refresh_token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("scope", sa.Text(), nullable=True),
|
||||||
|
sa.Column("id_token", sa.Text(), nullable=True),
|
||||||
|
sa.Column("password", sa.Text(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_accounts_user_id", "accounts", ["user_id"])
|
||||||
|
|
||||||
|
# --- Create verifications table ---
|
||||||
|
op.create_table(
|
||||||
|
"verifications",
|
||||||
|
sa.Column("id", sa.Text(), nullable=False),
|
||||||
|
sa.Column("identifier", sa.Text(), nullable=False),
|
||||||
|
sa.Column("value", sa.Text(), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Migrate existing password hashes to accounts table ---
|
||||||
|
# For each user with a hashed_password, create a 'credential' account row
|
||||||
|
conn = op.get_bind()
|
||||||
|
users = conn.execute(
|
||||||
|
text("SELECT id, hashed_password FROM users WHERE hashed_password IS NOT NULL")
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
for user_id, hashed_password in users:
|
||||||
|
user_id_str = str(user_id)
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO accounts (id, user_id, account_id, provider_id, password, created_at, updated_at) "
|
||||||
|
"VALUES (gen_random_uuid()::text, :user_id, :account_id, 'credential', :password, now(), now())"
|
||||||
|
),
|
||||||
|
{"user_id": user_id_str, "account_id": user_id_str, "password": hashed_password},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("verifications")
|
||||||
|
op.drop_table("accounts")
|
||||||
|
op.drop_index("ix_sessions_user_id", table_name="sessions")
|
||||||
|
op.drop_index("ix_sessions_token", table_name="sessions")
|
||||||
|
op.drop_table("sessions")
|
||||||
|
op.drop_column("users", "image")
|
||||||
|
op.drop_column("users", "email_verified")
|
||||||
@@ -1,34 +1,88 @@
|
|||||||
"""FastAPI dependency injection for authentication."""
|
"""FastAPI dependency injection for authentication.
|
||||||
|
|
||||||
|
Validates Better-Auth session tokens from cookies or Bearer header.
|
||||||
|
Sessions are verified by querying the shared sessions table directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import Depends, Header, HTTPException, status
|
from fastapi import Cookie, Depends, Header, HTTPException, Request, status
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from cartsnitch_api.auth.jwt import decode_token
|
|
||||||
from cartsnitch_api.config import settings
|
from cartsnitch_api.config import settings
|
||||||
|
from cartsnitch_api.database import get_db
|
||||||
|
|
||||||
bearer_scheme = HTTPBearer()
|
# Keep Bearer scheme as optional — Better-Auth primarily uses cookies,
|
||||||
|
# but we support Bearer tokens for service-to-service or mobile clients.
|
||||||
|
bearer_scheme = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
# Better-Auth session cookie name
|
||||||
|
SESSION_COOKIE_NAME = "better-auth.session_token"
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_session_token(token: str, db: AsyncSession) -> UUID:
|
||||||
|
"""Validate a Better-Auth session token against the sessions table.
|
||||||
|
|
||||||
|
Returns the user_id (as UUID) if the session is valid and not expired.
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
|
||||||
|
{"token": token},
|
||||||
|
)
|
||||||
|
row = result.first()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid session token",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id, expires_at = row
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
# Treat naive datetimes as UTC
|
||||||
|
expires_at = expires_at.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
if expires_at < datetime.now(UTC):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Session expired",
|
||||||
|
)
|
||||||
|
|
||||||
|
return UUID(str(user_id))
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
request: Request,
|
||||||
|
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> UUID:
|
) -> UUID:
|
||||||
try:
|
"""Extract and validate the session token from cookie or Authorization header.
|
||||||
payload = decode_token(credentials.credentials)
|
|
||||||
except ValueError:
|
Checks in order:
|
||||||
|
1. Better-Auth session cookie (primary — web clients)
|
||||||
|
2. Bearer token in Authorization header (fallback — API clients)
|
||||||
|
"""
|
||||||
|
token: str | None = None
|
||||||
|
|
||||||
|
# 1. Check session cookie
|
||||||
|
cookie_token = request.cookies.get(SESSION_COOKIE_NAME)
|
||||||
|
if cookie_token:
|
||||||
|
token = cookie_token
|
||||||
|
|
||||||
|
# 2. Fall back to Bearer header
|
||||||
|
if not token and credentials:
|
||||||
|
token = credentials.credentials
|
||||||
|
|
||||||
|
if not token:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid or expired token",
|
detail="Authentication required",
|
||||||
) from None
|
)
|
||||||
|
|
||||||
if payload.get("type") != "access":
|
return await _validate_session_token(token, db)
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid token type",
|
|
||||||
) from None
|
|
||||||
|
|
||||||
return UUID(payload["sub"])
|
|
||||||
|
|
||||||
|
|
||||||
async def verify_service_key(x_service_key: str = Header()) -> None:
|
async def verify_service_key(x_service_key: str = Header()) -> None:
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
"""Auth routes: register, login, refresh, me, update, delete."""
|
"""Auth routes: user profile management.
|
||||||
|
|
||||||
|
Registration, login, refresh, and session management are handled by
|
||||||
|
the Better-Auth service (auth/). This router provides user profile
|
||||||
|
endpoints that query our own user data from the shared database.
|
||||||
|
"""
|
||||||
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -8,10 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from cartsnitch_api.auth.dependencies import get_current_user
|
from cartsnitch_api.auth.dependencies import get_current_user
|
||||||
from cartsnitch_api.database import get_db
|
from cartsnitch_api.database import get_db
|
||||||
from cartsnitch_api.schemas import (
|
from cartsnitch_api.schemas import (
|
||||||
LoginRequest,
|
|
||||||
RefreshRequest,
|
|
||||||
RegisterRequest,
|
|
||||||
TokenResponse,
|
|
||||||
UpdateUserRequest,
|
UpdateUserRequest,
|
||||||
UserResponse,
|
UserResponse,
|
||||||
)
|
)
|
||||||
@@ -20,37 +21,6 @@ from cartsnitch_api.services.auth import AuthService
|
|||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
|
||||||
svc = AuthService(db)
|
|
||||||
try:
|
|
||||||
return await svc.register(body.email, body.password, body.display_name)
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=TokenResponse)
|
|
||||||
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
|
||||||
svc = AuthService(db)
|
|
||||||
try:
|
|
||||||
return await svc.login(body.email, body.password)
|
|
||||||
except ValueError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password"
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=TokenResponse)
|
|
||||||
async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
|
||||||
svc = AuthService(db)
|
|
||||||
try:
|
|
||||||
return await svc.refresh(body.refresh_token)
|
|
||||||
except ValueError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
|
||||||
) from None
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
@router.get("/me", response_model=UserResponse)
|
||||||
async def get_me(
|
async def get_me(
|
||||||
user_id: UUID = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ class Settings(BaseSettings):
|
|||||||
# Valid Fernet key for local dev — MUST be overridden in production
|
# Valid Fernet key for local dev — MUST be overridden in production
|
||||||
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||||
|
|
||||||
|
auth_service_url: str = "http://auth:3001"
|
||||||
|
|
||||||
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
|
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
|
||||||
|
|
||||||
receiptwitness_url: str = "http://receiptwitness:8001"
|
receiptwitness_url: str = "http://receiptwitness:8001"
|
||||||
|
|||||||
@@ -6,28 +6,8 @@ from uuid import UUID
|
|||||||
from pydantic import BaseModel, EmailStr, Field
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
|
|
||||||
# ---------- Auth ----------
|
# ---------- Auth ----------
|
||||||
|
# Registration, login, and session management are handled by Better-Auth (auth/ service).
|
||||||
|
# These schemas are for the profile management endpoints only.
|
||||||
class RegisterRequest(BaseModel):
|
|
||||||
email: EmailStr
|
|
||||||
password: str = Field(min_length=8, max_length=128)
|
|
||||||
display_name: str = Field(min_length=1, max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
|
||||||
email: EmailStr
|
|
||||||
password: str
|
|
||||||
|
|
||||||
|
|
||||||
class RefreshRequest(BaseModel):
|
|
||||||
refresh_token: str
|
|
||||||
|
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
refresh_token: str
|
|
||||||
token_type: str = "bearer"
|
|
||||||
expires_in: int
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateUserRequest(BaseModel):
|
class UpdateUserRequest(BaseModel):
|
||||||
|
|||||||
@@ -1,67 +1,20 @@
|
|||||||
"""Auth service — user registration, login, token management."""
|
"""Auth service — user profile management.
|
||||||
|
|
||||||
|
Registration, login, token management, and session handling are now
|
||||||
|
handled by the Better-Auth service (auth/). This service provides
|
||||||
|
user lookup and profile update operations for the API gateway.
|
||||||
|
"""
|
||||||
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token
|
|
||||||
from cartsnitch_api.auth.passwords import hash_password, verify_password
|
|
||||||
from cartsnitch_api.config import settings
|
|
||||||
|
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
def __init__(self, db: AsyncSession) -> None:
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
async def register(self, email: str, password: str, display_name: str) -> dict:
|
|
||||||
from cartsnitch_api.models import User
|
|
||||||
|
|
||||||
existing = await self.db.execute(select(User).where(User.email == email))
|
|
||||||
if existing.scalar_one_or_none():
|
|
||||||
raise ValueError("Email already registered")
|
|
||||||
|
|
||||||
user = User(
|
|
||||||
email=email,
|
|
||||||
hashed_password=hash_password(password),
|
|
||||||
display_name=display_name,
|
|
||||||
)
|
|
||||||
self.db.add(user)
|
|
||||||
await self.db.commit()
|
|
||||||
await self.db.refresh(user)
|
|
||||||
|
|
||||||
return self._make_token_response(user.id)
|
|
||||||
|
|
||||||
async def login(self, email: str, password: str) -> dict:
|
|
||||||
from cartsnitch_api.models import User
|
|
||||||
|
|
||||||
result = await self.db.execute(select(User).where(User.email == email))
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if not user or not verify_password(password, user.hashed_password):
|
|
||||||
raise ValueError("Invalid email or password")
|
|
||||||
|
|
||||||
return self._make_token_response(user.id)
|
|
||||||
|
|
||||||
async def refresh(self, refresh_token: str) -> dict:
|
|
||||||
from cartsnitch_api.models import User
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload = decode_token(refresh_token)
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError("Invalid refresh token") from None
|
|
||||||
|
|
||||||
if payload.get("type") != "refresh":
|
|
||||||
raise ValueError("Invalid token type") from None
|
|
||||||
|
|
||||||
user_id = UUID(payload["sub"])
|
|
||||||
|
|
||||||
# Verify the user still exists before issuing new tokens
|
|
||||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
|
||||||
if not result.scalar_one_or_none():
|
|
||||||
raise ValueError("User no longer exists")
|
|
||||||
|
|
||||||
return self._make_token_response(user_id)
|
|
||||||
|
|
||||||
async def get_user(self, user_id: UUID) -> dict:
|
async def get_user(self, user_id: UUID) -> dict:
|
||||||
from cartsnitch_api.models import User
|
from cartsnitch_api.models import User
|
||||||
|
|
||||||
@@ -115,11 +68,3 @@ class AuthService:
|
|||||||
|
|
||||||
await self.db.delete(user)
|
await self.db.delete(user)
|
||||||
await self.db.commit()
|
await self.db.commit()
|
||||||
|
|
||||||
def _make_token_response(self, user_id: UUID) -> dict:
|
|
||||||
return {
|
|
||||||
"access_token": create_access_token(user_id),
|
|
||||||
"refresh_token": create_refresh_token(user_id),
|
|
||||||
"token_type": "bearer",
|
|
||||||
"expires_in": settings.jwt_access_token_expire_minutes * 60,
|
|
||||||
}
|
|
||||||
|
|||||||
+101
-15
@@ -1,8 +1,16 @@
|
|||||||
"""Shared test fixtures with in-memory SQLite database."""
|
"""Shared test fixtures with in-memory SQLite database.
|
||||||
|
|
||||||
|
Session-based auth: tests create users and sessions directly in the DB,
|
||||||
|
matching the Better-Auth session validation flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
from sqlalchemy import create_engine, event
|
from sqlalchemy import create_engine, event, text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
@@ -51,6 +59,46 @@ async def db_engine():
|
|||||||
|
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
||||||
|
await conn.execute(text("""
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
token TEXT NOT NULL UNIQUE,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMP NOT NULL,
|
||||||
|
ip_address TEXT,
|
||||||
|
user_agent TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
|
)
|
||||||
|
"""))
|
||||||
|
await conn.execute(text("""
|
||||||
|
CREATE TABLE IF NOT EXISTS accounts (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
account_id TEXT NOT NULL,
|
||||||
|
provider_id TEXT NOT NULL,
|
||||||
|
access_token TEXT,
|
||||||
|
refresh_token TEXT,
|
||||||
|
access_token_expires_at TIMESTAMP,
|
||||||
|
refresh_token_expires_at TIMESTAMP,
|
||||||
|
scope TEXT,
|
||||||
|
id_token TEXT,
|
||||||
|
password TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
|
)
|
||||||
|
"""))
|
||||||
|
await conn.execute(text("""
|
||||||
|
CREATE TABLE IF NOT EXISTS verifications (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
identifier TEXT NOT NULL,
|
||||||
|
value TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMP NOT NULL,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
|
)
|
||||||
|
"""))
|
||||||
|
|
||||||
yield engine
|
yield engine
|
||||||
|
|
||||||
@@ -85,17 +133,55 @@ async def client(db_engine):
|
|||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
||||||
|
"""Create a test user and a valid session directly in the DB.
|
||||||
|
|
||||||
|
Returns (user_dict, session_token).
|
||||||
|
"""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
email = user_overrides.get("email", "test@example.com")
|
||||||
|
display_name = user_overrides.get("display_name", "Test User")
|
||||||
|
session_token = secrets.token_urlsafe(32)
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
now = datetime.now(UTC).isoformat()
|
||||||
|
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
||||||
|
|
||||||
|
async with db_engine.begin() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
||||||
|
"VALUES (:id, :email, :hashed_password, :display_name, :email_verified, :created_at, :updated_at)"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": user_id,
|
||||||
|
"email": email,
|
||||||
|
"hashed_password": "not-used-with-better-auth",
|
||||||
|
"display_name": display_name,
|
||||||
|
"email_verified": False,
|
||||||
|
"created_at": now,
|
||||||
|
"updated_at": now,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||||
|
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": session_id,
|
||||||
|
"token": session_token,
|
||||||
|
"user_id": user_id,
|
||||||
|
"expires_at": expires,
|
||||||
|
"created_at": now,
|
||||||
|
"updated_at": now,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"id": user_id, "email": email, "display_name": display_name}, session_token
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def auth_headers(client):
|
async def auth_headers(client, db_engine):
|
||||||
"""Register a test user and return auth headers."""
|
"""Create a test user with a valid session and return auth headers."""
|
||||||
resp = await client.post(
|
_, session_token = await _create_test_user_and_session(client, db_engine)
|
||||||
"/auth/register",
|
return {"Cookie": f"better-auth.session_token={session_token}"}
|
||||||
json={
|
|
||||||
"email": "test@example.com",
|
|
||||||
"password": "testpass123",
|
|
||||||
"display_name": "Test User",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 201
|
|
||||||
token = resp.json()["access_token"]
|
|
||||||
return {"Authorization": f"Bearer {token}"}
|
|
||||||
|
|||||||
@@ -1,146 +1,13 @@
|
|||||||
"""Integration tests for auth endpoints."""
|
"""Integration tests for auth profile endpoints.
|
||||||
|
|
||||||
|
Registration, login, and session management are handled by the Better-Auth
|
||||||
|
service. These tests cover the profile endpoints (GET/PATCH/DELETE /auth/me)
|
||||||
|
which validate sessions via the shared sessions table.
|
||||||
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_success(client):
|
|
||||||
resp = await client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "new@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "New User",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 201
|
|
||||||
data = resp.json()
|
|
||||||
assert "access_token" in data
|
|
||||||
assert "refresh_token" in data
|
|
||||||
assert data["token_type"] == "bearer"
|
|
||||||
assert data["expires_in"] == 900 # 15 min * 60
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_duplicate_email(client):
|
|
||||||
await client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "dupe@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "User One",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp = await client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "dupe@example.com",
|
|
||||||
"password": "securepass456",
|
|
||||||
"display_name": "User Two",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 409
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_short_password(client):
|
|
||||||
resp = await client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "short@example.com",
|
|
||||||
"password": "short",
|
|
||||||
"display_name": "Short Pass",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 422
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_login_success(client):
|
|
||||||
await client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "login@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "Login User",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp = await client.post(
|
|
||||||
"/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "login@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "access_token" in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_login_wrong_password(client):
|
|
||||||
await client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "wrong@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "Wrong Pass",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp = await client.post(
|
|
||||||
"/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "wrong@example.com",
|
|
||||||
"password": "badpassword1",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_login_nonexistent_user(client):
|
|
||||||
resp = await client.post(
|
|
||||||
"/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "ghost@example.com",
|
|
||||||
"password": "doesntmatter",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_refresh_token(client):
|
|
||||||
reg = await client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "refresh@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "Refresh User",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
refresh_token = reg.json()["refresh_token"]
|
|
||||||
|
|
||||||
resp = await client.post(
|
|
||||||
"/auth/refresh",
|
|
||||||
json={
|
|
||||||
"refresh_token": refresh_token,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "access_token" in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_refresh_with_invalid_token(client):
|
|
||||||
resp = await client.post(
|
|
||||||
"/auth/refresh",
|
|
||||||
json={
|
|
||||||
"refresh_token": "invalid.token.here",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_me(client, auth_headers):
|
async def test_get_me(client, auth_headers):
|
||||||
resp = await client.get("/auth/me", headers=auth_headers)
|
resp = await client.get("/auth/me", headers=auth_headers)
|
||||||
@@ -155,7 +22,32 @@ async def test_get_me(client, auth_headers):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_me_unauthorized(client):
|
async def test_get_me_unauthorized(client):
|
||||||
resp = await client.get("/auth/me")
|
resp = await client.get("/auth/me")
|
||||||
assert resp.status_code in (401, 403) # No auth header
|
assert resp.status_code in (401, 403)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_invalid_session(client):
|
||||||
|
resp = await client.get(
|
||||||
|
"/auth/me",
|
||||||
|
headers={"Cookie": "better-auth.session_token=invalid-token"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_with_bearer_token(client, db_engine):
|
||||||
|
"""Session tokens can also be passed as Bearer tokens for API clients."""
|
||||||
|
from tests.conftest import _create_test_user_and_session
|
||||||
|
|
||||||
|
_, session_token = await _create_test_user_and_session(
|
||||||
|
client, db_engine, email="bearer@example.com", display_name="Bearer User"
|
||||||
|
)
|
||||||
|
resp = await client.get(
|
||||||
|
"/auth/me",
|
||||||
|
headers={"Authorization": f"Bearer {session_token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["email"] == "bearer@example.com"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -163,9 +55,7 @@ async def test_update_me(client, auth_headers):
|
|||||||
resp = await client.patch(
|
resp = await client.patch(
|
||||||
"/auth/me",
|
"/auth/me",
|
||||||
headers=auth_headers,
|
headers=auth_headers,
|
||||||
json={
|
json={"display_name": "Updated Name"},
|
||||||
"display_name": "Updated Name",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json()["display_name"] == "Updated Name"
|
assert resp.json()["display_name"] == "Updated Name"
|
||||||
@@ -176,34 +66,58 @@ async def test_delete_me(client, auth_headers):
|
|||||||
resp = await client.delete("/auth/me", headers=auth_headers)
|
resp = await client.delete("/auth/me", headers=auth_headers)
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
# Verify user is gone (token still valid but user deleted)
|
# Session is still valid but user is gone
|
||||||
resp = await client.get("/auth/me", headers=auth_headers)
|
resp = await client.get("/auth/me", headers=auth_headers)
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_refresh_after_delete_fails(client):
|
async def test_expired_session_rejected(client, db_engine):
|
||||||
"""Refresh token for a deleted user must be rejected."""
|
"""Expired sessions must be rejected."""
|
||||||
reg = await client.post(
|
import secrets
|
||||||
"/auth/register",
|
import uuid
|
||||||
json={
|
from datetime import UTC, datetime, timedelta
|
||||||
"email": "ghost@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "Ghost User",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
tokens = reg.json()
|
|
||||||
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
|
||||||
|
|
||||||
# Delete the user
|
from sqlalchemy import text
|
||||||
resp = await client.delete("/auth/me", headers=headers)
|
|
||||||
assert resp.status_code == 204
|
|
||||||
|
|
||||||
# Refresh token should now fail
|
user_id = str(uuid.uuid4())
|
||||||
resp = await client.post(
|
session_token = secrets.token_urlsafe(32)
|
||||||
"/auth/refresh",
|
now = datetime.now(UTC).isoformat()
|
||||||
json={
|
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
||||||
"refresh_token": tokens["refresh_token"],
|
|
||||||
},
|
async with db_engine.begin() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
||||||
|
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": user_id,
|
||||||
|
"email": "expired@example.com",
|
||||||
|
"hp": "unused",
|
||||||
|
"dn": "Expired User",
|
||||||
|
"ev": False,
|
||||||
|
"ca": now,
|
||||||
|
"ua": now,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||||
|
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"token": session_token,
|
||||||
|
"uid": user_id,
|
||||||
|
"ea": expired,
|
||||||
|
"ca": now,
|
||||||
|
"ua": now,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
"/auth/me",
|
||||||
|
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
||||||
)
|
)
|
||||||
assert resp.status_code == 401
|
assert resp.status_code == 401
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ from decimal import Decimal
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy import text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from cartsnitch_api.auth.jwt import decode_token
|
|
||||||
from cartsnitch_api.models import (
|
from cartsnitch_api.models import (
|
||||||
Coupon,
|
Coupon,
|
||||||
NormalizedProduct,
|
NormalizedProduct,
|
||||||
@@ -126,10 +126,16 @@ async def seed_data(db_engine, auth_headers):
|
|||||||
session.add_all(prices)
|
session.add_all(prices)
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
# -- Purchases (need the user_id from the registered test user) --
|
# -- Get the user_id from the session token in auth_headers --
|
||||||
token = auth_headers["Authorization"].split(" ")[1]
|
cookie_str = auth_headers.get("Cookie", "")
|
||||||
payload = decode_token(token)
|
session_token = cookie_str.split("=", 1)[1] if "=" in cookie_str else ""
|
||||||
user_id = UUID(payload["sub"])
|
|
||||||
|
result = await session.execute(
|
||||||
|
text("SELECT user_id FROM sessions WHERE token = :token"),
|
||||||
|
{"token": session_token},
|
||||||
|
)
|
||||||
|
row = result.first()
|
||||||
|
user_id = UUID(row[0])
|
||||||
|
|
||||||
purchase1 = Purchase(
|
purchase1 = Purchase(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|||||||
@@ -1,132 +1,103 @@
|
|||||||
"""E2E: Auth and token validation flows."""
|
"""E2E: Auth and session validation flows.
|
||||||
|
|
||||||
import asyncio
|
Registration and login are handled by the Better-Auth service.
|
||||||
|
These tests validate session token handling at the API gateway level.
|
||||||
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.conftest import _create_test_user_and_session
|
||||||
@pytest.mark.asyncio
|
|
||||||
class TestAuthRegistrationLogin:
|
|
||||||
"""Full registration → login → token refresh → profile flow."""
|
|
||||||
|
|
||||||
async def test_full_auth_lifecycle(self, client, db_engine):
|
|
||||||
"""Register → login → get profile → refresh → get profile again."""
|
|
||||||
# Register
|
|
||||||
reg = await client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "lifecycle@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "Lifecycle User",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert reg.status_code == 201
|
|
||||||
tokens = reg.json()
|
|
||||||
assert "access_token" in tokens
|
|
||||||
assert "refresh_token" in tokens
|
|
||||||
assert tokens["token_type"] == "bearer"
|
|
||||||
assert tokens["expires_in"] > 0
|
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
|
||||||
|
|
||||||
# Get profile with access token
|
|
||||||
me = await client.get("/auth/me", headers=headers)
|
|
||||||
assert me.status_code == 200
|
|
||||||
assert me.json()["email"] == "lifecycle@example.com"
|
|
||||||
assert me.json()["display_name"] == "Lifecycle User"
|
|
||||||
|
|
||||||
# Sleep 1s so the new token has a different exp than the registration token
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
# Login with same credentials
|
|
||||||
login = await client.post(
|
|
||||||
"/auth/login",
|
|
||||||
json={"email": "lifecycle@example.com", "password": "securepass123"},
|
|
||||||
)
|
|
||||||
assert login.status_code == 200
|
|
||||||
login_tokens = login.json()
|
|
||||||
assert login_tokens["access_token"] != tokens["access_token"]
|
|
||||||
|
|
||||||
# Refresh token
|
|
||||||
refresh = await client.post(
|
|
||||||
"/auth/refresh",
|
|
||||||
json={"refresh_token": tokens["refresh_token"]},
|
|
||||||
)
|
|
||||||
assert refresh.status_code == 200
|
|
||||||
new_tokens = refresh.json()
|
|
||||||
assert new_tokens["access_token"] != tokens["access_token"]
|
|
||||||
|
|
||||||
# Use refreshed token to access profile
|
|
||||||
new_headers = {"Authorization": f"Bearer {new_tokens['access_token']}"}
|
|
||||||
me2 = await client.get("/auth/me", headers=new_headers)
|
|
||||||
assert me2.status_code == 200
|
|
||||||
assert me2.json()["email"] == "lifecycle@example.com"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestTokenValidation:
|
class TestSessionValidation:
|
||||||
"""Token edge cases and error responses."""
|
"""Session edge cases and error responses."""
|
||||||
|
|
||||||
async def test_expired_token_rejected(self, client, db_engine):
|
async def test_invalid_session_token_rejected(self, client, db_engine):
|
||||||
"""Manually craft an expired token and verify rejection."""
|
resp = await client.get(
|
||||||
import uuid
|
"/auth/me",
|
||||||
from datetime import UTC, datetime, timedelta
|
headers={"Cookie": "better-auth.session_token=not-a-real-token"},
|
||||||
|
)
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
from cartsnitch_api.config import settings
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"sub": str(uuid.uuid4()),
|
|
||||||
"exp": datetime.now(UTC) - timedelta(minutes=5),
|
|
||||||
"type": "access",
|
|
||||||
}
|
|
||||||
token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
|
||||||
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
|
|
||||||
assert resp.status_code == 401
|
assert resp.status_code == 401
|
||||||
|
|
||||||
async def test_invalid_token_rejected(self, client, db_engine):
|
async def test_missing_auth(self, client, db_engine):
|
||||||
resp = await client.get("/auth/me", headers={"Authorization": "Bearer not-a-real-token"})
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
async def test_missing_auth_header(self, client, db_engine):
|
|
||||||
resp = await client.get("/auth/me")
|
resp = await client.get("/auth/me")
|
||||||
assert resp.status_code in (401, 403)
|
assert resp.status_code in (401, 403)
|
||||||
|
|
||||||
async def test_refresh_token_cannot_access_endpoints(self, client, db_engine):
|
async def test_bearer_token_also_works(self, client, db_engine):
|
||||||
"""A refresh token should not work as an access token."""
|
"""Session tokens passed as Bearer tokens should also be accepted."""
|
||||||
reg = await client.post(
|
_, session_token = await _create_test_user_and_session(
|
||||||
"/auth/register",
|
client, db_engine, email="bearer@e2e.com", display_name="Bearer E2E"
|
||||||
json={
|
|
||||||
"email": "refresh-test@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "Refresh Test",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
refresh_token = reg.json()["refresh_token"]
|
resp = await client.get(
|
||||||
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {refresh_token}"})
|
"/auth/me",
|
||||||
assert resp.status_code == 401
|
headers={"Authorization": f"Bearer {session_token}"},
|
||||||
|
|
||||||
async def test_deleted_user_token_invalid(self, client, db_engine):
|
|
||||||
"""After deleting an account, tokens should no longer work."""
|
|
||||||
reg = await client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "delete-me@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "Delete Me",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
tokens = reg.json()
|
assert resp.status_code == 200
|
||||||
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
assert resp.json()["email"] == "bearer@e2e.com"
|
||||||
|
|
||||||
|
async def test_deleted_user_session_returns_not_found(self, client, db_engine):
|
||||||
|
"""After deleting a user, their session should result in 404 for profile."""
|
||||||
|
_, session_token = await _create_test_user_and_session(
|
||||||
|
client, db_engine, email="delete-me@e2e.com", display_name="Delete Me"
|
||||||
|
)
|
||||||
|
headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||||
|
|
||||||
# Delete account
|
|
||||||
delete_resp = await client.delete("/auth/me", headers=headers)
|
delete_resp = await client.delete("/auth/me", headers=headers)
|
||||||
assert delete_resp.status_code == 204
|
assert delete_resp.status_code == 204
|
||||||
|
|
||||||
# Profile should fail
|
|
||||||
me = await client.get("/auth/me", headers=headers)
|
me = await client.get("/auth/me", headers=headers)
|
||||||
assert me.status_code in (401, 404)
|
assert me.status_code == 404
|
||||||
|
|
||||||
|
async def test_expired_session_rejected(self, client, db_engine):
|
||||||
|
"""Expired sessions must be rejected."""
|
||||||
|
import secrets
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
session_token = secrets.token_urlsafe(32)
|
||||||
|
now = datetime.now(UTC).isoformat()
|
||||||
|
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
||||||
|
|
||||||
|
async with db_engine.begin() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
||||||
|
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": user_id,
|
||||||
|
"email": "expired@e2e.com",
|
||||||
|
"hp": "unused",
|
||||||
|
"dn": "Expired User",
|
||||||
|
"ev": False,
|
||||||
|
"ca": now,
|
||||||
|
"ua": now,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||||
|
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"token": session_token,
|
||||||
|
"uid": user_id,
|
||||||
|
"ea": expired,
|
||||||
|
"ca": now,
|
||||||
|
"ua": now,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
"/auth/me",
|
||||||
|
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -154,60 +125,38 @@ class TestAuthProtectedEndpoints:
|
|||||||
class TestCrossUserDataIsolation:
|
class TestCrossUserDataIsolation:
|
||||||
"""Verify that users cannot access other users' data."""
|
"""Verify that users cannot access other users' data."""
|
||||||
|
|
||||||
async def test_user_b_cannot_access_user_a_purchases(self, client, seed_data):
|
async def test_user_b_cannot_access_user_a_purchases(self, client, db_engine, seed_data):
|
||||||
"""Register a second user and verify they cannot see User A's purchases."""
|
"""A second user cannot see User A's purchases."""
|
||||||
# User A's purchase (from seed_data)
|
|
||||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||||
|
|
||||||
# Register User B
|
_, session_token = await _create_test_user_and_session(
|
||||||
reg = await client.post(
|
client, db_engine, email="userb@e2e.com", display_name="User B"
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "userb@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "User B",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
assert reg.status_code == 201
|
user_b_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||||
user_b_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
|
||||||
|
|
||||||
# User B tries to access User A's specific purchase
|
|
||||||
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
|
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
|
||||||
assert resp.status_code in (403, 404), (
|
assert resp.status_code in (403, 404), (
|
||||||
"User B should not be able to access User A's purchase"
|
"User B should not be able to access User A's purchase"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_user_b_purchase_list_is_empty(self, client, seed_data):
|
async def test_user_b_purchase_list_is_empty(self, client, db_engine, seed_data):
|
||||||
"""A new user should see no purchases (not User A's purchases)."""
|
"""A new user should see no purchases."""
|
||||||
reg = await client.post(
|
_, session_token = await _create_test_user_and_session(
|
||||||
"/auth/register",
|
client, db_engine, email="userc@e2e.com", display_name="User C"
|
||||||
json={
|
|
||||||
"email": "userc@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "User C",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
assert reg.status_code == 201
|
user_c_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||||
user_c_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
|
||||||
|
|
||||||
resp = await client.get("/purchases", headers=user_c_headers)
|
resp = await client.get("/purchases", headers=user_c_headers)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert len(resp.json()) == 0, "New user should have no purchases"
|
assert len(resp.json()) == 0, "New user should have no purchases"
|
||||||
|
|
||||||
async def test_user_b_stores_isolated(self, client, seed_data):
|
async def test_user_b_stores_isolated(self, client, db_engine, seed_data):
|
||||||
"""User B's connected stores should be independent from User A."""
|
"""User B's connected stores should be independent from User A."""
|
||||||
reg = await client.post(
|
_, session_token = await _create_test_user_and_session(
|
||||||
"/auth/register",
|
client, db_engine, email="userd@e2e.com", display_name="User D"
|
||||||
json={
|
|
||||||
"email": "userd@example.com",
|
|
||||||
"password": "securepass123",
|
|
||||||
"display_name": "User D",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
assert reg.status_code == 201
|
user_d_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||||
user_d_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
|
||||||
|
|
||||||
# User D should have no connected stores
|
|
||||||
resp = await client.get("/me/stores", headers=user_d_headers)
|
resp = await client.get("/me/stores", headers=user_d_headers)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert len(resp.json()) == 0, "New user should have no connected stores"
|
assert len(resp.json()) == 0, "New user should have no connected stores"
|
||||||
|
|||||||
@@ -1,26 +1,25 @@
|
|||||||
"""Integration tests for purchase endpoints."""
|
"""Integration tests for purchase endpoints."""
|
||||||
|
|
||||||
|
import secrets
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import date
|
from datetime import UTC, date, datetime, timedelta
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy import text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from cartsnitch_api.auth.jwt import create_access_token
|
|
||||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
|
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def purchase_data(db_engine):
|
async def purchase_data(db_engine):
|
||||||
"""Seed a user, store, purchase, and items."""
|
"""Seed a user, store, purchase, items, and a valid session."""
|
||||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
async with factory() as session:
|
async with factory() as session:
|
||||||
from cartsnitch_api.auth.passwords import hash_password
|
|
||||||
|
|
||||||
user = User(
|
user = User(
|
||||||
email="buyer@example.com",
|
email="buyer@example.com",
|
||||||
hashed_password=hash_password("testpass123"),
|
hashed_password="not-used-with-better-auth",
|
||||||
display_name="Buyer",
|
display_name="Buyer",
|
||||||
)
|
)
|
||||||
store = Store(name="Kroger", slug="kroger")
|
store = Store(name="Kroger", slug="kroger")
|
||||||
@@ -50,13 +49,33 @@ async def purchase_data(db_engine):
|
|||||||
session.add(item)
|
session.add(item)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
token = create_access_token(user.id)
|
# Create a session token directly in the sessions table
|
||||||
return {
|
session_token = secrets.token_urlsafe(32)
|
||||||
"user": user,
|
now = datetime.now(UTC).isoformat()
|
||||||
"store": store,
|
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
||||||
"purchase": purchase,
|
|
||||||
"headers": {"Authorization": f"Bearer {token}"},
|
async with db_engine.begin() as conn:
|
||||||
}
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||||
|
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"token": session_token,
|
||||||
|
"user_id": str(user.id),
|
||||||
|
"expires_at": expires,
|
||||||
|
"created_at": now,
|
||||||
|
"updated_at": now,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user": user,
|
||||||
|
"store": store,
|
||||||
|
"purchase": purchase,
|
||||||
|
"headers": {"Cookie": f"better-auth.session_token={session_token}"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user