sync(api): copy latest standalone code and merge alembic migrations
Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -0,0 +1,49 @@
|
|||||||
|
"""Add email_inbound_token to users.
|
||||||
|
|
||||||
|
Revision ID: 005_add_email_inbound_token
|
||||||
|
Revises: 004_fix_user_id_text
|
||||||
|
Create Date: 2026-04-02
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "005_add_email_inbound_token"
|
||||||
|
down_revision = "004_fix_user_id_text"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Add column nullable first so existing rows can be backfilled
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column("email_inbound_token", sa.String(22), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Backfill existing users with unique tokens
|
||||||
|
connection = op.get_bind()
|
||||||
|
result = connection.execute(sa.text("SELECT id FROM users WHERE email_inbound_token IS NULL"))
|
||||||
|
for (user_id,) in result:
|
||||||
|
token = secrets.token_urlsafe(16)
|
||||||
|
connection.execute(
|
||||||
|
sa.text("UPDATE users SET email_inbound_token = :token WHERE id = :id"),
|
||||||
|
{"token": token, "id": user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now enforce non-null and unique
|
||||||
|
op.alter_column("users", "email_inbound_token", nullable=False)
|
||||||
|
op.create_index(
|
||||||
|
"ix_users_email_inbound_token",
|
||||||
|
"users",
|
||||||
|
["email_inbound_token"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_users_email_inbound_token", table_name="users")
|
||||||
|
op.drop_column("users", "email_inbound_token")
|
||||||
@@ -1,100 +1,34 @@
|
|||||||
"""FastAPI dependency injection for authentication.
|
"""FastAPI dependency injection for authentication."""
|
||||||
|
|
||||||
Validates Better-Auth session tokens from cookies or Bearer header.
|
from uuid import UUID
|
||||||
Sessions are verified by querying the shared sessions table directly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from fastapi import Depends, Header, HTTPException, status
|
||||||
|
|
||||||
from fastapi import Cookie, Depends, Header, HTTPException, Request, status
|
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
|
from cartsnitch_api.auth.jwt import decode_token
|
||||||
from cartsnitch_api.config import settings
|
from cartsnitch_api.config import settings
|
||||||
from cartsnitch_api.database import get_db
|
|
||||||
|
|
||||||
# Keep Bearer scheme as optional — Better-Auth primarily uses cookies,
|
bearer_scheme = HTTPBearer()
|
||||||
# but we support Bearer tokens for service-to-service or mobile clients.
|
|
||||||
bearer_scheme = HTTPBearer(auto_error=False)
|
|
||||||
|
|
||||||
# Better-Auth session cookie names.
|
|
||||||
# Over HTTPS Better-Auth adds the __Secure- prefix automatically.
|
|
||||||
SESSION_COOKIE_NAMES = [
|
|
||||||
"__Secure-better-auth.session_token", # HTTPS (deployed)
|
|
||||||
"better-auth.session_token", # HTTP (local dev)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def _validate_session_token(token: str, db: AsyncSession) -> str:
|
|
||||||
"""Validate a Better-Auth session token against the sessions table.
|
|
||||||
|
|
||||||
Returns the user_id (as str) if the session is valid and not expired.
|
|
||||||
Better-Auth v1.5.6 stores raw tokens in the DB. The session cookie
|
|
||||||
is signed: ``rawToken.base64HMACSignature``. Strip the signature
|
|
||||||
before querying.
|
|
||||||
"""
|
|
||||||
# Signed cookie format: rawToken.hmacSignature — split and use only the token part
|
|
||||||
raw_token = token.split(".")[0] if "." in token else token
|
|
||||||
result = await db.execute(
|
|
||||||
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
|
|
||||||
{"token": raw_token},
|
|
||||||
)
|
|
||||||
row = result.first()
|
|
||||||
|
|
||||||
if not row:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid session token",
|
|
||||||
)
|
|
||||||
|
|
||||||
user_id, expires_at = row
|
|
||||||
if expires_at.tzinfo is None:
|
|
||||||
# Treat naive datetimes as UTC
|
|
||||||
expires_at = expires_at.replace(tzinfo=UTC)
|
|
||||||
|
|
||||||
if expires_at < datetime.now(UTC):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Session expired",
|
|
||||||
)
|
|
||||||
|
|
||||||
return str(user_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
request: Request,
|
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
||||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
) -> UUID:
|
||||||
db: AsyncSession = Depends(get_db),
|
try:
|
||||||
) -> str:
|
payload = decode_token(credentials.credentials)
|
||||||
"""Extract and validate the session token from cookie or Authorization header.
|
except ValueError:
|
||||||
|
|
||||||
Checks in order:
|
|
||||||
1. Better-Auth session cookie (primary — web clients)
|
|
||||||
2. Bearer token in Authorization header (fallback — API clients)
|
|
||||||
"""
|
|
||||||
token: str | None = None
|
|
||||||
|
|
||||||
# 1. Check session cookie (try both names for HTTP/HTTPS compatibility)
|
|
||||||
cookie_token = None
|
|
||||||
for name in SESSION_COOKIE_NAMES:
|
|
||||||
cookie_token = request.cookies.get(name)
|
|
||||||
if cookie_token:
|
|
||||||
break
|
|
||||||
if cookie_token:
|
|
||||||
token = cookie_token
|
|
||||||
|
|
||||||
# 2. Fall back to Bearer header
|
|
||||||
if not token and credentials:
|
|
||||||
token = credentials.credentials
|
|
||||||
|
|
||||||
if not token:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Authentication required",
|
detail="Invalid or expired token",
|
||||||
)
|
) from None
|
||||||
|
|
||||||
return await _validate_session_token(token, db)
|
if payload.get("type") != "access":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token type",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
return UUID(payload["sub"])
|
||||||
|
|
||||||
|
|
||||||
async def verify_service_key(x_service_key: str = Header()) -> None:
|
async def verify_service_key(x_service_key: str = Header()) -> None:
|
||||||
|
|||||||
@@ -2,21 +2,22 @@
|
|||||||
|
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
from cartsnitch_api.config import settings
|
from cartsnitch_api.config import settings
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(user_id: str) -> str:
|
def create_access_token(user_id: UUID) -> str:
|
||||||
expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes)
|
expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes)
|
||||||
payload = {"sub": user_id, "exp": expire, "type": "access"}
|
payload = {"sub": str(user_id), "exp": expire, "type": "access"}
|
||||||
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_token(user_id: str) -> str:
|
def create_refresh_token(user_id: UUID) -> str:
|
||||||
expire = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days)
|
expire = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days)
|
||||||
payload = {"sub": user_id, "exp": expire, "type": "refresh"}
|
payload = {"sub": str(user_id), "exp": expire, "type": "refresh"}
|
||||||
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,20 @@
|
|||||||
"""Auth routes: user profile management.
|
"""Auth routes: register, login, refresh, me, update, delete."""
|
||||||
|
|
||||||
Registration, login, refresh, and session management are handled by
|
from uuid import UUID
|
||||||
the Better-Auth service (auth/). This router provides user profile
|
|
||||||
endpoints that query our own user data from the shared database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from cartsnitch_api.auth.dependencies import get_current_user
|
from cartsnitch_api.auth.dependencies import get_current_user
|
||||||
from cartsnitch_api.database import get_db
|
from cartsnitch_api.database import get_db
|
||||||
|
from cartsnitch_api.models import User
|
||||||
from cartsnitch_api.schemas import (
|
from cartsnitch_api.schemas import (
|
||||||
|
LoginRequest,
|
||||||
|
RefreshRequest,
|
||||||
|
RegisterRequest,
|
||||||
|
TokenResponse,
|
||||||
UpdateUserRequest,
|
UpdateUserRequest,
|
||||||
UserResponse,
|
UserResponse,
|
||||||
)
|
)
|
||||||
@@ -19,9 +23,40 @@ from cartsnitch_api.services.auth import AuthService
|
|||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||||
|
svc = AuthService(db)
|
||||||
|
try:
|
||||||
|
return await svc.register(body.email, body.password, body.display_name)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=TokenResponse)
|
||||||
|
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||||
|
svc = AuthService(db)
|
||||||
|
try:
|
||||||
|
return await svc.login(body.email, body.password)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
|
async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
||||||
|
svc = AuthService(db)
|
||||||
|
try:
|
||||||
|
return await svc.refresh(body.refresh_token)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
@router.get("/me", response_model=UserResponse)
|
||||||
async def get_me(
|
async def get_me(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = AuthService(db)
|
svc = AuthService(db)
|
||||||
@@ -36,7 +71,7 @@ async def get_me(
|
|||||||
@router.patch("/me", response_model=UserResponse)
|
@router.patch("/me", response_model=UserResponse)
|
||||||
async def update_me(
|
async def update_me(
|
||||||
body: UpdateUserRequest,
|
body: UpdateUserRequest,
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = AuthService(db)
|
svc = AuthService(db)
|
||||||
@@ -52,7 +87,7 @@ async def update_me(
|
|||||||
|
|
||||||
@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def delete_me(
|
async def delete_me(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = AuthService(db)
|
svc = AuthService(db)
|
||||||
@@ -62,3 +97,28 @@ async def delete_me(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
class EmailInAddressResponse(BaseModel):
|
||||||
|
email_address: str
|
||||||
|
instructions: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/email-in-address", response_model=EmailInAddressResponse)
|
||||||
|
async def get_email_in_address(
|
||||||
|
user_id: UUID = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(select(User.email_inbound_token).where(User.id == user_id))
|
||||||
|
token = result.scalar_one_or_none()
|
||||||
|
if not token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Email inbound token not found"
|
||||||
|
) from None
|
||||||
|
return EmailInAddressResponse(
|
||||||
|
email_address=f"receipts+{token}@receipts.cartsnitch.com",
|
||||||
|
instructions=(
|
||||||
|
"Forward your digital receipt emails to this address. "
|
||||||
|
"We currently support Meijer, Kroger, and Target receipt emails."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ class Settings(BaseSettings):
|
|||||||
# Valid Fernet key for local dev — MUST be overridden in production
|
# Valid Fernet key for local dev — MUST be overridden in production
|
||||||
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||||
|
|
||||||
auth_service_url: str = "http://auth:3001"
|
|
||||||
|
|
||||||
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
|
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
|
||||||
|
|
||||||
receiptwitness_url: str = "http://receiptwitness:8001"
|
receiptwitness_url: str = "http://receiptwitness:8001"
|
||||||
|
|||||||
+10
-14
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from cartsnitch_api.auth.routes import router as auth_router
|
from cartsnitch_api.auth.routes import router as auth_router
|
||||||
from cartsnitch_api.middleware.cors import add_cors_middleware
|
from cartsnitch_api.middleware.cors import add_cors_middleware
|
||||||
@@ -46,19 +46,15 @@ def create_app() -> FastAPI:
|
|||||||
# Routers
|
# Routers
|
||||||
app.include_router(health_router)
|
app.include_router(health_router)
|
||||||
app.include_router(auth_router)
|
app.include_router(auth_router)
|
||||||
|
app.include_router(stores_router)
|
||||||
# Data endpoints mounted under /api/v1
|
app.include_router(purchases_router)
|
||||||
v1_router = APIRouter(prefix="/api/v1")
|
app.include_router(products_router)
|
||||||
v1_router.include_router(stores_router)
|
app.include_router(prices_router)
|
||||||
v1_router.include_router(purchases_router)
|
app.include_router(coupons_router)
|
||||||
v1_router.include_router(products_router)
|
app.include_router(shopping_router)
|
||||||
v1_router.include_router(prices_router)
|
app.include_router(alerts_router)
|
||||||
v1_router.include_router(coupons_router)
|
app.include_router(scraping_router)
|
||||||
v1_router.include_router(shopping_router)
|
app.include_router(public_router)
|
||||||
v1_router.include_router(alerts_router)
|
|
||||||
v1_router.include_router(scraping_router)
|
|
||||||
v1_router.include_router(public_router)
|
|
||||||
app.include_router(v1_router)
|
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|||||||
@@ -9,14 +9,14 @@ from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String
|
|||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from cartsnitch_api.constants import DiscountType
|
from cartsnitch_api.constants import DiscountType
|
||||||
from cartsnitch_api.models.base import Base, UUIDPrimaryKeyMixin
|
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from cartsnitch_api.models.product import NormalizedProduct
|
from cartsnitch_api.models.product import NormalizedProduct
|
||||||
from cartsnitch_api.models.store import Store
|
from cartsnitch_api.models.store import Store
|
||||||
|
|
||||||
|
|
||||||
class Coupon(UUIDPrimaryKeyMixin, Base):
|
class Coupon(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||||
"""A coupon or deal for a product at a store."""
|
"""A coupon or deal for a product at a store."""
|
||||||
|
|
||||||
__tablename__ = "coupons"
|
__tablename__ = "coupons"
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from sqlalchemy import Date, ForeignKey, Index, Numeric, String
|
|||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from cartsnitch_api.constants import PriceSource
|
from cartsnitch_api.constants import PriceSource
|
||||||
from cartsnitch_api.models.base import Base, UUIDPrimaryKeyMixin
|
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from cartsnitch_api.models.product import NormalizedProduct
|
from cartsnitch_api.models.product import NormalizedProduct
|
||||||
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|||||||
from cartsnitch_api.models.store import Store
|
from cartsnitch_api.models.store import Store
|
||||||
|
|
||||||
|
|
||||||
class PriceHistory(UUIDPrimaryKeyMixin, Base):
|
class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||||
"""A single price observation for a product at a store on a date."""
|
"""A single price observation for a product at a store on a date."""
|
||||||
|
|
||||||
__tablename__ = "price_history"
|
__tablename__ = "price_history"
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from sqlalchemy import (
|
|||||||
)
|
)
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from cartsnitch_api.models.base import Base, UUIDPrimaryKeyMixin
|
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from cartsnitch_api.models.price import PriceHistory
|
from cartsnitch_api.models.price import PriceHistory
|
||||||
@@ -27,12 +27,12 @@ if TYPE_CHECKING:
|
|||||||
from cartsnitch_api.models.user import User
|
from cartsnitch_api.models.user import User
|
||||||
|
|
||||||
|
|
||||||
class Purchase(UUIDPrimaryKeyMixin, Base):
|
class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||||
"""A single shopping trip / receipt."""
|
"""A single shopping trip / receipt."""
|
||||||
|
|
||||||
__tablename__ = "purchases"
|
__tablename__ = "purchases"
|
||||||
|
|
||||||
user_id: Mapped[str] = mapped_column(ForeignKey("users.id"), nullable=False)
|
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||||
store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id"))
|
store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id"))
|
||||||
receipt_id: Mapped[str] = mapped_column(String(200), nullable=False)
|
receipt_id: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||||
@@ -61,7 +61,7 @@ class Purchase(UUIDPrimaryKeyMixin, Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PurchaseItem(UUIDPrimaryKeyMixin, Base):
|
class PurchaseItem(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||||
"""Individual line item on a receipt."""
|
"""Individual line item on a receipt."""
|
||||||
|
|
||||||
__tablename__ = "purchase_items"
|
__tablename__ = "purchase_items"
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ from sqlalchemy import Date, ForeignKey, Numeric, String
|
|||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from cartsnitch_api.constants import SizeUnit
|
from cartsnitch_api.constants import SizeUnit
|
||||||
from cartsnitch_api.models.base import Base, UUIDPrimaryKeyMixin
|
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from cartsnitch_api.models.product import NormalizedProduct
|
from cartsnitch_api.models.product import NormalizedProduct
|
||||||
|
|
||||||
|
|
||||||
class ShrinkflationEvent(UUIDPrimaryKeyMixin, Base):
|
class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||||
"""Detected shrinkflation event — product size changed while price held or rose."""
|
"""Detected shrinkflation event — product size changed while price held or rose."""
|
||||||
|
|
||||||
__tablename__ = "shrinkflation_events"
|
__tablename__ = "shrinkflation_events"
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
"""User and UserStoreAccount models."""
|
"""User and UserStoreAccount models."""
|
||||||
|
|
||||||
|
import secrets
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint
|
from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from cartsnitch_api.constants import AccountStatus
|
from cartsnitch_api.constants import AccountStatus
|
||||||
@@ -16,15 +17,20 @@ if TYPE_CHECKING:
|
|||||||
from cartsnitch_api.models.store import Store
|
from cartsnitch_api.models.store import Store
|
||||||
|
|
||||||
|
|
||||||
class User(TimestampMixin, Base):
|
class User(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||||
"""Application user."""
|
"""Application user."""
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(Text, primary_key=True)
|
|
||||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
display_name: Mapped[str | None] = mapped_column(String(100))
|
display_name: Mapped[str | None] = mapped_column(String(100))
|
||||||
|
email_inbound_token: Mapped[str] = mapped_column(
|
||||||
|
String(22),
|
||||||
|
nullable=False,
|
||||||
|
unique=True,
|
||||||
|
default=lambda: secrets.token_urlsafe(16),
|
||||||
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
store_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="user")
|
store_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="user")
|
||||||
@@ -37,7 +43,7 @@ class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
|||||||
__tablename__ = "user_store_accounts"
|
__tablename__ = "user_store_accounts"
|
||||||
__table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),)
|
__table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),)
|
||||||
|
|
||||||
user_id: Mapped[str] = mapped_column(ForeignKey("users.id"), nullable=False)
|
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||||
session_data: Mapped[dict | None] = mapped_column(EncryptedJSON)
|
session_data: Mapped[dict | None] = mapped_column(EncryptedJSON)
|
||||||
session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Alert routes: list alerts, manage settings."""
|
"""Alert routes: list alerts, manage settings."""
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -13,7 +15,7 @@ router = APIRouter(prefix="/alerts", tags=["alerts"])
|
|||||||
|
|
||||||
@router.get("", response_model=list[AlertResponse])
|
@router.get("", response_model=list[AlertResponse])
|
||||||
async def list_alerts(
|
async def list_alerts(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = AlertService(db)
|
svc = AlertService(db)
|
||||||
@@ -22,7 +24,7 @@ async def list_alerts(
|
|||||||
|
|
||||||
@router.get("/settings", response_model=AlertSettingsResponse)
|
@router.get("/settings", response_model=AlertSettingsResponse)
|
||||||
async def get_alert_settings(
|
async def get_alert_settings(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = AlertService(db)
|
svc = AlertService(db)
|
||||||
@@ -32,7 +34,7 @@ async def get_alert_settings(
|
|||||||
@router.put("/settings")
|
@router.put("/settings")
|
||||||
async def update_alert_settings(
|
async def update_alert_settings(
|
||||||
body: AlertSettingsRequest,
|
body: AlertSettingsRequest,
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ router = APIRouter(prefix="/coupons", tags=["coupons"])
|
|||||||
@router.get("", response_model=list[CouponResponse])
|
@router.get("", response_model=list[CouponResponse])
|
||||||
async def list_coupons(
|
async def list_coupons(
|
||||||
store_id: UUID | None = Query(None),
|
store_id: UUID | None = Query(None),
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = CouponService(db)
|
svc = CouponService(db)
|
||||||
@@ -25,7 +25,7 @@ async def list_coupons(
|
|||||||
|
|
||||||
@router.get("/relevant", response_model=list[CouponResponse])
|
@router.get("/relevant", response_model=list[CouponResponse])
|
||||||
async def relevant_coupons(
|
async def relevant_coupons(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = CouponService(db)
|
svc = CouponService(db)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ router = APIRouter(prefix="/prices", tags=["prices"])
|
|||||||
|
|
||||||
@router.get("/trends", response_model=list[PriceTrendResponse])
|
@router.get("/trends", response_model=list[PriceTrendResponse])
|
||||||
async def price_trends(
|
async def price_trends(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
category: str | None = Query(None),
|
category: str | None = Query(None),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
@@ -30,7 +30,7 @@ async def price_trends(
|
|||||||
|
|
||||||
@router.get("/increases", response_model=list[PriceIncreaseResponse])
|
@router.get("/increases", response_model=list[PriceIncreaseResponse])
|
||||||
async def price_increases(
|
async def price_increases(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = PriceService(db)
|
svc = PriceService(db)
|
||||||
@@ -40,7 +40,7 @@ async def price_increases(
|
|||||||
@router.get("/comparison", response_model=list[PriceComparisonResponse])
|
@router.get("/comparison", response_model=list[PriceComparisonResponse])
|
||||||
async def price_comparison(
|
async def price_comparison(
|
||||||
product_ids: Annotated[list[UUID], Query()],
|
product_ids: Annotated[list[UUID], Query()],
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = PriceService(db)
|
svc = PriceService(db)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ router = APIRouter(prefix="/products", tags=["products"])
|
|||||||
|
|
||||||
@router.get("", response_model=list[ProductResponse])
|
@router.get("", response_model=list[ProductResponse])
|
||||||
async def list_products(
|
async def list_products(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
q: str | None = Query(None),
|
q: str | None = Query(None),
|
||||||
category: str | None = Query(None),
|
category: str | None = Query(None),
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
@@ -29,7 +29,7 @@ async def list_products(
|
|||||||
@router.get("/{product_id}", response_model=ProductDetailResponse)
|
@router.get("/{product_id}", response_model=ProductDetailResponse)
|
||||||
async def get_product(
|
async def get_product(
|
||||||
product_id: UUID,
|
product_id: UUID,
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = ProductService(db)
|
svc = ProductService(db)
|
||||||
@@ -44,7 +44,7 @@ async def get_product(
|
|||||||
@router.get("/{product_id}/prices", response_model=PriceTrendResponse)
|
@router.get("/{product_id}/prices", response_model=PriceTrendResponse)
|
||||||
async def get_product_prices(
|
async def get_product_prices(
|
||||||
product_id: UUID,
|
product_id: UUID,
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = ProductService(db)
|
svc = ProductService(db)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ router = APIRouter(prefix="/purchases", tags=["purchases"])
|
|||||||
|
|
||||||
@router.get("", response_model=list[PurchaseResponse])
|
@router.get("", response_model=list[PurchaseResponse])
|
||||||
async def list_purchases(
|
async def list_purchases(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
store_id: UUID | None = Query(None),
|
store_id: UUID | None = Query(None),
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
page_size: int = Query(20, ge=1, le=100),
|
page_size: int = Query(20, ge=1, le=100),
|
||||||
@@ -27,7 +27,7 @@ async def list_purchases(
|
|||||||
|
|
||||||
@router.get("/stats", response_model=PurchaseStatsResponse)
|
@router.get("/stats", response_model=PurchaseStatsResponse)
|
||||||
async def purchase_stats(
|
async def purchase_stats(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = PurchaseService(db)
|
svc = PurchaseService(db)
|
||||||
@@ -37,7 +37,7 @@ async def purchase_stats(
|
|||||||
@router.get("/{purchase_id}", response_model=PurchaseDetailResponse)
|
@router.get("/{purchase_id}", response_model=PurchaseDetailResponse)
|
||||||
async def get_purchase(
|
async def get_purchase(
|
||||||
purchase_id: UUID,
|
purchase_id: UUID,
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = PurchaseService(db)
|
svc = PurchaseService(db)
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Scraping routes: trigger sync, check status (proxy to ReceiptWitness)."""
|
"""Scraping routes: trigger sync, check status (proxy to ReceiptWitness)."""
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from httpx import HTTPStatusError, RequestError
|
from httpx import HTTPStatusError, RequestError
|
||||||
|
|
||||||
@@ -11,7 +13,7 @@ router = APIRouter(prefix="/scraping", tags=["scraping"])
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{store_slug}/sync", response_model=SyncTriggerResponse)
|
@router.post("/{store_slug}/sync", response_model=SyncTriggerResponse)
|
||||||
async def trigger_sync(store_slug: str, user_id: str = Depends(get_current_user)):
|
async def trigger_sync(store_slug: str, user_id: UUID = Depends(get_current_user)):
|
||||||
client = ReceiptWitnessClient()
|
client = ReceiptWitnessClient()
|
||||||
try:
|
try:
|
||||||
result = await client.trigger_sync(str(user_id), store_slug)
|
result = await client.trigger_sync(str(user_id), store_slug)
|
||||||
@@ -29,7 +31,7 @@ async def trigger_sync(store_slug: str, user_id: str = Depends(get_current_user)
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/status", response_model=list[SyncStatusResponse])
|
@router.get("/status", response_model=list[SyncStatusResponse])
|
||||||
async def sync_status(user_id: str = Depends(get_current_user)):
|
async def sync_status(user_id: UUID = Depends(get_current_user)):
|
||||||
client = ReceiptWitnessClient()
|
client = ReceiptWitnessClient()
|
||||||
try:
|
try:
|
||||||
return await client.get_sync_status(str(user_id))
|
return await client.get_sync_status(str(user_id))
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Shopping routes: optimize list, saved lists."""
|
"""Shopping routes: optimize list, saved lists."""
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from httpx import HTTPStatusError, RequestError
|
from httpx import HTTPStatusError, RequestError
|
||||||
|
|
||||||
@@ -11,7 +13,7 @@ router = APIRouter(prefix="/shopping", tags=["shopping"])
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/optimize", response_model=OptimizeResponse)
|
@router.post("/optimize", response_model=OptimizeResponse)
|
||||||
async def optimize_shopping(body: OptimizeRequest, user_id: str = Depends(get_current_user)):
|
async def optimize_shopping(body: OptimizeRequest, user_id: UUID = Depends(get_current_user)):
|
||||||
client = ClipArtistClient()
|
client = ClipArtistClient()
|
||||||
try:
|
try:
|
||||||
result = await client.optimize(
|
result = await client.optimize(
|
||||||
@@ -35,7 +37,7 @@ async def optimize_shopping(body: OptimizeRequest, user_id: str = Depends(get_cu
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/lists", response_model=list[ShoppingListResponse])
|
@router.get("/lists", response_model=list[ShoppingListResponse])
|
||||||
async def list_shopping_lists(user_id: str = Depends(get_current_user)):
|
async def list_shopping_lists(user_id: UUID = Depends(get_current_user)):
|
||||||
client = ClipArtistClient()
|
client = ClipArtistClient()
|
||||||
try:
|
try:
|
||||||
return await client.get_shopping_lists(str(user_id))
|
return await client.get_shopping_lists(str(user_id))
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Store routes: list stores, manage user store connections."""
|
"""Store routes: list stores, manage user store connections."""
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -19,7 +21,7 @@ async def list_stores(db: AsyncSession = Depends(get_db)):
|
|||||||
|
|
||||||
@router.get("/me/stores", response_model=list[StoreAccountResponse])
|
@router.get("/me/stores", response_model=list[StoreAccountResponse])
|
||||||
async def list_user_stores(
|
async def list_user_stores(
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = StoreService(db)
|
svc = StoreService(db)
|
||||||
@@ -34,7 +36,7 @@ async def list_user_stores(
|
|||||||
async def connect_store(
|
async def connect_store(
|
||||||
store_slug: str,
|
store_slug: str,
|
||||||
body: ConnectStoreRequest,
|
body: ConnectStoreRequest,
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = StoreService(db)
|
svc = StoreService(db)
|
||||||
@@ -49,7 +51,7 @@ async def connect_store(
|
|||||||
@router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def disconnect_store(
|
async def disconnect_store(
|
||||||
store_slug: str,
|
store_slug: str,
|
||||||
user_id: str = Depends(get_current_user),
|
user_id: UUID = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
svc = StoreService(db)
|
svc = StoreService(db)
|
||||||
|
|||||||
@@ -1,13 +1,33 @@
|
|||||||
"""Pydantic v2 request/response schemas for all API endpoints."""
|
"""Pydantic v2 request/response schemas for all API endpoints."""
|
||||||
|
|
||||||
from datetime import date, datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel, EmailStr, Field
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
|
|
||||||
# ---------- Auth ----------
|
# ---------- Auth ----------
|
||||||
# Registration, login, and session management are handled by Better-Auth (auth/ service).
|
|
||||||
# These schemas are for the profile management endpoints only.
|
|
||||||
|
class RegisterRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str = Field(min_length=8, max_length=128)
|
||||||
|
display_name: str = Field(min_length=1, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshRequest(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
expires_in: int
|
||||||
|
|
||||||
|
|
||||||
class UpdateUserRequest(BaseModel):
|
class UpdateUserRequest(BaseModel):
|
||||||
@@ -16,7 +36,7 @@ class UpdateUserRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
id: str
|
id: UUID
|
||||||
email: str
|
email: str
|
||||||
display_name: str
|
display_name: str
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
@@ -60,7 +80,7 @@ class PurchaseResponse(BaseModel):
|
|||||||
id: UUID
|
id: UUID
|
||||||
store_id: UUID
|
store_id: UUID
|
||||||
store_name: str
|
store_name: str
|
||||||
purchased_at: date
|
purchased_at: datetime
|
||||||
total: float
|
total: float
|
||||||
item_count: int
|
item_count: int
|
||||||
|
|
||||||
@@ -142,7 +162,7 @@ class CouponResponse(BaseModel):
|
|||||||
discount_value: float
|
discount_value: float
|
||||||
discount_type: str
|
discount_type: str
|
||||||
product_id: UUID | None = None
|
product_id: UUID | None = None
|
||||||
expires_at: date | None = None
|
expires_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
# ---------- Shopping ----------
|
# ---------- Shopping ----------
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ Alerts are generated by StickerShock and ShrinkRay services and written to the D
|
|||||||
This service reads them for the API gateway.
|
This service reads them for the API gateway.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
@@ -13,7 +15,7 @@ class AlertService:
|
|||||||
def __init__(self, db: AsyncSession) -> None:
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
async def list_alerts(self, user_id: str) -> list[dict]:
|
async def list_alerts(self, user_id: UUID) -> list[dict]:
|
||||||
"""List shrinkflation events for products the user has purchased."""
|
"""List shrinkflation events for products the user has purchased."""
|
||||||
from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent
|
from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent
|
||||||
|
|
||||||
@@ -55,7 +57,7 @@ class AlertService:
|
|||||||
for e in events
|
for e in events
|
||||||
]
|
]
|
||||||
|
|
||||||
async def get_settings(self, user_id: str) -> dict:
|
async def get_settings(self, user_id: UUID) -> dict:
|
||||||
# Alert settings would be stored in a user_settings table.
|
# Alert settings would be stored in a user_settings table.
|
||||||
# For now, return defaults since the table doesn't exist yet in common lib.
|
# For now, return defaults since the table doesn't exist yet in common lib.
|
||||||
return {
|
return {
|
||||||
@@ -64,7 +66,7 @@ class AlertService:
|
|||||||
"email_notifications": False,
|
"email_notifications": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def update_settings(self, user_id: str, **fields) -> dict:
|
async def update_settings(self, user_id: UUID, **fields) -> dict:
|
||||||
# Would update user_settings table. Return merged defaults for now.
|
# Would update user_settings table. Return merged defaults for now.
|
||||||
current = await self.get_settings(user_id)
|
current = await self.get_settings(user_id)
|
||||||
for k, v in fields.items():
|
for k, v in fields.items():
|
||||||
|
|||||||
@@ -1,19 +1,68 @@
|
|||||||
"""Auth service — user profile management.
|
"""Auth service — user registration, login, token management."""
|
||||||
|
|
||||||
Registration, login, token management, and session handling are now
|
from uuid import UUID
|
||||||
handled by the Better-Auth service (auth/). This service provides
|
|
||||||
user lookup and profile update operations for the API gateway.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token
|
||||||
|
from cartsnitch_api.auth.passwords import hash_password, verify_password
|
||||||
|
from cartsnitch_api.config import settings
|
||||||
|
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
def __init__(self, db: AsyncSession) -> None:
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
async def get_user(self, user_id: str) -> dict:
|
async def register(self, email: str, password: str, display_name: str) -> dict:
|
||||||
|
from cartsnitch_api.models import User
|
||||||
|
|
||||||
|
existing = await self.db.execute(select(User).where(User.email == email))
|
||||||
|
if existing.scalar_one_or_none():
|
||||||
|
raise ValueError("Email already registered")
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
email=email,
|
||||||
|
hashed_password=hash_password(password),
|
||||||
|
display_name=display_name,
|
||||||
|
)
|
||||||
|
self.db.add(user)
|
||||||
|
await self.db.commit()
|
||||||
|
await self.db.refresh(user)
|
||||||
|
|
||||||
|
return self._make_token_response(user.id)
|
||||||
|
|
||||||
|
async def login(self, email: str, password: str) -> dict:
|
||||||
|
from cartsnitch_api.models import User
|
||||||
|
|
||||||
|
result = await self.db.execute(select(User).where(User.email == email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if not user or not verify_password(password, user.hashed_password):
|
||||||
|
raise ValueError("Invalid email or password")
|
||||||
|
|
||||||
|
return self._make_token_response(user.id)
|
||||||
|
|
||||||
|
async def refresh(self, refresh_token: str) -> dict:
|
||||||
|
from cartsnitch_api.models import User
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = decode_token(refresh_token)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Invalid refresh token") from None
|
||||||
|
|
||||||
|
if payload.get("type") != "refresh":
|
||||||
|
raise ValueError("Invalid token type") from None
|
||||||
|
|
||||||
|
user_id = UUID(payload["sub"])
|
||||||
|
|
||||||
|
# Verify the user still exists before issuing new tokens
|
||||||
|
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
raise ValueError("User no longer exists")
|
||||||
|
|
||||||
|
return self._make_token_response(user_id)
|
||||||
|
|
||||||
|
async def get_user(self, user_id: UUID) -> dict:
|
||||||
from cartsnitch_api.models import User
|
from cartsnitch_api.models import User
|
||||||
|
|
||||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||||
@@ -28,7 +77,7 @@ class AuthService:
|
|||||||
"created_at": user.created_at,
|
"created_at": user.created_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def update_user(self, user_id: str, **fields) -> dict:
|
async def update_user(self, user_id: UUID, **fields) -> dict:
|
||||||
from cartsnitch_api.models import User
|
from cartsnitch_api.models import User
|
||||||
|
|
||||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||||
@@ -56,7 +105,7 @@ class AuthService:
|
|||||||
"created_at": user.created_at,
|
"created_at": user.created_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def delete_user(self, user_id: str) -> None:
|
async def delete_user(self, user_id: UUID) -> None:
|
||||||
from cartsnitch_api.models import User
|
from cartsnitch_api.models import User
|
||||||
|
|
||||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||||
@@ -66,3 +115,11 @@ class AuthService:
|
|||||||
|
|
||||||
await self.db.delete(user)
|
await self.db.delete(user)
|
||||||
await self.db.commit()
|
await self.db.commit()
|
||||||
|
|
||||||
|
def _make_token_response(self, user_id: UUID) -> dict:
|
||||||
|
return {
|
||||||
|
"access_token": create_access_token(user_id),
|
||||||
|
"refresh_token": create_refresh_token(user_id),
|
||||||
|
"token_type": "bearer",
|
||||||
|
"expires_in": settings.jwt_access_token_expire_minutes * 60,
|
||||||
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class CouponService:
|
|||||||
coupons = result.scalars().all()
|
coupons = result.scalars().all()
|
||||||
return [self._to_dict(c) for c in coupons]
|
return [self._to_dict(c) for c in coupons]
|
||||||
|
|
||||||
async def relevant_coupons(self, user_id: str) -> list[dict]:
|
async def relevant_coupons(self, user_id: UUID) -> list[dict]:
|
||||||
"""Coupons for products the user has purchased."""
|
"""Coupons for products the user has purchased."""
|
||||||
from cartsnitch_api.models import Coupon, PurchaseItem
|
from cartsnitch_api.models import Coupon, PurchaseItem
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class PurchaseService:
|
|||||||
|
|
||||||
async def list_purchases(
|
async def list_purchases(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: UUID,
|
||||||
store_id: UUID | None = None,
|
store_id: UUID | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
@@ -56,7 +56,7 @@ class PurchaseService:
|
|||||||
for p, item_count, store_name in result.all()
|
for p, item_count, store_name in result.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
async def get_purchase(self, purchase_id: UUID, user_id: str) -> dict:
|
async def get_purchase(self, purchase_id: UUID, user_id: UUID) -> dict:
|
||||||
from cartsnitch_api.models import Purchase
|
from cartsnitch_api.models import Purchase
|
||||||
|
|
||||||
result = await self.db.execute(
|
result = await self.db.execute(
|
||||||
@@ -88,7 +88,7 @@ class PurchaseService:
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_stats(self, user_id: str) -> dict:
|
async def get_stats(self, user_id: UUID) -> dict:
|
||||||
from cartsnitch_api.models import Purchase
|
from cartsnitch_api.models import Purchase
|
||||||
|
|
||||||
result = await self.db.execute(
|
result = await self.db.execute(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Store service — list stores, manage user store account connections."""
|
"""Store service — list stores, manage user store account connections."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -34,7 +35,7 @@ class StoreService:
|
|||||||
for s in stores
|
for s in stores
|
||||||
]
|
]
|
||||||
|
|
||||||
async def list_user_stores(self, user_id: str) -> list[dict]:
|
async def list_user_stores(self, user_id: UUID) -> list[dict]:
|
||||||
from cartsnitch_api.models import UserStoreAccount
|
from cartsnitch_api.models import UserStoreAccount
|
||||||
|
|
||||||
result = await self.db.execute(
|
result = await self.db.execute(
|
||||||
@@ -59,7 +60,7 @@ class StoreService:
|
|||||||
for a in accounts
|
for a in accounts
|
||||||
]
|
]
|
||||||
|
|
||||||
async def connect_store(self, user_id: str, store_slug: str, credentials: dict | None) -> dict:
|
async def connect_store(self, user_id: UUID, store_slug: str, credentials: dict | None) -> dict:
|
||||||
from cartsnitch_api.models import Store, UserStoreAccount
|
from cartsnitch_api.models import Store, UserStoreAccount
|
||||||
|
|
||||||
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
||||||
@@ -106,7 +107,7 @@ class StoreService:
|
|||||||
"sync_status": "active",
|
"sync_status": "active",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def disconnect_store(self, user_id: str, store_slug: str) -> None:
|
async def disconnect_store(self, user_id: UUID, store_slug: str) -> None:
|
||||||
from cartsnitch_api.models import Store, UserStoreAccount
|
from cartsnitch_api.models import Store, UserStoreAccount
|
||||||
|
|
||||||
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
||||||
|
|||||||
+15
-101
@@ -1,16 +1,8 @@
|
|||||||
"""Shared test fixtures with in-memory SQLite database.
|
"""Shared test fixtures with in-memory SQLite database."""
|
||||||
|
|
||||||
Session-based auth: tests create users and sessions directly in the DB,
|
|
||||||
matching the Better-Auth session validation flow.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
import uuid
|
|
||||||
from datetime import UTC, datetime, timedelta
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
from sqlalchemy import create_engine, event, text
|
from sqlalchemy import create_engine, event
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
@@ -59,46 +51,6 @@ async def db_engine():
|
|||||||
|
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
|
||||||
await conn.execute(text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS sessions (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
token TEXT NOT NULL UNIQUE,
|
|
||||||
user_id TEXT NOT NULL,
|
|
||||||
expires_at TIMESTAMP NOT NULL,
|
|
||||||
ip_address TEXT,
|
|
||||||
user_agent TEXT,
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
|
||||||
)
|
|
||||||
"""))
|
|
||||||
await conn.execute(text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS accounts (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
user_id TEXT NOT NULL,
|
|
||||||
account_id TEXT NOT NULL,
|
|
||||||
provider_id TEXT NOT NULL,
|
|
||||||
access_token TEXT,
|
|
||||||
refresh_token TEXT,
|
|
||||||
access_token_expires_at TIMESTAMP,
|
|
||||||
refresh_token_expires_at TIMESTAMP,
|
|
||||||
scope TEXT,
|
|
||||||
id_token TEXT,
|
|
||||||
password TEXT,
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
|
||||||
)
|
|
||||||
"""))
|
|
||||||
await conn.execute(text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS verifications (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
identifier TEXT NOT NULL,
|
|
||||||
value TEXT NOT NULL,
|
|
||||||
expires_at TIMESTAMP NOT NULL,
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
|
||||||
)
|
|
||||||
"""))
|
|
||||||
|
|
||||||
yield engine
|
yield engine
|
||||||
|
|
||||||
@@ -133,55 +85,17 @@ async def client(db_engine):
|
|||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
|
||||||
"""Create a test user and a valid session directly in the DB.
|
|
||||||
|
|
||||||
Returns (user_dict, session_token).
|
|
||||||
"""
|
|
||||||
user_id = str(uuid.uuid4())
|
|
||||||
email = user_overrides.get("email", "test@example.com")
|
|
||||||
display_name = user_overrides.get("display_name", "Test User")
|
|
||||||
session_token = secrets.token_urlsafe(32)
|
|
||||||
session_id = str(uuid.uuid4())
|
|
||||||
now = datetime.now(UTC).isoformat()
|
|
||||||
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
|
||||||
|
|
||||||
async with db_engine.begin() as conn:
|
|
||||||
await conn.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
|
||||||
"VALUES (:id, :email, :hashed_password, :display_name, :email_verified, :created_at, :updated_at)"
|
|
||||||
),
|
|
||||||
{
|
|
||||||
"id": user_id,
|
|
||||||
"email": email,
|
|
||||||
"hashed_password": "not-used-with-better-auth",
|
|
||||||
"display_name": display_name,
|
|
||||||
"email_verified": False,
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await conn.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
|
||||||
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
|
||||||
),
|
|
||||||
{
|
|
||||||
"id": session_id,
|
|
||||||
"token": session_token,
|
|
||||||
"user_id": user_id,
|
|
||||||
"expires_at": expires,
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"id": user_id, "email": email, "display_name": display_name}, session_token
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def auth_headers(client, db_engine):
|
async def auth_headers(client):
|
||||||
"""Create a test user with a valid session and return auth headers."""
|
"""Register a test user and return auth headers."""
|
||||||
_, session_token = await _create_test_user_and_session(client, db_engine)
|
resp = await client.post(
|
||||||
return {"Cookie": f"better-auth.session_token={session_token}"}
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "test@example.com",
|
||||||
|
"password": "testpass123",
|
||||||
|
"display_name": "Test User",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
token = resp.json()["access_token"]
|
||||||
|
return {"Authorization": f"Bearer {token}"}
|
||||||
|
|||||||
@@ -1,13 +1,146 @@
|
|||||||
"""Integration tests for auth profile endpoints.
|
"""Integration tests for auth endpoints."""
|
||||||
|
|
||||||
Registration, login, and session management are handled by the Better-Auth
|
|
||||||
service. These tests cover the profile endpoints (GET/PATCH/DELETE /auth/me)
|
|
||||||
which validate sessions via the shared sessions table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_success(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "new@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "New User",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert "refresh_token" in data
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
assert data["expires_in"] == 900 # 15 min * 60
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_duplicate_email(client):
|
||||||
|
await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "dupe@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "User One",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "dupe@example.com",
|
||||||
|
"password": "securepass456",
|
||||||
|
"display_name": "User Two",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_short_password(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "short@example.com",
|
||||||
|
"password": "short",
|
||||||
|
"display_name": "Short Pass",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_success(client):
|
||||||
|
await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "login@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "Login User",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={
|
||||||
|
"email": "login@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "access_token" in resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_wrong_password(client):
|
||||||
|
await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "wrong@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "Wrong Pass",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={
|
||||||
|
"email": "wrong@example.com",
|
||||||
|
"password": "badpassword1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_nonexistent_user(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={
|
||||||
|
"email": "ghost@example.com",
|
||||||
|
"password": "doesntmatter",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_token(client):
|
||||||
|
reg = await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "refresh@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "Refresh User",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
refresh_token = reg.json()["refresh_token"]
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/refresh",
|
||||||
|
json={
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "access_token" in resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_with_invalid_token(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/refresh",
|
||||||
|
json={
|
||||||
|
"refresh_token": "invalid.token.here",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_me(client, auth_headers):
|
async def test_get_me(client, auth_headers):
|
||||||
resp = await client.get("/auth/me", headers=auth_headers)
|
resp = await client.get("/auth/me", headers=auth_headers)
|
||||||
@@ -22,32 +155,7 @@ async def test_get_me(client, auth_headers):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_me_unauthorized(client):
|
async def test_get_me_unauthorized(client):
|
||||||
resp = await client.get("/auth/me")
|
resp = await client.get("/auth/me")
|
||||||
assert resp.status_code in (401, 403)
|
assert resp.status_code in (401, 403) # No auth header
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_me_invalid_session(client):
|
|
||||||
resp = await client.get(
|
|
||||||
"/auth/me",
|
|
||||||
headers={"Cookie": "better-auth.session_token=invalid-token"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_me_with_bearer_token(client, db_engine):
|
|
||||||
"""Session tokens can also be passed as Bearer tokens for API clients."""
|
|
||||||
from tests.conftest import _create_test_user_and_session
|
|
||||||
|
|
||||||
_, session_token = await _create_test_user_and_session(
|
|
||||||
client, db_engine, email="bearer@example.com", display_name="Bearer User"
|
|
||||||
)
|
|
||||||
resp = await client.get(
|
|
||||||
"/auth/me",
|
|
||||||
headers={"Authorization": f"Bearer {session_token}"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["email"] == "bearer@example.com"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -55,7 +163,9 @@ async def test_update_me(client, auth_headers):
|
|||||||
resp = await client.patch(
|
resp = await client.patch(
|
||||||
"/auth/me",
|
"/auth/me",
|
||||||
headers=auth_headers,
|
headers=auth_headers,
|
||||||
json={"display_name": "Updated Name"},
|
json={
|
||||||
|
"display_name": "Updated Name",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json()["display_name"] == "Updated Name"
|
assert resp.json()["display_name"] == "Updated Name"
|
||||||
@@ -66,58 +176,34 @@ async def test_delete_me(client, auth_headers):
|
|||||||
resp = await client.delete("/auth/me", headers=auth_headers)
|
resp = await client.delete("/auth/me", headers=auth_headers)
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
# Session is still valid but user is gone
|
# Verify user is gone (token still valid but user deleted)
|
||||||
resp = await client.get("/auth/me", headers=auth_headers)
|
resp = await client.get("/auth/me", headers=auth_headers)
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_expired_session_rejected(client, db_engine):
|
async def test_refresh_after_delete_fails(client):
|
||||||
"""Expired sessions must be rejected."""
|
"""Refresh token for a deleted user must be rejected."""
|
||||||
import secrets
|
reg = await client.post(
|
||||||
import uuid
|
"/auth/register",
|
||||||
from datetime import UTC, datetime, timedelta
|
json={
|
||||||
|
"email": "ghost@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "Ghost User",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
tokens = reg.json()
|
||||||
|
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
||||||
|
|
||||||
from sqlalchemy import text
|
# Delete the user
|
||||||
|
resp = await client.delete("/auth/me", headers=headers)
|
||||||
|
assert resp.status_code == 204
|
||||||
|
|
||||||
user_id = str(uuid.uuid4())
|
# Refresh token should now fail
|
||||||
session_token = secrets.token_urlsafe(32)
|
resp = await client.post(
|
||||||
now = datetime.now(UTC).isoformat()
|
"/auth/refresh",
|
||||||
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
json={
|
||||||
|
"refresh_token": tokens["refresh_token"],
|
||||||
async with db_engine.begin() as conn:
|
},
|
||||||
await conn.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
|
||||||
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
|
|
||||||
),
|
|
||||||
{
|
|
||||||
"id": user_id,
|
|
||||||
"email": "expired@example.com",
|
|
||||||
"hp": "unused",
|
|
||||||
"dn": "Expired User",
|
|
||||||
"ev": False,
|
|
||||||
"ca": now,
|
|
||||||
"ua": now,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await conn.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
|
||||||
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
|
||||||
),
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"token": session_token,
|
|
||||||
"uid": user_id,
|
|
||||||
"ea": expired,
|
|
||||||
"ca": now,
|
|
||||||
"ua": now,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = await client.get(
|
|
||||||
"/auth/me",
|
|
||||||
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
|
||||||
)
|
)
|
||||||
assert resp.status_code == 401
|
assert resp.status_code == 401
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ from decimal import Decimal
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from cartsnitch_api.auth.jwt import decode_token
|
||||||
from cartsnitch_api.models import (
|
from cartsnitch_api.models import (
|
||||||
Coupon,
|
Coupon,
|
||||||
NormalizedProduct,
|
NormalizedProduct,
|
||||||
@@ -26,8 +26,8 @@ from cartsnitch_api.models import (
|
|||||||
# Shared test constants
|
# Shared test constants
|
||||||
ZERO_UUID = "00000000-0000-0000-0000-000000000000"
|
ZERO_UUID = "00000000-0000-0000-0000-000000000000"
|
||||||
BAD_UUID = "not-a-uuid"
|
BAD_UUID = "not-a-uuid"
|
||||||
# Fixed anchor date for deterministic tests
|
# Anchor date relative to today so coupon validity windows stay in the future
|
||||||
ANCHOR_DATE = date(2026, 3, 15)
|
ANCHOR_DATE = date.today()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -126,16 +126,10 @@ async def seed_data(db_engine, auth_headers):
|
|||||||
session.add_all(prices)
|
session.add_all(prices)
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
# -- Get the user_id from the session token in auth_headers --
|
# -- Purchases (need the user_id from the registered test user) --
|
||||||
cookie_str = auth_headers.get("Cookie", "")
|
token = auth_headers["Authorization"].split(" ")[1]
|
||||||
session_token = cookie_str.split("=", 1)[1] if "=" in cookie_str else ""
|
payload = decode_token(token)
|
||||||
|
user_id = UUID(payload["sub"])
|
||||||
result = await session.execute(
|
|
||||||
text("SELECT user_id FROM sessions WHERE token = :token"),
|
|
||||||
{"token": session_token},
|
|
||||||
)
|
|
||||||
row = result.first()
|
|
||||||
user_id = UUID(row[0])
|
|
||||||
|
|
||||||
purchase1 = Purchase(
|
purchase1 = Purchase(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|||||||
@@ -1,104 +1,133 @@
|
|||||||
"""E2E: Auth and session validation flows.
|
"""E2E: Auth and token validation flows."""
|
||||||
|
|
||||||
Registration and login are handled by the Better-Auth service.
|
import asyncio
|
||||||
These tests validate session token handling at the API gateway level.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.conftest import _create_test_user_and_session
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestAuthRegistrationLogin:
|
||||||
|
"""Full registration → login → token refresh → profile flow."""
|
||||||
|
|
||||||
|
async def test_full_auth_lifecycle(self, client, db_engine):
|
||||||
|
"""Register → login → get profile → refresh → get profile again."""
|
||||||
|
# Register
|
||||||
|
reg = await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "lifecycle@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "Lifecycle User",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert reg.status_code == 201
|
||||||
|
tokens = reg.json()
|
||||||
|
assert "access_token" in tokens
|
||||||
|
assert "refresh_token" in tokens
|
||||||
|
assert tokens["token_type"] == "bearer"
|
||||||
|
assert tokens["expires_in"] > 0
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
||||||
|
|
||||||
|
# Get profile with access token
|
||||||
|
me = await client.get("/auth/me", headers=headers)
|
||||||
|
assert me.status_code == 200
|
||||||
|
assert me.json()["email"] == "lifecycle@example.com"
|
||||||
|
assert me.json()["display_name"] == "Lifecycle User"
|
||||||
|
|
||||||
|
# Sleep 1s so the new token has a different exp than the registration token
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# Login with same credentials
|
||||||
|
login = await client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={"email": "lifecycle@example.com", "password": "securepass123"},
|
||||||
|
)
|
||||||
|
assert login.status_code == 200
|
||||||
|
login_tokens = login.json()
|
||||||
|
assert login_tokens["access_token"] != tokens["access_token"]
|
||||||
|
|
||||||
|
# Refresh token
|
||||||
|
refresh = await client.post(
|
||||||
|
"/auth/refresh",
|
||||||
|
json={"refresh_token": tokens["refresh_token"]},
|
||||||
|
)
|
||||||
|
assert refresh.status_code == 200
|
||||||
|
new_tokens = refresh.json()
|
||||||
|
assert new_tokens["access_token"] != tokens["access_token"]
|
||||||
|
|
||||||
|
# Use refreshed token to access profile
|
||||||
|
new_headers = {"Authorization": f"Bearer {new_tokens['access_token']}"}
|
||||||
|
me2 = await client.get("/auth/me", headers=new_headers)
|
||||||
|
assert me2.status_code == 200
|
||||||
|
assert me2.json()["email"] == "lifecycle@example.com"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestSessionValidation:
|
class TestTokenValidation:
|
||||||
"""Session edge cases and error responses."""
|
"""Token edge cases and error responses."""
|
||||||
|
|
||||||
async def test_invalid_session_token_rejected(self, client, db_engine):
|
async def test_expired_token_rejected(self, client, db_engine):
|
||||||
resp = await client.get(
|
"""Manually craft an expired token and verify rejection."""
|
||||||
"/auth/me",
|
|
||||||
headers={"Cookie": "better-auth.session_token=not-a-real-token"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
async def test_missing_auth(self, client, db_engine):
|
|
||||||
resp = await client.get("/auth/me")
|
|
||||||
assert resp.status_code in (401, 403)
|
|
||||||
|
|
||||||
async def test_bearer_token_also_works(self, client, db_engine):
|
|
||||||
"""Session tokens passed as Bearer tokens should also be accepted."""
|
|
||||||
_, session_token = await _create_test_user_and_session(
|
|
||||||
client, db_engine, email="bearer@e2e.com", display_name="Bearer E2E"
|
|
||||||
)
|
|
||||||
resp = await client.get(
|
|
||||||
"/auth/me",
|
|
||||||
headers={"Authorization": f"Bearer {session_token}"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["email"] == "bearer@e2e.com"
|
|
||||||
|
|
||||||
async def test_deleted_user_session_returns_not_found(self, client, db_engine):
|
|
||||||
"""After deleting a user, their session should result in 404 for profile."""
|
|
||||||
_, session_token = await _create_test_user_and_session(
|
|
||||||
client, db_engine, email="delete-me@e2e.com", display_name="Delete Me"
|
|
||||||
)
|
|
||||||
headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
|
||||||
|
|
||||||
delete_resp = await client.delete("/auth/me", headers=headers)
|
|
||||||
assert delete_resp.status_code == 204
|
|
||||||
|
|
||||||
me = await client.get("/auth/me", headers=headers)
|
|
||||||
assert me.status_code == 404
|
|
||||||
|
|
||||||
async def test_expired_session_rejected(self, client, db_engine):
|
|
||||||
"""Expired sessions must be rejected."""
|
|
||||||
import secrets
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
from sqlalchemy import text
|
from jose import jwt
|
||||||
|
|
||||||
user_id = str(uuid.uuid4())
|
from cartsnitch_api.config import settings
|
||||||
session_token = secrets.token_urlsafe(32)
|
|
||||||
now = datetime.now(UTC).isoformat()
|
|
||||||
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
|
||||||
|
|
||||||
async with db_engine.begin() as conn:
|
payload = {
|
||||||
await conn.execute(
|
"sub": str(uuid.uuid4()),
|
||||||
text(
|
"exp": datetime.now(UTC) - timedelta(minutes=5),
|
||||||
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
"type": "access",
|
||||||
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
|
}
|
||||||
),
|
token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||||
{
|
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||||
"id": user_id,
|
|
||||||
"email": "expired@e2e.com",
|
|
||||||
"hp": "unused",
|
|
||||||
"dn": "Expired User",
|
|
||||||
"ev": False,
|
|
||||||
"ca": now,
|
|
||||||
"ua": now,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await conn.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
|
||||||
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
|
||||||
),
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"token": session_token,
|
|
||||||
"uid": user_id,
|
|
||||||
"ea": expired,
|
|
||||||
"ca": now,
|
|
||||||
"ua": now,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = await client.get(
|
|
||||||
"/auth/me",
|
|
||||||
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 401
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
async def test_invalid_token_rejected(self, client, db_engine):
|
||||||
|
resp = await client.get("/auth/me", headers={"Authorization": "Bearer not-a-real-token"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
async def test_missing_auth_header(self, client, db_engine):
|
||||||
|
resp = await client.get("/auth/me")
|
||||||
|
assert resp.status_code in (401, 403)
|
||||||
|
|
||||||
|
async def test_refresh_token_cannot_access_endpoints(self, client, db_engine):
|
||||||
|
"""A refresh token should not work as an access token."""
|
||||||
|
reg = await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "refresh-test@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "Refresh Test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
refresh_token = reg.json()["refresh_token"]
|
||||||
|
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {refresh_token}"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
async def test_deleted_user_token_invalid(self, client, db_engine):
|
||||||
|
"""After deleting an account, tokens should no longer work."""
|
||||||
|
reg = await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "delete-me@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "Delete Me",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
tokens = reg.json()
|
||||||
|
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
||||||
|
|
||||||
|
# Delete account
|
||||||
|
delete_resp = await client.delete("/auth/me", headers=headers)
|
||||||
|
assert delete_resp.status_code == 204
|
||||||
|
|
||||||
|
# Profile should fail
|
||||||
|
me = await client.get("/auth/me", headers=headers)
|
||||||
|
assert me.status_code in (401, 404)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestAuthProtectedEndpoints:
|
class TestAuthProtectedEndpoints:
|
||||||
@@ -125,38 +154,60 @@ class TestAuthProtectedEndpoints:
|
|||||||
class TestCrossUserDataIsolation:
|
class TestCrossUserDataIsolation:
|
||||||
"""Verify that users cannot access other users' data."""
|
"""Verify that users cannot access other users' data."""
|
||||||
|
|
||||||
async def test_user_b_cannot_access_user_a_purchases(self, client, db_engine, seed_data):
|
async def test_user_b_cannot_access_user_a_purchases(self, client, seed_data):
|
||||||
"""A second user cannot see User A's purchases."""
|
"""Register a second user and verify they cannot see User A's purchases."""
|
||||||
|
# User A's purchase (from seed_data)
|
||||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||||
|
|
||||||
_, session_token = await _create_test_user_and_session(
|
# Register User B
|
||||||
client, db_engine, email="userb@e2e.com", display_name="User B"
|
reg = await client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "userb@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "User B",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
user_b_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
assert reg.status_code == 201
|
||||||
|
user_b_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
||||||
|
|
||||||
|
# User B tries to access User A's specific purchase
|
||||||
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
|
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
|
||||||
assert resp.status_code in (403, 404), (
|
assert resp.status_code in (403, 404), (
|
||||||
"User B should not be able to access User A's purchase"
|
"User B should not be able to access User A's purchase"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_user_b_purchase_list_is_empty(self, client, db_engine, seed_data):
|
async def test_user_b_purchase_list_is_empty(self, client, seed_data):
|
||||||
"""A new user should see no purchases."""
|
"""A new user should see no purchases (not User A's purchases)."""
|
||||||
_, session_token = await _create_test_user_and_session(
|
reg = await client.post(
|
||||||
client, db_engine, email="userc@e2e.com", display_name="User C"
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "userc@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "User C",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
user_c_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
assert reg.status_code == 201
|
||||||
|
user_c_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
||||||
|
|
||||||
resp = await client.get("/purchases", headers=user_c_headers)
|
resp = await client.get("/purchases", headers=user_c_headers)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert len(resp.json()) == 0, "New user should have no purchases"
|
assert len(resp.json()) == 0, "New user should have no purchases"
|
||||||
|
|
||||||
async def test_user_b_stores_isolated(self, client, db_engine, seed_data):
|
async def test_user_b_stores_isolated(self, client, seed_data):
|
||||||
"""User B's connected stores should be independent from User A."""
|
"""User B's connected stores should be independent from User A."""
|
||||||
_, session_token = await _create_test_user_and_session(
|
reg = await client.post(
|
||||||
client, db_engine, email="userd@e2e.com", display_name="User D"
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "userd@example.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
"display_name": "User D",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
user_d_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
assert reg.status_code == 201
|
||||||
|
user_d_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
||||||
|
|
||||||
|
# User D should have no connected stores
|
||||||
resp = await client.get("/me/stores", headers=user_d_headers)
|
resp = await client.get("/me/stores", headers=user_d_headers)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert len(resp.json()) == 0, "New user should have no connected stores"
|
assert len(resp.json()) == 0, "New user should have no connected stores"
|
||||||
|
|||||||
@@ -0,0 +1,61 @@
|
|||||||
|
"""Tests for GET /auth/me/email-in-address endpoint."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_email_in_address_authenticated(client: AsyncClient, auth_headers: dict):
|
||||||
|
"""Authenticated user gets their email-in address."""
|
||||||
|
response = await client.get(
|
||||||
|
"/auth/me/email-in-address",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "email_address" in data
|
||||||
|
assert data["email_address"].startswith("receipts+")
|
||||||
|
assert data["email_address"].endswith("@receipts.cartsnitch.com")
|
||||||
|
assert len(data["email_address"]) > len("receipts+@receipts.cartsnitch.com")
|
||||||
|
assert "instructions" in data
|
||||||
|
assert "Meijer" in data["instructions"]
|
||||||
|
assert "Kroger" in data["instructions"]
|
||||||
|
assert "Target" in data["instructions"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_email_in_address_unauthenticated(client: AsyncClient):
|
||||||
|
"""Unauthenticated request returns 401."""
|
||||||
|
response = await client.get("/auth/me/email-in-address")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_email_in_address_invalid_token(client: AsyncClient):
|
||||||
|
"""Invalid JWT token returns 401."""
|
||||||
|
response = await client.get(
|
||||||
|
"/auth/me/email-in-address",
|
||||||
|
headers={"Authorization": "Bearer invalid-token-xyz"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_email_address_format(client: AsyncClient, auth_headers: dict):
|
||||||
|
"""Email address format is receipts+{22-char-urlsafe-token}@receipts.cartsnitch.com."""
|
||||||
|
response = await client.get(
|
||||||
|
"/auth/me/email-in-address",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
email = data["email_address"]
|
||||||
|
# Format: receipts+<22-char-urlsafe-token>@receipts.cartsnitch.com
|
||||||
|
assert email.startswith("receipts+")
|
||||||
|
assert email.endswith("@receipts.cartsnitch.com")
|
||||||
|
# token_urlsafe(16) produces 22 chars
|
||||||
|
middle = email[len("receipts+") : -len("@receipts.cartsnitch.com")]
|
||||||
|
assert len(middle) == 22
|
||||||
|
assert "@" not in middle
|
||||||
@@ -89,4 +89,4 @@ async def test_route_count():
|
|||||||
if method in ("get", "post", "put", "delete", "patch"):
|
if method in ("get", "post", "put", "delete", "patch"):
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
assert count == 33, f"Expected 33 routes, found {count}"
|
assert count == 34, f"Expected 34 routes, found {count}"
|
||||||
|
|||||||
@@ -1,25 +1,26 @@
|
|||||||
"""Integration tests for purchase endpoints."""
|
"""Integration tests for purchase endpoints."""
|
||||||
|
|
||||||
import secrets
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, date, datetime, timedelta
|
from datetime import date
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from cartsnitch_api.auth.jwt import create_access_token
|
||||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
|
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def purchase_data(db_engine):
|
async def purchase_data(db_engine):
|
||||||
"""Seed a user, store, purchase, items, and a valid session."""
|
"""Seed a user, store, purchase, and items."""
|
||||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
async with factory() as session:
|
async with factory() as session:
|
||||||
|
from cartsnitch_api.auth.passwords import hash_password
|
||||||
|
|
||||||
user = User(
|
user = User(
|
||||||
email="buyer@example.com",
|
email="buyer@example.com",
|
||||||
hashed_password="not-used-with-better-auth",
|
hashed_password=hash_password("testpass123"),
|
||||||
display_name="Buyer",
|
display_name="Buyer",
|
||||||
)
|
)
|
||||||
store = Store(name="Kroger", slug="kroger")
|
store = Store(name="Kroger", slug="kroger")
|
||||||
@@ -49,33 +50,13 @@ async def purchase_data(db_engine):
|
|||||||
session.add(item)
|
session.add(item)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# Create a session token directly in the sessions table
|
token = create_access_token(user.id)
|
||||||
session_token = secrets.token_urlsafe(32)
|
return {
|
||||||
now = datetime.now(UTC).isoformat()
|
"user": user,
|
||||||
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
"store": store,
|
||||||
|
"purchase": purchase,
|
||||||
async with db_engine.begin() as conn:
|
"headers": {"Authorization": f"Bearer {token}"},
|
||||||
await conn.execute(
|
}
|
||||||
text(
|
|
||||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
|
||||||
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
|
||||||
),
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"token": session_token,
|
|
||||||
"user_id": str(user.id),
|
|
||||||
"expires_at": expires,
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"user": user,
|
|
||||||
"store": store,
|
|
||||||
"purchase": purchase,
|
|
||||||
"headers": {"Cookie": f"better-auth.session_token={session_token}"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user