feat: merge cartsnitch/api into api/ subdirectory
Consolidate API gateway service into monorepo. Squashed from https://github.com/cartsnitch/api main (89bacb1). Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
"""FastAPI dependency injection for authentication."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from cartsnitch_api.auth.jwt import decode_token
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
bearer_scheme = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
||||
) -> UUID:
|
||||
try:
|
||||
payload = decode_token(credentials.credentials)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
) from None
|
||||
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type",
|
||||
) from None
|
||||
|
||||
return UUID(payload["sub"])
|
||||
|
||||
|
||||
async def verify_service_key(x_service_key: str = Header()) -> None:
|
||||
if x_service_key != settings.service_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid service key",
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""JWT token creation and validation."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def create_access_token(user_id: UUID) -> str:
|
||||
expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes)
|
||||
payload = {"sub": str(user_id), "exp": expire, "type": "access"}
|
||||
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
||||
|
||||
|
||||
def create_refresh_token(user_id: UUID) -> str:
|
||||
expire = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days)
|
||||
payload = {"sub": str(user_id), "exp": expire, "type": "refresh"}
|
||||
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
try:
|
||||
return cast(
|
||||
dict[str, Any],
|
||||
jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]),
|
||||
)
|
||||
except JWTError as e:
|
||||
raise ValueError(f"Invalid token: {e}") from e
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Password hashing and verification with bcrypt."""
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Auth routes: register, login, refresh, me, update, delete."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import (
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
RegisterRequest,
|
||||
TokenResponse,
|
||||
UpdateUserRequest,
|
||||
UserResponse,
|
||||
)
|
||||
from cartsnitch_api.services.auth import AuthService
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.register(body.email, body.password, body.display_name)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.login(body.email, body.password)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password"
|
||||
) from None
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.refresh(body.refresh_token)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.get_user(user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
) from None
|
||||
|
||||
|
||||
@router.patch("/me", response_model=UserResponse)
|
||||
async def update_me(
|
||||
body: UpdateUserRequest,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.update_user(user_id, email=body.email, display_name=body.display_name)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
) from None
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_me(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
await svc.delete_user(user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
) from None
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Redis/DragonflyDB caching helpers."""
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class CacheClient:
|
||||
"""Stub for Redis/DragonflyDB caching.
|
||||
|
||||
Will be used for expensive queries: price trends, product comparisons.
|
||||
Cache invalidation via Redis pub/sub events from other services.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.url = settings.redis_url
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
# TODO: implement with redis-py async
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: str, ttl_seconds: int = 300) -> None:
|
||||
# TODO: implement with redis-py async
|
||||
pass
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
# TODO: implement with redis-py async
|
||||
pass
|
||||
@@ -0,0 +1,51 @@
|
||||
import base64
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = {"env_prefix": "CARTSNITCH_"}
|
||||
|
||||
database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
jwt_secret_key: str = "change-me-in-production"
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_access_token_expire_minutes: int = 15
|
||||
jwt_refresh_token_expire_days: int = 7
|
||||
|
||||
service_key: str = "change-me-in-production"
|
||||
# Valid Fernet key for local dev — MUST be overridden in production
|
||||
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||
|
||||
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
|
||||
|
||||
receiptwitness_url: str = "http://receiptwitness:8001"
|
||||
stickershock_url: str = "http://stickershock:8002"
|
||||
clipartist_url: str = "http://clipartist:8003"
|
||||
shrinkray_url: str = "http://shrinkray:8004"
|
||||
|
||||
rate_limit_requests: int = 60
|
||||
rate_limit_window_seconds: int = 60
|
||||
rate_limit_enabled: bool = True
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_fernet_key(self):
|
||||
"""Validate fernet_key is a valid 32-byte url-safe base64 key at startup."""
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
|
||||
if len(decoded) != 32:
|
||||
raise ValueError
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"CARTSNITCH_FERNET_KEY must be a valid Fernet key "
|
||||
"(32 bytes, url-safe base64 encoded). "
|
||||
"Generate one with: python -c "
|
||||
"'from cryptography.fernet import Fernet; "
|
||||
"print(Fernet.generate_key().decode())'"
|
||||
) from None
|
||||
return self
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Constants and enums shared across CartSnitch services."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class StoreSlug(StrEnum):
|
||||
"""Supported retailer slugs."""
|
||||
|
||||
MEIJER = "meijer"
|
||||
KROGER = "kroger"
|
||||
TARGET = "target"
|
||||
|
||||
|
||||
class AccountStatus(StrEnum):
|
||||
"""User store account link status."""
|
||||
|
||||
ACTIVE = "active"
|
||||
EXPIRED = "expired"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class DiscountType(StrEnum):
|
||||
"""Coupon discount type."""
|
||||
|
||||
PERCENT = "percent"
|
||||
FIXED = "fixed"
|
||||
BOGO = "bogo"
|
||||
BUY_X_GET_Y = "buy_x_get_y"
|
||||
|
||||
|
||||
class PriceSource(StrEnum):
|
||||
"""Source of a price observation."""
|
||||
|
||||
RECEIPT = "receipt"
|
||||
CATALOG = "catalog"
|
||||
WEEKLY_AD = "weekly_ad"
|
||||
|
||||
|
||||
class EventType(StrEnum):
|
||||
"""Redis pub/sub event types."""
|
||||
|
||||
RECEIPTS_INGESTED = "cartsnitch.receipts.ingested"
|
||||
PRICES_UPDATED = "cartsnitch.prices.updated"
|
||||
PRODUCTS_NORMALIZED = "cartsnitch.products.normalized"
|
||||
COUPONS_UPDATED = "cartsnitch.coupons.updated"
|
||||
ALERT_PRICE_INCREASE = "cartsnitch.alerts.price_increase"
|
||||
ALERT_SHRINKFLATION = "cartsnitch.alerts.shrinkflation"
|
||||
|
||||
|
||||
class ProductCategory(StrEnum):
|
||||
"""Top-level product categories."""
|
||||
|
||||
PRODUCE = "produce"
|
||||
DAIRY = "dairy"
|
||||
MEAT = "meat"
|
||||
BAKERY = "bakery"
|
||||
FROZEN = "frozen"
|
||||
PANTRY = "pantry"
|
||||
BEVERAGES = "beverages"
|
||||
SNACKS = "snacks"
|
||||
HOUSEHOLD = "household"
|
||||
PERSONAL_CARE = "personal_care"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class MatchConfidence(StrEnum):
|
||||
"""Confidence level for product matching."""
|
||||
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
class SizeUnit(StrEnum):
|
||||
"""Standardized product size units."""
|
||||
|
||||
OZ = "oz"
|
||||
FL_OZ = "fl_oz"
|
||||
LB = "lb"
|
||||
G = "g"
|
||||
KG = "kg"
|
||||
ML = "ml"
|
||||
L = "l"
|
||||
CT = "ct"
|
||||
PK = "pk"
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Database session management for the API gateway."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
engine = create_async_engine(settings.database_url, echo=False)
|
||||
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI dependency that yields an async DB session."""
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
@@ -0,0 +1,62 @@
|
||||
"""FastAPI app factory for CartSnitch API Gateway."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from cartsnitch_api.auth.routes import router as auth_router
|
||||
from cartsnitch_api.middleware.cors import add_cors_middleware
|
||||
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
|
||||
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
|
||||
from cartsnitch_api.routes.alerts import router as alerts_router
|
||||
from cartsnitch_api.routes.coupons import router as coupons_router
|
||||
from cartsnitch_api.routes.health import router as health_router
|
||||
from cartsnitch_api.routes.prices import router as prices_router
|
||||
from cartsnitch_api.routes.products import router as products_router
|
||||
from cartsnitch_api.routes.public import router as public_router
|
||||
from cartsnitch_api.routes.purchases import router as purchases_router
|
||||
from cartsnitch_api.routes.scraping import router as scraping_router
|
||||
from cartsnitch_api.routes.shopping import router as shopping_router
|
||||
from cartsnitch_api.routes.stores import router as stores_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# TODO: initialize DB session pool, Redis connection, service clients
|
||||
yield
|
||||
# TODO: cleanup connections
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="CartSnitch API",
|
||||
description="Grocery price tracking and shrinkflation detection API",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Middleware (order matters — outermost first)
|
||||
add_cors_middleware(app)
|
||||
add_error_monitor_middleware(app)
|
||||
add_rate_limit_middleware(app)
|
||||
|
||||
# Exception handlers
|
||||
add_error_handlers(app)
|
||||
|
||||
# Routers
|
||||
app.include_router(health_router)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(stores_router)
|
||||
app.include_router(purchases_router)
|
||||
app.include_router(products_router)
|
||||
app.include_router(prices_router)
|
||||
app.include_router(coupons_router)
|
||||
app.include_router(shopping_router)
|
||||
app.include_router(alerts_router)
|
||||
app.include_router(scraping_router)
|
||||
app.include_router(public_router)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
@@ -0,0 +1,16 @@
|
||||
"""CORS middleware configuration."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def add_cors_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Structured error responses and error monitoring.
|
||||
|
||||
Ensures all errors return a consistent JSON shape and never leak stack traces.
|
||||
Provides hooks for error monitoring/alerting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger("cartsnitch_api.errors")
|
||||
|
||||
|
||||
def _error_response(
|
||||
status_code: int,
|
||||
detail: str,
|
||||
code: str | None = None,
|
||||
errors: list[dict] | None = None,
|
||||
) -> JSONResponse:
|
||||
"""Build a consistent error response."""
|
||||
body: dict = {"detail": detail}
|
||||
if code:
|
||||
body["code"] = code
|
||||
if errors:
|
||||
body["errors"] = errors
|
||||
return JSONResponse(status_code=status_code, content=body)
|
||||
|
||||
|
||||
def add_error_handlers(app: FastAPI) -> None:
|
||||
"""Register global exception handlers for consistent error responses."""
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_error_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
"""Return 422 with structured field-level error details."""
|
||||
field_errors = []
|
||||
for err in exc.errors():
|
||||
loc = err.get("loc", ())
|
||||
field_errors.append(
|
||||
{
|
||||
"field": ".".join(str(p) for p in loc[1:]) if len(loc) > 1 else str(loc),
|
||||
"message": err.get("msg", "Invalid value"),
|
||||
"type": err.get("type", "value_error"),
|
||||
}
|
||||
)
|
||||
return _error_response(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Validation error",
|
||||
code="VALIDATION_ERROR",
|
||||
errors=field_errors,
|
||||
)
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
|
||||
"""Wrap HTTP exceptions (Starlette and FastAPI) in consistent format."""
|
||||
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
||||
return _error_response(
|
||||
status_code=exc.status_code,
|
||||
detail=detail,
|
||||
code=_status_to_code(exc.status_code),
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Catch-all: log full traceback, return safe 500 to client."""
|
||||
logger.error(
|
||||
"Unhandled exception on %s %s: %s\n%s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
_notify_error_monitor(request, exc)
|
||||
|
||||
return _error_response(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error",
|
||||
code="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
def _status_to_code(status_code: int) -> str:
|
||||
"""Map HTTP status code to a machine-readable error code."""
|
||||
mapping = {
|
||||
400: "BAD_REQUEST",
|
||||
401: "UNAUTHORIZED",
|
||||
403: "FORBIDDEN",
|
||||
404: "NOT_FOUND",
|
||||
409: "CONFLICT",
|
||||
422: "VALIDATION_ERROR",
|
||||
429: "RATE_LIMITED",
|
||||
502: "BAD_GATEWAY",
|
||||
503: "SERVICE_UNAVAILABLE",
|
||||
}
|
||||
return mapping.get(status_code, f"HTTP_{status_code}")
|
||||
|
||||
|
||||
# ---------- Error Monitoring ----------
|
||||
|
||||
|
||||
class _ErrorMonitor:
|
||||
"""Simple error counter for monitoring and alerting hooks.
|
||||
|
||||
Tracks error counts and rates. In production, this would forward
|
||||
to an external monitoring service (Prometheus, Sentry, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.error_counts: dict[int, int] = {}
|
||||
self.recent_5xx: list[dict] = []
|
||||
self._max_recent = 100
|
||||
|
||||
def record(self, status_code: int, path: str, method: str, error: str | None = None) -> None:
|
||||
self.error_counts[status_code] = self.error_counts.get(status_code, 0) + 1
|
||||
|
||||
if status_code >= 500:
|
||||
entry = {
|
||||
"timestamp": time.time(),
|
||||
"status": status_code,
|
||||
"path": path,
|
||||
"method": method,
|
||||
"error": error,
|
||||
}
|
||||
self.recent_5xx.append(entry)
|
||||
if len(self.recent_5xx) > self._max_recent:
|
||||
self.recent_5xx = self.recent_5xx[-self._max_recent :]
|
||||
|
||||
logger.warning(
|
||||
"5xx error recorded: %s %s -> %d (%s)",
|
||||
method,
|
||||
path,
|
||||
status_code,
|
||||
error or "unknown",
|
||||
)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
return {
|
||||
"error_counts": dict(self.error_counts),
|
||||
"recent_5xx_count": len(self.recent_5xx),
|
||||
}
|
||||
|
||||
|
||||
_monitor = _ErrorMonitor()
|
||||
|
||||
|
||||
def get_error_monitor() -> _ErrorMonitor:
|
||||
"""Access the global error monitor (for health/metrics endpoints)."""
|
||||
return _monitor
|
||||
|
||||
|
||||
def _notify_error_monitor(request: Request, exc: Exception) -> None:
|
||||
"""Record unhandled exception in the error monitor."""
|
||||
_monitor.record(
|
||||
status_code=500,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
error=str(exc)[:200],
|
||||
)
|
||||
|
||||
|
||||
class ErrorMonitorMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to track all 4xx/5xx responses for monitoring."""
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable],
|
||||
):
|
||||
response = await call_next(request)
|
||||
|
||||
if response.status_code >= 400:
|
||||
_monitor.record(
|
||||
status_code=response.status_code,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def add_error_monitor_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(ErrorMonitorMiddleware)
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Rate limiting middleware for public and authenticated endpoints.
|
||||
|
||||
Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
|
||||
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class _SlidingWindowCounter:
|
||||
"""Thread-safe in-memory sliding window rate limiter."""
|
||||
|
||||
def __init__(self, max_requests: int, window_seconds: int) -> None:
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
self._hits: dict[str, list[float]] = defaultdict(list)
|
||||
self._lock = Lock()
|
||||
|
||||
def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
with self._lock:
|
||||
# Prune expired entries
|
||||
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
|
||||
|
||||
current_count = len(self._hits[key])
|
||||
if current_count >= self.max_requests:
|
||||
retry_after = int(self._hits[key][0] - cutoff) + 1
|
||||
return False, 0, retry_after
|
||||
|
||||
self._hits[key].append(now)
|
||||
remaining = self.max_requests - current_count - 1
|
||||
return True, remaining, 0
|
||||
|
||||
|
||||
# Module-level counters — one for public (per-IP), one for auth (per-token)
|
||||
_public_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests,
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
_auth_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract client IP, respecting X-Forwarded-For behind a reverse proxy."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
|
||||
"""Determine rate limit key and which limiter to use."""
|
||||
if request.url.path.startswith("/public"):
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
# For authenticated endpoints, use Bearer token as key if present
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
# Use last 16 chars of token as key to avoid storing full tokens
|
||||
return f"token:{token[-16:]}", _auth_limiter
|
||||
|
||||
# Fallback to IP for unauthenticated non-public endpoints
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip rate limiting when disabled (e.g. in tests) or for health checks
|
||||
if not settings.rate_limit_enabled or request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
key, limiter = _get_rate_limit_key(request)
|
||||
allowed, remaining, retry_after = limiter.is_allowed(key)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={
|
||||
"detail": "Rate limit exceeded",
|
||||
"code": "RATE_LIMITED",
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"X-RateLimit-Limit": str(limiter.max_requests),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
|
||||
def add_rate_limit_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""SQLAlchemy ORM models — re-exports all models for convenience."""
|
||||
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
from cartsnitch_api.models.coupon import Coupon
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.purchase import Purchase, PurchaseItem
|
||||
from cartsnitch_api.models.shrinkflation import ShrinkflationEvent
|
||||
from cartsnitch_api.models.store import Store, StoreLocation
|
||||
from cartsnitch_api.models.user import User, UserStoreAccount
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"TimestampMixin",
|
||||
"UUIDPrimaryKeyMixin",
|
||||
"Store",
|
||||
"StoreLocation",
|
||||
"User",
|
||||
"UserStoreAccount",
|
||||
"Purchase",
|
||||
"PurchaseItem",
|
||||
"NormalizedProduct",
|
||||
"PriceHistory",
|
||||
"Coupon",
|
||||
"ShrinkflationEvent",
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Base model and mixins for all CartSnitch ORM models."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all CartSnitch models."""
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin providing created_at / updated_at columns."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
class UUIDPrimaryKeyMixin:
|
||||
"""Mixin providing a UUID primary key."""
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid()
|
||||
)
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Coupon model."""
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import DiscountType
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.store import Store
|
||||
|
||||
|
||||
class Coupon(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""A coupon or deal for a product at a store."""
|
||||
|
||||
__tablename__ = "coupons"
|
||||
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
normalized_product_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("normalized_products.id")
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(String(1000))
|
||||
discount_type: Mapped[DiscountType] = mapped_column(String(20), nullable=False)
|
||||
discount_value: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
min_purchase: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
valid_from: Mapped[date | None] = mapped_column(Date)
|
||||
valid_to: Mapped[date | None] = mapped_column(Date)
|
||||
requires_clip: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
coupon_code: Mapped[str | None] = mapped_column(String(100))
|
||||
source_url: Mapped[str | None] = mapped_column(String(500))
|
||||
scraped_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
# Relationships
|
||||
store: Mapped["Store"] = relationship(back_populates="coupons")
|
||||
normalized_product: Mapped["NormalizedProduct | None"] = relationship(back_populates="coupons")
|
||||
@@ -0,0 +1,50 @@
|
||||
"""PriceHistory model — tracks product prices over time."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Date, ForeignKey, Index, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import PriceSource
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.purchase import PurchaseItem
|
||||
from cartsnitch_api.models.store import Store
|
||||
|
||||
|
||||
class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""A single price observation for a product at a store on a date."""
|
||||
|
||||
__tablename__ = "price_history"
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_price_history_product_store_date",
|
||||
"normalized_product_id",
|
||||
"store_id",
|
||||
"observed_date",
|
||||
),
|
||||
)
|
||||
|
||||
normalized_product_id: Mapped[uuid.UUID] = mapped_column(
|
||||
ForeignKey("normalized_products.id"), nullable=False
|
||||
)
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
observed_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
regular_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
loyalty_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
coupon_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
source: Mapped[PriceSource] = mapped_column(String(20), nullable=False)
|
||||
purchase_item_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("purchase_items.id"))
|
||||
|
||||
# Relationships
|
||||
normalized_product: Mapped["NormalizedProduct"] = relationship(back_populates="price_histories")
|
||||
store: Mapped["Store"] = relationship(back_populates="price_histories")
|
||||
purchase_item: Mapped["PurchaseItem | None"] = relationship(
|
||||
back_populates="price_history_entries"
|
||||
)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""NormalizedProduct model — the canonical product identity."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import ProductCategory, SizeUnit
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.coupon import Coupon
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.purchase import PurchaseItem
|
||||
from cartsnitch_api.models.shrinkflation import ShrinkflationEvent
|
||||
|
||||
|
||||
class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Canonical product identity — matches products across retailers."""
|
||||
|
||||
__tablename__ = "normalized_products"
|
||||
|
||||
canonical_name: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
category: Mapped[ProductCategory | None] = mapped_column(String(50))
|
||||
subcategory: Mapped[str | None] = mapped_column(String(100))
|
||||
brand: Mapped[str | None] = mapped_column(String(200))
|
||||
size: Mapped[str | None] = mapped_column(String(50))
|
||||
size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
|
||||
upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list)
|
||||
|
||||
# Relationships
|
||||
purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
|
||||
price_histories: Mapped[list["PriceHistory"]] = relationship(
|
||||
back_populates="normalized_product"
|
||||
)
|
||||
coupons: Mapped[list["Coupon"]] = relationship(back_populates="normalized_product")
|
||||
shrinkflation_events: Mapped[list["ShrinkflationEvent"]] = relationship(
|
||||
back_populates="normalized_product"
|
||||
)
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Purchase and PurchaseItem models."""
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Date,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Numeric,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.store import Store, StoreLocation
|
||||
from cartsnitch_api.models.user import User
|
||||
|
||||
|
||||
class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""A single shopping trip / receipt."""
|
||||
|
||||
__tablename__ = "purchases"
|
||||
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id"))
|
||||
receipt_id: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
purchase_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
total: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
subtotal: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
tax: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
savings_total: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
source_url: Mapped[str | None] = mapped_column(String(500))
|
||||
raw_data: Mapped[dict | None] = mapped_column(JSON)
|
||||
ingested_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(back_populates="purchases")
|
||||
store: Mapped["Store"] = relationship(back_populates="purchases")
|
||||
store_location: Mapped["StoreLocation | None"] = relationship(back_populates="purchases")
|
||||
items: Mapped[list["PurchaseItem"]] = relationship(back_populates="purchase")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_purchases_user_store", "user_id", "store_id"),
|
||||
UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"),
|
||||
)
|
||||
|
||||
|
||||
class PurchaseItem(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Individual line item on a receipt."""
|
||||
|
||||
__tablename__ = "purchase_items"
|
||||
|
||||
purchase_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("purchases.id"), nullable=False)
|
||||
product_name_raw: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
upc: Mapped[str | None] = mapped_column(String(20))
|
||||
quantity: Mapped[Decimal] = mapped_column(Numeric(10, 3), nullable=False, default=1)
|
||||
unit_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
extended_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
regular_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
coupon_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
loyalty_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
category_raw: Mapped[str | None] = mapped_column(String(100))
|
||||
normalized_product_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("normalized_products.id")
|
||||
)
|
||||
|
||||
# Relationships
|
||||
purchase: Mapped["Purchase"] = relationship(back_populates="items")
|
||||
normalized_product: Mapped["NormalizedProduct | None"] = relationship(
|
||||
back_populates="purchase_items"
|
||||
)
|
||||
price_history_entries: Mapped[list["PriceHistory"]] = relationship(
|
||||
back_populates="purchase_item"
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""ShrinkflationEvent model."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Date, ForeignKey, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import SizeUnit
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
|
||||
|
||||
class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Detected shrinkflation event — product size changed while price held or rose."""
|
||||
|
||||
__tablename__ = "shrinkflation_events"
|
||||
|
||||
normalized_product_id: Mapped[uuid.UUID] = mapped_column(
|
||||
ForeignKey("normalized_products.id"), nullable=False
|
||||
)
|
||||
detected_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
old_size: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
new_size: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
old_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
|
||||
new_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
|
||||
price_at_old_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
price_at_new_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
confidence: Mapped[Decimal] = mapped_column(
|
||||
Numeric(3, 2), nullable=False, default=Decimal("1.00")
|
||||
)
|
||||
notes: Mapped[str | None] = mapped_column(String(1000))
|
||||
|
||||
# Relationships
|
||||
normalized_product: Mapped["NormalizedProduct"] = relationship(
|
||||
back_populates="shrinkflation_events"
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Store and StoreLocation models."""
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Float, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import StoreSlug
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.coupon import Coupon
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.purchase import Purchase
|
||||
from cartsnitch_api.models.user import UserStoreAccount
|
||||
|
||||
|
||||
class Store(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Supported retailer."""
|
||||
|
||||
__tablename__ = "stores"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
slug: Mapped[StoreSlug] = mapped_column(String(20), nullable=False, unique=True)
|
||||
logo_url: Mapped[str | None] = mapped_column(String(500))
|
||||
website_url: Mapped[str | None] = mapped_column(String(500))
|
||||
|
||||
# Relationships
|
||||
locations: Mapped[list["StoreLocation"]] = relationship(back_populates="store")
|
||||
purchases: Mapped[list["Purchase"]] = relationship(back_populates="store")
|
||||
user_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="store")
|
||||
price_histories: Mapped[list["PriceHistory"]] = relationship(back_populates="store")
|
||||
coupons: Mapped[list["Coupon"]] = relationship(back_populates="store")
|
||||
|
||||
|
||||
class StoreLocation(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Physical store location."""
|
||||
|
||||
__tablename__ = "store_locations"
|
||||
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
address: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
city: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
state: Mapped[str] = mapped_column(String(2), nullable=False)
|
||||
zip: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
lat: Mapped[float | None] = mapped_column(Float)
|
||||
lng: Mapped[float | None] = mapped_column(Float)
|
||||
|
||||
# Relationships
|
||||
store: Mapped["Store"] = relationship(back_populates="locations")
|
||||
purchases: Mapped[list["Purchase"]] = relationship(back_populates="store_location")
|
||||
@@ -0,0 +1,50 @@
|
||||
"""User and UserStoreAccount models."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import AccountStatus
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
from cartsnitch_api.types import EncryptedJSON
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.purchase import Purchase
|
||||
from cartsnitch_api.models.store import Store
|
||||
|
||||
|
||||
class User(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Application user."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
display_name: Mapped[str | None] = mapped_column(String(100))
|
||||
|
||||
# Relationships
|
||||
store_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="user")
|
||||
purchases: Mapped[list["Purchase"]] = relationship(back_populates="user")
|
||||
|
||||
|
||||
class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Link between a user and their retailer account credentials."""
|
||||
|
||||
__tablename__ = "user_store_accounts"
|
||||
__table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),)
|
||||
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
session_data: Mapped[dict | None] = mapped_column(EncryptedJSON)
|
||||
session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
last_sync_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
status: Mapped[AccountStatus] = mapped_column(
|
||||
String(20), nullable=False, default=AccountStatus.ACTIVE
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(back_populates="store_accounts")
|
||||
store: Mapped["Store"] = relationship(back_populates="user_accounts")
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Alert routes: list alerts, manage settings."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import AlertResponse, AlertSettingsRequest, AlertSettingsResponse
|
||||
from cartsnitch_api.services.alerts import AlertService
|
||||
|
||||
router = APIRouter(prefix="/alerts", tags=["alerts"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[AlertResponse])
|
||||
async def list_alerts(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AlertService(db)
|
||||
return await svc.list_alerts(user_id)
|
||||
|
||||
|
||||
@router.get("/settings", response_model=AlertSettingsResponse)
|
||||
async def get_alert_settings(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AlertService(db)
|
||||
return await svc.get_settings(user_id)
|
||||
|
||||
|
||||
@router.put("/settings")
|
||||
async def update_alert_settings(
|
||||
body: AlertSettingsRequest,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Alert settings persistence not yet implemented. "
|
||||
"Use GET /alerts/settings for current defaults.",
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Coupon routes: browse, relevant matches."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import CouponResponse
|
||||
from cartsnitch_api.services.coupons import CouponService
|
||||
|
||||
router = APIRouter(prefix="/coupons", tags=["coupons"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[CouponResponse])
|
||||
async def list_coupons(
|
||||
store_id: UUID | None = Query(None),
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CouponService(db)
|
||||
return await svc.list_coupons(store_id)
|
||||
|
||||
|
||||
@router.get("/relevant", response_model=list[CouponResponse])
|
||||
async def relevant_coupons(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CouponService(db)
|
||||
return await svc.relevant_coupons(user_id)
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Health check and error metrics endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from cartsnitch_api.auth.dependencies import verify_service_key
|
||||
from cartsnitch_api.middleware.error_handler import get_error_monitor
|
||||
|
||||
router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)])
|
||||
async def error_stats():
|
||||
"""Error monitoring stats — internal only (requires X-Service-Key)."""
|
||||
monitor = get_error_monitor()
|
||||
return monitor.get_stats()
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Price routes: trends, increases, comparison."""
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import (
|
||||
PriceComparisonResponse,
|
||||
PriceIncreaseResponse,
|
||||
PriceTrendResponse,
|
||||
)
|
||||
from cartsnitch_api.services.prices import PriceService
|
||||
|
||||
router = APIRouter(prefix="/prices", tags=["prices"])
|
||||
|
||||
|
||||
@router.get("/trends", response_model=list[PriceTrendResponse])
|
||||
async def price_trends(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
category: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PriceService(db)
|
||||
return await svc.get_trends(category)
|
||||
|
||||
|
||||
@router.get("/increases", response_model=list[PriceIncreaseResponse])
|
||||
async def price_increases(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PriceService(db)
|
||||
return await svc.get_increases()
|
||||
|
||||
|
||||
@router.get("/comparison", response_model=list[PriceComparisonResponse])
|
||||
async def price_comparison(
|
||||
product_ids: Annotated[list[UUID], Query()],
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PriceService(db)
|
||||
return await svc.get_comparison(product_ids)
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Product routes: search/list, detail, price history."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import PriceTrendResponse, ProductDetailResponse, ProductResponse
|
||||
from cartsnitch_api.services.products import ProductService
|
||||
|
||||
router = APIRouter(prefix="/products", tags=["products"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[ProductResponse])
|
||||
async def list_products(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
q: str | None = Query(None),
|
||||
category: str | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = ProductService(db)
|
||||
return await svc.list_products(q, category, page, page_size)
|
||||
|
||||
|
||||
@router.get("/{product_id}", response_model=ProductDetailResponse)
|
||||
async def get_product(
|
||||
product_id: UUID,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = ProductService(db)
|
||||
try:
|
||||
return await svc.get_product(product_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/{product_id}/prices", response_model=PriceTrendResponse)
|
||||
async def get_product_prices(
|
||||
product_id: UUID,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = ProductService(db)
|
||||
try:
|
||||
return await svc.get_price_history(product_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
) from None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Public endpoints: price transparency data (no auth required)."""
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import (
|
||||
PublicInflationResponse,
|
||||
PublicStoreComparisonResponse,
|
||||
PublicTrendResponse,
|
||||
)
|
||||
from cartsnitch_api.services.public import PublicService
|
||||
|
||||
router = APIRouter(prefix="/public", tags=["public"])
|
||||
|
||||
|
||||
@router.get("/trends/{product_id}", response_model=PublicTrendResponse)
|
||||
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
svc = PublicService(db)
|
||||
try:
|
||||
return await svc.get_trend(product_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
|
||||
async def public_store_comparison(
|
||||
product_ids: Annotated[list[UUID], Query(max_length=20)],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not product_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one product_id is required",
|
||||
)
|
||||
svc = PublicService(db)
|
||||
return await svc.get_store_comparison(product_ids)
|
||||
|
||||
|
||||
@router.get("/inflation", response_model=PublicInflationResponse)
|
||||
async def public_inflation(db: AsyncSession = Depends(get_db)):
|
||||
svc = PublicService(db)
|
||||
return await svc.get_inflation()
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Purchase routes: list, detail, stats."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import PurchaseDetailResponse, PurchaseResponse, PurchaseStatsResponse
|
||||
from cartsnitch_api.services.purchases import PurchaseService
|
||||
|
||||
router = APIRouter(prefix="/purchases", tags=["purchases"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[PurchaseResponse])
|
||||
async def list_purchases(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
store_id: UUID | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PurchaseService(db)
|
||||
return await svc.list_purchases(user_id, store_id, page, page_size)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=PurchaseStatsResponse)
|
||||
async def purchase_stats(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PurchaseService(db)
|
||||
return await svc.get_stats(user_id)
|
||||
|
||||
|
||||
@router.get("/{purchase_id}", response_model=PurchaseDetailResponse)
|
||||
async def get_purchase(
|
||||
purchase_id: UUID,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PurchaseService(db)
|
||||
try:
|
||||
return await svc.get_purchase(purchase_id, user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Purchase not found"
|
||||
) from None
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Scraping routes: trigger sync, check status (proxy to ReceiptWitness)."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from httpx import HTTPStatusError, RequestError
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.schemas import SyncStatusResponse, SyncTriggerResponse
|
||||
from cartsnitch_api.services.receiptwitness import ReceiptWitnessClient
|
||||
|
||||
router = APIRouter(prefix="/scraping", tags=["scraping"])
|
||||
|
||||
|
||||
@router.post("/{store_slug}/sync", response_model=SyncTriggerResponse)
|
||||
async def trigger_sync(store_slug: str, user_id: UUID = Depends(get_current_user)):
|
||||
client = ReceiptWitnessClient()
|
||||
try:
|
||||
result = await client.trigger_sync(str(user_id), store_slug)
|
||||
return result
|
||||
except HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=e.response.status_code,
|
||||
detail="Sync service error",
|
||||
) from e
|
||||
except RequestError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach sync service",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/status", response_model=list[SyncStatusResponse])
|
||||
async def sync_status(user_id: UUID = Depends(get_current_user)):
|
||||
client = ReceiptWitnessClient()
|
||||
try:
|
||||
return await client.get_sync_status(str(user_id))
|
||||
except (HTTPStatusError, RequestError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach sync service",
|
||||
) from None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Shopping routes: optimize list, saved lists."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from httpx import HTTPStatusError, RequestError
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.schemas import OptimizeRequest, OptimizeResponse, ShoppingListResponse
|
||||
from cartsnitch_api.services.clipartist import ClipArtistClient
|
||||
|
||||
router = APIRouter(prefix="/shopping", tags=["shopping"])
|
||||
|
||||
|
||||
@router.post("/optimize", response_model=OptimizeResponse)
|
||||
async def optimize_shopping(body: OptimizeRequest, user_id: UUID = Depends(get_current_user)):
|
||||
client = ClipArtistClient()
|
||||
try:
|
||||
result = await client.optimize(
|
||||
user_id=str(user_id),
|
||||
items=[item.model_dump() for item in body.items],
|
||||
preferred_stores=(
|
||||
[str(s) for s in body.preferred_stores] if body.preferred_stores else None
|
||||
),
|
||||
)
|
||||
return result
|
||||
except HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=e.response.status_code,
|
||||
detail="Shopping optimization service error",
|
||||
) from e
|
||||
except RequestError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach shopping optimization service",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/lists", response_model=list[ShoppingListResponse])
|
||||
async def list_shopping_lists(user_id: UUID = Depends(get_current_user)):
|
||||
client = ClipArtistClient()
|
||||
try:
|
||||
return await client.get_shopping_lists(str(user_id))
|
||||
except (HTTPStatusError, RequestError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach shopping service",
|
||||
) from None
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Store routes: list stores, manage user store connections."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import ConnectStoreRequest, StoreAccountResponse, StoreResponse
|
||||
from cartsnitch_api.services.stores import StoreService
|
||||
|
||||
router = APIRouter(tags=["stores"])
|
||||
|
||||
|
||||
@router.get("/stores", response_model=list[StoreResponse])
|
||||
async def list_stores(db: AsyncSession = Depends(get_db)):
|
||||
svc = StoreService(db)
|
||||
return await svc.list_stores()
|
||||
|
||||
|
||||
@router.get("/me/stores", response_model=list[StoreAccountResponse])
|
||||
async def list_user_stores(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = StoreService(db)
|
||||
return await svc.list_user_stores(user_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/me/stores/{store_slug}/connect",
|
||||
response_model=StoreAccountResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def connect_store(
|
||||
store_slug: str,
|
||||
body: ConnectStoreRequest,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = StoreService(db)
|
||||
try:
|
||||
return await svc.connect_store(user_id, store_slug, body.credentials)
|
||||
except LookupError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def disconnect_store(
|
||||
store_slug: str,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = StoreService(db)
|
||||
try:
|
||||
await svc.disconnect_store(user_id, store_slug)
|
||||
except LookupError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
@@ -0,0 +1,291 @@
|
||||
"""Pydantic v2 request/response schemas for all API endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
# ---------- Auth ----------
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=8, max_length=128)
|
||||
display_name: str = Field(min_length=1, max_length=100)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
email: EmailStr | None = None
|
||||
display_name: str | None = Field(None, min_length=1, max_length=100)
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
display_name: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------- Stores ----------
|
||||
|
||||
|
||||
class StoreResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
slug: str
|
||||
logo_url: str | None = None
|
||||
supported: bool = True
|
||||
|
||||
|
||||
class StoreAccountResponse(BaseModel):
|
||||
store: StoreResponse
|
||||
connected: bool
|
||||
last_sync_at: datetime | None = None
|
||||
sync_status: str | None = None
|
||||
|
||||
|
||||
class ConnectStoreRequest(BaseModel):
|
||||
credentials: dict | None = None
|
||||
|
||||
|
||||
# ---------- Purchases ----------
|
||||
|
||||
|
||||
class LineItemResponse(BaseModel):
|
||||
id: UUID
|
||||
product_id: UUID | None = None
|
||||
name: str
|
||||
quantity: float
|
||||
unit_price: float
|
||||
total_price: float
|
||||
|
||||
|
||||
class PurchaseResponse(BaseModel):
|
||||
id: UUID
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
purchased_at: datetime
|
||||
total: float
|
||||
item_count: int
|
||||
|
||||
|
||||
class PurchaseDetailResponse(PurchaseResponse):
|
||||
line_items: list[LineItemResponse]
|
||||
|
||||
|
||||
class PurchaseStatsResponse(BaseModel):
|
||||
total_spent: float
|
||||
purchase_count: int
|
||||
by_store: dict[str, float]
|
||||
by_period: dict[str, float]
|
||||
|
||||
|
||||
# ---------- Products ----------
|
||||
|
||||
|
||||
class ProductResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
brand: str | None = None
|
||||
category: str | None = None
|
||||
upc: str | None = None
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
class ProductDetailResponse(ProductResponse):
|
||||
prices_by_store: list["StorePriceResponse"]
|
||||
|
||||
|
||||
class StorePriceResponse(BaseModel):
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
current_price: float
|
||||
last_seen_at: datetime
|
||||
|
||||
|
||||
# ---------- Prices ----------
|
||||
|
||||
|
||||
class PriceTrendResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
data_points: list["PricePointResponse"]
|
||||
|
||||
|
||||
class PricePointResponse(BaseModel):
|
||||
date: datetime
|
||||
price: float
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
|
||||
|
||||
class PriceIncreaseResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
store_name: str
|
||||
old_price: float
|
||||
new_price: float
|
||||
increase_pct: float
|
||||
detected_at: datetime
|
||||
|
||||
|
||||
class PriceComparisonResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
prices: list[StorePriceResponse]
|
||||
|
||||
|
||||
# ---------- Coupons ----------
|
||||
|
||||
|
||||
class CouponResponse(BaseModel):
|
||||
id: UUID
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
description: str
|
||||
discount_value: float
|
||||
discount_type: str
|
||||
product_id: UUID | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
# ---------- Shopping ----------
|
||||
|
||||
|
||||
class ShoppingListItemRequest(BaseModel):
|
||||
product_id: UUID | None = None
|
||||
name: str
|
||||
quantity: int = 1
|
||||
|
||||
|
||||
class OptimizeRequest(BaseModel):
|
||||
items: list[ShoppingListItemRequest]
|
||||
preferred_stores: list[UUID] | None = None
|
||||
|
||||
|
||||
class OptimizedStoreTrip(BaseModel):
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
items: list["OptimizedItemResponse"]
|
||||
subtotal: float
|
||||
coupons: list[CouponResponse]
|
||||
savings: float
|
||||
|
||||
|
||||
class OptimizedItemResponse(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
product_id: UUID | None = None
|
||||
|
||||
|
||||
class OptimizeResponse(BaseModel):
|
||||
trips: list[OptimizedStoreTrip]
|
||||
total_cost: float
|
||||
total_savings: float
|
||||
|
||||
|
||||
class ShoppingListResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
item_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# ---------- Alerts ----------
|
||||
|
||||
|
||||
class AlertResponse(BaseModel):
|
||||
id: UUID
|
||||
alert_type: str
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
message: str
|
||||
triggered_at: datetime
|
||||
read: bool = False
|
||||
|
||||
|
||||
class AlertSettingsRequest(BaseModel):
|
||||
price_increase_threshold_pct: float | None = None
|
||||
shrinkflation_enabled: bool | None = None
|
||||
email_notifications: bool | None = None
|
||||
|
||||
|
||||
class AlertSettingsResponse(BaseModel):
|
||||
price_increase_threshold_pct: float
|
||||
shrinkflation_enabled: bool
|
||||
email_notifications: bool
|
||||
|
||||
|
||||
# ---------- Scraping ----------
|
||||
|
||||
|
||||
class SyncTriggerResponse(BaseModel):
|
||||
job_id: UUID
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class SyncStatusResponse(BaseModel):
|
||||
store_slug: str
|
||||
status: str
|
||||
last_sync_at: datetime | None = None
|
||||
items_synced: int | None = None
|
||||
|
||||
|
||||
# ---------- Public ----------
|
||||
|
||||
|
||||
class PublicTrendResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
data_points: list[PricePointResponse]
|
||||
|
||||
|
||||
class PublicStoreComparisonResponse(BaseModel):
|
||||
products: list[PriceComparisonResponse]
|
||||
|
||||
|
||||
class PublicInflationResponse(BaseModel):
|
||||
period: str
|
||||
cartsnitch_index: float
|
||||
cpi_baseline: float
|
||||
categories: dict[str, float]
|
||||
|
||||
|
||||
# ---------- Common ----------
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
items: list
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
detail: str
|
||||
code: str | None = None
|
||||
|
||||
|
||||
# Rebuild forward refs
|
||||
ProductDetailResponse.model_rebuild()
|
||||
PriceTrendResponse.model_rebuild()
|
||||
OptimizedStoreTrip.model_rebuild()
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Alert service — price and shrinkflation alerts for users.
|
||||
|
||||
Alerts are generated by StickerShock and ShrinkRay services and written to the DB.
|
||||
This service reads them for the API gateway.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
class AlertService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_alerts(self, user_id: UUID) -> list[dict]:
|
||||
"""List shrinkflation events for products the user has purchased."""
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent
|
||||
|
||||
# Get product IDs from user's purchases
|
||||
items_result = await self.db.execute(
|
||||
select(PurchaseItem.normalized_product_id)
|
||||
.join(Purchase)
|
||||
.where(
|
||||
Purchase.user_id == user_id,
|
||||
PurchaseItem.normalized_product_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
product_ids = [row[0] for row in items_result.all()]
|
||||
|
||||
if not product_ids:
|
||||
return []
|
||||
|
||||
result = await self.db.execute(
|
||||
select(ShrinkflationEvent)
|
||||
.where(ShrinkflationEvent.normalized_product_id.in_(product_ids))
|
||||
.options(selectinload(ShrinkflationEvent.normalized_product))
|
||||
.order_by(ShrinkflationEvent.detected_date.desc())
|
||||
)
|
||||
events = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": e.id,
|
||||
"alert_type": "shrinkflation",
|
||||
"product_id": e.normalized_product_id,
|
||||
"product_name": e.normalized_product.canonical_name,
|
||||
"message": (
|
||||
f"Size changed from {e.old_size}{e.old_unit} to {e.new_size}{e.new_unit}"
|
||||
),
|
||||
"triggered_at": e.detected_date,
|
||||
"read": False,
|
||||
}
|
||||
for e in events
|
||||
]
|
||||
|
||||
async def get_settings(self, user_id: UUID) -> dict:
|
||||
# Alert settings would be stored in a user_settings table.
|
||||
# For now, return defaults since the table doesn't exist yet in common lib.
|
||||
return {
|
||||
"price_increase_threshold_pct": 5.0,
|
||||
"shrinkflation_enabled": True,
|
||||
"email_notifications": False,
|
||||
}
|
||||
|
||||
async def update_settings(self, user_id: UUID, **fields) -> dict:
|
||||
# Would update user_settings table. Return merged defaults for now.
|
||||
current = await self.get_settings(user_id)
|
||||
for k, v in fields.items():
|
||||
if v is not None and k in current:
|
||||
current[k] = v
|
||||
return current
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Auth service — user registration, login, token management."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token
|
||||
from cartsnitch_api.auth.passwords import hash_password, verify_password
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def register(self, email: str, password: str, display_name: str) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
existing = await self.db.execute(select(User).where(User.email == email))
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
hashed_password=hash_password(password),
|
||||
display_name=display_name,
|
||||
)
|
||||
self.db.add(user)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user)
|
||||
|
||||
return self._make_token_response(user.id)
|
||||
|
||||
async def login(self, email: str, password: str) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise ValueError("Invalid email or password")
|
||||
|
||||
return self._make_token_response(user.id)
|
||||
|
||||
async def refresh(self, refresh_token: str) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
try:
|
||||
payload = decode_token(refresh_token)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid refresh token") from None
|
||||
|
||||
if payload.get("type") != "refresh":
|
||||
raise ValueError("Invalid token type") from None
|
||||
|
||||
user_id = UUID(payload["sub"])
|
||||
|
||||
# Verify the user still exists before issuing new tokens
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
if not result.scalar_one_or_none():
|
||||
raise ValueError("User no longer exists")
|
||||
|
||||
return self._make_token_response(user_id)
|
||||
|
||||
async def get_user(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"display_name": user.display_name,
|
||||
"created_at": user.created_at,
|
||||
}
|
||||
|
||||
async def update_user(self, user_id: UUID, **fields) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
|
||||
if "display_name" in fields and fields["display_name"] is not None:
|
||||
user.display_name = fields["display_name"]
|
||||
if "email" in fields and fields["email"] is not None:
|
||||
existing = await self.db.execute(
|
||||
select(User).where(User.email == fields["email"], User.id != user_id)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Email already in use")
|
||||
user.email = fields["email"]
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"display_name": user.display_name,
|
||||
"created_at": user.created_at,
|
||||
}
|
||||
|
||||
async def delete_user(self, user_id: UUID) -> None:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
|
||||
await self.db.delete(user)
|
||||
await self.db.commit()
|
||||
|
||||
def _make_token_response(self, user_id: UUID) -> dict:
|
||||
return {
|
||||
"access_token": create_access_token(user_id),
|
||||
"refresh_token": create_refresh_token(user_id),
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.jwt_access_token_expire_minutes * 60,
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
"""HTTP client for ClipArtist internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class ClipArtistClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.clipartist_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def optimize(
|
||||
self,
|
||||
user_id: str,
|
||||
items: list[dict],
|
||||
preferred_stores: list[str] | None = None,
|
||||
) -> dict:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/optimize",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"user_id": user_id,
|
||||
"items": items,
|
||||
"preferred_stores": preferred_stores,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def get_shopping_lists(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/shopping-lists",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
|
||||
async def get_relevant_coupons(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/coupons/relevant",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Coupon service — browse coupons, find relevant ones."""
|
||||
|
||||
from datetime import date
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
class CouponService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_coupons(self, store_id: UUID | None = None) -> list[dict]:
|
||||
from cartsnitch_api.models import Coupon
|
||||
|
||||
today = date.today()
|
||||
query = (
|
||||
select(Coupon)
|
||||
.where((Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)))
|
||||
.options(selectinload(Coupon.store))
|
||||
.order_by(Coupon.valid_to.asc().nullslast())
|
||||
)
|
||||
if store_id:
|
||||
query = query.where(Coupon.store_id == store_id)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
coupons = result.scalars().all()
|
||||
return [self._to_dict(c) for c in coupons]
|
||||
|
||||
async def relevant_coupons(self, user_id: UUID) -> list[dict]:
|
||||
"""Coupons for products the user has purchased."""
|
||||
from cartsnitch_api.models import Coupon, PurchaseItem
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Get product IDs from user's purchase history
|
||||
from cartsnitch_api.models import Purchase
|
||||
|
||||
items_result = await self.db.execute(
|
||||
select(PurchaseItem.normalized_product_id)
|
||||
.join(Purchase)
|
||||
.where(
|
||||
Purchase.user_id == user_id,
|
||||
PurchaseItem.normalized_product_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
product_ids = [row[0] for row in items_result.all()]
|
||||
|
||||
if not product_ids:
|
||||
return []
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Coupon)
|
||||
.where(
|
||||
Coupon.normalized_product_id.in_(product_ids),
|
||||
(Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)),
|
||||
)
|
||||
.options(selectinload(Coupon.store))
|
||||
)
|
||||
coupons = result.scalars().all()
|
||||
return [self._to_dict(c) for c in coupons]
|
||||
|
||||
def _to_dict(self, c) -> dict:
|
||||
return {
|
||||
"id": c.id,
|
||||
"store_id": c.store_id,
|
||||
"store_name": c.store.name,
|
||||
"description": c.description or c.title,
|
||||
"discount_value": float(c.discount_value) if c.discount_value else 0,
|
||||
"discount_type": c.discount_type,
|
||||
"product_id": c.normalized_product_id,
|
||||
"expires_at": c.valid_to,
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Price service — trends, increases, comparison."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.services.queries import latest_price_per_store
|
||||
|
||||
|
||||
class PriceService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_trends(self, category: str | None = None) -> list[dict]:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
query = (
|
||||
select(PriceHistory)
|
||||
.join(NormalizedProduct)
|
||||
.options(
|
||||
selectinload(PriceHistory.store),
|
||||
selectinload(PriceHistory.normalized_product),
|
||||
)
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
if category:
|
||||
query = query.where(NormalizedProduct.category == category)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
prices = result.scalars().all()
|
||||
|
||||
# Group by product
|
||||
by_product: dict[UUID, dict] = {}
|
||||
for ph in prices:
|
||||
pid = ph.normalized_product_id
|
||||
if pid not in by_product:
|
||||
by_product[pid] = {
|
||||
"product_id": pid,
|
||||
"product_name": ph.normalized_product.canonical_name,
|
||||
"data_points": [],
|
||||
}
|
||||
by_product[pid]["data_points"].append(
|
||||
{
|
||||
"date": ph.observed_date,
|
||||
"price": float(ph.regular_price),
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
}
|
||||
)
|
||||
return list(by_product.values())
|
||||
|
||||
async def get_increases(self) -> list[dict]:
|
||||
"""Find products with recent significant price increases.
|
||||
|
||||
Uses a window function (lag) to compare each price observation with the
|
||||
previous one per product+store, avoiding the N+1 query pattern.
|
||||
"""
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
# Use lag() window function to get previous price in a single query
|
||||
prev_price = (
|
||||
func.lag(PriceHistory.regular_price)
|
||||
.over(
|
||||
partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id],
|
||||
order_by=PriceHistory.observed_date,
|
||||
)
|
||||
.label("prev_price")
|
||||
)
|
||||
|
||||
row_num = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id],
|
||||
order_by=PriceHistory.observed_date.desc(),
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
|
||||
inner = select(
|
||||
PriceHistory.normalized_product_id,
|
||||
PriceHistory.store_id,
|
||||
PriceHistory.regular_price,
|
||||
PriceHistory.observed_date,
|
||||
prev_price,
|
||||
row_num,
|
||||
).subquery()
|
||||
|
||||
# Only keep the latest row (rn=1) where price increased
|
||||
result = await self.db.execute(
|
||||
select(
|
||||
inner.c.normalized_product_id,
|
||||
inner.c.store_id,
|
||||
inner.c.regular_price,
|
||||
inner.c.observed_date,
|
||||
inner.c.prev_price,
|
||||
NormalizedProduct.canonical_name,
|
||||
Store.name.label("store_name"),
|
||||
)
|
||||
.join(NormalizedProduct, NormalizedProduct.id == inner.c.normalized_product_id)
|
||||
.join(Store, Store.id == inner.c.store_id)
|
||||
.where(
|
||||
inner.c.rn == 1,
|
||||
inner.c.prev_price.isnot(None),
|
||||
inner.c.regular_price > inner.c.prev_price,
|
||||
)
|
||||
)
|
||||
|
||||
increases = []
|
||||
for row in result.all():
|
||||
old = float(row.prev_price)
|
||||
new = float(row.regular_price)
|
||||
increases.append(
|
||||
{
|
||||
"product_id": row.normalized_product_id,
|
||||
"product_name": row.canonical_name,
|
||||
"store_name": row.store_name,
|
||||
"old_price": old,
|
||||
"new_price": new,
|
||||
"increase_pct": round((new - old) / old * 100, 2),
|
||||
"detected_at": row.observed_date,
|
||||
}
|
||||
)
|
||||
|
||||
increases.sort(key=lambda x: x["increase_pct"], reverse=True)
|
||||
return increases
|
||||
|
||||
async def get_comparison(self, product_ids: list[UUID]) -> list[dict]:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
if not product_ids:
|
||||
return []
|
||||
|
||||
# Fetch all requested products in one query
|
||||
prod_result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
)
|
||||
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
||||
|
||||
# Latest prices for all requested products in one query
|
||||
subq = latest_price_per_store(product_ids)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
subq,
|
||||
and_(
|
||||
PriceHistory.store_id == subq.c.store_id,
|
||||
PriceHistory.observed_date == subq.c.max_date,
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
all_prices = prices_result.scalars().all()
|
||||
|
||||
# Group prices by product
|
||||
prices_by_product: dict[UUID, list] = {pid: [] for pid in product_ids}
|
||||
for ph in all_prices:
|
||||
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
||||
|
||||
comparisons = []
|
||||
for pid in product_ids:
|
||||
product = products_by_id.get(pid)
|
||||
if not product:
|
||||
continue
|
||||
comparisons.append(
|
||||
{
|
||||
"product_id": pid,
|
||||
"product_name": product.canonical_name,
|
||||
"prices": [
|
||||
{
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
"current_price": float(ph.regular_price),
|
||||
"last_seen_at": ph.observed_date,
|
||||
}
|
||||
for ph in prices_by_product.get(pid, [])
|
||||
],
|
||||
}
|
||||
)
|
||||
return comparisons
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Product service — catalog, detail, price history."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.services.queries import latest_price_per_store
|
||||
|
||||
|
||||
class ProductService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_products(
|
||||
self,
|
||||
q: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[dict]:
|
||||
from cartsnitch_api.models import NormalizedProduct
|
||||
|
||||
query = select(NormalizedProduct)
|
||||
if q:
|
||||
# Escape SQL LIKE wildcards in user input
|
||||
safe_q = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
query = query.where(NormalizedProduct.canonical_name.ilike(f"%{safe_q}%"))
|
||||
if category:
|
||||
query = query.where(NormalizedProduct.category == category)
|
||||
query = query.order_by(NormalizedProduct.canonical_name)
|
||||
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
products = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": p.id,
|
||||
"name": p.canonical_name,
|
||||
"brand": p.brand,
|
||||
"category": p.category,
|
||||
"upc": (p.upc_variants[0] if p.upc_variants else None),
|
||||
"image_url": None,
|
||||
}
|
||||
for p in products
|
||||
]
|
||||
|
||||
async def get_product(self, product_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
|
||||
)
|
||||
product = result.scalar_one_or_none()
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
# Get latest price per store
|
||||
subq = latest_price_per_store([product_id])
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
subq,
|
||||
and_(
|
||||
PriceHistory.store_id == subq.c.store_id,
|
||||
PriceHistory.observed_date == subq.c.max_date,
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
prices = prices_result.scalars().all()
|
||||
|
||||
return {
|
||||
"id": product.id,
|
||||
"name": product.canonical_name,
|
||||
"brand": product.brand,
|
||||
"category": product.category,
|
||||
"upc": (product.upc_variants[0] if product.upc_variants else None),
|
||||
"image_url": None,
|
||||
"prices_by_store": [
|
||||
{
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
"current_price": float(ph.regular_price),
|
||||
"last_seen_at": ph.observed_date,
|
||||
}
|
||||
for ph in prices
|
||||
],
|
||||
}
|
||||
|
||||
async def get_price_history(self, product_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
|
||||
)
|
||||
product = result.scalar_one_or_none()
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
prices = prices_result.scalars().all()
|
||||
|
||||
return {
|
||||
"product_id": product.id,
|
||||
"product_name": product.canonical_name,
|
||||
"data_points": [
|
||||
{
|
||||
"date": ph.observed_date,
|
||||
"price": float(ph.regular_price),
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
}
|
||||
for ph in prices
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Public service — unauthenticated price transparency endpoints."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.services.queries import latest_price_per_store
|
||||
|
||||
|
||||
class PublicService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_trend(self, product_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
|
||||
)
|
||||
product = result.scalar_one_or_none()
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
prices = prices_result.scalars().all()
|
||||
|
||||
return {
|
||||
"product_id": product.id,
|
||||
"product_name": product.canonical_name,
|
||||
"data_points": [
|
||||
{
|
||||
"date": ph.observed_date,
|
||||
"price": float(ph.regular_price),
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
}
|
||||
for ph in prices
|
||||
],
|
||||
}
|
||||
|
||||
async def get_store_comparison(self, product_ids: list[UUID]) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
if not product_ids:
|
||||
return {"products": []}
|
||||
|
||||
# Fetch all products in one query
|
||||
prod_result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
)
|
||||
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
||||
|
||||
# Latest prices for all requested products in one query
|
||||
subq = latest_price_per_store(product_ids)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
subq,
|
||||
and_(
|
||||
PriceHistory.store_id == subq.c.store_id,
|
||||
PriceHistory.observed_date == subq.c.max_date,
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
all_prices = prices_result.scalars().all()
|
||||
|
||||
# Group by product
|
||||
prices_by_product: dict[UUID, list] = {}
|
||||
for ph in all_prices:
|
||||
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
||||
|
||||
products = []
|
||||
for pid in product_ids:
|
||||
product = products_by_id.get(pid)
|
||||
if not product:
|
||||
continue
|
||||
products.append(
|
||||
{
|
||||
"product_id": pid,
|
||||
"product_name": product.canonical_name,
|
||||
"prices": [
|
||||
{
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
"current_price": float(ph.regular_price),
|
||||
"last_seen_at": ph.observed_date,
|
||||
}
|
||||
for ph in prices_by_product.get(pid, [])
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return {"products": products}
|
||||
|
||||
async def get_inflation(self) -> dict:
|
||||
"""Aggregate price change stats. Compares average prices across periods."""
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
# Get average prices grouped by category for recent vs older data
|
||||
result = await self.db.execute(
|
||||
select(
|
||||
NormalizedProduct.category,
|
||||
func.avg(PriceHistory.regular_price),
|
||||
)
|
||||
.join(NormalizedProduct)
|
||||
.group_by(NormalizedProduct.category)
|
||||
)
|
||||
categories = {}
|
||||
for row in result.all():
|
||||
cat, avg_price = row
|
||||
if cat:
|
||||
categories[cat] = float(avg_price) if avg_price else 0.0
|
||||
|
||||
return {
|
||||
"period": "all-time",
|
||||
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
|
||||
"cpi_baseline": 100.0,
|
||||
"categories": categories,
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Purchase service — list, detail, stats."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
class PurchaseService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_purchases(
|
||||
self,
|
||||
user_id: UUID,
|
||||
store_id: UUID | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[dict]:
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store
|
||||
|
||||
# Count items per purchase in a single subquery instead of N+1
|
||||
item_counts = (
|
||||
select(
|
||||
PurchaseItem.purchase_id,
|
||||
func.count().label("item_count"),
|
||||
)
|
||||
.group_by(PurchaseItem.purchase_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = (
|
||||
select(Purchase, item_counts.c.item_count, Store.name.label("store_name"))
|
||||
.join(Store, Store.id == Purchase.store_id)
|
||||
.outerjoin(item_counts, item_counts.c.purchase_id == Purchase.id)
|
||||
.where(Purchase.user_id == user_id)
|
||||
)
|
||||
if store_id:
|
||||
query = query.where(Purchase.store_id == store_id)
|
||||
|
||||
query = query.order_by(Purchase.purchase_date.desc())
|
||||
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": p.id,
|
||||
"store_id": p.store_id,
|
||||
"store_name": store_name,
|
||||
"purchased_at": p.purchase_date,
|
||||
"total": float(p.total),
|
||||
"item_count": item_count or 0,
|
||||
}
|
||||
for p, item_count, store_name in result.all()
|
||||
]
|
||||
|
||||
async def get_purchase(self, purchase_id: UUID, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import Purchase
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Purchase)
|
||||
.where(Purchase.id == purchase_id, Purchase.user_id == user_id)
|
||||
.options(selectinload(Purchase.store), selectinload(Purchase.items))
|
||||
)
|
||||
purchase = result.scalar_one_or_none()
|
||||
if not purchase:
|
||||
raise LookupError("Purchase not found")
|
||||
|
||||
return {
|
||||
"id": purchase.id,
|
||||
"store_id": purchase.store_id,
|
||||
"store_name": purchase.store.name,
|
||||
"purchased_at": purchase.purchase_date,
|
||||
"total": float(purchase.total),
|
||||
"item_count": len(purchase.items),
|
||||
"line_items": [
|
||||
{
|
||||
"id": item.id,
|
||||
"product_id": item.normalized_product_id,
|
||||
"name": item.product_name_raw,
|
||||
"quantity": float(item.quantity),
|
||||
"unit_price": float(item.unit_price),
|
||||
"total_price": float(item.extended_price),
|
||||
}
|
||||
for item in purchase.items
|
||||
],
|
||||
}
|
||||
|
||||
async def get_stats(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import Purchase
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Purchase)
|
||||
.where(Purchase.user_id == user_id)
|
||||
.options(selectinload(Purchase.store))
|
||||
)
|
||||
purchases = result.scalars().all()
|
||||
|
||||
total_spent = sum(float(p.total) for p in purchases)
|
||||
by_store: dict[str, float] = {}
|
||||
by_period: dict[str, float] = {}
|
||||
|
||||
for p in purchases:
|
||||
store_name = p.store.name
|
||||
by_store[store_name] = by_store.get(store_name, 0) + float(p.total)
|
||||
period = p.purchase_date.strftime("%Y-%m")
|
||||
by_period[period] = by_period.get(period, 0) + float(p.total)
|
||||
|
||||
return {
|
||||
"total_spent": total_spent,
|
||||
"purchase_count": len(purchases),
|
||||
"by_store": by_store,
|
||||
"by_period": by_period,
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Shared query helpers for service layer."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
|
||||
def latest_price_per_store(product_ids: list[UUID] | None = None):
|
||||
"""Subquery returning the latest observed_date per product+store.
|
||||
|
||||
Optionally filtered to a list of product IDs. Returns a subquery with
|
||||
columns: normalized_product_id, store_id, max_date.
|
||||
"""
|
||||
from cartsnitch_api.models import PriceHistory
|
||||
|
||||
query = select(
|
||||
PriceHistory.normalized_product_id,
|
||||
PriceHistory.store_id,
|
||||
func.max(PriceHistory.observed_date).label("max_date"),
|
||||
).group_by(PriceHistory.normalized_product_id, PriceHistory.store_id)
|
||||
if product_ids is not None:
|
||||
query = query.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
return query.subquery()
|
||||
@@ -0,0 +1,33 @@
|
||||
"""HTTP client for ReceiptWitness internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class ReceiptWitnessClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.receiptwitness_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def trigger_sync(self, user_id: str, store_slug: str) -> dict:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/sync/{store_slug}",
|
||||
headers=self.headers,
|
||||
json={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def get_sync_status(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/sync/status",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
@@ -0,0 +1,23 @@
|
||||
"""HTTP client for ShrinkRay internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class ShrinkRayClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.shrinkray_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def get_shrinkflation_alerts(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/alerts",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
@@ -0,0 +1,32 @@
|
||||
"""HTTP client for StickerShock internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class StickerShockClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.stickershock_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def get_price_increases(self, params: dict | None = None) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/increases",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
|
||||
async def get_inflation_data(self) -> dict:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/inflation",
|
||||
headers=self.headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Store service — list stores, manage user store account connections."""
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
return Fernet(settings.fernet_key.encode())
|
||||
|
||||
|
||||
class StoreService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_stores(self) -> list[dict]:
|
||||
from cartsnitch_api.models import Store
|
||||
|
||||
result = await self.db.execute(select(Store).order_by(Store.name))
|
||||
stores = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"slug": s.slug,
|
||||
"logo_url": s.logo_url,
|
||||
"supported": True,
|
||||
}
|
||||
for s in stores
|
||||
]
|
||||
|
||||
async def list_user_stores(self, user_id: UUID) -> list[dict]:
|
||||
from cartsnitch_api.models import UserStoreAccount
|
||||
|
||||
result = await self.db.execute(
|
||||
select(UserStoreAccount)
|
||||
.where(UserStoreAccount.user_id == user_id)
|
||||
.options(selectinload(UserStoreAccount.store))
|
||||
)
|
||||
accounts = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"store": {
|
||||
"id": a.store.id,
|
||||
"name": a.store.name,
|
||||
"slug": a.store.slug,
|
||||
"logo_url": a.store.logo_url,
|
||||
"supported": True,
|
||||
},
|
||||
"connected": a.status == "active",
|
||||
"last_sync_at": a.last_sync_at,
|
||||
"sync_status": a.status,
|
||||
}
|
||||
for a in accounts
|
||||
]
|
||||
|
||||
async def connect_store(self, user_id: UUID, store_slug: str, credentials: dict | None) -> dict:
|
||||
from cartsnitch_api.models import Store, UserStoreAccount
|
||||
|
||||
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
||||
store = result.scalar_one_or_none()
|
||||
if not store:
|
||||
raise LookupError(f"Store '{store_slug}' not found")
|
||||
|
||||
existing = await self.db.execute(
|
||||
select(UserStoreAccount).where(
|
||||
UserStoreAccount.user_id == user_id,
|
||||
UserStoreAccount.store_id == store.id,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Store account already connected")
|
||||
|
||||
encrypted_data = None
|
||||
if credentials:
|
||||
fernet = _get_fernet()
|
||||
encrypted_data = {
|
||||
"encrypted": fernet.encrypt(json.dumps(credentials).encode()).decode()
|
||||
}
|
||||
|
||||
account = UserStoreAccount(
|
||||
user_id=user_id,
|
||||
store_id=store.id,
|
||||
session_data=encrypted_data,
|
||||
status="active",
|
||||
)
|
||||
self.db.add(account)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(account)
|
||||
|
||||
return {
|
||||
"store": {
|
||||
"id": store.id,
|
||||
"name": store.name,
|
||||
"slug": store.slug,
|
||||
"logo_url": store.logo_url,
|
||||
"supported": True,
|
||||
},
|
||||
"connected": True,
|
||||
"last_sync_at": None,
|
||||
"sync_status": "active",
|
||||
}
|
||||
|
||||
async def disconnect_store(self, user_id: UUID, store_slug: str) -> None:
|
||||
from cartsnitch_api.models import Store, UserStoreAccount
|
||||
|
||||
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
||||
store = result.scalar_one_or_none()
|
||||
if not store:
|
||||
raise LookupError(f"Store '{store_slug}' not found")
|
||||
|
||||
result = await self.db.execute(
|
||||
select(UserStoreAccount).where(
|
||||
UserStoreAccount.user_id == user_id,
|
||||
UserStoreAccount.store_id == store.id,
|
||||
)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise LookupError("Store account not connected")
|
||||
|
||||
await self.db.delete(account)
|
||||
await self.db.commit()
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Custom SQLAlchemy column types."""
|
||||
|
||||
import json
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
return Fernet(settings.fernet_key.encode())
|
||||
|
||||
|
||||
class EncryptedJSON(TypeDecorator):
|
||||
"""SQLAlchemy type that transparently encrypts/decrypts JSON using Fernet.
|
||||
|
||||
Stores data as a Fernet-encrypted text blob in the database.
|
||||
On read, decrypts and deserialises back to a Python dict/list.
|
||||
"""
|
||||
|
||||
impl = Text
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return None
|
||||
plaintext = json.dumps(value).encode()
|
||||
return _get_fernet().encrypt(plaintext).decode()
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return None
|
||||
decrypted = _get_fernet().decrypt(value.encode())
|
||||
return json.loads(decrypted)
|
||||
Reference in New Issue
Block a user