feat: merge cartsnitch/api into api/ subdirectory

Consolidate API gateway service into monorepo.
Squashed from https://github.com/cartsnitch/api main (89bacb1).

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
Coupon Carl
2026-03-28 02:24:02 +00:00
commit b7e6f637a7
91 changed files with 6296 additions and 0 deletions
View File
View File
+39
View File
@@ -0,0 +1,39 @@
"""FastAPI dependency injection for authentication."""
from uuid import UUID
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from cartsnitch_api.auth.jwt import decode_token
from cartsnitch_api.config import settings
bearer_scheme = HTTPBearer()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> UUID:
try:
payload = decode_token(credentials.credentials)
except ValueError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
) from None
if payload.get("type") != "access":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type",
) from None
return UUID(payload["sub"])
async def verify_service_key(x_service_key: str = Header()) -> None:
if x_service_key != settings.service_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid service key",
)
+31
View File
@@ -0,0 +1,31 @@
"""JWT token creation and validation."""
from datetime import UTC, datetime, timedelta
from typing import Any, cast
from uuid import UUID
from jose import JWTError, jwt
from cartsnitch_api.config import settings
def create_access_token(user_id: UUID) -> str:
expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes)
payload = {"sub": str(user_id), "exp": expire, "type": "access"}
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
def create_refresh_token(user_id: UUID) -> str:
expire = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days)
payload = {"sub": str(user_id), "exp": expire, "type": "refresh"}
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
def decode_token(token: str) -> dict:
try:
return cast(
dict[str, Any],
jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]),
)
except JWTError as e:
raise ValueError(f"Invalid token: {e}") from e
+11
View File
@@ -0,0 +1,11 @@
"""Password hashing and verification with bcrypt."""
import bcrypt
def hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def verify_password(plain_password: str, hashed_password: str) -> bool:
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
+96
View File
@@ -0,0 +1,96 @@
"""Auth routes: register, login, refresh, me, update, delete."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.database import get_db
from cartsnitch_api.schemas import (
LoginRequest,
RefreshRequest,
RegisterRequest,
TokenResponse,
UpdateUserRequest,
UserResponse,
)
from cartsnitch_api.services.auth import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
svc = AuthService(db)
try:
return await svc.register(body.email, body.password, body.display_name)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
svc = AuthService(db)
try:
return await svc.login(body.email, body.password)
except ValueError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password"
) from None
@router.post("/refresh", response_model=TokenResponse)
async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
svc = AuthService(db)
try:
return await svc.refresh(body.refresh_token)
except ValueError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
) from None
@router.get("/me", response_model=UserResponse)
async def get_me(
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = AuthService(db)
try:
return await svc.get_user(user_id)
except LookupError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
) from None
@router.patch("/me", response_model=UserResponse)
async def update_me(
body: UpdateUserRequest,
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = AuthService(db)
try:
return await svc.update_user(user_id, email=body.email, display_name=body.display_name)
except LookupError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
) from None
except ValueError as e:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT)
async def delete_me(
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = AuthService(db)
try:
await svc.delete_user(user_id)
except LookupError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
) from None
+26
View File
@@ -0,0 +1,26 @@
"""Redis/DragonflyDB caching helpers."""
from cartsnitch_api.config import settings
class CacheClient:
"""Stub for Redis/DragonflyDB caching.
Will be used for expensive queries: price trends, product comparisons.
Cache invalidation via Redis pub/sub events from other services.
"""
def __init__(self) -> None:
self.url = settings.redis_url
async def get(self, key: str) -> str | None:
# TODO: implement with redis-py async
return None
async def set(self, key: str, value: str, ttl_seconds: int = 300) -> None:
# TODO: implement with redis-py async
pass
async def delete(self, key: str) -> None:
# TODO: implement with redis-py async
pass
+51
View File
@@ -0,0 +1,51 @@
import base64
from pydantic import model_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
model_config = {"env_prefix": "CARTSNITCH_"}
database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
redis_url: str = "redis://localhost:6379/0"
jwt_secret_key: str = "change-me-in-production"
jwt_algorithm: str = "HS256"
jwt_access_token_expire_minutes: int = 15
jwt_refresh_token_expire_days: int = 7
service_key: str = "change-me-in-production"
# Valid Fernet key for local dev — MUST be overridden in production
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
receiptwitness_url: str = "http://receiptwitness:8001"
stickershock_url: str = "http://stickershock:8002"
clipartist_url: str = "http://clipartist:8003"
shrinkray_url: str = "http://shrinkray:8004"
rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60
rate_limit_enabled: bool = True
@model_validator(mode="after")
def validate_fernet_key(self):
"""Validate fernet_key is a valid 32-byte url-safe base64 key at startup."""
try:
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
if len(decoded) != 32:
raise ValueError
except Exception:
raise ValueError(
"CARTSNITCH_FERNET_KEY must be a valid Fernet key "
"(32 bytes, url-safe base64 encoded). "
"Generate one with: python -c "
"'from cryptography.fernet import Fernet; "
"print(Fernet.generate_key().decode())'"
) from None
return self
settings = Settings()
+85
View File
@@ -0,0 +1,85 @@
"""Constants and enums shared across CartSnitch services."""
from enum import StrEnum
class StoreSlug(StrEnum):
"""Supported retailer slugs."""
MEIJER = "meijer"
KROGER = "kroger"
TARGET = "target"
class AccountStatus(StrEnum):
"""User store account link status."""
ACTIVE = "active"
EXPIRED = "expired"
ERROR = "error"
class DiscountType(StrEnum):
"""Coupon discount type."""
PERCENT = "percent"
FIXED = "fixed"
BOGO = "bogo"
BUY_X_GET_Y = "buy_x_get_y"
class PriceSource(StrEnum):
"""Source of a price observation."""
RECEIPT = "receipt"
CATALOG = "catalog"
WEEKLY_AD = "weekly_ad"
class EventType(StrEnum):
"""Redis pub/sub event types."""
RECEIPTS_INGESTED = "cartsnitch.receipts.ingested"
PRICES_UPDATED = "cartsnitch.prices.updated"
PRODUCTS_NORMALIZED = "cartsnitch.products.normalized"
COUPONS_UPDATED = "cartsnitch.coupons.updated"
ALERT_PRICE_INCREASE = "cartsnitch.alerts.price_increase"
ALERT_SHRINKFLATION = "cartsnitch.alerts.shrinkflation"
class ProductCategory(StrEnum):
"""Top-level product categories."""
PRODUCE = "produce"
DAIRY = "dairy"
MEAT = "meat"
BAKERY = "bakery"
FROZEN = "frozen"
PANTRY = "pantry"
BEVERAGES = "beverages"
SNACKS = "snacks"
HOUSEHOLD = "household"
PERSONAL_CARE = "personal_care"
OTHER = "other"
class MatchConfidence(StrEnum):
"""Confidence level for product matching."""
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
class SizeUnit(StrEnum):
"""Standardized product size units."""
OZ = "oz"
FL_OZ = "fl_oz"
LB = "lb"
G = "g"
KG = "kg"
ML = "ml"
L = "l"
CT = "ct"
PK = "pk"
+16
View File
@@ -0,0 +1,16 @@
"""Database session management for the API gateway."""
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from cartsnitch_api.config import settings
engine = create_async_engine(settings.database_url, echo=False)
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI dependency that yields an async DB session."""
async with async_session_factory() as session:
yield session
+62
View File
@@ -0,0 +1,62 @@
"""FastAPI app factory for CartSnitch API Gateway."""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from cartsnitch_api.auth.routes import router as auth_router
from cartsnitch_api.middleware.cors import add_cors_middleware
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
from cartsnitch_api.routes.alerts import router as alerts_router
from cartsnitch_api.routes.coupons import router as coupons_router
from cartsnitch_api.routes.health import router as health_router
from cartsnitch_api.routes.prices import router as prices_router
from cartsnitch_api.routes.products import router as products_router
from cartsnitch_api.routes.public import router as public_router
from cartsnitch_api.routes.purchases import router as purchases_router
from cartsnitch_api.routes.scraping import router as scraping_router
from cartsnitch_api.routes.shopping import router as shopping_router
from cartsnitch_api.routes.stores import router as stores_router
@asynccontextmanager
async def lifespan(app: FastAPI):
# TODO: initialize DB session pool, Redis connection, service clients
yield
# TODO: cleanup connections
def create_app() -> FastAPI:
app = FastAPI(
title="CartSnitch API",
description="Grocery price tracking and shrinkflation detection API",
version="0.1.0",
lifespan=lifespan,
)
# Middleware (order matters — outermost first)
add_cors_middleware(app)
add_error_monitor_middleware(app)
add_rate_limit_middleware(app)
# Exception handlers
add_error_handlers(app)
# Routers
app.include_router(health_router)
app.include_router(auth_router)
app.include_router(stores_router)
app.include_router(purchases_router)
app.include_router(products_router)
app.include_router(prices_router)
app.include_router(coupons_router)
app.include_router(shopping_router)
app.include_router(alerts_router)
app.include_router(scraping_router)
app.include_router(public_router)
return app
app = create_app()
+16
View File
@@ -0,0 +1,16 @@
"""CORS middleware configuration."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from cartsnitch_api.config import settings
def add_cors_middleware(app: FastAPI) -> None:
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@@ -0,0 +1,190 @@
"""Structured error responses and error monitoring.
Ensures all errors return a consistent JSON shape and never leak stack traces.
Provides hooks for error monitoring/alerting.
"""
import logging
import time
import traceback
from collections.abc import Awaitable, Callable
from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger("cartsnitch_api.errors")
def _error_response(
status_code: int,
detail: str,
code: str | None = None,
errors: list[dict] | None = None,
) -> JSONResponse:
"""Build a consistent error response."""
body: dict = {"detail": detail}
if code:
body["code"] = code
if errors:
body["errors"] = errors
return JSONResponse(status_code=status_code, content=body)
def add_error_handlers(app: FastAPI) -> None:
"""Register global exception handlers for consistent error responses."""
@app.exception_handler(RequestValidationError)
async def validation_error_handler(
request: Request, exc: RequestValidationError
) -> JSONResponse:
"""Return 422 with structured field-level error details."""
field_errors = []
for err in exc.errors():
loc = err.get("loc", ())
field_errors.append(
{
"field": ".".join(str(p) for p in loc[1:]) if len(loc) > 1 else str(loc),
"message": err.get("msg", "Invalid value"),
"type": err.get("type", "value_error"),
}
)
return _error_response(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Validation error",
code="VALIDATION_ERROR",
errors=field_errors,
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
"""Wrap HTTP exceptions (Starlette and FastAPI) in consistent format."""
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
return _error_response(
status_code=exc.status_code,
detail=detail,
code=_status_to_code(exc.status_code),
)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Catch-all: log full traceback, return safe 500 to client."""
logger.error(
"Unhandled exception on %s %s: %s\n%s",
request.method,
request.url.path,
exc,
traceback.format_exc(),
)
_notify_error_monitor(request, exc)
return _error_response(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error",
code="INTERNAL_ERROR",
)
def _status_to_code(status_code: int) -> str:
"""Map HTTP status code to a machine-readable error code."""
mapping = {
400: "BAD_REQUEST",
401: "UNAUTHORIZED",
403: "FORBIDDEN",
404: "NOT_FOUND",
409: "CONFLICT",
422: "VALIDATION_ERROR",
429: "RATE_LIMITED",
502: "BAD_GATEWAY",
503: "SERVICE_UNAVAILABLE",
}
return mapping.get(status_code, f"HTTP_{status_code}")
# ---------- Error Monitoring ----------
class _ErrorMonitor:
"""Simple error counter for monitoring and alerting hooks.
Tracks error counts and rates. In production, this would forward
to an external monitoring service (Prometheus, Sentry, etc.).
"""
def __init__(self) -> None:
self.error_counts: dict[int, int] = {}
self.recent_5xx: list[dict] = []
self._max_recent = 100
def record(self, status_code: int, path: str, method: str, error: str | None = None) -> None:
self.error_counts[status_code] = self.error_counts.get(status_code, 0) + 1
if status_code >= 500:
entry = {
"timestamp": time.time(),
"status": status_code,
"path": path,
"method": method,
"error": error,
}
self.recent_5xx.append(entry)
if len(self.recent_5xx) > self._max_recent:
self.recent_5xx = self.recent_5xx[-self._max_recent :]
logger.warning(
"5xx error recorded: %s %s -> %d (%s)",
method,
path,
status_code,
error or "unknown",
)
def get_stats(self) -> dict:
return {
"error_counts": dict(self.error_counts),
"recent_5xx_count": len(self.recent_5xx),
}
_monitor = _ErrorMonitor()
def get_error_monitor() -> _ErrorMonitor:
"""Access the global error monitor (for health/metrics endpoints)."""
return _monitor
def _notify_error_monitor(request: Request, exc: Exception) -> None:
"""Record unhandled exception in the error monitor."""
_monitor.record(
status_code=500,
path=request.url.path,
method=request.method,
error=str(exc)[:200],
)
class ErrorMonitorMiddleware(BaseHTTPMiddleware):
"""Middleware to track all 4xx/5xx responses for monitoring."""
async def dispatch(
self,
request: Request,
call_next: Callable[[Request], Awaitable],
):
response = await call_next(request)
if response.status_code >= 400:
_monitor.record(
status_code=response.status_code,
path=request.url.path,
method=request.method,
)
return response
def add_error_monitor_middleware(app: FastAPI) -> None:
app.add_middleware(ErrorMonitorMiddleware)
+111
View File
@@ -0,0 +1,111 @@
"""Rate limiting middleware for public and authenticated endpoints.
Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
"""
import time
from collections import defaultdict
from threading import Lock
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from cartsnitch_api.config import settings
class _SlidingWindowCounter:
"""Thread-safe in-memory sliding window rate limiter."""
def __init__(self, max_requests: int, window_seconds: int) -> None:
self.max_requests = max_requests
self.window_seconds = window_seconds
self._hits: dict[str, list[float]] = defaultdict(list)
self._lock = Lock()
def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
now = time.monotonic()
cutoff = now - self.window_seconds
with self._lock:
# Prune expired entries
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
current_count = len(self._hits[key])
if current_count >= self.max_requests:
retry_after = int(self._hits[key][0] - cutoff) + 1
return False, 0, retry_after
self._hits[key].append(now)
remaining = self.max_requests - current_count - 1
return True, remaining, 0
# Module-level counters — one for public (per-IP), one for auth (per-token)
_public_limiter = _SlidingWindowCounter(
max_requests=settings.rate_limit_requests,
window_seconds=settings.rate_limit_window_seconds,
)
_auth_limiter = _SlidingWindowCounter(
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
window_seconds=settings.rate_limit_window_seconds,
)
def _get_client_ip(request: Request) -> str:
"""Extract client IP, respecting X-Forwarded-For behind a reverse proxy."""
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
"""Determine rate limit key and which limiter to use."""
if request.url.path.startswith("/public"):
return f"ip:{_get_client_ip(request)}", _public_limiter
# For authenticated endpoints, use Bearer token as key if present
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
# Use last 16 chars of token as key to avoid storing full tokens
return f"token:{token[-16:]}", _auth_limiter
# Fallback to IP for unauthenticated non-public endpoints
return f"ip:{_get_client_ip(request)}", _public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Skip rate limiting when disabled (e.g. in tests) or for health checks
if not settings.rate_limit_enabled or request.url.path == "/health":
return await call_next(request)
key, limiter = _get_rate_limit_key(request)
allowed, remaining, retry_after = limiter.is_allowed(key)
if not allowed:
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={
"detail": "Rate limit exceeded",
"code": "RATE_LIMITED",
},
headers={
"Retry-After": str(retry_after),
"X-RateLimit-Limit": str(limiter.max_requests),
"X-RateLimit-Remaining": "0",
},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
def add_rate_limit_middleware(app: FastAPI) -> None:
app.add_middleware(RateLimitMiddleware)
+26
View File
@@ -0,0 +1,26 @@
"""SQLAlchemy ORM models — re-exports all models for convenience."""
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
from cartsnitch_api.models.coupon import Coupon
from cartsnitch_api.models.price import PriceHistory
from cartsnitch_api.models.product import NormalizedProduct
from cartsnitch_api.models.purchase import Purchase, PurchaseItem
from cartsnitch_api.models.shrinkflation import ShrinkflationEvent
from cartsnitch_api.models.store import Store, StoreLocation
from cartsnitch_api.models.user import User, UserStoreAccount
__all__ = [
"Base",
"TimestampMixin",
"UUIDPrimaryKeyMixin",
"Store",
"StoreLocation",
"User",
"UserStoreAccount",
"Purchase",
"PurchaseItem",
"NormalizedProduct",
"PriceHistory",
"Coupon",
"ShrinkflationEvent",
]
+30
View File
@@ -0,0 +1,30 @@
"""Base model and mixins for all CartSnitch ORM models."""
import uuid
from datetime import datetime
from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
"""Base class for all CartSnitch models."""
class TimestampMixin:
"""Mixin providing created_at / updated_at columns."""
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
)
class UUIDPrimaryKeyMixin:
"""Mixin providing a UUID primary key."""
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid()
)
+42
View File
@@ -0,0 +1,42 @@
"""Coupon model."""
import uuid
from datetime import date, datetime
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import DiscountType
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING:
from cartsnitch_api.models.product import NormalizedProduct
from cartsnitch_api.models.store import Store
class Coupon(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""A coupon or deal for a product at a store."""
__tablename__ = "coupons"
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
normalized_product_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("normalized_products.id")
)
title: Mapped[str] = mapped_column(String(300), nullable=False)
description: Mapped[str | None] = mapped_column(String(1000))
discount_type: Mapped[DiscountType] = mapped_column(String(20), nullable=False)
discount_value: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
min_purchase: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
valid_from: Mapped[date | None] = mapped_column(Date)
valid_to: Mapped[date | None] = mapped_column(Date)
requires_clip: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
coupon_code: Mapped[str | None] = mapped_column(String(100))
source_url: Mapped[str | None] = mapped_column(String(500))
scraped_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
# Relationships
store: Mapped["Store"] = relationship(back_populates="coupons")
normalized_product: Mapped["NormalizedProduct | None"] = relationship(back_populates="coupons")
+50
View File
@@ -0,0 +1,50 @@
"""PriceHistory model — tracks product prices over time."""
import uuid
from datetime import date
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import Date, ForeignKey, Index, Numeric, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import PriceSource
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING:
from cartsnitch_api.models.product import NormalizedProduct
from cartsnitch_api.models.purchase import PurchaseItem
from cartsnitch_api.models.store import Store
class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""A single price observation for a product at a store on a date."""
__tablename__ = "price_history"
__table_args__ = (
Index(
"ix_price_history_product_store_date",
"normalized_product_id",
"store_id",
"observed_date",
),
)
normalized_product_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("normalized_products.id"), nullable=False
)
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
observed_date: Mapped[date] = mapped_column(Date, nullable=False)
regular_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
loyalty_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
coupon_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
source: Mapped[PriceSource] = mapped_column(String(20), nullable=False)
purchase_item_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("purchase_items.id"))
# Relationships
normalized_product: Mapped["NormalizedProduct"] = relationship(back_populates="price_histories")
store: Mapped["Store"] = relationship(back_populates="price_histories")
purchase_item: Mapped["PurchaseItem | None"] = relationship(
back_populates="price_history_entries"
)
+39
View File
@@ -0,0 +1,39 @@
"""NormalizedProduct model — the canonical product identity."""
from typing import TYPE_CHECKING
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import ProductCategory, SizeUnit
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING:
from cartsnitch_api.models.coupon import Coupon
from cartsnitch_api.models.price import PriceHistory
from cartsnitch_api.models.purchase import PurchaseItem
from cartsnitch_api.models.shrinkflation import ShrinkflationEvent
class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""Canonical product identity — matches products across retailers."""
__tablename__ = "normalized_products"
canonical_name: Mapped[str] = mapped_column(String(300), nullable=False)
category: Mapped[ProductCategory | None] = mapped_column(String(50))
subcategory: Mapped[str | None] = mapped_column(String(100))
brand: Mapped[str | None] = mapped_column(String(200))
size: Mapped[str | None] = mapped_column(String(50))
size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list)
# Relationships
purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
price_histories: Mapped[list["PriceHistory"]] = relationship(
back_populates="normalized_product"
)
coupons: Mapped[list["Coupon"]] = relationship(back_populates="normalized_product")
shrinkflation_events: Mapped[list["ShrinkflationEvent"]] = relationship(
back_populates="normalized_product"
)
+91
View File
@@ -0,0 +1,91 @@
"""Purchase and PurchaseItem models."""
import uuid
from datetime import date, datetime
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import (
JSON,
Date,
DateTime,
ForeignKey,
Index,
Numeric,
String,
UniqueConstraint,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING:
from cartsnitch_api.models.price import PriceHistory
from cartsnitch_api.models.product import NormalizedProduct
from cartsnitch_api.models.store import Store, StoreLocation
from cartsnitch_api.models.user import User
class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""A single shopping trip / receipt."""
__tablename__ = "purchases"
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id"))
receipt_id: Mapped[str] = mapped_column(String(200), nullable=False)
purchase_date: Mapped[date] = mapped_column(Date, nullable=False)
total: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
subtotal: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
tax: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
savings_total: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
source_url: Mapped[str | None] = mapped_column(String(500))
raw_data: Mapped[dict | None] = mapped_column(JSON)
ingested_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
# Relationships
user: Mapped["User"] = relationship(back_populates="purchases")
store: Mapped["Store"] = relationship(back_populates="purchases")
store_location: Mapped["StoreLocation | None"] = relationship(back_populates="purchases")
items: Mapped[list["PurchaseItem"]] = relationship(back_populates="purchase")
__table_args__ = (
Index("ix_purchases_user_store", "user_id", "store_id"),
UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"),
)
class PurchaseItem(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""Individual line item on a receipt."""
__tablename__ = "purchase_items"
purchase_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("purchases.id"), nullable=False)
product_name_raw: Mapped[str] = mapped_column(String(300), nullable=False)
upc: Mapped[str | None] = mapped_column(String(20))
quantity: Mapped[Decimal] = mapped_column(Numeric(10, 3), nullable=False, default=1)
unit_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
extended_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
regular_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
coupon_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
loyalty_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
category_raw: Mapped[str | None] = mapped_column(String(100))
normalized_product_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("normalized_products.id")
)
# Relationships
purchase: Mapped["Purchase"] = relationship(back_populates="items")
normalized_product: Mapped["NormalizedProduct | None"] = relationship(
back_populates="purchase_items"
)
price_history_entries: Mapped[list["PriceHistory"]] = relationship(
back_populates="purchase_item"
)
@@ -0,0 +1,41 @@
"""ShrinkflationEvent model."""
import uuid
from datetime import date
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import Date, ForeignKey, Numeric, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import SizeUnit
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING:
from cartsnitch_api.models.product import NormalizedProduct
class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""Detected shrinkflation event — product size changed while price held or rose."""
__tablename__ = "shrinkflation_events"
normalized_product_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("normalized_products.id"), nullable=False
)
detected_date: Mapped[date] = mapped_column(Date, nullable=False)
old_size: Mapped[str] = mapped_column(String(50), nullable=False)
new_size: Mapped[str] = mapped_column(String(50), nullable=False)
old_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
new_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
price_at_old_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
price_at_new_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
confidence: Mapped[Decimal] = mapped_column(
Numeric(3, 2), nullable=False, default=Decimal("1.00")
)
notes: Mapped[str | None] = mapped_column(String(1000))
# Relationships
normalized_product: Mapped["NormalizedProduct"] = relationship(
back_populates="shrinkflation_events"
)
+52
View File
@@ -0,0 +1,52 @@
"""Store and StoreLocation models."""
import uuid
from typing import TYPE_CHECKING
from sqlalchemy import Float, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import StoreSlug
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING:
from cartsnitch_api.models.coupon import Coupon
from cartsnitch_api.models.price import PriceHistory
from cartsnitch_api.models.purchase import Purchase
from cartsnitch_api.models.user import UserStoreAccount
class Store(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""Supported retailer."""
__tablename__ = "stores"
name: Mapped[str] = mapped_column(String(100), nullable=False)
slug: Mapped[StoreSlug] = mapped_column(String(20), nullable=False, unique=True)
logo_url: Mapped[str | None] = mapped_column(String(500))
website_url: Mapped[str | None] = mapped_column(String(500))
# Relationships
locations: Mapped[list["StoreLocation"]] = relationship(back_populates="store")
purchases: Mapped[list["Purchase"]] = relationship(back_populates="store")
user_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="store")
price_histories: Mapped[list["PriceHistory"]] = relationship(back_populates="store")
coupons: Mapped[list["Coupon"]] = relationship(back_populates="store")
class StoreLocation(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""Physical store location."""
__tablename__ = "store_locations"
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
address: Mapped[str] = mapped_column(String(300), nullable=False)
city: Mapped[str] = mapped_column(String(100), nullable=False)
state: Mapped[str] = mapped_column(String(2), nullable=False)
zip: Mapped[str] = mapped_column(String(10), nullable=False)
lat: Mapped[float | None] = mapped_column(Float)
lng: Mapped[float | None] = mapped_column(Float)
# Relationships
store: Mapped["Store"] = relationship(back_populates="locations")
purchases: Mapped[list["Purchase"]] = relationship(back_populates="store_location")
+50
View File
@@ -0,0 +1,50 @@
"""User and UserStoreAccount models."""
import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import AccountStatus
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
from cartsnitch_api.types import EncryptedJSON
if TYPE_CHECKING:
from cartsnitch_api.models.purchase import Purchase
from cartsnitch_api.models.store import Store
class User(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""Application user."""
__tablename__ = "users"
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
display_name: Mapped[str | None] = mapped_column(String(100))
# Relationships
store_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="user")
purchases: Mapped[list["Purchase"]] = relationship(back_populates="user")
class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""Link between a user and their retailer account credentials."""
__tablename__ = "user_store_accounts"
__table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),)
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
session_data: Mapped[dict | None] = mapped_column(EncryptedJSON)
session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
last_sync_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
status: Mapped[AccountStatus] = mapped_column(
String(20), nullable=False, default=AccountStatus.ACTIVE
)
# Relationships
user: Mapped["User"] = relationship(back_populates="store_accounts")
store: Mapped["Store"] = relationship(back_populates="user_accounts")
+44
View File
@@ -0,0 +1,44 @@
"""Alert routes: list alerts, manage settings."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.database import get_db
from cartsnitch_api.schemas import AlertResponse, AlertSettingsRequest, AlertSettingsResponse
from cartsnitch_api.services.alerts import AlertService
router = APIRouter(prefix="/alerts", tags=["alerts"])
@router.get("", response_model=list[AlertResponse])
async def list_alerts(
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = AlertService(db)
return await svc.list_alerts(user_id)
@router.get("/settings", response_model=AlertSettingsResponse)
async def get_alert_settings(
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = AlertService(db)
return await svc.get_settings(user_id)
@router.put("/settings")
async def update_alert_settings(
body: AlertSettingsRequest,
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Alert settings persistence not yet implemented. "
"Use GET /alerts/settings for current defaults.",
)
+32
View File
@@ -0,0 +1,32 @@
"""Coupon routes: browse, relevant matches."""
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.database import get_db
from cartsnitch_api.schemas import CouponResponse
from cartsnitch_api.services.coupons import CouponService
router = APIRouter(prefix="/coupons", tags=["coupons"])
@router.get("", response_model=list[CouponResponse])
async def list_coupons(
store_id: UUID | None = Query(None),
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = CouponService(db)
return await svc.list_coupons(store_id)
@router.get("/relevant", response_model=list[CouponResponse])
async def relevant_coupons(
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = CouponService(db)
return await svc.relevant_coupons(user_id)
+20
View File
@@ -0,0 +1,20 @@
"""Health check and error metrics endpoints."""
from fastapi import APIRouter, Depends
from cartsnitch_api.auth.dependencies import verify_service_key
from cartsnitch_api.middleware.error_handler import get_error_monitor
router = APIRouter(tags=["health"])
@router.get("/health")
async def health():
return {"status": "ok"}
@router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)])
async def error_stats():
"""Error monitoring stats — internal only (requires X-Service-Key)."""
monitor = get_error_monitor()
return monitor.get_stats()
+47
View File
@@ -0,0 +1,47 @@
"""Price routes: trends, increases, comparison."""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.database import get_db
from cartsnitch_api.schemas import (
PriceComparisonResponse,
PriceIncreaseResponse,
PriceTrendResponse,
)
from cartsnitch_api.services.prices import PriceService
router = APIRouter(prefix="/prices", tags=["prices"])
@router.get("/trends", response_model=list[PriceTrendResponse])
async def price_trends(
user_id: UUID = Depends(get_current_user),
category: str | None = Query(None),
db: AsyncSession = Depends(get_db),
):
svc = PriceService(db)
return await svc.get_trends(category)
@router.get("/increases", response_model=list[PriceIncreaseResponse])
async def price_increases(
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = PriceService(db)
return await svc.get_increases()
@router.get("/comparison", response_model=list[PriceComparisonResponse])
async def price_comparison(
product_ids: Annotated[list[UUID], Query()],
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = PriceService(db)
return await svc.get_comparison(product_ids)
+56
View File
@@ -0,0 +1,56 @@
"""Product routes: search/list, detail, price history."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.database import get_db
from cartsnitch_api.schemas import PriceTrendResponse, ProductDetailResponse, ProductResponse
from cartsnitch_api.services.products import ProductService
router = APIRouter(prefix="/products", tags=["products"])
@router.get("", response_model=list[ProductResponse])
async def list_products(
user_id: UUID = Depends(get_current_user),
q: str | None = Query(None),
category: str | None = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
db: AsyncSession = Depends(get_db),
):
svc = ProductService(db)
return await svc.list_products(q, category, page, page_size)
@router.get("/{product_id}", response_model=ProductDetailResponse)
async def get_product(
product_id: UUID,
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = ProductService(db)
try:
return await svc.get_product(product_id)
except LookupError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
) from None
@router.get("/{product_id}/prices", response_model=PriceTrendResponse)
async def get_product_prices(
product_id: UUID,
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = ProductService(db)
try:
return await svc.get_price_history(product_id)
except LookupError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
) from None
+48
View File
@@ -0,0 +1,48 @@
"""Public endpoints: price transparency data (no auth required)."""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.database import get_db
from cartsnitch_api.schemas import (
PublicInflationResponse,
PublicStoreComparisonResponse,
PublicTrendResponse,
)
from cartsnitch_api.services.public import PublicService
router = APIRouter(prefix="/public", tags=["public"])
@router.get("/trends/{product_id}", response_model=PublicTrendResponse)
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)):
svc = PublicService(db)
try:
return await svc.get_trend(product_id)
except LookupError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
) from None
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
async def public_store_comparison(
product_ids: Annotated[list[UUID], Query(max_length=20)],
db: AsyncSession = Depends(get_db),
):
if not product_ids:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one product_id is required",
)
svc = PublicService(db)
return await svc.get_store_comparison(product_ids)
@router.get("/inflation", response_model=PublicInflationResponse)
async def public_inflation(db: AsyncSession = Depends(get_db)):
svc = PublicService(db)
return await svc.get_inflation()
+49
View File
@@ -0,0 +1,49 @@
"""Purchase routes: list, detail, stats."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.database import get_db
from cartsnitch_api.schemas import PurchaseDetailResponse, PurchaseResponse, PurchaseStatsResponse
from cartsnitch_api.services.purchases import PurchaseService
router = APIRouter(prefix="/purchases", tags=["purchases"])
@router.get("", response_model=list[PurchaseResponse])
async def list_purchases(
user_id: UUID = Depends(get_current_user),
store_id: UUID | None = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
db: AsyncSession = Depends(get_db),
):
svc = PurchaseService(db)
return await svc.list_purchases(user_id, store_id, page, page_size)
@router.get("/stats", response_model=PurchaseStatsResponse)
async def purchase_stats(
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = PurchaseService(db)
return await svc.get_stats(user_id)
@router.get("/{purchase_id}", response_model=PurchaseDetailResponse)
async def get_purchase(
purchase_id: UUID,
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = PurchaseService(db)
try:
return await svc.get_purchase(purchase_id, user_id)
except LookupError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Purchase not found"
) from None
+42
View File
@@ -0,0 +1,42 @@
"""Scraping routes: trigger sync, check status (proxy to ReceiptWitness)."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from httpx import HTTPStatusError, RequestError
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.schemas import SyncStatusResponse, SyncTriggerResponse
from cartsnitch_api.services.receiptwitness import ReceiptWitnessClient
router = APIRouter(prefix="/scraping", tags=["scraping"])
@router.post("/{store_slug}/sync", response_model=SyncTriggerResponse)
async def trigger_sync(store_slug: str, user_id: UUID = Depends(get_current_user)):
client = ReceiptWitnessClient()
try:
result = await client.trigger_sync(str(user_id), store_slug)
return result
except HTTPStatusError as e:
raise HTTPException(
status_code=e.response.status_code,
detail="Sync service error",
) from e
except RequestError:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Unable to reach sync service",
) from None
@router.get("/status", response_model=list[SyncStatusResponse])
async def sync_status(user_id: UUID = Depends(get_current_user)):
client = ReceiptWitnessClient()
try:
return await client.get_sync_status(str(user_id))
except (HTTPStatusError, RequestError):
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Unable to reach sync service",
) from None
+48
View File
@@ -0,0 +1,48 @@
"""Shopping routes: optimize list, saved lists."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from httpx import HTTPStatusError, RequestError
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.schemas import OptimizeRequest, OptimizeResponse, ShoppingListResponse
from cartsnitch_api.services.clipartist import ClipArtistClient
router = APIRouter(prefix="/shopping", tags=["shopping"])
@router.post("/optimize", response_model=OptimizeResponse)
async def optimize_shopping(body: OptimizeRequest, user_id: UUID = Depends(get_current_user)):
client = ClipArtistClient()
try:
result = await client.optimize(
user_id=str(user_id),
items=[item.model_dump() for item in body.items],
preferred_stores=(
[str(s) for s in body.preferred_stores] if body.preferred_stores else None
),
)
return result
except HTTPStatusError as e:
raise HTTPException(
status_code=e.response.status_code,
detail="Shopping optimization service error",
) from e
except RequestError:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Unable to reach shopping optimization service",
) from None
@router.get("/lists", response_model=list[ShoppingListResponse])
async def list_shopping_lists(user_id: UUID = Depends(get_current_user)):
client = ClipArtistClient()
try:
return await client.get_shopping_lists(str(user_id))
except (HTTPStatusError, RequestError):
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Unable to reach shopping service",
) from None
+61
View File
@@ -0,0 +1,61 @@
"""Store routes: list stores, manage user store connections."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.database import get_db
from cartsnitch_api.schemas import ConnectStoreRequest, StoreAccountResponse, StoreResponse
from cartsnitch_api.services.stores import StoreService
router = APIRouter(tags=["stores"])
@router.get("/stores", response_model=list[StoreResponse])
async def list_stores(db: AsyncSession = Depends(get_db)):
svc = StoreService(db)
return await svc.list_stores()
@router.get("/me/stores", response_model=list[StoreAccountResponse])
async def list_user_stores(
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = StoreService(db)
return await svc.list_user_stores(user_id)
@router.post(
"/me/stores/{store_slug}/connect",
response_model=StoreAccountResponse,
status_code=status.HTTP_201_CREATED,
)
async def connect_store(
store_slug: str,
body: ConnectStoreRequest,
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = StoreService(db)
try:
return await svc.connect_store(user_id, store_slug, body.credentials)
except LookupError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
except ValueError as e:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
@router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT)
async def disconnect_store(
store_slug: str,
user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
svc = StoreService(db)
try:
await svc.disconnect_store(user_id, store_slug)
except LookupError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
+291
View File
@@ -0,0 +1,291 @@
"""Pydantic v2 request/response schemas for all API endpoints."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, EmailStr, Field
# ---------- Auth ----------
class RegisterRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=8, max_length=128)
display_name: str = Field(min_length=1, max_length=100)
class LoginRequest(BaseModel):
email: EmailStr
password: str
class RefreshRequest(BaseModel):
refresh_token: str
class TokenResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int
class UpdateUserRequest(BaseModel):
email: EmailStr | None = None
display_name: str | None = Field(None, min_length=1, max_length=100)
class UserResponse(BaseModel):
id: UUID
email: str
display_name: str
created_at: datetime
# ---------- Stores ----------
class StoreResponse(BaseModel):
id: UUID
name: str
slug: str
logo_url: str | None = None
supported: bool = True
class StoreAccountResponse(BaseModel):
store: StoreResponse
connected: bool
last_sync_at: datetime | None = None
sync_status: str | None = None
class ConnectStoreRequest(BaseModel):
credentials: dict | None = None
# ---------- Purchases ----------
class LineItemResponse(BaseModel):
id: UUID
product_id: UUID | None = None
name: str
quantity: float
unit_price: float
total_price: float
class PurchaseResponse(BaseModel):
id: UUID
store_id: UUID
store_name: str
purchased_at: datetime
total: float
item_count: int
class PurchaseDetailResponse(PurchaseResponse):
line_items: list[LineItemResponse]
class PurchaseStatsResponse(BaseModel):
total_spent: float
purchase_count: int
by_store: dict[str, float]
by_period: dict[str, float]
# ---------- Products ----------
class ProductResponse(BaseModel):
id: UUID
name: str
brand: str | None = None
category: str | None = None
upc: str | None = None
image_url: str | None = None
class ProductDetailResponse(ProductResponse):
prices_by_store: list["StorePriceResponse"]
class StorePriceResponse(BaseModel):
store_id: UUID
store_name: str
current_price: float
last_seen_at: datetime
# ---------- Prices ----------
class PriceTrendResponse(BaseModel):
product_id: UUID
product_name: str
data_points: list["PricePointResponse"]
class PricePointResponse(BaseModel):
date: datetime
price: float
store_id: UUID
store_name: str
class PriceIncreaseResponse(BaseModel):
product_id: UUID
product_name: str
store_name: str
old_price: float
new_price: float
increase_pct: float
detected_at: datetime
class PriceComparisonResponse(BaseModel):
product_id: UUID
product_name: str
prices: list[StorePriceResponse]
# ---------- Coupons ----------
class CouponResponse(BaseModel):
id: UUID
store_id: UUID
store_name: str
description: str
discount_value: float
discount_type: str
product_id: UUID | None = None
expires_at: datetime | None = None
# ---------- Shopping ----------
class ShoppingListItemRequest(BaseModel):
product_id: UUID | None = None
name: str
quantity: int = 1
class OptimizeRequest(BaseModel):
items: list[ShoppingListItemRequest]
preferred_stores: list[UUID] | None = None
class OptimizedStoreTrip(BaseModel):
store_id: UUID
store_name: str
items: list["OptimizedItemResponse"]
subtotal: float
coupons: list[CouponResponse]
savings: float
class OptimizedItemResponse(BaseModel):
name: str
price: float
product_id: UUID | None = None
class OptimizeResponse(BaseModel):
trips: list[OptimizedStoreTrip]
total_cost: float
total_savings: float
class ShoppingListResponse(BaseModel):
id: UUID
name: str
item_count: int
created_at: datetime
updated_at: datetime
# ---------- Alerts ----------
class AlertResponse(BaseModel):
id: UUID
alert_type: str
product_id: UUID
product_name: str
message: str
triggered_at: datetime
read: bool = False
class AlertSettingsRequest(BaseModel):
price_increase_threshold_pct: float | None = None
shrinkflation_enabled: bool | None = None
email_notifications: bool | None = None
class AlertSettingsResponse(BaseModel):
price_increase_threshold_pct: float
shrinkflation_enabled: bool
email_notifications: bool
# ---------- Scraping ----------
class SyncTriggerResponse(BaseModel):
job_id: UUID
status: str
message: str
class SyncStatusResponse(BaseModel):
store_slug: str
status: str
last_sync_at: datetime | None = None
items_synced: int | None = None
# ---------- Public ----------
class PublicTrendResponse(BaseModel):
product_id: UUID
product_name: str
data_points: list[PricePointResponse]
class PublicStoreComparisonResponse(BaseModel):
products: list[PriceComparisonResponse]
class PublicInflationResponse(BaseModel):
period: str
cartsnitch_index: float
cpi_baseline: float
categories: dict[str, float]
# ---------- Common ----------
class PaginatedResponse(BaseModel):
items: list
total: int
page: int
page_size: int
pages: int
class ErrorResponse(BaseModel):
detail: str
code: str | None = None
# Rebuild forward refs
ProductDetailResponse.model_rebuild()
PriceTrendResponse.model_rebuild()
OptimizedStoreTrip.model_rebuild()
+75
View File
@@ -0,0 +1,75 @@
"""Alert service — price and shrinkflation alerts for users.
Alerts are generated by StickerShock and ShrinkRay services and written to the DB.
This service reads them for the API gateway.
"""
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
class AlertService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def list_alerts(self, user_id: UUID) -> list[dict]:
"""List shrinkflation events for products the user has purchased."""
from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent
# Get product IDs from user's purchases
items_result = await self.db.execute(
select(PurchaseItem.normalized_product_id)
.join(Purchase)
.where(
Purchase.user_id == user_id,
PurchaseItem.normalized_product_id.isnot(None),
)
.distinct()
)
product_ids = [row[0] for row in items_result.all()]
if not product_ids:
return []
result = await self.db.execute(
select(ShrinkflationEvent)
.where(ShrinkflationEvent.normalized_product_id.in_(product_ids))
.options(selectinload(ShrinkflationEvent.normalized_product))
.order_by(ShrinkflationEvent.detected_date.desc())
)
events = result.scalars().all()
return [
{
"id": e.id,
"alert_type": "shrinkflation",
"product_id": e.normalized_product_id,
"product_name": e.normalized_product.canonical_name,
"message": (
f"Size changed from {e.old_size}{e.old_unit} to {e.new_size}{e.new_unit}"
),
"triggered_at": e.detected_date,
"read": False,
}
for e in events
]
async def get_settings(self, user_id: UUID) -> dict:
# Alert settings would be stored in a user_settings table.
# For now, return defaults since the table doesn't exist yet in common lib.
return {
"price_increase_threshold_pct": 5.0,
"shrinkflation_enabled": True,
"email_notifications": False,
}
async def update_settings(self, user_id: UUID, **fields) -> dict:
# Would update user_settings table. Return merged defaults for now.
current = await self.get_settings(user_id)
for k, v in fields.items():
if v is not None and k in current:
current[k] = v
return current
+125
View File
@@ -0,0 +1,125 @@
"""Auth service — user registration, login, token management."""
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token
from cartsnitch_api.auth.passwords import hash_password, verify_password
from cartsnitch_api.config import settings
class AuthService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def register(self, email: str, password: str, display_name: str) -> dict:
from cartsnitch_api.models import User
existing = await self.db.execute(select(User).where(User.email == email))
if existing.scalar_one_or_none():
raise ValueError("Email already registered")
user = User(
email=email,
hashed_password=hash_password(password),
display_name=display_name,
)
self.db.add(user)
await self.db.commit()
await self.db.refresh(user)
return self._make_token_response(user.id)
async def login(self, email: str, password: str) -> dict:
from cartsnitch_api.models import User
result = await self.db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
if not user or not verify_password(password, user.hashed_password):
raise ValueError("Invalid email or password")
return self._make_token_response(user.id)
async def refresh(self, refresh_token: str) -> dict:
from cartsnitch_api.models import User
try:
payload = decode_token(refresh_token)
except ValueError:
raise ValueError("Invalid refresh token") from None
if payload.get("type") != "refresh":
raise ValueError("Invalid token type") from None
user_id = UUID(payload["sub"])
# Verify the user still exists before issuing new tokens
result = await self.db.execute(select(User).where(User.id == user_id))
if not result.scalar_one_or_none():
raise ValueError("User no longer exists")
return self._make_token_response(user_id)
async def get_user(self, user_id: UUID) -> dict:
from cartsnitch_api.models import User
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise LookupError("User not found")
return {
"id": user.id,
"email": user.email,
"display_name": user.display_name,
"created_at": user.created_at,
}
async def update_user(self, user_id: UUID, **fields) -> dict:
from cartsnitch_api.models import User
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise LookupError("User not found")
if "display_name" in fields and fields["display_name"] is not None:
user.display_name = fields["display_name"]
if "email" in fields and fields["email"] is not None:
existing = await self.db.execute(
select(User).where(User.email == fields["email"], User.id != user_id)
)
if existing.scalar_one_or_none():
raise ValueError("Email already in use")
user.email = fields["email"]
await self.db.commit()
await self.db.refresh(user)
return {
"id": user.id,
"email": user.email,
"display_name": user.display_name,
"created_at": user.created_at,
}
async def delete_user(self, user_id: UUID) -> None:
from cartsnitch_api.models import User
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise LookupError("User not found")
await self.db.delete(user)
await self.db.commit()
def _make_token_response(self, user_id: UUID) -> dict:
return {
"access_token": create_access_token(user_id),
"refresh_token": create_refresh_token(user_id),
"token_type": "bearer",
"expires_in": settings.jwt_access_token_expire_minutes * 60,
}
+52
View File
@@ -0,0 +1,52 @@
"""HTTP client for ClipArtist internal API."""
from typing import Any, cast
import httpx
from cartsnitch_api.config import settings
class ClipArtistClient:
def __init__(self) -> None:
self.base_url = settings.clipartist_url
self.headers = {"X-Service-Key": settings.service_key}
async def optimize(
self,
user_id: str,
items: list[dict],
preferred_stores: list[str] | None = None,
) -> dict:
async with httpx.AsyncClient() as client:
resp = await client.post(
f"{self.base_url}/optimize",
headers=self.headers,
json={
"user_id": user_id,
"items": items,
"preferred_stores": preferred_stores,
},
)
resp.raise_for_status()
return cast(dict[str, Any], resp.json())
async def get_shopping_lists(self, user_id: str) -> list[dict]:
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{self.base_url}/shopping-lists",
headers=self.headers,
params={"user_id": user_id},
)
resp.raise_for_status()
return cast(list[dict[str, Any]], resp.json())
async def get_relevant_coupons(self, user_id: str) -> list[dict]:
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{self.base_url}/coupons/relevant",
headers=self.headers,
params={"user_id": user_id},
)
resp.raise_for_status()
return cast(list[dict[str, Any]], resp.json())
+76
View File
@@ -0,0 +1,76 @@
"""Coupon service — browse coupons, find relevant ones."""
from datetime import date
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
class CouponService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def list_coupons(self, store_id: UUID | None = None) -> list[dict]:
from cartsnitch_api.models import Coupon
today = date.today()
query = (
select(Coupon)
.where((Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)))
.options(selectinload(Coupon.store))
.order_by(Coupon.valid_to.asc().nullslast())
)
if store_id:
query = query.where(Coupon.store_id == store_id)
result = await self.db.execute(query)
coupons = result.scalars().all()
return [self._to_dict(c) for c in coupons]
async def relevant_coupons(self, user_id: UUID) -> list[dict]:
"""Coupons for products the user has purchased."""
from cartsnitch_api.models import Coupon, PurchaseItem
today = date.today()
# Get product IDs from user's purchase history
from cartsnitch_api.models import Purchase
items_result = await self.db.execute(
select(PurchaseItem.normalized_product_id)
.join(Purchase)
.where(
Purchase.user_id == user_id,
PurchaseItem.normalized_product_id.isnot(None),
)
.distinct()
)
product_ids = [row[0] for row in items_result.all()]
if not product_ids:
return []
result = await self.db.execute(
select(Coupon)
.where(
Coupon.normalized_product_id.in_(product_ids),
(Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)),
)
.options(selectinload(Coupon.store))
)
coupons = result.scalars().all()
return [self._to_dict(c) for c in coupons]
def _to_dict(self, c) -> dict:
return {
"id": c.id,
"store_id": c.store_id,
"store_name": c.store.name,
"description": c.description or c.title,
"discount_value": float(c.discount_value) if c.discount_value else 0,
"discount_type": c.discount_type,
"product_id": c.normalized_product_id,
"expires_at": c.valid_to,
}
+183
View File
@@ -0,0 +1,183 @@
"""Price service — trends, increases, comparison."""
from uuid import UUID
from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from cartsnitch_api.services.queries import latest_price_per_store
class PriceService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def get_trends(self, category: str | None = None) -> list[dict]:
from cartsnitch_api.models import NormalizedProduct, PriceHistory
query = (
select(PriceHistory)
.join(NormalizedProduct)
.options(
selectinload(PriceHistory.store),
selectinload(PriceHistory.normalized_product),
)
.order_by(PriceHistory.observed_date)
)
if category:
query = query.where(NormalizedProduct.category == category)
result = await self.db.execute(query)
prices = result.scalars().all()
# Group by product
by_product: dict[UUID, dict] = {}
for ph in prices:
pid = ph.normalized_product_id
if pid not in by_product:
by_product[pid] = {
"product_id": pid,
"product_name": ph.normalized_product.canonical_name,
"data_points": [],
}
by_product[pid]["data_points"].append(
{
"date": ph.observed_date,
"price": float(ph.regular_price),
"store_id": ph.store_id,
"store_name": ph.store.name,
}
)
return list(by_product.values())
async def get_increases(self) -> list[dict]:
"""Find products with recent significant price increases.
Uses a window function (lag) to compare each price observation with the
previous one per product+store, avoiding the N+1 query pattern.
"""
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
# Use lag() window function to get previous price in a single query
prev_price = (
func.lag(PriceHistory.regular_price)
.over(
partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id],
order_by=PriceHistory.observed_date,
)
.label("prev_price")
)
row_num = (
func.row_number()
.over(
partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id],
order_by=PriceHistory.observed_date.desc(),
)
.label("rn")
)
inner = select(
PriceHistory.normalized_product_id,
PriceHistory.store_id,
PriceHistory.regular_price,
PriceHistory.observed_date,
prev_price,
row_num,
).subquery()
# Only keep the latest row (rn=1) where price increased
result = await self.db.execute(
select(
inner.c.normalized_product_id,
inner.c.store_id,
inner.c.regular_price,
inner.c.observed_date,
inner.c.prev_price,
NormalizedProduct.canonical_name,
Store.name.label("store_name"),
)
.join(NormalizedProduct, NormalizedProduct.id == inner.c.normalized_product_id)
.join(Store, Store.id == inner.c.store_id)
.where(
inner.c.rn == 1,
inner.c.prev_price.isnot(None),
inner.c.regular_price > inner.c.prev_price,
)
)
increases = []
for row in result.all():
old = float(row.prev_price)
new = float(row.regular_price)
increases.append(
{
"product_id": row.normalized_product_id,
"product_name": row.canonical_name,
"store_name": row.store_name,
"old_price": old,
"new_price": new,
"increase_pct": round((new - old) / old * 100, 2),
"detected_at": row.observed_date,
}
)
increases.sort(key=lambda x: x["increase_pct"], reverse=True)
return increases
async def get_comparison(self, product_ids: list[UUID]) -> list[dict]:
from cartsnitch_api.models import NormalizedProduct, PriceHistory
if not product_ids:
return []
# Fetch all requested products in one query
prod_result = await self.db.execute(
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
)
products_by_id = {p.id: p for p in prod_result.scalars().all()}
# Latest prices for all requested products in one query
subq = latest_price_per_store(product_ids)
prices_result = await self.db.execute(
select(PriceHistory)
.join(
subq,
and_(
PriceHistory.store_id == subq.c.store_id,
PriceHistory.observed_date == subq.c.max_date,
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
),
)
.where(PriceHistory.normalized_product_id.in_(product_ids))
.options(selectinload(PriceHistory.store))
)
all_prices = prices_result.scalars().all()
# Group prices by product
prices_by_product: dict[UUID, list] = {pid: [] for pid in product_ids}
for ph in all_prices:
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
comparisons = []
for pid in product_ids:
product = products_by_id.get(pid)
if not product:
continue
comparisons.append(
{
"product_id": pid,
"product_name": product.canonical_name,
"prices": [
{
"store_id": ph.store_id,
"store_name": ph.store.name,
"current_price": float(ph.regular_price),
"last_seen_at": ph.observed_date,
}
for ph in prices_by_product.get(pid, [])
],
}
)
return comparisons
+124
View File
@@ -0,0 +1,124 @@
"""Product service — catalog, detail, price history."""
from uuid import UUID
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from cartsnitch_api.services.queries import latest_price_per_store
class ProductService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def list_products(
self,
q: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
) -> list[dict]:
from cartsnitch_api.models import NormalizedProduct
query = select(NormalizedProduct)
if q:
# Escape SQL LIKE wildcards in user input
safe_q = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
query = query.where(NormalizedProduct.canonical_name.ilike(f"%{safe_q}%"))
if category:
query = query.where(NormalizedProduct.category == category)
query = query.order_by(NormalizedProduct.canonical_name)
query = query.offset((page - 1) * page_size).limit(page_size)
result = await self.db.execute(query)
products = result.scalars().all()
return [
{
"id": p.id,
"name": p.canonical_name,
"brand": p.brand,
"category": p.category,
"upc": (p.upc_variants[0] if p.upc_variants else None),
"image_url": None,
}
for p in products
]
async def get_product(self, product_id: UUID) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory
result = await self.db.execute(
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
)
product = result.scalar_one_or_none()
if not product:
raise LookupError("Product not found")
# Get latest price per store
subq = latest_price_per_store([product_id])
prices_result = await self.db.execute(
select(PriceHistory)
.join(
subq,
and_(
PriceHistory.store_id == subq.c.store_id,
PriceHistory.observed_date == subq.c.max_date,
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
),
)
.where(PriceHistory.normalized_product_id == product_id)
.options(selectinload(PriceHistory.store))
)
prices = prices_result.scalars().all()
return {
"id": product.id,
"name": product.canonical_name,
"brand": product.brand,
"category": product.category,
"upc": (product.upc_variants[0] if product.upc_variants else None),
"image_url": None,
"prices_by_store": [
{
"store_id": ph.store_id,
"store_name": ph.store.name,
"current_price": float(ph.regular_price),
"last_seen_at": ph.observed_date,
}
for ph in prices
],
}
async def get_price_history(self, product_id: UUID) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory
result = await self.db.execute(
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
)
product = result.scalar_one_or_none()
if not product:
raise LookupError("Product not found")
prices_result = await self.db.execute(
select(PriceHistory)
.where(PriceHistory.normalized_product_id == product_id)
.options(selectinload(PriceHistory.store))
.order_by(PriceHistory.observed_date)
)
prices = prices_result.scalars().all()
return {
"product_id": product.id,
"product_name": product.canonical_name,
"data_points": [
{
"date": ph.observed_date,
"price": float(ph.regular_price),
"store_id": ph.store_id,
"store_name": ph.store.name,
}
for ph in prices
],
}
+129
View File
@@ -0,0 +1,129 @@
"""Public service — unauthenticated price transparency endpoints."""
from uuid import UUID
from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from cartsnitch_api.services.queries import latest_price_per_store
class PublicService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def get_trend(self, product_id: UUID) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory
result = await self.db.execute(
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
)
product = result.scalar_one_or_none()
if not product:
raise LookupError("Product not found")
prices_result = await self.db.execute(
select(PriceHistory)
.where(PriceHistory.normalized_product_id == product_id)
.options(selectinload(PriceHistory.store))
.order_by(PriceHistory.observed_date)
)
prices = prices_result.scalars().all()
return {
"product_id": product.id,
"product_name": product.canonical_name,
"data_points": [
{
"date": ph.observed_date,
"price": float(ph.regular_price),
"store_id": ph.store_id,
"store_name": ph.store.name,
}
for ph in prices
],
}
async def get_store_comparison(self, product_ids: list[UUID]) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory
if not product_ids:
return {"products": []}
# Fetch all products in one query
prod_result = await self.db.execute(
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
)
products_by_id = {p.id: p for p in prod_result.scalars().all()}
# Latest prices for all requested products in one query
subq = latest_price_per_store(product_ids)
prices_result = await self.db.execute(
select(PriceHistory)
.join(
subq,
and_(
PriceHistory.store_id == subq.c.store_id,
PriceHistory.observed_date == subq.c.max_date,
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
),
)
.where(PriceHistory.normalized_product_id.in_(product_ids))
.options(selectinload(PriceHistory.store))
)
all_prices = prices_result.scalars().all()
# Group by product
prices_by_product: dict[UUID, list] = {}
for ph in all_prices:
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
products = []
for pid in product_ids:
product = products_by_id.get(pid)
if not product:
continue
products.append(
{
"product_id": pid,
"product_name": product.canonical_name,
"prices": [
{
"store_id": ph.store_id,
"store_name": ph.store.name,
"current_price": float(ph.regular_price),
"last_seen_at": ph.observed_date,
}
for ph in prices_by_product.get(pid, [])
],
}
)
return {"products": products}
async def get_inflation(self) -> dict:
"""Aggregate price change stats. Compares average prices across periods."""
from cartsnitch_api.models import NormalizedProduct, PriceHistory
# Get average prices grouped by category for recent vs older data
result = await self.db.execute(
select(
NormalizedProduct.category,
func.avg(PriceHistory.regular_price),
)
.join(NormalizedProduct)
.group_by(NormalizedProduct.category)
)
categories = {}
for row in result.all():
cat, avg_price = row
if cat:
categories[cat] = float(avg_price) if avg_price else 0.0
return {
"period": "all-time",
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
"cpi_baseline": 100.0,
"categories": categories,
}
+116
View File
@@ -0,0 +1,116 @@
"""Purchase service — list, detail, stats."""
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
class PurchaseService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def list_purchases(
self,
user_id: UUID,
store_id: UUID | None = None,
page: int = 1,
page_size: int = 20,
) -> list[dict]:
from cartsnitch_api.models import Purchase, PurchaseItem, Store
# Count items per purchase in a single subquery instead of N+1
item_counts = (
select(
PurchaseItem.purchase_id,
func.count().label("item_count"),
)
.group_by(PurchaseItem.purchase_id)
.subquery()
)
query = (
select(Purchase, item_counts.c.item_count, Store.name.label("store_name"))
.join(Store, Store.id == Purchase.store_id)
.outerjoin(item_counts, item_counts.c.purchase_id == Purchase.id)
.where(Purchase.user_id == user_id)
)
if store_id:
query = query.where(Purchase.store_id == store_id)
query = query.order_by(Purchase.purchase_date.desc())
query = query.offset((page - 1) * page_size).limit(page_size)
result = await self.db.execute(query)
return [
{
"id": p.id,
"store_id": p.store_id,
"store_name": store_name,
"purchased_at": p.purchase_date,
"total": float(p.total),
"item_count": item_count or 0,
}
for p, item_count, store_name in result.all()
]
async def get_purchase(self, purchase_id: UUID, user_id: UUID) -> dict:
from cartsnitch_api.models import Purchase
result = await self.db.execute(
select(Purchase)
.where(Purchase.id == purchase_id, Purchase.user_id == user_id)
.options(selectinload(Purchase.store), selectinload(Purchase.items))
)
purchase = result.scalar_one_or_none()
if not purchase:
raise LookupError("Purchase not found")
return {
"id": purchase.id,
"store_id": purchase.store_id,
"store_name": purchase.store.name,
"purchased_at": purchase.purchase_date,
"total": float(purchase.total),
"item_count": len(purchase.items),
"line_items": [
{
"id": item.id,
"product_id": item.normalized_product_id,
"name": item.product_name_raw,
"quantity": float(item.quantity),
"unit_price": float(item.unit_price),
"total_price": float(item.extended_price),
}
for item in purchase.items
],
}
async def get_stats(self, user_id: UUID) -> dict:
from cartsnitch_api.models import Purchase
result = await self.db.execute(
select(Purchase)
.where(Purchase.user_id == user_id)
.options(selectinload(Purchase.store))
)
purchases = result.scalars().all()
total_spent = sum(float(p.total) for p in purchases)
by_store: dict[str, float] = {}
by_period: dict[str, float] = {}
for p in purchases:
store_name = p.store.name
by_store[store_name] = by_store.get(store_name, 0) + float(p.total)
period = p.purchase_date.strftime("%Y-%m")
by_period[period] = by_period.get(period, 0) + float(p.total)
return {
"total_spent": total_spent,
"purchase_count": len(purchases),
"by_store": by_store,
"by_period": by_period,
}
+23
View File
@@ -0,0 +1,23 @@
"""Shared query helpers for service layer."""
from uuid import UUID
from sqlalchemy import func, select
def latest_price_per_store(product_ids: list[UUID] | None = None):
"""Subquery returning the latest observed_date per product+store.
Optionally filtered to a list of product IDs. Returns a subquery with
columns: normalized_product_id, store_id, max_date.
"""
from cartsnitch_api.models import PriceHistory
query = select(
PriceHistory.normalized_product_id,
PriceHistory.store_id,
func.max(PriceHistory.observed_date).label("max_date"),
).group_by(PriceHistory.normalized_product_id, PriceHistory.store_id)
if product_ids is not None:
query = query.where(PriceHistory.normalized_product_id.in_(product_ids))
return query.subquery()
@@ -0,0 +1,33 @@
"""HTTP client for ReceiptWitness internal API."""
from typing import Any, cast
import httpx
from cartsnitch_api.config import settings
class ReceiptWitnessClient:
def __init__(self) -> None:
self.base_url = settings.receiptwitness_url
self.headers = {"X-Service-Key": settings.service_key}
async def trigger_sync(self, user_id: str, store_slug: str) -> dict:
async with httpx.AsyncClient() as client:
resp = await client.post(
f"{self.base_url}/sync/{store_slug}",
headers=self.headers,
json={"user_id": user_id},
)
resp.raise_for_status()
return cast(dict[str, Any], resp.json())
async def get_sync_status(self, user_id: str) -> list[dict]:
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{self.base_url}/sync/status",
headers=self.headers,
params={"user_id": user_id},
)
resp.raise_for_status()
return cast(list[dict[str, Any]], resp.json())
+23
View File
@@ -0,0 +1,23 @@
"""HTTP client for ShrinkRay internal API."""
from typing import Any, cast
import httpx
from cartsnitch_api.config import settings
class ShrinkRayClient:
def __init__(self) -> None:
self.base_url = settings.shrinkray_url
self.headers = {"X-Service-Key": settings.service_key}
async def get_shrinkflation_alerts(self, user_id: str) -> list[dict]:
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{self.base_url}/alerts",
headers=self.headers,
params={"user_id": user_id},
)
resp.raise_for_status()
return cast(list[dict[str, Any]], resp.json())
@@ -0,0 +1,32 @@
"""HTTP client for StickerShock internal API."""
from typing import Any, cast
import httpx
from cartsnitch_api.config import settings
class StickerShockClient:
def __init__(self) -> None:
self.base_url = settings.stickershock_url
self.headers = {"X-Service-Key": settings.service_key}
async def get_price_increases(self, params: dict | None = None) -> list[dict]:
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{self.base_url}/increases",
headers=self.headers,
params=params,
)
resp.raise_for_status()
return cast(list[dict[str, Any]], resp.json())
async def get_inflation_data(self) -> dict:
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{self.base_url}/inflation",
headers=self.headers,
)
resp.raise_for_status()
return cast(dict[str, Any], resp.json())
+129
View File
@@ -0,0 +1,129 @@
"""Store service — list stores, manage user store account connections."""
import json
from uuid import UUID
from cryptography.fernet import Fernet
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from cartsnitch_api.config import settings
def _get_fernet() -> Fernet:
return Fernet(settings.fernet_key.encode())
class StoreService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def list_stores(self) -> list[dict]:
from cartsnitch_api.models import Store
result = await self.db.execute(select(Store).order_by(Store.name))
stores = result.scalars().all()
return [
{
"id": s.id,
"name": s.name,
"slug": s.slug,
"logo_url": s.logo_url,
"supported": True,
}
for s in stores
]
async def list_user_stores(self, user_id: UUID) -> list[dict]:
from cartsnitch_api.models import UserStoreAccount
result = await self.db.execute(
select(UserStoreAccount)
.where(UserStoreAccount.user_id == user_id)
.options(selectinload(UserStoreAccount.store))
)
accounts = result.scalars().all()
return [
{
"store": {
"id": a.store.id,
"name": a.store.name,
"slug": a.store.slug,
"logo_url": a.store.logo_url,
"supported": True,
},
"connected": a.status == "active",
"last_sync_at": a.last_sync_at,
"sync_status": a.status,
}
for a in accounts
]
async def connect_store(self, user_id: UUID, store_slug: str, credentials: dict | None) -> dict:
from cartsnitch_api.models import Store, UserStoreAccount
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
store = result.scalar_one_or_none()
if not store:
raise LookupError(f"Store '{store_slug}' not found")
existing = await self.db.execute(
select(UserStoreAccount).where(
UserStoreAccount.user_id == user_id,
UserStoreAccount.store_id == store.id,
)
)
if existing.scalar_one_or_none():
raise ValueError("Store account already connected")
encrypted_data = None
if credentials:
fernet = _get_fernet()
encrypted_data = {
"encrypted": fernet.encrypt(json.dumps(credentials).encode()).decode()
}
account = UserStoreAccount(
user_id=user_id,
store_id=store.id,
session_data=encrypted_data,
status="active",
)
self.db.add(account)
await self.db.commit()
await self.db.refresh(account)
return {
"store": {
"id": store.id,
"name": store.name,
"slug": store.slug,
"logo_url": store.logo_url,
"supported": True,
},
"connected": True,
"last_sync_at": None,
"sync_status": "active",
}
async def disconnect_store(self, user_id: UUID, store_slug: str) -> None:
from cartsnitch_api.models import Store, UserStoreAccount
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
store = result.scalar_one_or_none()
if not store:
raise LookupError(f"Store '{store_slug}' not found")
result = await self.db.execute(
select(UserStoreAccount).where(
UserStoreAccount.user_id == user_id,
UserStoreAccount.store_id == store.id,
)
)
account = result.scalar_one_or_none()
if not account:
raise LookupError("Store account not connected")
await self.db.delete(account)
await self.db.commit()
+36
View File
@@ -0,0 +1,36 @@
"""Custom SQLAlchemy column types."""
import json
from cryptography.fernet import Fernet
from sqlalchemy import Text
from sqlalchemy.types import TypeDecorator
from cartsnitch_api.config import settings
def _get_fernet() -> Fernet:
return Fernet(settings.fernet_key.encode())
class EncryptedJSON(TypeDecorator):
"""SQLAlchemy type that transparently encrypts/decrypts JSON using Fernet.
Stores data as a Fernet-encrypted text blob in the database.
On read, decrypts and deserialises back to a Python dict/list.
"""
impl = Text
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return None
plaintext = json.dumps(value).encode()
return _get_fernet().encrypt(plaintext).decode()
def process_result_value(self, value, dialect):
if value is None:
return None
decrypted = _get_fernet().decrypt(value.encode())
return json.loads(decrypted)