diff --git a/api/.dockerignore b/api/.dockerignore new file mode 100644 index 0000000..d7f640e --- /dev/null +++ b/api/.dockerignore @@ -0,0 +1,14 @@ +.git +.github +.pytest_cache +.ruff_cache +__pycache__ +*.py[cod] +*.egg-info +dist +.venv +.env +tests +openapi.json +CLAUDE.md +README.md diff --git a/api/.github/workflows/ci.yml b/api/.github/workflows/ci.yml new file mode 100644 index 0000000..5c61bb7 --- /dev/null +++ b/api/.github/workflows/ci.yml @@ -0,0 +1,164 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ci-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: write + packages: write + +env: + REGISTRY: ghcr.io + IMAGE_NAME: cartsnitch/api + +jobs: + lint: + runs-on: runners-cartsnitch + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - run: pip install ruff + - name: Ruff lint + run: ruff check . + - name: Ruff format check + run: ruff format --check . + + typecheck: + runs-on: runners-cartsnitch + continue-on-error: true + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential + - name: Install cartsnitch-common from GitHub + run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git" + - run: pip install -e ".[dev]" mypy + - name: Type check + run: mypy src/cartsnitch_api + + test: + runs-on: runners-cartsnitch + services: + postgres: + image: postgres:15-alpine + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + env: + POSTGRES_USER: cartsnitch + POSTGRES_PASSWORD: cartsnitch_test + POSTGRES_DB: cartsnitch_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + redis: + image: redis:7-alpine + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + env: + CARTSNITCH_DATABASE_URL: postgresql+asyncpg://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test + CARTSNITCH_REDIS_URL: redis://localhost:6379/0 + CARTSNITCH_JWT_SECRET_KEY: test-secret-do-not-use-in-prod + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential + - name: Install cartsnitch-common from GitHub + run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git" + - run: pip install -e ".[dev]" + - name: Run tests + run: pytest --tb=short -q + + build-and-push: + runs-on: runners-cartsnitch + needs: [lint, test] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Generate CalVer tag + id: calver + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: | + DATE_TAG=$(date -u +%Y.%m.%d) + EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1) + if [ -z "$EXISTING" ]; then + VERSION="$DATE_TAG" + elif [ "$EXISTING" = "v${DATE_TAG}" ]; then + VERSION="${DATE_TAG}.2" + else + BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//") + VERSION="${DATE_TAG}.$((BUILD_NUM + 1))" + fi + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + echo "CalVer tag: $VERSION" + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Log in to GHCR + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=sha,prefix=sha- + type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }} + type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }} + + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: . + push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + target: prod + + - name: Create git tag + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: | + git tag "v${{ steps.calver.outputs.version }}" + git push origin "v${{ steps.calver.outputs.version }}" \ No newline at end of file diff --git a/api/.gitignore b/api/.gitignore new file mode 100644 index 0000000..b0492c2 --- /dev/null +++ b/api/.gitignore @@ -0,0 +1,9 @@ +__pycache__/ +*.py[cod] +*.egg-info/ +dist/ +.venv/ +.env +.pytest_cache/ +.ruff_cache/ +openapi.json diff --git a/api/CLAUDE.md b/api/CLAUDE.md new file mode 100644 index 0000000..fcba89c --- /dev/null +++ b/api/CLAUDE.md @@ -0,0 +1,175 @@ +# CartSnitch API Gateway + +## Project Context + +CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/api`) is the public-facing API gateway that serves the frontend and proxies requests to internal services. + +**GitHub org:** github.com/cartsnitch +**Domain:** cartsnitch.com + +### CartSnitch Services + +| Repo | Service | Purpose | +|------|---------|---------| +| `cartsnitch/common` | — | Shared models, schemas, utilities | +| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers | +| `cartsnitch/api` | API Gateway | Frontend-facing REST API (this repo) | +| `cartsnitch/cartsnitch` | Frontend | React PWA (mobile-first) | +| `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison | +| `cartsnitch/shrinkray` | ShrinkRay | Shrinkflation monitoring | +| `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization | +| `cartsnitch/infra` | — | K8s manifests, Flux kustomizations | + +### Architecture Decisions + +- **Polyrepo:** Each service has its own repo, Dockerfile, CI/CD pipeline. +- **Shared DB:** One PostgreSQL cluster. This service reads from all tables for serving frontend queries. Models come from `cartsnitch-common`. +- **Inter-service comms:** REST to internal services, Redis pub/sub for event subscriptions. +- **Target scale:** 500–1,000 users initially. + +## What This Service Does + +The API Gateway is the single entry point for the frontend PWA and any external consumers. It: + +1. **Handles user authentication** — registration, login, JWT token management +2. **Serves purchase/product/price data** — reads from the shared DB +3. **Proxies scraping operations** — forwards scrape triggers to ReceiptWitness +4. **Serves coupon/deal data** — reads from shared DB (written by ClipArtist) +5. **Serves alerts** — price increase alerts (StickerShock), shrinkflation alerts (ShrinkRay) +6. **Provides public data endpoints** — aggregate price trends for the transparency/shaming features + +## Tech Stack + +- Python 3.12+ +- FastAPI (async) +- SQLAlchemy 2.0 (via `cartsnitch-common`, read-heavy) +- Pydantic v2 (request/response validation) +- python-jose or PyJWT (JWT auth) +- passlib + bcrypt (password hashing) +- httpx (async HTTP client for proxying to internal services) +- Redis (subscribe to events for websocket push, caching) +- uvicorn (ASGI server) + +## Repo Structure + +``` +api/ +├── CLAUDE.md +├── README.md +├── pyproject.toml +├── Dockerfile +├── docker-compose.yml +├── src/ +│ └── cartsnitch_api/ +│ ├── __init__.py +│ ├── config.py # Service-specific settings +│ ├── main.py # FastAPI app factory, lifespan, middleware +│ ├── auth/ +│ │ ├── __init__.py +│ │ ├── jwt.py # JWT creation/validation +│ │ ├── passwords.py # Hashing, verification +│ │ ├── dependencies.py # FastAPI dependency injection (get_current_user) +│ │ └── routes.py # /auth/register, /auth/login, /auth/refresh +│ ├── routes/ +│ │ ├── __init__.py +│ │ ├── purchases.py # Purchase history endpoints +│ │ ├── products.py # Normalized product catalog +│ │ ├── prices.py # Price history and trends +│ │ ├── coupons.py # Active coupons and deals +│ │ ├── alerts.py # Price increase / shrinkflation alerts +│ │ ├── stores.py # Store info, user store account management +│ │ ├── scraping.py # Proxy to ReceiptWitness (trigger scrape, status) +│ │ ├── shopping.py # Optimized shopping list (proxy to ClipArtist) +│ │ ├── public.py # Public price transparency endpoints (no auth) +│ │ └── health.py +│ ├── services/ +│ │ ├── __init__.py +│ │ ├── receiptwitness.py # HTTP client for ReceiptWitness internal API +│ │ ├── stickershock.py # HTTP client for StickerShock internal API +│ │ ├── clipartist.py # HTTP client for ClipArtist internal API +│ │ └── shrinkray.py # HTTP client for ShrinkRay internal API +│ ├── middleware/ +│ │ ├── __init__.py +│ │ ├── cors.py +│ │ └── rate_limit.py +│ └── cache.py # Redis caching helpers +└── tests/ + ├── conftest.py + ├── test_auth/ + ├── test_routes/ + └── test_services/ +``` + +## API Endpoint Design + +### Auth +- `POST /auth/register` — create account +- `POST /auth/login` — get JWT access + refresh tokens +- `POST /auth/refresh` — refresh access token +- `GET /auth/me` — current user profile + +### Store Accounts +- `GET /stores` — list supported stores +- `GET /me/stores` — list user's connected store accounts + sync status +- `POST /me/stores/{store_slug}/connect` — initiate store connection flow +- `DELETE /me/stores/{store_slug}` — disconnect store account + +### Purchases +- `GET /purchases` — list user's purchases (paginated, filterable by store/date) +- `GET /purchases/{id}` — purchase detail with line items +- `GET /purchases/stats` — spending summary (by store, by category, by period) + +### Products +- `GET /products` — normalized product catalog (search, filter) +- `GET /products/{id}` — product detail with cross-store price comparison +- `GET /products/{id}/prices` — price history for a product across stores + +### Prices +- `GET /prices/trends` — aggregate price trends (public-capable) +- `GET /prices/increases` — recent significant price increases +- `GET /prices/comparison` — compare specific items across stores + +### Coupons +- `GET /coupons` — active coupons/deals (filterable by store) +- `GET /coupons/relevant` — coupons relevant to user's purchase history + +### Shopping +- `POST /shopping/optimize` — input: shopping list → output: store-split + coupons +- `GET /shopping/lists` — user's saved shopping lists + +### Alerts +- `GET /alerts` — user's price increase and shrinkflation alerts +- `PUT /alerts/settings` — configure alert thresholds + +### Public (No Auth) +- `GET /public/trends/{product_id}` — public price trend for a product +- `GET /public/store-comparison` — public store-vs-store price comparison +- `GET /public/inflation` — price changes vs CPI baseline + +### Scraping (Proxy to ReceiptWitness) +- `POST /scraping/{store_slug}/sync` — trigger a sync for the current user +- `GET /scraping/status` — sync status across all stores + +## Authentication + +- JWT-based auth with short-lived access tokens (15 min) and longer refresh tokens (7 days). +- Passwords hashed with bcrypt via passlib. +- All user-specific endpoints require a valid JWT in the `Authorization: Bearer` header. +- Public endpoints under `/public/` do not require auth. +- Internal service-to-service calls (ReceiptWitness, etc.) use a shared API key in the `X-Service-Key` header — not user JWTs. + +## Development Workflow + +- **Never push directly to main.** Always create feature branches and open PRs. +- Branch naming: `feature/` or `fix/` +- Use conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `chore:` +- OpenAPI docs auto-generated at `/docs` (Swagger) and `/redoc`. +- Write tests for all routes. Use httpx.AsyncClient with FastAPI's TestClient pattern. + +## Important Notes + +- This service is read-heavy on the shared DB. Use async SQLAlchemy sessions. +- Consider Redis caching for expensive queries (price trends, product comparisons). Cache invalidation via Redis pub/sub events from other services. +- Rate limiting on public endpoints is important — these could get hammered if the price transparency features get attention. +- CORS must allow the frontend origin (cartsnitch.com and localhost for dev). +- The store connection flow is the trickiest UX challenge: the user needs to authenticate with each retailer, and we need to capture the resulting session. This likely involves a controlled Playwright browser session that the user can see/interact with, or an OAuth-like redirect flow if the retailer supports it (Kroger does for its public API, but not for purchase history access). diff --git a/api/Dockerfile b/api/Dockerfile new file mode 100644 index 0000000..bb5d3bd --- /dev/null +++ b/api/Dockerfile @@ -0,0 +1,26 @@ +FROM python:3.12-slim AS build + +RUN apt-get update && apt-get install -y --no-install-recommends \ + libpq-dev \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app +COPY pyproject.toml ./ +COPY src/ ./src/ +RUN pip install --no-cache-dir --prefix=/install . + +FROM python:3.12-slim AS prod + +WORKDIR /app +RUN adduser --system --group --uid 1000 app +COPY --from=build /install /usr/local +COPY src/ ./src/ + +USER 1000 +EXPOSE 8000 + +HEALTHCHECK --interval=30s --timeout=3s \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" + +CMD ["uvicorn", "cartsnitch_api.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/api/alembic.ini b/api/alembic.ini new file mode 100644 index 0000000..42fafc3 --- /dev/null +++ b/api/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = alembic +sqlalchemy.url = postgresql://OVERRIDE_VIA_ENV_VAR + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/api/alembic/env.py b/api/alembic/env.py new file mode 100644 index 0000000..3e563e1 --- /dev/null +++ b/api/alembic/env.py @@ -0,0 +1,55 @@ +"""Alembic environment configuration for CartSnitch.""" + +import os +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool + +from alembic import context +from cartsnitch_api.models import Base # noqa: F401 — imports all models for autogenerate + +config = context.config +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +db_url = os.environ.get("CARTSNITCH_DATABASE_URL_SYNC") +if not db_url: + raise RuntimeError( + "CARTSNITCH_DATABASE_URL_SYNC must be set. " + "Example: postgresql://user:pass@localhost:5432/cartsnitch" + ) +config.set_main_option("sqlalchemy.url", db_url) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode.""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/api/alembic/script.py.mako b/api/alembic/script.py.mako new file mode 100644 index 0000000..fe3b097 --- /dev/null +++ b/api/alembic/script.py.mako @@ -0,0 +1,25 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +${imports if imports else ""} + +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/api/alembic/versions/001_encrypt_session_data.py b/api/alembic/versions/001_encrypt_session_data.py new file mode 100644 index 0000000..4932231 --- /dev/null +++ b/api/alembic/versions/001_encrypt_session_data.py @@ -0,0 +1,89 @@ +"""Encrypt existing plaintext session_data with Fernet. + +Revision ID: 001_encrypt_session_data +Revises: +Create Date: 2026-03-19 +""" + +import json +import os + +import sqlalchemy as sa +from cryptography.fernet import Fernet +from sqlalchemy import text + +from alembic import op + +revision = "001_encrypt_session_data" +down_revision = None +branch_labels = None +depends_on = None + + +def _get_fernet() -> Fernet: + key = os.environ.get("CARTSNITCH_FERNET_KEY") + if not key: + raise RuntimeError("CARTSNITCH_FERNET_KEY must be set to run this migration") + return Fernet(key.encode()) + + +def _is_fernet_token(value: str) -> bool: + """Check if a string looks like a Fernet token (base64 starting with gAAAAA).""" + return value.startswith("gAAAAA") + + +def upgrade() -> None: + # Change column type from JSON to TEXT to hold Fernet ciphertext + op.alter_column( + "user_store_accounts", + "session_data", + type_=sa.Text(), + existing_type=sa.JSON(), + existing_nullable=True, + postgresql_using="session_data::text", + ) + + conn = op.get_bind() + rows = conn.execute( + text("SELECT id, session_data FROM user_store_accounts WHERE session_data IS NOT NULL") + ).fetchall() + + f = _get_fernet() + for row_id, session_data in rows: + raw = str(session_data) + if _is_fernet_token(raw): + continue + plaintext = raw if isinstance(session_data, str) else json.dumps(session_data) + encrypted = f.encrypt(plaintext.encode()).decode() + conn.execute( + text("UPDATE user_store_accounts SET session_data = :data WHERE id = :id"), + {"data": encrypted, "id": row_id}, + ) + + +def downgrade() -> None: + conn = op.get_bind() + rows = conn.execute( + text("SELECT id, session_data FROM user_store_accounts WHERE session_data IS NOT NULL") + ).fetchall() + + f = _get_fernet() + for row_id, session_data in rows: + raw = str(session_data) + if not _is_fernet_token(raw): + continue + decrypted = f.decrypt(raw.encode()).decode() + conn.execute( + text("UPDATE user_store_accounts SET session_data = :data WHERE id = :id"), + {"data": decrypted, "id": row_id}, + ) + + # Revert column type from TEXT back to JSON + op.alter_column( + "user_store_accounts", + "session_data", + type_=sa.JSON(), + existing_type=sa.Text(), + existing_nullable=True, + postgresql_using="session_data::json", + ) diff --git a/api/pyproject.toml b/api/pyproject.toml new file mode 100644 index 0000000..8509182 --- /dev/null +++ b/api/pyproject.toml @@ -0,0 +1,58 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "cartsnitch-api" +version = "0.1.0" +description = "CartSnitch API Gateway — public-facing REST API" +requires-python = ">=3.12" +dependencies = [ + "fastapi>=0.115.0", + "uvicorn[standard]>=0.30.0", + "pydantic[email]>=2.9.0", + "pydantic-settings>=2.5.0", + "sqlalchemy[asyncio]>=2.0.35", + "asyncpg>=0.30.0", + "alembic>=1.13,<2.0", + "psycopg2>=2.9,<3.0", + "python-jose[cryptography]>=3.3.0", + "passlib[bcrypt]>=1.7.4", + "httpx>=0.27.0", + "redis[hiredis]>=5.2.0", + "cryptography>=43.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3.0", + "pytest-asyncio>=0.24.0", + "aiosqlite>=0.20.0", + "httpx>=0.27.0", + "ruff>=0.7.0", + "psycopg2-binary>=2.9,<3.0", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/cartsnitch_api"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] + +[tool.ruff] +target-version = "py312" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "UP", "B"] + +[tool.ruff.lint.per-file-ignores] +"src/cartsnitch_api/**/routes*.py" = ["B008"] +"src/cartsnitch_api/**/dependencies.py" = ["B008"] + +[tool.mypy] +python_version = "3.12" +ignore_missing_imports = true +warn_return_any = true +warn_unused_configs = true diff --git a/api/renovate.json b/api/renovate.json new file mode 100644 index 0000000..833ba3b --- /dev/null +++ b/api/renovate.json @@ -0,0 +1,4 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": ["local>cartsnitch/.github:renovate-config"] +} diff --git a/api/src/cartsnitch_api/__init__.py b/api/src/cartsnitch_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/cartsnitch_api/auth/__init__.py b/api/src/cartsnitch_api/auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/cartsnitch_api/auth/dependencies.py b/api/src/cartsnitch_api/auth/dependencies.py new file mode 100644 index 0000000..61735ee --- /dev/null +++ b/api/src/cartsnitch_api/auth/dependencies.py @@ -0,0 +1,39 @@ +"""FastAPI dependency injection for authentication.""" + +from uuid import UUID + +from fastapi import Depends, Header, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from cartsnitch_api.auth.jwt import decode_token +from cartsnitch_api.config import settings + +bearer_scheme = HTTPBearer() + + +async def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme), +) -> UUID: + try: + payload = decode_token(credentials.credentials) + except ValueError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + ) from None + + if payload.get("type") != "access": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token type", + ) from None + + return UUID(payload["sub"]) + + +async def verify_service_key(x_service_key: str = Header()) -> None: + if x_service_key != settings.service_key: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid service key", + ) diff --git a/api/src/cartsnitch_api/auth/jwt.py b/api/src/cartsnitch_api/auth/jwt.py new file mode 100644 index 0000000..100c77b --- /dev/null +++ b/api/src/cartsnitch_api/auth/jwt.py @@ -0,0 +1,31 @@ +"""JWT token creation and validation.""" + +from datetime import UTC, datetime, timedelta +from typing import Any, cast +from uuid import UUID + +from jose import JWTError, jwt + +from cartsnitch_api.config import settings + + +def create_access_token(user_id: UUID) -> str: + expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes) + payload = {"sub": str(user_id), "exp": expire, "type": "access"} + return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)) + + +def create_refresh_token(user_id: UUID) -> str: + expire = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days) + payload = {"sub": str(user_id), "exp": expire, "type": "refresh"} + return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)) + + +def decode_token(token: str) -> dict: + try: + return cast( + dict[str, Any], + jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]), + ) + except JWTError as e: + raise ValueError(f"Invalid token: {e}") from e diff --git a/api/src/cartsnitch_api/auth/passwords.py b/api/src/cartsnitch_api/auth/passwords.py new file mode 100644 index 0000000..180f994 --- /dev/null +++ b/api/src/cartsnitch_api/auth/passwords.py @@ -0,0 +1,11 @@ +"""Password hashing and verification with bcrypt.""" + +import bcrypt + + +def hash_password(password: str) -> str: + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + return bcrypt.checkpw(plain_password.encode(), hashed_password.encode()) diff --git a/api/src/cartsnitch_api/auth/routes.py b/api/src/cartsnitch_api/auth/routes.py new file mode 100644 index 0000000..ab34c3e --- /dev/null +++ b/api/src/cartsnitch_api/auth/routes.py @@ -0,0 +1,96 @@ +"""Auth routes: register, login, refresh, me, update, delete.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import ( + LoginRequest, + RefreshRequest, + RegisterRequest, + TokenResponse, + UpdateUserRequest, + UserResponse, +) +from cartsnitch_api.services.auth import AuthService + +router = APIRouter(prefix="/auth", tags=["auth"]) + + +@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED) +async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)): + svc = AuthService(db) + try: + return await svc.register(body.email, body.password, body.display_name) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + +@router.post("/login", response_model=TokenResponse) +async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)): + svc = AuthService(db) + try: + return await svc.login(body.email, body.password) + except ValueError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password" + ) from None + + +@router.post("/refresh", response_model=TokenResponse) +async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)): + svc = AuthService(db) + try: + return await svc.refresh(body.refresh_token) + except ValueError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" + ) from None + + +@router.get("/me", response_model=UserResponse) +async def get_me( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = AuthService(db) + try: + return await svc.get_user(user_id) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" + ) from None + + +@router.patch("/me", response_model=UserResponse) +async def update_me( + body: UpdateUserRequest, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = AuthService(db) + try: + return await svc.update_user(user_id, email=body.email, display_name=body.display_name) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" + ) from None + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + +@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT) +async def delete_me( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = AuthService(db) + try: + await svc.delete_user(user_id) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" + ) from None diff --git a/api/src/cartsnitch_api/cache.py b/api/src/cartsnitch_api/cache.py new file mode 100644 index 0000000..a7fdc81 --- /dev/null +++ b/api/src/cartsnitch_api/cache.py @@ -0,0 +1,26 @@ +"""Redis/DragonflyDB caching helpers.""" + +from cartsnitch_api.config import settings + + +class CacheClient: + """Stub for Redis/DragonflyDB caching. + + Will be used for expensive queries: price trends, product comparisons. + Cache invalidation via Redis pub/sub events from other services. + """ + + def __init__(self) -> None: + self.url = settings.redis_url + + async def get(self, key: str) -> str | None: + # TODO: implement with redis-py async + return None + + async def set(self, key: str, value: str, ttl_seconds: int = 300) -> None: + # TODO: implement with redis-py async + pass + + async def delete(self, key: str) -> None: + # TODO: implement with redis-py async + pass diff --git a/api/src/cartsnitch_api/config.py b/api/src/cartsnitch_api/config.py new file mode 100644 index 0000000..52474b2 --- /dev/null +++ b/api/src/cartsnitch_api/config.py @@ -0,0 +1,51 @@ +import base64 + +from pydantic import model_validator +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + model_config = {"env_prefix": "CARTSNITCH_"} + + database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch" + redis_url: str = "redis://localhost:6379/0" + + jwt_secret_key: str = "change-me-in-production" + jwt_algorithm: str = "HS256" + jwt_access_token_expire_minutes: int = 15 + jwt_refresh_token_expire_days: int = 7 + + service_key: str = "change-me-in-production" + # Valid Fernet key for local dev — MUST be overridden in production + fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=" + + cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"] + + receiptwitness_url: str = "http://receiptwitness:8001" + stickershock_url: str = "http://stickershock:8002" + clipartist_url: str = "http://clipartist:8003" + shrinkray_url: str = "http://shrinkray:8004" + + rate_limit_requests: int = 60 + rate_limit_window_seconds: int = 60 + rate_limit_enabled: bool = True + + @model_validator(mode="after") + def validate_fernet_key(self): + """Validate fernet_key is a valid 32-byte url-safe base64 key at startup.""" + try: + decoded = base64.urlsafe_b64decode(self.fernet_key.encode()) + if len(decoded) != 32: + raise ValueError + except Exception: + raise ValueError( + "CARTSNITCH_FERNET_KEY must be a valid Fernet key " + "(32 bytes, url-safe base64 encoded). " + "Generate one with: python -c " + "'from cryptography.fernet import Fernet; " + "print(Fernet.generate_key().decode())'" + ) from None + return self + + +settings = Settings() diff --git a/api/src/cartsnitch_api/constants.py b/api/src/cartsnitch_api/constants.py new file mode 100644 index 0000000..b7a716c --- /dev/null +++ b/api/src/cartsnitch_api/constants.py @@ -0,0 +1,85 @@ +"""Constants and enums shared across CartSnitch services.""" + +from enum import StrEnum + + +class StoreSlug(StrEnum): + """Supported retailer slugs.""" + + MEIJER = "meijer" + KROGER = "kroger" + TARGET = "target" + + +class AccountStatus(StrEnum): + """User store account link status.""" + + ACTIVE = "active" + EXPIRED = "expired" + ERROR = "error" + + +class DiscountType(StrEnum): + """Coupon discount type.""" + + PERCENT = "percent" + FIXED = "fixed" + BOGO = "bogo" + BUY_X_GET_Y = "buy_x_get_y" + + +class PriceSource(StrEnum): + """Source of a price observation.""" + + RECEIPT = "receipt" + CATALOG = "catalog" + WEEKLY_AD = "weekly_ad" + + +class EventType(StrEnum): + """Redis pub/sub event types.""" + + RECEIPTS_INGESTED = "cartsnitch.receipts.ingested" + PRICES_UPDATED = "cartsnitch.prices.updated" + PRODUCTS_NORMALIZED = "cartsnitch.products.normalized" + COUPONS_UPDATED = "cartsnitch.coupons.updated" + ALERT_PRICE_INCREASE = "cartsnitch.alerts.price_increase" + ALERT_SHRINKFLATION = "cartsnitch.alerts.shrinkflation" + + +class ProductCategory(StrEnum): + """Top-level product categories.""" + + PRODUCE = "produce" + DAIRY = "dairy" + MEAT = "meat" + BAKERY = "bakery" + FROZEN = "frozen" + PANTRY = "pantry" + BEVERAGES = "beverages" + SNACKS = "snacks" + HOUSEHOLD = "household" + PERSONAL_CARE = "personal_care" + OTHER = "other" + + +class MatchConfidence(StrEnum): + """Confidence level for product matching.""" + + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + +class SizeUnit(StrEnum): + """Standardized product size units.""" + + OZ = "oz" + FL_OZ = "fl_oz" + LB = "lb" + G = "g" + KG = "kg" + ML = "ml" + L = "l" + CT = "ct" + PK = "pk" diff --git a/api/src/cartsnitch_api/database.py b/api/src/cartsnitch_api/database.py new file mode 100644 index 0000000..324c5bf --- /dev/null +++ b/api/src/cartsnitch_api/database.py @@ -0,0 +1,16 @@ +"""Database session management for the API gateway.""" + +from collections.abc import AsyncGenerator + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from cartsnitch_api.config import settings + +engine = create_async_engine(settings.database_url, echo=False) +async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + +async def get_db() -> AsyncGenerator[AsyncSession, None]: + """FastAPI dependency that yields an async DB session.""" + async with async_session_factory() as session: + yield session diff --git a/api/src/cartsnitch_api/main.py b/api/src/cartsnitch_api/main.py new file mode 100644 index 0000000..1cd54ef --- /dev/null +++ b/api/src/cartsnitch_api/main.py @@ -0,0 +1,62 @@ +"""FastAPI app factory for CartSnitch API Gateway.""" + +from contextlib import asynccontextmanager + +from fastapi import FastAPI + +from cartsnitch_api.auth.routes import router as auth_router +from cartsnitch_api.middleware.cors import add_cors_middleware +from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware +from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware +from cartsnitch_api.routes.alerts import router as alerts_router +from cartsnitch_api.routes.coupons import router as coupons_router +from cartsnitch_api.routes.health import router as health_router +from cartsnitch_api.routes.prices import router as prices_router +from cartsnitch_api.routes.products import router as products_router +from cartsnitch_api.routes.public import router as public_router +from cartsnitch_api.routes.purchases import router as purchases_router +from cartsnitch_api.routes.scraping import router as scraping_router +from cartsnitch_api.routes.shopping import router as shopping_router +from cartsnitch_api.routes.stores import router as stores_router + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # TODO: initialize DB session pool, Redis connection, service clients + yield + # TODO: cleanup connections + + +def create_app() -> FastAPI: + app = FastAPI( + title="CartSnitch API", + description="Grocery price tracking and shrinkflation detection API", + version="0.1.0", + lifespan=lifespan, + ) + + # Middleware (order matters — outermost first) + add_cors_middleware(app) + add_error_monitor_middleware(app) + add_rate_limit_middleware(app) + + # Exception handlers + add_error_handlers(app) + + # Routers + app.include_router(health_router) + app.include_router(auth_router) + app.include_router(stores_router) + app.include_router(purchases_router) + app.include_router(products_router) + app.include_router(prices_router) + app.include_router(coupons_router) + app.include_router(shopping_router) + app.include_router(alerts_router) + app.include_router(scraping_router) + app.include_router(public_router) + + return app + + +app = create_app() diff --git a/api/src/cartsnitch_api/middleware/__init__.py b/api/src/cartsnitch_api/middleware/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/cartsnitch_api/middleware/cors.py b/api/src/cartsnitch_api/middleware/cors.py new file mode 100644 index 0000000..0e6a4ae --- /dev/null +++ b/api/src/cartsnitch_api/middleware/cors.py @@ -0,0 +1,16 @@ +"""CORS middleware configuration.""" + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from cartsnitch_api.config import settings + + +def add_cors_middleware(app: FastAPI) -> None: + app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) diff --git a/api/src/cartsnitch_api/middleware/error_handler.py b/api/src/cartsnitch_api/middleware/error_handler.py new file mode 100644 index 0000000..a32a008 --- /dev/null +++ b/api/src/cartsnitch_api/middleware/error_handler.py @@ -0,0 +1,190 @@ +"""Structured error responses and error monitoring. + +Ensures all errors return a consistent JSON shape and never leak stack traces. +Provides hooks for error monitoring/alerting. +""" + +import logging +import time +import traceback +from collections.abc import Awaitable, Callable + +from fastapi import FastAPI, Request, status +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.base import BaseHTTPMiddleware + +logger = logging.getLogger("cartsnitch_api.errors") + + +def _error_response( + status_code: int, + detail: str, + code: str | None = None, + errors: list[dict] | None = None, +) -> JSONResponse: + """Build a consistent error response.""" + body: dict = {"detail": detail} + if code: + body["code"] = code + if errors: + body["errors"] = errors + return JSONResponse(status_code=status_code, content=body) + + +def add_error_handlers(app: FastAPI) -> None: + """Register global exception handlers for consistent error responses.""" + + @app.exception_handler(RequestValidationError) + async def validation_error_handler( + request: Request, exc: RequestValidationError + ) -> JSONResponse: + """Return 422 with structured field-level error details.""" + field_errors = [] + for err in exc.errors(): + loc = err.get("loc", ()) + field_errors.append( + { + "field": ".".join(str(p) for p in loc[1:]) if len(loc) > 1 else str(loc), + "message": err.get("msg", "Invalid value"), + "type": err.get("type", "value_error"), + } + ) + return _error_response( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Validation error", + code="VALIDATION_ERROR", + errors=field_errors, + ) + + @app.exception_handler(StarletteHTTPException) + async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse: + """Wrap HTTP exceptions (Starlette and FastAPI) in consistent format.""" + detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail) + return _error_response( + status_code=exc.status_code, + detail=detail, + code=_status_to_code(exc.status_code), + ) + + @app.exception_handler(Exception) + async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + """Catch-all: log full traceback, return safe 500 to client.""" + logger.error( + "Unhandled exception on %s %s: %s\n%s", + request.method, + request.url.path, + exc, + traceback.format_exc(), + ) + _notify_error_monitor(request, exc) + + return _error_response( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error", + code="INTERNAL_ERROR", + ) + + +def _status_to_code(status_code: int) -> str: + """Map HTTP status code to a machine-readable error code.""" + mapping = { + 400: "BAD_REQUEST", + 401: "UNAUTHORIZED", + 403: "FORBIDDEN", + 404: "NOT_FOUND", + 409: "CONFLICT", + 422: "VALIDATION_ERROR", + 429: "RATE_LIMITED", + 502: "BAD_GATEWAY", + 503: "SERVICE_UNAVAILABLE", + } + return mapping.get(status_code, f"HTTP_{status_code}") + + +# ---------- Error Monitoring ---------- + + +class _ErrorMonitor: + """Simple error counter for monitoring and alerting hooks. + + Tracks error counts and rates. In production, this would forward + to an external monitoring service (Prometheus, Sentry, etc.). + """ + + def __init__(self) -> None: + self.error_counts: dict[int, int] = {} + self.recent_5xx: list[dict] = [] + self._max_recent = 100 + + def record(self, status_code: int, path: str, method: str, error: str | None = None) -> None: + self.error_counts[status_code] = self.error_counts.get(status_code, 0) + 1 + + if status_code >= 500: + entry = { + "timestamp": time.time(), + "status": status_code, + "path": path, + "method": method, + "error": error, + } + self.recent_5xx.append(entry) + if len(self.recent_5xx) > self._max_recent: + self.recent_5xx = self.recent_5xx[-self._max_recent :] + + logger.warning( + "5xx error recorded: %s %s -> %d (%s)", + method, + path, + status_code, + error or "unknown", + ) + + def get_stats(self) -> dict: + return { + "error_counts": dict(self.error_counts), + "recent_5xx_count": len(self.recent_5xx), + } + + +_monitor = _ErrorMonitor() + + +def get_error_monitor() -> _ErrorMonitor: + """Access the global error monitor (for health/metrics endpoints).""" + return _monitor + + +def _notify_error_monitor(request: Request, exc: Exception) -> None: + """Record unhandled exception in the error monitor.""" + _monitor.record( + status_code=500, + path=request.url.path, + method=request.method, + error=str(exc)[:200], + ) + + +class ErrorMonitorMiddleware(BaseHTTPMiddleware): + """Middleware to track all 4xx/5xx responses for monitoring.""" + + async def dispatch( + self, + request: Request, + call_next: Callable[[Request], Awaitable], + ): + response = await call_next(request) + + if response.status_code >= 400: + _monitor.record( + status_code=response.status_code, + path=request.url.path, + method=request.method, + ) + + return response + + +def add_error_monitor_middleware(app: FastAPI) -> None: + app.add_middleware(ErrorMonitorMiddleware) diff --git a/api/src/cartsnitch_api/middleware/rate_limit.py b/api/src/cartsnitch_api/middleware/rate_limit.py new file mode 100644 index 0000000..424ed19 --- /dev/null +++ b/api/src/cartsnitch_api/middleware/rate_limit.py @@ -0,0 +1,111 @@ +"""Rate limiting middleware for public and authenticated endpoints. + +Uses in-memory sliding window as fallback, Redis/DragonflyDB when available. +Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. +""" + +import time +from collections import defaultdict +from threading import Lock + +from fastapi import FastAPI, Request, status +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + +from cartsnitch_api.config import settings + + +class _SlidingWindowCounter: + """Thread-safe in-memory sliding window rate limiter.""" + + def __init__(self, max_requests: int, window_seconds: int) -> None: + self.max_requests = max_requests + self.window_seconds = window_seconds + self._hits: dict[str, list[float]] = defaultdict(list) + self._lock = Lock() + + def is_allowed(self, key: str) -> tuple[bool, int, int]: + """Check if request is allowed. Returns (allowed, remaining, retry_after).""" + now = time.monotonic() + cutoff = now - self.window_seconds + + with self._lock: + # Prune expired entries + self._hits[key] = [t for t in self._hits[key] if t > cutoff] + + current_count = len(self._hits[key]) + if current_count >= self.max_requests: + retry_after = int(self._hits[key][0] - cutoff) + 1 + return False, 0, retry_after + + self._hits[key].append(now) + remaining = self.max_requests - current_count - 1 + return True, remaining, 0 + + +# Module-level counters — one for public (per-IP), one for auth (per-token) +_public_limiter = _SlidingWindowCounter( + max_requests=settings.rate_limit_requests, + window_seconds=settings.rate_limit_window_seconds, +) +_auth_limiter = _SlidingWindowCounter( + max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users + window_seconds=settings.rate_limit_window_seconds, +) + + +def _get_client_ip(request: Request) -> str: + """Extract client IP, respecting X-Forwarded-For behind a reverse proxy.""" + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + return request.client.host if request.client else "unknown" + + +def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]: + """Determine rate limit key and which limiter to use.""" + if request.url.path.startswith("/public"): + return f"ip:{_get_client_ip(request)}", _public_limiter + + # For authenticated endpoints, use Bearer token as key if present + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + # Use last 16 chars of token as key to avoid storing full tokens + return f"token:{token[-16:]}", _auth_limiter + + # Fallback to IP for unauthenticated non-public endpoints + return f"ip:{_get_client_ip(request)}", _public_limiter + + +class RateLimitMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # Skip rate limiting when disabled (e.g. in tests) or for health checks + if not settings.rate_limit_enabled or request.url.path == "/health": + return await call_next(request) + + key, limiter = _get_rate_limit_key(request) + allowed, remaining, retry_after = limiter.is_allowed(key) + + if not allowed: + return JSONResponse( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + content={ + "detail": "Rate limit exceeded", + "code": "RATE_LIMITED", + }, + headers={ + "Retry-After": str(retry_after), + "X-RateLimit-Limit": str(limiter.max_requests), + "X-RateLimit-Remaining": "0", + }, + ) + + response = await call_next(request) + response.headers["X-RateLimit-Limit"] = str(limiter.max_requests) + response.headers["X-RateLimit-Remaining"] = str(remaining) + return response + + +def add_rate_limit_middleware(app: FastAPI) -> None: + app.add_middleware(RateLimitMiddleware) diff --git a/api/src/cartsnitch_api/models/__init__.py b/api/src/cartsnitch_api/models/__init__.py new file mode 100644 index 0000000..d037b05 --- /dev/null +++ b/api/src/cartsnitch_api/models/__init__.py @@ -0,0 +1,26 @@ +"""SQLAlchemy ORM models — re-exports all models for convenience.""" + +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin +from cartsnitch_api.models.coupon import Coupon +from cartsnitch_api.models.price import PriceHistory +from cartsnitch_api.models.product import NormalizedProduct +from cartsnitch_api.models.purchase import Purchase, PurchaseItem +from cartsnitch_api.models.shrinkflation import ShrinkflationEvent +from cartsnitch_api.models.store import Store, StoreLocation +from cartsnitch_api.models.user import User, UserStoreAccount + +__all__ = [ + "Base", + "TimestampMixin", + "UUIDPrimaryKeyMixin", + "Store", + "StoreLocation", + "User", + "UserStoreAccount", + "Purchase", + "PurchaseItem", + "NormalizedProduct", + "PriceHistory", + "Coupon", + "ShrinkflationEvent", +] diff --git a/api/src/cartsnitch_api/models/base.py b/api/src/cartsnitch_api/models/base.py new file mode 100644 index 0000000..f93cf79 --- /dev/null +++ b/api/src/cartsnitch_api/models/base.py @@ -0,0 +1,30 @@ +"""Base model and mixins for all CartSnitch ORM models.""" + +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, func +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + +class Base(DeclarativeBase): + """Base class for all CartSnitch models.""" + + +class TimestampMixin: + """Mixin providing created_at / updated_at columns.""" + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), nullable=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False + ) + + +class UUIDPrimaryKeyMixin: + """Mixin providing a UUID primary key.""" + + id: Mapped[uuid.UUID] = mapped_column( + primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid() + ) diff --git a/api/src/cartsnitch_api/models/coupon.py b/api/src/cartsnitch_api/models/coupon.py new file mode 100644 index 0000000..df2630a --- /dev/null +++ b/api/src/cartsnitch_api/models/coupon.py @@ -0,0 +1,42 @@ +"""Coupon model.""" + +import uuid +from datetime import date, datetime +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.constants import DiscountType +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_api.models.product import NormalizedProduct + from cartsnitch_api.models.store import Store + + +class Coupon(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """A coupon or deal for a product at a store.""" + + __tablename__ = "coupons" + + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + normalized_product_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("normalized_products.id") + ) + title: Mapped[str] = mapped_column(String(300), nullable=False) + description: Mapped[str | None] = mapped_column(String(1000)) + discount_type: Mapped[DiscountType] = mapped_column(String(20), nullable=False) + discount_value: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + min_purchase: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + valid_from: Mapped[date | None] = mapped_column(Date) + valid_to: Mapped[date | None] = mapped_column(Date) + requires_clip: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + coupon_code: Mapped[str | None] = mapped_column(String(100)) + source_url: Mapped[str | None] = mapped_column(String(500)) + scraped_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + + # Relationships + store: Mapped["Store"] = relationship(back_populates="coupons") + normalized_product: Mapped["NormalizedProduct | None"] = relationship(back_populates="coupons") diff --git a/api/src/cartsnitch_api/models/price.py b/api/src/cartsnitch_api/models/price.py new file mode 100644 index 0000000..7da0fa6 --- /dev/null +++ b/api/src/cartsnitch_api/models/price.py @@ -0,0 +1,50 @@ +"""PriceHistory model — tracks product prices over time.""" + +import uuid +from datetime import date +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import Date, ForeignKey, Index, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.constants import PriceSource +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_api.models.product import NormalizedProduct + from cartsnitch_api.models.purchase import PurchaseItem + from cartsnitch_api.models.store import Store + + +class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """A single price observation for a product at a store on a date.""" + + __tablename__ = "price_history" + __table_args__ = ( + Index( + "ix_price_history_product_store_date", + "normalized_product_id", + "store_id", + "observed_date", + ), + ) + + normalized_product_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("normalized_products.id"), nullable=False + ) + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + observed_date: Mapped[date] = mapped_column(Date, nullable=False) + regular_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + loyalty_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + coupon_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + source: Mapped[PriceSource] = mapped_column(String(20), nullable=False) + purchase_item_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("purchase_items.id")) + + # Relationships + normalized_product: Mapped["NormalizedProduct"] = relationship(back_populates="price_histories") + store: Mapped["Store"] = relationship(back_populates="price_histories") + purchase_item: Mapped["PurchaseItem | None"] = relationship( + back_populates="price_history_entries" + ) diff --git a/api/src/cartsnitch_api/models/product.py b/api/src/cartsnitch_api/models/product.py new file mode 100644 index 0000000..4061132 --- /dev/null +++ b/api/src/cartsnitch_api/models/product.py @@ -0,0 +1,39 @@ +"""NormalizedProduct model — the canonical product identity.""" + +from typing import TYPE_CHECKING + +from sqlalchemy import JSON, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.constants import ProductCategory, SizeUnit +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_api.models.coupon import Coupon + from cartsnitch_api.models.price import PriceHistory + from cartsnitch_api.models.purchase import PurchaseItem + from cartsnitch_api.models.shrinkflation import ShrinkflationEvent + + +class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Canonical product identity — matches products across retailers.""" + + __tablename__ = "normalized_products" + + canonical_name: Mapped[str] = mapped_column(String(300), nullable=False) + category: Mapped[ProductCategory | None] = mapped_column(String(50)) + subcategory: Mapped[str | None] = mapped_column(String(100)) + brand: Mapped[str | None] = mapped_column(String(200)) + size: Mapped[str | None] = mapped_column(String(50)) + size_unit: Mapped[SizeUnit | None] = mapped_column(String(10)) + upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list) + + # Relationships + purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product") + price_histories: Mapped[list["PriceHistory"]] = relationship( + back_populates="normalized_product" + ) + coupons: Mapped[list["Coupon"]] = relationship(back_populates="normalized_product") + shrinkflation_events: Mapped[list["ShrinkflationEvent"]] = relationship( + back_populates="normalized_product" + ) diff --git a/api/src/cartsnitch_api/models/purchase.py b/api/src/cartsnitch_api/models/purchase.py new file mode 100644 index 0000000..f57fde9 --- /dev/null +++ b/api/src/cartsnitch_api/models/purchase.py @@ -0,0 +1,91 @@ +"""Purchase and PurchaseItem models.""" + +import uuid +from datetime import date, datetime +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import ( + JSON, + Date, + DateTime, + ForeignKey, + Index, + Numeric, + String, + UniqueConstraint, + func, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_api.models.price import PriceHistory + from cartsnitch_api.models.product import NormalizedProduct + from cartsnitch_api.models.store import Store, StoreLocation + from cartsnitch_api.models.user import User + + +class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """A single shopping trip / receipt.""" + + __tablename__ = "purchases" + + user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False) + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id")) + receipt_id: Mapped[str] = mapped_column(String(200), nullable=False) + purchase_date: Mapped[date] = mapped_column(Date, nullable=False) + total: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + subtotal: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + tax: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + savings_total: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + source_url: Mapped[str | None] = mapped_column(String(500)) + raw_data: Mapped[dict | None] = mapped_column(JSON) + ingested_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + ) + + # Relationships + user: Mapped["User"] = relationship(back_populates="purchases") + store: Mapped["Store"] = relationship(back_populates="purchases") + store_location: Mapped["StoreLocation | None"] = relationship(back_populates="purchases") + items: Mapped[list["PurchaseItem"]] = relationship(back_populates="purchase") + + __table_args__ = ( + Index("ix_purchases_user_store", "user_id", "store_id"), + UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"), + ) + + +class PurchaseItem(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Individual line item on a receipt.""" + + __tablename__ = "purchase_items" + + purchase_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("purchases.id"), nullable=False) + product_name_raw: Mapped[str] = mapped_column(String(300), nullable=False) + upc: Mapped[str | None] = mapped_column(String(20)) + quantity: Mapped[Decimal] = mapped_column(Numeric(10, 3), nullable=False, default=1) + unit_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + extended_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + regular_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + coupon_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + loyalty_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + category_raw: Mapped[str | None] = mapped_column(String(100)) + normalized_product_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("normalized_products.id") + ) + + # Relationships + purchase: Mapped["Purchase"] = relationship(back_populates="items") + normalized_product: Mapped["NormalizedProduct | None"] = relationship( + back_populates="purchase_items" + ) + price_history_entries: Mapped[list["PriceHistory"]] = relationship( + back_populates="purchase_item" + ) diff --git a/api/src/cartsnitch_api/models/shrinkflation.py b/api/src/cartsnitch_api/models/shrinkflation.py new file mode 100644 index 0000000..2ce6f9d --- /dev/null +++ b/api/src/cartsnitch_api/models/shrinkflation.py @@ -0,0 +1,41 @@ +"""ShrinkflationEvent model.""" + +import uuid +from datetime import date +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import Date, ForeignKey, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.constants import SizeUnit +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_api.models.product import NormalizedProduct + + +class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Detected shrinkflation event — product size changed while price held or rose.""" + + __tablename__ = "shrinkflation_events" + + normalized_product_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("normalized_products.id"), nullable=False + ) + detected_date: Mapped[date] = mapped_column(Date, nullable=False) + old_size: Mapped[str] = mapped_column(String(50), nullable=False) + new_size: Mapped[str] = mapped_column(String(50), nullable=False) + old_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False) + new_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False) + price_at_old_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + price_at_new_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + confidence: Mapped[Decimal] = mapped_column( + Numeric(3, 2), nullable=False, default=Decimal("1.00") + ) + notes: Mapped[str | None] = mapped_column(String(1000)) + + # Relationships + normalized_product: Mapped["NormalizedProduct"] = relationship( + back_populates="shrinkflation_events" + ) diff --git a/api/src/cartsnitch_api/models/store.py b/api/src/cartsnitch_api/models/store.py new file mode 100644 index 0000000..f75897f --- /dev/null +++ b/api/src/cartsnitch_api/models/store.py @@ -0,0 +1,52 @@ +"""Store and StoreLocation models.""" + +import uuid +from typing import TYPE_CHECKING + +from sqlalchemy import Float, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.constants import StoreSlug +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_api.models.coupon import Coupon + from cartsnitch_api.models.price import PriceHistory + from cartsnitch_api.models.purchase import Purchase + from cartsnitch_api.models.user import UserStoreAccount + + +class Store(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Supported retailer.""" + + __tablename__ = "stores" + + name: Mapped[str] = mapped_column(String(100), nullable=False) + slug: Mapped[StoreSlug] = mapped_column(String(20), nullable=False, unique=True) + logo_url: Mapped[str | None] = mapped_column(String(500)) + website_url: Mapped[str | None] = mapped_column(String(500)) + + # Relationships + locations: Mapped[list["StoreLocation"]] = relationship(back_populates="store") + purchases: Mapped[list["Purchase"]] = relationship(back_populates="store") + user_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="store") + price_histories: Mapped[list["PriceHistory"]] = relationship(back_populates="store") + coupons: Mapped[list["Coupon"]] = relationship(back_populates="store") + + +class StoreLocation(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Physical store location.""" + + __tablename__ = "store_locations" + + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + address: Mapped[str] = mapped_column(String(300), nullable=False) + city: Mapped[str] = mapped_column(String(100), nullable=False) + state: Mapped[str] = mapped_column(String(2), nullable=False) + zip: Mapped[str] = mapped_column(String(10), nullable=False) + lat: Mapped[float | None] = mapped_column(Float) + lng: Mapped[float | None] = mapped_column(Float) + + # Relationships + store: Mapped["Store"] = relationship(back_populates="locations") + purchases: Mapped[list["Purchase"]] = relationship(back_populates="store_location") diff --git a/api/src/cartsnitch_api/models/user.py b/api/src/cartsnitch_api/models/user.py new file mode 100644 index 0000000..56482b0 --- /dev/null +++ b/api/src/cartsnitch_api/models/user.py @@ -0,0 +1,50 @@ +"""User and UserStoreAccount models.""" + +import uuid +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.constants import AccountStatus +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin +from cartsnitch_api.types import EncryptedJSON + +if TYPE_CHECKING: + from cartsnitch_api.models.purchase import Purchase + from cartsnitch_api.models.store import Store + + +class User(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Application user.""" + + __tablename__ = "users" + + email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) + hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) + display_name: Mapped[str | None] = mapped_column(String(100)) + + # Relationships + store_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="user") + purchases: Mapped[list["Purchase"]] = relationship(back_populates="user") + + +class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Link between a user and their retailer account credentials.""" + + __tablename__ = "user_store_accounts" + __table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),) + + user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False) + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + session_data: Mapped[dict | None] = mapped_column(EncryptedJSON) + session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + last_sync_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + status: Mapped[AccountStatus] = mapped_column( + String(20), nullable=False, default=AccountStatus.ACTIVE + ) + + # Relationships + user: Mapped["User"] = relationship(back_populates="store_accounts") + store: Mapped["Store"] = relationship(back_populates="user_accounts") diff --git a/api/src/cartsnitch_api/routes/__init__.py b/api/src/cartsnitch_api/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/cartsnitch_api/routes/alerts.py b/api/src/cartsnitch_api/routes/alerts.py new file mode 100644 index 0000000..45ab33f --- /dev/null +++ b/api/src/cartsnitch_api/routes/alerts.py @@ -0,0 +1,44 @@ +"""Alert routes: list alerts, manage settings.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import AlertResponse, AlertSettingsRequest, AlertSettingsResponse +from cartsnitch_api.services.alerts import AlertService + +router = APIRouter(prefix="/alerts", tags=["alerts"]) + + +@router.get("", response_model=list[AlertResponse]) +async def list_alerts( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = AlertService(db) + return await svc.list_alerts(user_id) + + +@router.get("/settings", response_model=AlertSettingsResponse) +async def get_alert_settings( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = AlertService(db) + return await svc.get_settings(user_id) + + +@router.put("/settings") +async def update_alert_settings( + body: AlertSettingsRequest, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Alert settings persistence not yet implemented. " + "Use GET /alerts/settings for current defaults.", + ) diff --git a/api/src/cartsnitch_api/routes/coupons.py b/api/src/cartsnitch_api/routes/coupons.py new file mode 100644 index 0000000..d33d98a --- /dev/null +++ b/api/src/cartsnitch_api/routes/coupons.py @@ -0,0 +1,32 @@ +"""Coupon routes: browse, relevant matches.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import CouponResponse +from cartsnitch_api.services.coupons import CouponService + +router = APIRouter(prefix="/coupons", tags=["coupons"]) + + +@router.get("", response_model=list[CouponResponse]) +async def list_coupons( + store_id: UUID | None = Query(None), + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = CouponService(db) + return await svc.list_coupons(store_id) + + +@router.get("/relevant", response_model=list[CouponResponse]) +async def relevant_coupons( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = CouponService(db) + return await svc.relevant_coupons(user_id) diff --git a/api/src/cartsnitch_api/routes/health.py b/api/src/cartsnitch_api/routes/health.py new file mode 100644 index 0000000..0574b10 --- /dev/null +++ b/api/src/cartsnitch_api/routes/health.py @@ -0,0 +1,20 @@ +"""Health check and error metrics endpoints.""" + +from fastapi import APIRouter, Depends + +from cartsnitch_api.auth.dependencies import verify_service_key +from cartsnitch_api.middleware.error_handler import get_error_monitor + +router = APIRouter(tags=["health"]) + + +@router.get("/health") +async def health(): + return {"status": "ok"} + + +@router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)]) +async def error_stats(): + """Error monitoring stats — internal only (requires X-Service-Key).""" + monitor = get_error_monitor() + return monitor.get_stats() diff --git a/api/src/cartsnitch_api/routes/prices.py b/api/src/cartsnitch_api/routes/prices.py new file mode 100644 index 0000000..487dd92 --- /dev/null +++ b/api/src/cartsnitch_api/routes/prices.py @@ -0,0 +1,47 @@ +"""Price routes: trends, increases, comparison.""" + +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import ( + PriceComparisonResponse, + PriceIncreaseResponse, + PriceTrendResponse, +) +from cartsnitch_api.services.prices import PriceService + +router = APIRouter(prefix="/prices", tags=["prices"]) + + +@router.get("/trends", response_model=list[PriceTrendResponse]) +async def price_trends( + user_id: UUID = Depends(get_current_user), + category: str | None = Query(None), + db: AsyncSession = Depends(get_db), +): + svc = PriceService(db) + return await svc.get_trends(category) + + +@router.get("/increases", response_model=list[PriceIncreaseResponse]) +async def price_increases( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = PriceService(db) + return await svc.get_increases() + + +@router.get("/comparison", response_model=list[PriceComparisonResponse]) +async def price_comparison( + product_ids: Annotated[list[UUID], Query()], + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = PriceService(db) + return await svc.get_comparison(product_ids) diff --git a/api/src/cartsnitch_api/routes/products.py b/api/src/cartsnitch_api/routes/products.py new file mode 100644 index 0000000..473cefe --- /dev/null +++ b/api/src/cartsnitch_api/routes/products.py @@ -0,0 +1,56 @@ +"""Product routes: search/list, detail, price history.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import PriceTrendResponse, ProductDetailResponse, ProductResponse +from cartsnitch_api.services.products import ProductService + +router = APIRouter(prefix="/products", tags=["products"]) + + +@router.get("", response_model=list[ProductResponse]) +async def list_products( + user_id: UUID = Depends(get_current_user), + q: str | None = Query(None), + category: str | None = Query(None), + page: int = Query(1, ge=1), + page_size: int = Query(20, ge=1, le=100), + db: AsyncSession = Depends(get_db), +): + svc = ProductService(db) + return await svc.list_products(q, category, page, page_size) + + +@router.get("/{product_id}", response_model=ProductDetailResponse) +async def get_product( + product_id: UUID, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = ProductService(db) + try: + return await svc.get_product(product_id) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" + ) from None + + +@router.get("/{product_id}/prices", response_model=PriceTrendResponse) +async def get_product_prices( + product_id: UUID, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = ProductService(db) + try: + return await svc.get_price_history(product_id) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" + ) from None diff --git a/api/src/cartsnitch_api/routes/public.py b/api/src/cartsnitch_api/routes/public.py new file mode 100644 index 0000000..5d0b87b --- /dev/null +++ b/api/src/cartsnitch_api/routes/public.py @@ -0,0 +1,48 @@ +"""Public endpoints: price transparency data (no auth required).""" + +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import ( + PublicInflationResponse, + PublicStoreComparisonResponse, + PublicTrendResponse, +) +from cartsnitch_api.services.public import PublicService + +router = APIRouter(prefix="/public", tags=["public"]) + + +@router.get("/trends/{product_id}", response_model=PublicTrendResponse) +async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)): + svc = PublicService(db) + try: + return await svc.get_trend(product_id) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" + ) from None + + +@router.get("/store-comparison", response_model=PublicStoreComparisonResponse) +async def public_store_comparison( + product_ids: Annotated[list[UUID], Query(max_length=20)], + db: AsyncSession = Depends(get_db), +): + if not product_ids: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one product_id is required", + ) + svc = PublicService(db) + return await svc.get_store_comparison(product_ids) + + +@router.get("/inflation", response_model=PublicInflationResponse) +async def public_inflation(db: AsyncSession = Depends(get_db)): + svc = PublicService(db) + return await svc.get_inflation() diff --git a/api/src/cartsnitch_api/routes/purchases.py b/api/src/cartsnitch_api/routes/purchases.py new file mode 100644 index 0000000..eba86ac --- /dev/null +++ b/api/src/cartsnitch_api/routes/purchases.py @@ -0,0 +1,49 @@ +"""Purchase routes: list, detail, stats.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import PurchaseDetailResponse, PurchaseResponse, PurchaseStatsResponse +from cartsnitch_api.services.purchases import PurchaseService + +router = APIRouter(prefix="/purchases", tags=["purchases"]) + + +@router.get("", response_model=list[PurchaseResponse]) +async def list_purchases( + user_id: UUID = Depends(get_current_user), + store_id: UUID | None = Query(None), + page: int = Query(1, ge=1), + page_size: int = Query(20, ge=1, le=100), + db: AsyncSession = Depends(get_db), +): + svc = PurchaseService(db) + return await svc.list_purchases(user_id, store_id, page, page_size) + + +@router.get("/stats", response_model=PurchaseStatsResponse) +async def purchase_stats( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = PurchaseService(db) + return await svc.get_stats(user_id) + + +@router.get("/{purchase_id}", response_model=PurchaseDetailResponse) +async def get_purchase( + purchase_id: UUID, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = PurchaseService(db) + try: + return await svc.get_purchase(purchase_id, user_id) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Purchase not found" + ) from None diff --git a/api/src/cartsnitch_api/routes/scraping.py b/api/src/cartsnitch_api/routes/scraping.py new file mode 100644 index 0000000..d8bbd5f --- /dev/null +++ b/api/src/cartsnitch_api/routes/scraping.py @@ -0,0 +1,42 @@ +"""Scraping routes: trigger sync, check status (proxy to ReceiptWitness).""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from httpx import HTTPStatusError, RequestError + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.schemas import SyncStatusResponse, SyncTriggerResponse +from cartsnitch_api.services.receiptwitness import ReceiptWitnessClient + +router = APIRouter(prefix="/scraping", tags=["scraping"]) + + +@router.post("/{store_slug}/sync", response_model=SyncTriggerResponse) +async def trigger_sync(store_slug: str, user_id: UUID = Depends(get_current_user)): + client = ReceiptWitnessClient() + try: + result = await client.trigger_sync(str(user_id), store_slug) + return result + except HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, + detail="Sync service error", + ) from e + except RequestError: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Unable to reach sync service", + ) from None + + +@router.get("/status", response_model=list[SyncStatusResponse]) +async def sync_status(user_id: UUID = Depends(get_current_user)): + client = ReceiptWitnessClient() + try: + return await client.get_sync_status(str(user_id)) + except (HTTPStatusError, RequestError): + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Unable to reach sync service", + ) from None diff --git a/api/src/cartsnitch_api/routes/shopping.py b/api/src/cartsnitch_api/routes/shopping.py new file mode 100644 index 0000000..c64d5fd --- /dev/null +++ b/api/src/cartsnitch_api/routes/shopping.py @@ -0,0 +1,48 @@ +"""Shopping routes: optimize list, saved lists.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from httpx import HTTPStatusError, RequestError + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.schemas import OptimizeRequest, OptimizeResponse, ShoppingListResponse +from cartsnitch_api.services.clipartist import ClipArtistClient + +router = APIRouter(prefix="/shopping", tags=["shopping"]) + + +@router.post("/optimize", response_model=OptimizeResponse) +async def optimize_shopping(body: OptimizeRequest, user_id: UUID = Depends(get_current_user)): + client = ClipArtistClient() + try: + result = await client.optimize( + user_id=str(user_id), + items=[item.model_dump() for item in body.items], + preferred_stores=( + [str(s) for s in body.preferred_stores] if body.preferred_stores else None + ), + ) + return result + except HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, + detail="Shopping optimization service error", + ) from e + except RequestError: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Unable to reach shopping optimization service", + ) from None + + +@router.get("/lists", response_model=list[ShoppingListResponse]) +async def list_shopping_lists(user_id: UUID = Depends(get_current_user)): + client = ClipArtistClient() + try: + return await client.get_shopping_lists(str(user_id)) + except (HTTPStatusError, RequestError): + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Unable to reach shopping service", + ) from None diff --git a/api/src/cartsnitch_api/routes/stores.py b/api/src/cartsnitch_api/routes/stores.py new file mode 100644 index 0000000..1ab7947 --- /dev/null +++ b/api/src/cartsnitch_api/routes/stores.py @@ -0,0 +1,61 @@ +"""Store routes: list stores, manage user store connections.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import ConnectStoreRequest, StoreAccountResponse, StoreResponse +from cartsnitch_api.services.stores import StoreService + +router = APIRouter(tags=["stores"]) + + +@router.get("/stores", response_model=list[StoreResponse]) +async def list_stores(db: AsyncSession = Depends(get_db)): + svc = StoreService(db) + return await svc.list_stores() + + +@router.get("/me/stores", response_model=list[StoreAccountResponse]) +async def list_user_stores( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = StoreService(db) + return await svc.list_user_stores(user_id) + + +@router.post( + "/me/stores/{store_slug}/connect", + response_model=StoreAccountResponse, + status_code=status.HTTP_201_CREATED, +) +async def connect_store( + store_slug: str, + body: ConnectStoreRequest, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = StoreService(db) + try: + return await svc.connect_store(user_id, store_slug, body.credentials) + except LookupError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + +@router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT) +async def disconnect_store( + store_slug: str, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = StoreService(db) + try: + await svc.disconnect_store(user_id, store_slug) + except LookupError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e diff --git a/api/src/cartsnitch_api/schemas.py b/api/src/cartsnitch_api/schemas.py new file mode 100644 index 0000000..19e351a --- /dev/null +++ b/api/src/cartsnitch_api/schemas.py @@ -0,0 +1,291 @@ +"""Pydantic v2 request/response schemas for all API endpoints.""" + +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, EmailStr, Field + +# ---------- Auth ---------- + + +class RegisterRequest(BaseModel): + email: EmailStr + password: str = Field(min_length=8, max_length=128) + display_name: str = Field(min_length=1, max_length=100) + + +class LoginRequest(BaseModel): + email: EmailStr + password: str + + +class RefreshRequest(BaseModel): + refresh_token: str + + +class TokenResponse(BaseModel): + access_token: str + refresh_token: str + token_type: str = "bearer" + expires_in: int + + +class UpdateUserRequest(BaseModel): + email: EmailStr | None = None + display_name: str | None = Field(None, min_length=1, max_length=100) + + +class UserResponse(BaseModel): + id: UUID + email: str + display_name: str + created_at: datetime + + +# ---------- Stores ---------- + + +class StoreResponse(BaseModel): + id: UUID + name: str + slug: str + logo_url: str | None = None + supported: bool = True + + +class StoreAccountResponse(BaseModel): + store: StoreResponse + connected: bool + last_sync_at: datetime | None = None + sync_status: str | None = None + + +class ConnectStoreRequest(BaseModel): + credentials: dict | None = None + + +# ---------- Purchases ---------- + + +class LineItemResponse(BaseModel): + id: UUID + product_id: UUID | None = None + name: str + quantity: float + unit_price: float + total_price: float + + +class PurchaseResponse(BaseModel): + id: UUID + store_id: UUID + store_name: str + purchased_at: datetime + total: float + item_count: int + + +class PurchaseDetailResponse(PurchaseResponse): + line_items: list[LineItemResponse] + + +class PurchaseStatsResponse(BaseModel): + total_spent: float + purchase_count: int + by_store: dict[str, float] + by_period: dict[str, float] + + +# ---------- Products ---------- + + +class ProductResponse(BaseModel): + id: UUID + name: str + brand: str | None = None + category: str | None = None + upc: str | None = None + image_url: str | None = None + + +class ProductDetailResponse(ProductResponse): + prices_by_store: list["StorePriceResponse"] + + +class StorePriceResponse(BaseModel): + store_id: UUID + store_name: str + current_price: float + last_seen_at: datetime + + +# ---------- Prices ---------- + + +class PriceTrendResponse(BaseModel): + product_id: UUID + product_name: str + data_points: list["PricePointResponse"] + + +class PricePointResponse(BaseModel): + date: datetime + price: float + store_id: UUID + store_name: str + + +class PriceIncreaseResponse(BaseModel): + product_id: UUID + product_name: str + store_name: str + old_price: float + new_price: float + increase_pct: float + detected_at: datetime + + +class PriceComparisonResponse(BaseModel): + product_id: UUID + product_name: str + prices: list[StorePriceResponse] + + +# ---------- Coupons ---------- + + +class CouponResponse(BaseModel): + id: UUID + store_id: UUID + store_name: str + description: str + discount_value: float + discount_type: str + product_id: UUID | None = None + expires_at: datetime | None = None + + +# ---------- Shopping ---------- + + +class ShoppingListItemRequest(BaseModel): + product_id: UUID | None = None + name: str + quantity: int = 1 + + +class OptimizeRequest(BaseModel): + items: list[ShoppingListItemRequest] + preferred_stores: list[UUID] | None = None + + +class OptimizedStoreTrip(BaseModel): + store_id: UUID + store_name: str + items: list["OptimizedItemResponse"] + subtotal: float + coupons: list[CouponResponse] + savings: float + + +class OptimizedItemResponse(BaseModel): + name: str + price: float + product_id: UUID | None = None + + +class OptimizeResponse(BaseModel): + trips: list[OptimizedStoreTrip] + total_cost: float + total_savings: float + + +class ShoppingListResponse(BaseModel): + id: UUID + name: str + item_count: int + created_at: datetime + updated_at: datetime + + +# ---------- Alerts ---------- + + +class AlertResponse(BaseModel): + id: UUID + alert_type: str + product_id: UUID + product_name: str + message: str + triggered_at: datetime + read: bool = False + + +class AlertSettingsRequest(BaseModel): + price_increase_threshold_pct: float | None = None + shrinkflation_enabled: bool | None = None + email_notifications: bool | None = None + + +class AlertSettingsResponse(BaseModel): + price_increase_threshold_pct: float + shrinkflation_enabled: bool + email_notifications: bool + + +# ---------- Scraping ---------- + + +class SyncTriggerResponse(BaseModel): + job_id: UUID + status: str + message: str + + +class SyncStatusResponse(BaseModel): + store_slug: str + status: str + last_sync_at: datetime | None = None + items_synced: int | None = None + + +# ---------- Public ---------- + + +class PublicTrendResponse(BaseModel): + product_id: UUID + product_name: str + data_points: list[PricePointResponse] + + +class PublicStoreComparisonResponse(BaseModel): + products: list[PriceComparisonResponse] + + +class PublicInflationResponse(BaseModel): + period: str + cartsnitch_index: float + cpi_baseline: float + categories: dict[str, float] + + +# ---------- Common ---------- + + +class PaginatedResponse(BaseModel): + items: list + total: int + page: int + page_size: int + pages: int + + +class ErrorResponse(BaseModel): + detail: str + code: str | None = None + + +# Rebuild forward refs +ProductDetailResponse.model_rebuild() +PriceTrendResponse.model_rebuild() +OptimizedStoreTrip.model_rebuild() diff --git a/api/src/cartsnitch_api/services/__init__.py b/api/src/cartsnitch_api/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/cartsnitch_api/services/alerts.py b/api/src/cartsnitch_api/services/alerts.py new file mode 100644 index 0000000..fc3ddd4 --- /dev/null +++ b/api/src/cartsnitch_api/services/alerts.py @@ -0,0 +1,75 @@ +"""Alert service — price and shrinkflation alerts for users. + +Alerts are generated by StickerShock and ShrinkRay services and written to the DB. +This service reads them for the API gateway. +""" + +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + + +class AlertService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_alerts(self, user_id: UUID) -> list[dict]: + """List shrinkflation events for products the user has purchased.""" + from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent + + # Get product IDs from user's purchases + items_result = await self.db.execute( + select(PurchaseItem.normalized_product_id) + .join(Purchase) + .where( + Purchase.user_id == user_id, + PurchaseItem.normalized_product_id.isnot(None), + ) + .distinct() + ) + product_ids = [row[0] for row in items_result.all()] + + if not product_ids: + return [] + + result = await self.db.execute( + select(ShrinkflationEvent) + .where(ShrinkflationEvent.normalized_product_id.in_(product_ids)) + .options(selectinload(ShrinkflationEvent.normalized_product)) + .order_by(ShrinkflationEvent.detected_date.desc()) + ) + events = result.scalars().all() + + return [ + { + "id": e.id, + "alert_type": "shrinkflation", + "product_id": e.normalized_product_id, + "product_name": e.normalized_product.canonical_name, + "message": ( + f"Size changed from {e.old_size}{e.old_unit} to {e.new_size}{e.new_unit}" + ), + "triggered_at": e.detected_date, + "read": False, + } + for e in events + ] + + async def get_settings(self, user_id: UUID) -> dict: + # Alert settings would be stored in a user_settings table. + # For now, return defaults since the table doesn't exist yet in common lib. + return { + "price_increase_threshold_pct": 5.0, + "shrinkflation_enabled": True, + "email_notifications": False, + } + + async def update_settings(self, user_id: UUID, **fields) -> dict: + # Would update user_settings table. Return merged defaults for now. + current = await self.get_settings(user_id) + for k, v in fields.items(): + if v is not None and k in current: + current[k] = v + return current diff --git a/api/src/cartsnitch_api/services/auth.py b/api/src/cartsnitch_api/services/auth.py new file mode 100644 index 0000000..5ea6b77 --- /dev/null +++ b/api/src/cartsnitch_api/services/auth.py @@ -0,0 +1,125 @@ +"""Auth service — user registration, login, token management.""" + +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token +from cartsnitch_api.auth.passwords import hash_password, verify_password +from cartsnitch_api.config import settings + + +class AuthService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def register(self, email: str, password: str, display_name: str) -> dict: + from cartsnitch_api.models import User + + existing = await self.db.execute(select(User).where(User.email == email)) + if existing.scalar_one_or_none(): + raise ValueError("Email already registered") + + user = User( + email=email, + hashed_password=hash_password(password), + display_name=display_name, + ) + self.db.add(user) + await self.db.commit() + await self.db.refresh(user) + + return self._make_token_response(user.id) + + async def login(self, email: str, password: str) -> dict: + from cartsnitch_api.models import User + + result = await self.db.execute(select(User).where(User.email == email)) + user = result.scalar_one_or_none() + if not user or not verify_password(password, user.hashed_password): + raise ValueError("Invalid email or password") + + return self._make_token_response(user.id) + + async def refresh(self, refresh_token: str) -> dict: + from cartsnitch_api.models import User + + try: + payload = decode_token(refresh_token) + except ValueError: + raise ValueError("Invalid refresh token") from None + + if payload.get("type") != "refresh": + raise ValueError("Invalid token type") from None + + user_id = UUID(payload["sub"]) + + # Verify the user still exists before issuing new tokens + result = await self.db.execute(select(User).where(User.id == user_id)) + if not result.scalar_one_or_none(): + raise ValueError("User no longer exists") + + return self._make_token_response(user_id) + + async def get_user(self, user_id: UUID) -> dict: + from cartsnitch_api.models import User + + result = await self.db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if not user: + raise LookupError("User not found") + + return { + "id": user.id, + "email": user.email, + "display_name": user.display_name, + "created_at": user.created_at, + } + + async def update_user(self, user_id: UUID, **fields) -> dict: + from cartsnitch_api.models import User + + result = await self.db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if not user: + raise LookupError("User not found") + + if "display_name" in fields and fields["display_name"] is not None: + user.display_name = fields["display_name"] + if "email" in fields and fields["email"] is not None: + existing = await self.db.execute( + select(User).where(User.email == fields["email"], User.id != user_id) + ) + if existing.scalar_one_or_none(): + raise ValueError("Email already in use") + user.email = fields["email"] + + await self.db.commit() + await self.db.refresh(user) + + return { + "id": user.id, + "email": user.email, + "display_name": user.display_name, + "created_at": user.created_at, + } + + async def delete_user(self, user_id: UUID) -> None: + from cartsnitch_api.models import User + + result = await self.db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if not user: + raise LookupError("User not found") + + await self.db.delete(user) + await self.db.commit() + + def _make_token_response(self, user_id: UUID) -> dict: + return { + "access_token": create_access_token(user_id), + "refresh_token": create_refresh_token(user_id), + "token_type": "bearer", + "expires_in": settings.jwt_access_token_expire_minutes * 60, + } diff --git a/api/src/cartsnitch_api/services/clipartist.py b/api/src/cartsnitch_api/services/clipartist.py new file mode 100644 index 0000000..86d6c62 --- /dev/null +++ b/api/src/cartsnitch_api/services/clipartist.py @@ -0,0 +1,52 @@ +"""HTTP client for ClipArtist internal API.""" + +from typing import Any, cast + +import httpx + +from cartsnitch_api.config import settings + + +class ClipArtistClient: + def __init__(self) -> None: + self.base_url = settings.clipartist_url + self.headers = {"X-Service-Key": settings.service_key} + + async def optimize( + self, + user_id: str, + items: list[dict], + preferred_stores: list[str] | None = None, + ) -> dict: + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{self.base_url}/optimize", + headers=self.headers, + json={ + "user_id": user_id, + "items": items, + "preferred_stores": preferred_stores, + }, + ) + resp.raise_for_status() + return cast(dict[str, Any], resp.json()) + + async def get_shopping_lists(self, user_id: str) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/shopping-lists", + headers=self.headers, + params={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) + + async def get_relevant_coupons(self, user_id: str) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/coupons/relevant", + headers=self.headers, + params={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) diff --git a/api/src/cartsnitch_api/services/coupons.py b/api/src/cartsnitch_api/services/coupons.py new file mode 100644 index 0000000..9b1543e --- /dev/null +++ b/api/src/cartsnitch_api/services/coupons.py @@ -0,0 +1,76 @@ +"""Coupon service — browse coupons, find relevant ones.""" + +from datetime import date +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + + +class CouponService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_coupons(self, store_id: UUID | None = None) -> list[dict]: + from cartsnitch_api.models import Coupon + + today = date.today() + query = ( + select(Coupon) + .where((Coupon.valid_to >= today) | (Coupon.valid_to.is_(None))) + .options(selectinload(Coupon.store)) + .order_by(Coupon.valid_to.asc().nullslast()) + ) + if store_id: + query = query.where(Coupon.store_id == store_id) + + result = await self.db.execute(query) + coupons = result.scalars().all() + return [self._to_dict(c) for c in coupons] + + async def relevant_coupons(self, user_id: UUID) -> list[dict]: + """Coupons for products the user has purchased.""" + from cartsnitch_api.models import Coupon, PurchaseItem + + today = date.today() + + # Get product IDs from user's purchase history + from cartsnitch_api.models import Purchase + + items_result = await self.db.execute( + select(PurchaseItem.normalized_product_id) + .join(Purchase) + .where( + Purchase.user_id == user_id, + PurchaseItem.normalized_product_id.isnot(None), + ) + .distinct() + ) + product_ids = [row[0] for row in items_result.all()] + + if not product_ids: + return [] + + result = await self.db.execute( + select(Coupon) + .where( + Coupon.normalized_product_id.in_(product_ids), + (Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)), + ) + .options(selectinload(Coupon.store)) + ) + coupons = result.scalars().all() + return [self._to_dict(c) for c in coupons] + + def _to_dict(self, c) -> dict: + return { + "id": c.id, + "store_id": c.store_id, + "store_name": c.store.name, + "description": c.description or c.title, + "discount_value": float(c.discount_value) if c.discount_value else 0, + "discount_type": c.discount_type, + "product_id": c.normalized_product_id, + "expires_at": c.valid_to, + } diff --git a/api/src/cartsnitch_api/services/prices.py b/api/src/cartsnitch_api/services/prices.py new file mode 100644 index 0000000..44b74a0 --- /dev/null +++ b/api/src/cartsnitch_api/services/prices.py @@ -0,0 +1,183 @@ +"""Price service — trends, increases, comparison.""" + +from uuid import UUID + +from sqlalchemy import and_, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from cartsnitch_api.services.queries import latest_price_per_store + + +class PriceService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def get_trends(self, category: str | None = None) -> list[dict]: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + query = ( + select(PriceHistory) + .join(NormalizedProduct) + .options( + selectinload(PriceHistory.store), + selectinload(PriceHistory.normalized_product), + ) + .order_by(PriceHistory.observed_date) + ) + if category: + query = query.where(NormalizedProduct.category == category) + + result = await self.db.execute(query) + prices = result.scalars().all() + + # Group by product + by_product: dict[UUID, dict] = {} + for ph in prices: + pid = ph.normalized_product_id + if pid not in by_product: + by_product[pid] = { + "product_id": pid, + "product_name": ph.normalized_product.canonical_name, + "data_points": [], + } + by_product[pid]["data_points"].append( + { + "date": ph.observed_date, + "price": float(ph.regular_price), + "store_id": ph.store_id, + "store_name": ph.store.name, + } + ) + return list(by_product.values()) + + async def get_increases(self) -> list[dict]: + """Find products with recent significant price increases. + + Uses a window function (lag) to compare each price observation with the + previous one per product+store, avoiding the N+1 query pattern. + """ + from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store + + # Use lag() window function to get previous price in a single query + prev_price = ( + func.lag(PriceHistory.regular_price) + .over( + partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id], + order_by=PriceHistory.observed_date, + ) + .label("prev_price") + ) + + row_num = ( + func.row_number() + .over( + partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id], + order_by=PriceHistory.observed_date.desc(), + ) + .label("rn") + ) + + inner = select( + PriceHistory.normalized_product_id, + PriceHistory.store_id, + PriceHistory.regular_price, + PriceHistory.observed_date, + prev_price, + row_num, + ).subquery() + + # Only keep the latest row (rn=1) where price increased + result = await self.db.execute( + select( + inner.c.normalized_product_id, + inner.c.store_id, + inner.c.regular_price, + inner.c.observed_date, + inner.c.prev_price, + NormalizedProduct.canonical_name, + Store.name.label("store_name"), + ) + .join(NormalizedProduct, NormalizedProduct.id == inner.c.normalized_product_id) + .join(Store, Store.id == inner.c.store_id) + .where( + inner.c.rn == 1, + inner.c.prev_price.isnot(None), + inner.c.regular_price > inner.c.prev_price, + ) + ) + + increases = [] + for row in result.all(): + old = float(row.prev_price) + new = float(row.regular_price) + increases.append( + { + "product_id": row.normalized_product_id, + "product_name": row.canonical_name, + "store_name": row.store_name, + "old_price": old, + "new_price": new, + "increase_pct": round((new - old) / old * 100, 2), + "detected_at": row.observed_date, + } + ) + + increases.sort(key=lambda x: x["increase_pct"], reverse=True) + return increases + + async def get_comparison(self, product_ids: list[UUID]) -> list[dict]: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + if not product_ids: + return [] + + # Fetch all requested products in one query + prod_result = await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) + ) + products_by_id = {p.id: p for p in prod_result.scalars().all()} + + # Latest prices for all requested products in one query + subq = latest_price_per_store(product_ids) + prices_result = await self.db.execute( + select(PriceHistory) + .join( + subq, + and_( + PriceHistory.store_id == subq.c.store_id, + PriceHistory.observed_date == subq.c.max_date, + PriceHistory.normalized_product_id == subq.c.normalized_product_id, + ), + ) + .where(PriceHistory.normalized_product_id.in_(product_ids)) + .options(selectinload(PriceHistory.store)) + ) + all_prices = prices_result.scalars().all() + + # Group prices by product + prices_by_product: dict[UUID, list] = {pid: [] for pid in product_ids} + for ph in all_prices: + prices_by_product.setdefault(ph.normalized_product_id, []).append(ph) + + comparisons = [] + for pid in product_ids: + product = products_by_id.get(pid) + if not product: + continue + comparisons.append( + { + "product_id": pid, + "product_name": product.canonical_name, + "prices": [ + { + "store_id": ph.store_id, + "store_name": ph.store.name, + "current_price": float(ph.regular_price), + "last_seen_at": ph.observed_date, + } + for ph in prices_by_product.get(pid, []) + ], + } + ) + return comparisons diff --git a/api/src/cartsnitch_api/services/products.py b/api/src/cartsnitch_api/services/products.py new file mode 100644 index 0000000..ad35987 --- /dev/null +++ b/api/src/cartsnitch_api/services/products.py @@ -0,0 +1,124 @@ +"""Product service — catalog, detail, price history.""" + +from uuid import UUID + +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from cartsnitch_api.services.queries import latest_price_per_store + + +class ProductService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_products( + self, + q: str | None = None, + category: str | None = None, + page: int = 1, + page_size: int = 20, + ) -> list[dict]: + from cartsnitch_api.models import NormalizedProduct + + query = select(NormalizedProduct) + if q: + # Escape SQL LIKE wildcards in user input + safe_q = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + query = query.where(NormalizedProduct.canonical_name.ilike(f"%{safe_q}%")) + if category: + query = query.where(NormalizedProduct.category == category) + query = query.order_by(NormalizedProduct.canonical_name) + query = query.offset((page - 1) * page_size).limit(page_size) + + result = await self.db.execute(query) + products = result.scalars().all() + return [ + { + "id": p.id, + "name": p.canonical_name, + "brand": p.brand, + "category": p.category, + "upc": (p.upc_variants[0] if p.upc_variants else None), + "image_url": None, + } + for p in products + ] + + async def get_product(self, product_id: UUID) -> dict: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + result = await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id == product_id) + ) + product = result.scalar_one_or_none() + if not product: + raise LookupError("Product not found") + + # Get latest price per store + subq = latest_price_per_store([product_id]) + prices_result = await self.db.execute( + select(PriceHistory) + .join( + subq, + and_( + PriceHistory.store_id == subq.c.store_id, + PriceHistory.observed_date == subq.c.max_date, + PriceHistory.normalized_product_id == subq.c.normalized_product_id, + ), + ) + .where(PriceHistory.normalized_product_id == product_id) + .options(selectinload(PriceHistory.store)) + ) + prices = prices_result.scalars().all() + + return { + "id": product.id, + "name": product.canonical_name, + "brand": product.brand, + "category": product.category, + "upc": (product.upc_variants[0] if product.upc_variants else None), + "image_url": None, + "prices_by_store": [ + { + "store_id": ph.store_id, + "store_name": ph.store.name, + "current_price": float(ph.regular_price), + "last_seen_at": ph.observed_date, + } + for ph in prices + ], + } + + async def get_price_history(self, product_id: UUID) -> dict: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + result = await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id == product_id) + ) + product = result.scalar_one_or_none() + if not product: + raise LookupError("Product not found") + + prices_result = await self.db.execute( + select(PriceHistory) + .where(PriceHistory.normalized_product_id == product_id) + .options(selectinload(PriceHistory.store)) + .order_by(PriceHistory.observed_date) + ) + prices = prices_result.scalars().all() + + return { + "product_id": product.id, + "product_name": product.canonical_name, + "data_points": [ + { + "date": ph.observed_date, + "price": float(ph.regular_price), + "store_id": ph.store_id, + "store_name": ph.store.name, + } + for ph in prices + ], + } diff --git a/api/src/cartsnitch_api/services/public.py b/api/src/cartsnitch_api/services/public.py new file mode 100644 index 0000000..f1ccbeb --- /dev/null +++ b/api/src/cartsnitch_api/services/public.py @@ -0,0 +1,129 @@ +"""Public service — unauthenticated price transparency endpoints.""" + +from uuid import UUID + +from sqlalchemy import and_, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from cartsnitch_api.services.queries import latest_price_per_store + + +class PublicService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def get_trend(self, product_id: UUID) -> dict: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + result = await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id == product_id) + ) + product = result.scalar_one_or_none() + if not product: + raise LookupError("Product not found") + + prices_result = await self.db.execute( + select(PriceHistory) + .where(PriceHistory.normalized_product_id == product_id) + .options(selectinload(PriceHistory.store)) + .order_by(PriceHistory.observed_date) + ) + prices = prices_result.scalars().all() + + return { + "product_id": product.id, + "product_name": product.canonical_name, + "data_points": [ + { + "date": ph.observed_date, + "price": float(ph.regular_price), + "store_id": ph.store_id, + "store_name": ph.store.name, + } + for ph in prices + ], + } + + async def get_store_comparison(self, product_ids: list[UUID]) -> dict: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + if not product_ids: + return {"products": []} + + # Fetch all products in one query + prod_result = await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) + ) + products_by_id = {p.id: p for p in prod_result.scalars().all()} + + # Latest prices for all requested products in one query + subq = latest_price_per_store(product_ids) + prices_result = await self.db.execute( + select(PriceHistory) + .join( + subq, + and_( + PriceHistory.store_id == subq.c.store_id, + PriceHistory.observed_date == subq.c.max_date, + PriceHistory.normalized_product_id == subq.c.normalized_product_id, + ), + ) + .where(PriceHistory.normalized_product_id.in_(product_ids)) + .options(selectinload(PriceHistory.store)) + ) + all_prices = prices_result.scalars().all() + + # Group by product + prices_by_product: dict[UUID, list] = {} + for ph in all_prices: + prices_by_product.setdefault(ph.normalized_product_id, []).append(ph) + + products = [] + for pid in product_ids: + product = products_by_id.get(pid) + if not product: + continue + products.append( + { + "product_id": pid, + "product_name": product.canonical_name, + "prices": [ + { + "store_id": ph.store_id, + "store_name": ph.store.name, + "current_price": float(ph.regular_price), + "last_seen_at": ph.observed_date, + } + for ph in prices_by_product.get(pid, []) + ], + } + ) + + return {"products": products} + + async def get_inflation(self) -> dict: + """Aggregate price change stats. Compares average prices across periods.""" + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + # Get average prices grouped by category for recent vs older data + result = await self.db.execute( + select( + NormalizedProduct.category, + func.avg(PriceHistory.regular_price), + ) + .join(NormalizedProduct) + .group_by(NormalizedProduct.category) + ) + categories = {} + for row in result.all(): + cat, avg_price = row + if cat: + categories[cat] = float(avg_price) if avg_price else 0.0 + + return { + "period": "all-time", + "cartsnitch_index": sum(categories.values()) / max(len(categories), 1), + "cpi_baseline": 100.0, + "categories": categories, + } diff --git a/api/src/cartsnitch_api/services/purchases.py b/api/src/cartsnitch_api/services/purchases.py new file mode 100644 index 0000000..41776f4 --- /dev/null +++ b/api/src/cartsnitch_api/services/purchases.py @@ -0,0 +1,116 @@ +"""Purchase service — list, detail, stats.""" + +from uuid import UUID + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + + +class PurchaseService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_purchases( + self, + user_id: UUID, + store_id: UUID | None = None, + page: int = 1, + page_size: int = 20, + ) -> list[dict]: + from cartsnitch_api.models import Purchase, PurchaseItem, Store + + # Count items per purchase in a single subquery instead of N+1 + item_counts = ( + select( + PurchaseItem.purchase_id, + func.count().label("item_count"), + ) + .group_by(PurchaseItem.purchase_id) + .subquery() + ) + + query = ( + select(Purchase, item_counts.c.item_count, Store.name.label("store_name")) + .join(Store, Store.id == Purchase.store_id) + .outerjoin(item_counts, item_counts.c.purchase_id == Purchase.id) + .where(Purchase.user_id == user_id) + ) + if store_id: + query = query.where(Purchase.store_id == store_id) + + query = query.order_by(Purchase.purchase_date.desc()) + query = query.offset((page - 1) * page_size).limit(page_size) + + result = await self.db.execute(query) + + return [ + { + "id": p.id, + "store_id": p.store_id, + "store_name": store_name, + "purchased_at": p.purchase_date, + "total": float(p.total), + "item_count": item_count or 0, + } + for p, item_count, store_name in result.all() + ] + + async def get_purchase(self, purchase_id: UUID, user_id: UUID) -> dict: + from cartsnitch_api.models import Purchase + + result = await self.db.execute( + select(Purchase) + .where(Purchase.id == purchase_id, Purchase.user_id == user_id) + .options(selectinload(Purchase.store), selectinload(Purchase.items)) + ) + purchase = result.scalar_one_or_none() + if not purchase: + raise LookupError("Purchase not found") + + return { + "id": purchase.id, + "store_id": purchase.store_id, + "store_name": purchase.store.name, + "purchased_at": purchase.purchase_date, + "total": float(purchase.total), + "item_count": len(purchase.items), + "line_items": [ + { + "id": item.id, + "product_id": item.normalized_product_id, + "name": item.product_name_raw, + "quantity": float(item.quantity), + "unit_price": float(item.unit_price), + "total_price": float(item.extended_price), + } + for item in purchase.items + ], + } + + async def get_stats(self, user_id: UUID) -> dict: + from cartsnitch_api.models import Purchase + + result = await self.db.execute( + select(Purchase) + .where(Purchase.user_id == user_id) + .options(selectinload(Purchase.store)) + ) + purchases = result.scalars().all() + + total_spent = sum(float(p.total) for p in purchases) + by_store: dict[str, float] = {} + by_period: dict[str, float] = {} + + for p in purchases: + store_name = p.store.name + by_store[store_name] = by_store.get(store_name, 0) + float(p.total) + period = p.purchase_date.strftime("%Y-%m") + by_period[period] = by_period.get(period, 0) + float(p.total) + + return { + "total_spent": total_spent, + "purchase_count": len(purchases), + "by_store": by_store, + "by_period": by_period, + } diff --git a/api/src/cartsnitch_api/services/queries.py b/api/src/cartsnitch_api/services/queries.py new file mode 100644 index 0000000..8a94f7c --- /dev/null +++ b/api/src/cartsnitch_api/services/queries.py @@ -0,0 +1,23 @@ +"""Shared query helpers for service layer.""" + +from uuid import UUID + +from sqlalchemy import func, select + + +def latest_price_per_store(product_ids: list[UUID] | None = None): + """Subquery returning the latest observed_date per product+store. + + Optionally filtered to a list of product IDs. Returns a subquery with + columns: normalized_product_id, store_id, max_date. + """ + from cartsnitch_api.models import PriceHistory + + query = select( + PriceHistory.normalized_product_id, + PriceHistory.store_id, + func.max(PriceHistory.observed_date).label("max_date"), + ).group_by(PriceHistory.normalized_product_id, PriceHistory.store_id) + if product_ids is not None: + query = query.where(PriceHistory.normalized_product_id.in_(product_ids)) + return query.subquery() diff --git a/api/src/cartsnitch_api/services/receiptwitness.py b/api/src/cartsnitch_api/services/receiptwitness.py new file mode 100644 index 0000000..e6200a9 --- /dev/null +++ b/api/src/cartsnitch_api/services/receiptwitness.py @@ -0,0 +1,33 @@ +"""HTTP client for ReceiptWitness internal API.""" + +from typing import Any, cast + +import httpx + +from cartsnitch_api.config import settings + + +class ReceiptWitnessClient: + def __init__(self) -> None: + self.base_url = settings.receiptwitness_url + self.headers = {"X-Service-Key": settings.service_key} + + async def trigger_sync(self, user_id: str, store_slug: str) -> dict: + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{self.base_url}/sync/{store_slug}", + headers=self.headers, + json={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(dict[str, Any], resp.json()) + + async def get_sync_status(self, user_id: str) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/sync/status", + headers=self.headers, + params={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) diff --git a/api/src/cartsnitch_api/services/shrinkray.py b/api/src/cartsnitch_api/services/shrinkray.py new file mode 100644 index 0000000..862881e --- /dev/null +++ b/api/src/cartsnitch_api/services/shrinkray.py @@ -0,0 +1,23 @@ +"""HTTP client for ShrinkRay internal API.""" + +from typing import Any, cast + +import httpx + +from cartsnitch_api.config import settings + + +class ShrinkRayClient: + def __init__(self) -> None: + self.base_url = settings.shrinkray_url + self.headers = {"X-Service-Key": settings.service_key} + + async def get_shrinkflation_alerts(self, user_id: str) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/alerts", + headers=self.headers, + params={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) diff --git a/api/src/cartsnitch_api/services/stickershock.py b/api/src/cartsnitch_api/services/stickershock.py new file mode 100644 index 0000000..3a7928d --- /dev/null +++ b/api/src/cartsnitch_api/services/stickershock.py @@ -0,0 +1,32 @@ +"""HTTP client for StickerShock internal API.""" + +from typing import Any, cast + +import httpx + +from cartsnitch_api.config import settings + + +class StickerShockClient: + def __init__(self) -> None: + self.base_url = settings.stickershock_url + self.headers = {"X-Service-Key": settings.service_key} + + async def get_price_increases(self, params: dict | None = None) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/increases", + headers=self.headers, + params=params, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) + + async def get_inflation_data(self) -> dict: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/inflation", + headers=self.headers, + ) + resp.raise_for_status() + return cast(dict[str, Any], resp.json()) diff --git a/api/src/cartsnitch_api/services/stores.py b/api/src/cartsnitch_api/services/stores.py new file mode 100644 index 0000000..610f47e --- /dev/null +++ b/api/src/cartsnitch_api/services/stores.py @@ -0,0 +1,129 @@ +"""Store service — list stores, manage user store account connections.""" + +import json +from uuid import UUID + +from cryptography.fernet import Fernet +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from cartsnitch_api.config import settings + + +def _get_fernet() -> Fernet: + return Fernet(settings.fernet_key.encode()) + + +class StoreService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_stores(self) -> list[dict]: + from cartsnitch_api.models import Store + + result = await self.db.execute(select(Store).order_by(Store.name)) + stores = result.scalars().all() + return [ + { + "id": s.id, + "name": s.name, + "slug": s.slug, + "logo_url": s.logo_url, + "supported": True, + } + for s in stores + ] + + async def list_user_stores(self, user_id: UUID) -> list[dict]: + from cartsnitch_api.models import UserStoreAccount + + result = await self.db.execute( + select(UserStoreAccount) + .where(UserStoreAccount.user_id == user_id) + .options(selectinload(UserStoreAccount.store)) + ) + accounts = result.scalars().all() + return [ + { + "store": { + "id": a.store.id, + "name": a.store.name, + "slug": a.store.slug, + "logo_url": a.store.logo_url, + "supported": True, + }, + "connected": a.status == "active", + "last_sync_at": a.last_sync_at, + "sync_status": a.status, + } + for a in accounts + ] + + async def connect_store(self, user_id: UUID, store_slug: str, credentials: dict | None) -> dict: + from cartsnitch_api.models import Store, UserStoreAccount + + result = await self.db.execute(select(Store).where(Store.slug == store_slug)) + store = result.scalar_one_or_none() + if not store: + raise LookupError(f"Store '{store_slug}' not found") + + existing = await self.db.execute( + select(UserStoreAccount).where( + UserStoreAccount.user_id == user_id, + UserStoreAccount.store_id == store.id, + ) + ) + if existing.scalar_one_or_none(): + raise ValueError("Store account already connected") + + encrypted_data = None + if credentials: + fernet = _get_fernet() + encrypted_data = { + "encrypted": fernet.encrypt(json.dumps(credentials).encode()).decode() + } + + account = UserStoreAccount( + user_id=user_id, + store_id=store.id, + session_data=encrypted_data, + status="active", + ) + self.db.add(account) + await self.db.commit() + await self.db.refresh(account) + + return { + "store": { + "id": store.id, + "name": store.name, + "slug": store.slug, + "logo_url": store.logo_url, + "supported": True, + }, + "connected": True, + "last_sync_at": None, + "sync_status": "active", + } + + async def disconnect_store(self, user_id: UUID, store_slug: str) -> None: + from cartsnitch_api.models import Store, UserStoreAccount + + result = await self.db.execute(select(Store).where(Store.slug == store_slug)) + store = result.scalar_one_or_none() + if not store: + raise LookupError(f"Store '{store_slug}' not found") + + result = await self.db.execute( + select(UserStoreAccount).where( + UserStoreAccount.user_id == user_id, + UserStoreAccount.store_id == store.id, + ) + ) + account = result.scalar_one_or_none() + if not account: + raise LookupError("Store account not connected") + + await self.db.delete(account) + await self.db.commit() diff --git a/api/src/cartsnitch_api/types.py b/api/src/cartsnitch_api/types.py new file mode 100644 index 0000000..13a7820 --- /dev/null +++ b/api/src/cartsnitch_api/types.py @@ -0,0 +1,36 @@ +"""Custom SQLAlchemy column types.""" + +import json + +from cryptography.fernet import Fernet +from sqlalchemy import Text +from sqlalchemy.types import TypeDecorator + +from cartsnitch_api.config import settings + + +def _get_fernet() -> Fernet: + return Fernet(settings.fernet_key.encode()) + + +class EncryptedJSON(TypeDecorator): + """SQLAlchemy type that transparently encrypts/decrypts JSON using Fernet. + + Stores data as a Fernet-encrypted text blob in the database. + On read, decrypts and deserialises back to a Python dict/list. + """ + + impl = Text + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + plaintext = json.dumps(value).encode() + return _get_fernet().encrypt(plaintext).decode() + + def process_result_value(self, value, dialect): + if value is None: + return None + decrypted = _get_fernet().decrypt(value.encode()) + return json.loads(decrypted) diff --git a/api/tests/__init__.py b/api/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/conftest.py b/api/tests/conftest.py new file mode 100644 index 0000000..9873903 --- /dev/null +++ b/api/tests/conftest.py @@ -0,0 +1,101 @@ +"""Shared test fixtures with in-memory SQLite database.""" + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import create_engine, event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import sessionmaker + +from cartsnitch_api.config import settings as cartsnitch_settings +from cartsnitch_api.database import get_db +from cartsnitch_api.main import create_app +from cartsnitch_api.models import Base + +TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" + + +@pytest.fixture(autouse=True) +def disable_rate_limiting(): + """Disable rate limiting for all tests to prevent 429 interference.""" + cartsnitch_settings.rate_limit_enabled = False + yield + cartsnitch_settings.rate_limit_enabled = True + + +@pytest.fixture +def engine(): + """Sync in-memory SQLite engine for model unit tests.""" + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def session(engine): + """Sync SQLAlchemy session for model unit tests.""" + factory = sessionmaker(bind=engine) + with factory() as sess: + yield sess + + +@pytest.fixture +async def db_engine(): + engine = create_async_engine(TEST_DATABASE_URL, echo=False) + + @event.listens_for(engine.sync_engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + await engine.dispose() + + +@pytest.fixture +async def db_session(db_engine): + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + yield session + + +@pytest.fixture +async def client(db_engine): + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + + async def override_get_db(): + async with factory() as session: + yield session + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + app.dependency_overrides.clear() + + +@pytest.fixture +async def auth_headers(client): + """Register a test user and return auth headers.""" + resp = await client.post( + "/auth/register", + json={ + "email": "test@example.com", + "password": "testpass123", + "display_name": "Test User", + }, + ) + assert resp.status_code == 201 + token = resp.json()["access_token"] + return {"Authorization": f"Bearer {token}"} diff --git a/api/tests/test_auth/__init__.py b/api/tests/test_auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/test_auth/test_auth_endpoints.py b/api/tests/test_auth/test_auth_endpoints.py new file mode 100644 index 0000000..878cbc5 --- /dev/null +++ b/api/tests/test_auth/test_auth_endpoints.py @@ -0,0 +1,209 @@ +"""Integration tests for auth endpoints.""" + +import pytest + + +@pytest.mark.asyncio +async def test_register_success(client): + resp = await client.post( + "/auth/register", + json={ + "email": "new@example.com", + "password": "securepass123", + "display_name": "New User", + }, + ) + assert resp.status_code == 201 + data = resp.json() + assert "access_token" in data + assert "refresh_token" in data + assert data["token_type"] == "bearer" + assert data["expires_in"] == 900 # 15 min * 60 + + +@pytest.mark.asyncio +async def test_register_duplicate_email(client): + await client.post( + "/auth/register", + json={ + "email": "dupe@example.com", + "password": "securepass123", + "display_name": "User One", + }, + ) + resp = await client.post( + "/auth/register", + json={ + "email": "dupe@example.com", + "password": "securepass456", + "display_name": "User Two", + }, + ) + assert resp.status_code == 409 + + +@pytest.mark.asyncio +async def test_register_short_password(client): + resp = await client.post( + "/auth/register", + json={ + "email": "short@example.com", + "password": "short", + "display_name": "Short Pass", + }, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_login_success(client): + await client.post( + "/auth/register", + json={ + "email": "login@example.com", + "password": "securepass123", + "display_name": "Login User", + }, + ) + resp = await client.post( + "/auth/login", + json={ + "email": "login@example.com", + "password": "securepass123", + }, + ) + assert resp.status_code == 200 + assert "access_token" in resp.json() + + +@pytest.mark.asyncio +async def test_login_wrong_password(client): + await client.post( + "/auth/register", + json={ + "email": "wrong@example.com", + "password": "securepass123", + "display_name": "Wrong Pass", + }, + ) + resp = await client.post( + "/auth/login", + json={ + "email": "wrong@example.com", + "password": "badpassword1", + }, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_login_nonexistent_user(client): + resp = await client.post( + "/auth/login", + json={ + "email": "ghost@example.com", + "password": "doesntmatter", + }, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_refresh_token(client): + reg = await client.post( + "/auth/register", + json={ + "email": "refresh@example.com", + "password": "securepass123", + "display_name": "Refresh User", + }, + ) + refresh_token = reg.json()["refresh_token"] + + resp = await client.post( + "/auth/refresh", + json={ + "refresh_token": refresh_token, + }, + ) + assert resp.status_code == 200 + assert "access_token" in resp.json() + + +@pytest.mark.asyncio +async def test_refresh_with_invalid_token(client): + resp = await client.post( + "/auth/refresh", + json={ + "refresh_token": "invalid.token.here", + }, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_get_me(client, auth_headers): + resp = await client.get("/auth/me", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["email"] == "test@example.com" + assert data["display_name"] == "Test User" + assert "id" in data + assert "created_at" in data + + +@pytest.mark.asyncio +async def test_get_me_unauthorized(client): + resp = await client.get("/auth/me") + assert resp.status_code in (401, 403) # No auth header + + +@pytest.mark.asyncio +async def test_update_me(client, auth_headers): + resp = await client.patch( + "/auth/me", + headers=auth_headers, + json={ + "display_name": "Updated Name", + }, + ) + assert resp.status_code == 200 + assert resp.json()["display_name"] == "Updated Name" + + +@pytest.mark.asyncio +async def test_delete_me(client, auth_headers): + resp = await client.delete("/auth/me", headers=auth_headers) + assert resp.status_code == 204 + + # Verify user is gone (token still valid but user deleted) + resp = await client.get("/auth/me", headers=auth_headers) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_refresh_after_delete_fails(client): + """Refresh token for a deleted user must be rejected.""" + reg = await client.post( + "/auth/register", + json={ + "email": "ghost@example.com", + "password": "securepass123", + "display_name": "Ghost User", + }, + ) + tokens = reg.json() + headers = {"Authorization": f"Bearer {tokens['access_token']}"} + + # Delete the user + resp = await client.delete("/auth/me", headers=headers) + assert resp.status_code == 204 + + # Refresh token should now fail + resp = await client.post( + "/auth/refresh", + json={ + "refresh_token": tokens["refresh_token"], + }, + ) + assert resp.status_code == 401 diff --git a/api/tests/test_e2e/__init__.py b/api/tests/test_e2e/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/test_e2e/conftest.py b/api/tests/test_e2e/conftest.py new file mode 100644 index 0000000..f1390fd --- /dev/null +++ b/api/tests/test_e2e/conftest.py @@ -0,0 +1,250 @@ +"""Shared fixtures for E2E integration tests. + +Seeds a realistic dataset with stores, products, price history, +purchases, coupons, and shrinkflation events so E2E flows can +exercise cross-resource queries against real data. +""" + +from datetime import date, timedelta +from decimal import Decimal +from uuid import UUID + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.auth.jwt import decode_token +from cartsnitch_api.models import ( + Coupon, + NormalizedProduct, + PriceHistory, + Purchase, + PurchaseItem, + ShrinkflationEvent, + Store, +) + +# Shared test constants +ZERO_UUID = "00000000-0000-0000-0000-000000000000" +BAD_UUID = "not-a-uuid" +# Fixed anchor date for deterministic tests +ANCHOR_DATE = date(2026, 3, 15) + + +@pytest.fixture +async def seed_data(db_engine, auth_headers): + """Seed a full dataset and return identifiers for test assertions.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + # -- Stores -- + meijer = Store(name="Meijer", slug="meijer") + kroger = Store(name="Kroger", slug="kroger") + target = Store(name="Target", slug="target") + session.add_all([meijer, kroger, target]) + await session.flush() + + # -- Products -- + cheerios = NormalizedProduct( + canonical_name="Cheerios 18oz", + category="pantry", + brand="General Mills", + size="18", + size_unit="oz", + upc_variants=["016000275263"], + ) + milk = NormalizedProduct( + canonical_name="Whole Milk 1gal", + category="dairy", + brand="Meijer", + size="1", + size_unit="gal", + ) + chicken = NormalizedProduct( + canonical_name="Chicken Breast 1lb", + category="meat", + brand=None, + size="1", + size_unit="lb", + ) + session.add_all([cheerios, milk, chicken]) + await session.flush() + + # -- Price history (multiple dates, multiple stores) -- + today = ANCHOR_DATE + prices = [] + # Cheerios at Meijer: price increase over time + for i, price_val in enumerate([Decimal("3.99"), Decimal("4.29"), Decimal("4.79")]): + prices.append( + PriceHistory( + normalized_product_id=cheerios.id, + store_id=meijer.id, + observed_date=today - timedelta(days=60 - i * 30), + regular_price=price_val, + source="receipt", + ) + ) + # Cheerios at Kroger: stable price + for i in range(3): + prices.append( + PriceHistory( + normalized_product_id=cheerios.id, + store_id=kroger.id, + observed_date=today - timedelta(days=60 - i * 30), + regular_price=Decimal("4.49"), + source="catalog", + ) + ) + # Milk at Meijer + prices.append( + PriceHistory( + normalized_product_id=milk.id, + store_id=meijer.id, + observed_date=today - timedelta(days=7), + regular_price=Decimal("3.29"), + source="receipt", + ) + ) + # Milk at Kroger + prices.append( + PriceHistory( + normalized_product_id=milk.id, + store_id=kroger.id, + observed_date=today - timedelta(days=5), + regular_price=Decimal("3.49"), + source="catalog", + ) + ) + # Chicken at Target + prices.append( + PriceHistory( + normalized_product_id=chicken.id, + store_id=target.id, + observed_date=today - timedelta(days=3), + regular_price=Decimal("5.99"), + source="catalog", + ) + ) + session.add_all(prices) + await session.flush() + + # -- Purchases (need the user_id from the registered test user) -- + token = auth_headers["Authorization"].split(" ")[1] + payload = decode_token(token) + user_id = UUID(payload["sub"]) + + purchase1 = Purchase( + user_id=user_id, + store_id=meijer.id, + receipt_id="meijer-2026-001", + purchase_date=today - timedelta(days=10), + total=Decimal("23.45"), + subtotal=Decimal("21.50"), + tax=Decimal("1.95"), + ) + purchase2 = Purchase( + user_id=user_id, + store_id=kroger.id, + receipt_id="kroger-2026-001", + purchase_date=today - timedelta(days=5), + total=Decimal("15.78"), + subtotal=Decimal("14.50"), + tax=Decimal("1.28"), + ) + session.add_all([purchase1, purchase2]) + await session.flush() + + # -- Purchase Items -- + item1 = PurchaseItem( + purchase_id=purchase1.id, + product_name_raw="Cheerios 18oz Box", + quantity=Decimal("1"), + unit_price=Decimal("4.79"), + extended_price=Decimal("4.79"), + normalized_product_id=cheerios.id, + ) + item2 = PurchaseItem( + purchase_id=purchase1.id, + product_name_raw="Meijer Whole Milk 1gal", + quantity=Decimal("2"), + unit_price=Decimal("3.29"), + extended_price=Decimal("6.58"), + normalized_product_id=milk.id, + ) + item3 = PurchaseItem( + purchase_id=purchase2.id, + product_name_raw="KRO CHEERIOS 18OZ", + quantity=Decimal("1"), + unit_price=Decimal("4.49"), + extended_price=Decimal("4.49"), + normalized_product_id=cheerios.id, + ) + session.add_all([item1, item2, item3]) + await session.flush() + + # -- Coupons -- + coupon1 = Coupon( + store_id=meijer.id, + normalized_product_id=cheerios.id, + title="$1 off Cheerios", + description="Save $1 on any Cheerios 18oz or larger", + discount_type="fixed", + discount_value=Decimal("1.00"), + valid_from=today - timedelta(days=7), + valid_to=today + timedelta(days=30), + ) + coupon2 = Coupon( + store_id=kroger.id, + normalized_product_id=None, + title="10% off dairy", + description="10% off all dairy products", + discount_type="percent", + discount_value=Decimal("10.00"), + valid_from=today - timedelta(days=3), + valid_to=today + timedelta(days=14), + ) + session.add_all([coupon1, coupon2]) + await session.flush() + + # -- Shrinkflation events -- + shrink = ShrinkflationEvent( + normalized_product_id=cheerios.id, + detected_date=today - timedelta(days=15), + old_size="20", + new_size="18", + old_unit="oz", + new_unit="oz", + price_at_old_size=Decimal("3.99"), + price_at_new_size=Decimal("4.29"), + confidence=Decimal("0.95"), + notes="Size reduced from 20oz to 18oz while price increased", + ) + session.add(shrink) + await session.commit() + + for obj in [ + meijer, + kroger, + target, + cheerios, + milk, + chicken, + purchase1, + purchase2, + item1, + item2, + item3, + coupon1, + coupon2, + shrink, + ]: + await session.refresh(obj) + + return { + "headers": auth_headers, + "user_id": user_id, + "stores": {"meijer": meijer, "kroger": kroger, "target": target}, + "products": {"cheerios": cheerios, "milk": milk, "chicken": chicken}, + "purchases": {"meijer_trip": purchase1, "kroger_trip": purchase2}, + "items": {"cheerios_meijer": item1, "milk_meijer": item2, "cheerios_kroger": item3}, + "coupons": {"cheerios_coupon": coupon1, "dairy_coupon": coupon2}, + "shrinkflation": {"cheerios_shrink": shrink}, + } diff --git a/api/tests/test_e2e/test_auth_validation.py b/api/tests/test_e2e/test_auth_validation.py new file mode 100644 index 0000000..bbded83 --- /dev/null +++ b/api/tests/test_e2e/test_auth_validation.py @@ -0,0 +1,213 @@ +"""E2E: Auth and token validation flows.""" + +import asyncio + +import pytest + + +@pytest.mark.asyncio +class TestAuthRegistrationLogin: + """Full registration → login → token refresh → profile flow.""" + + async def test_full_auth_lifecycle(self, client, db_engine): + """Register → login → get profile → refresh → get profile again.""" + # Register + reg = await client.post( + "/auth/register", + json={ + "email": "lifecycle@example.com", + "password": "securepass123", + "display_name": "Lifecycle User", + }, + ) + assert reg.status_code == 201 + tokens = reg.json() + assert "access_token" in tokens + assert "refresh_token" in tokens + assert tokens["token_type"] == "bearer" + assert tokens["expires_in"] > 0 + + headers = {"Authorization": f"Bearer {tokens['access_token']}"} + + # Get profile with access token + me = await client.get("/auth/me", headers=headers) + assert me.status_code == 200 + assert me.json()["email"] == "lifecycle@example.com" + assert me.json()["display_name"] == "Lifecycle User" + + # Sleep 1s so the new token has a different exp than the registration token + await asyncio.sleep(1) + + # Login with same credentials + login = await client.post( + "/auth/login", + json={"email": "lifecycle@example.com", "password": "securepass123"}, + ) + assert login.status_code == 200 + login_tokens = login.json() + assert login_tokens["access_token"] != tokens["access_token"] + + # Refresh token + refresh = await client.post( + "/auth/refresh", + json={"refresh_token": tokens["refresh_token"]}, + ) + assert refresh.status_code == 200 + new_tokens = refresh.json() + assert new_tokens["access_token"] != tokens["access_token"] + + # Use refreshed token to access profile + new_headers = {"Authorization": f"Bearer {new_tokens['access_token']}"} + me2 = await client.get("/auth/me", headers=new_headers) + assert me2.status_code == 200 + assert me2.json()["email"] == "lifecycle@example.com" + + +@pytest.mark.asyncio +class TestTokenValidation: + """Token edge cases and error responses.""" + + async def test_expired_token_rejected(self, client, db_engine): + """Manually craft an expired token and verify rejection.""" + import uuid + from datetime import UTC, datetime, timedelta + + from jose import jwt + + from cartsnitch_api.config import settings + + payload = { + "sub": str(uuid.uuid4()), + "exp": datetime.now(UTC) - timedelta(minutes=5), + "type": "access", + } + token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm) + resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"}) + assert resp.status_code == 401 + + async def test_invalid_token_rejected(self, client, db_engine): + resp = await client.get("/auth/me", headers={"Authorization": "Bearer not-a-real-token"}) + assert resp.status_code == 401 + + async def test_missing_auth_header(self, client, db_engine): + resp = await client.get("/auth/me") + assert resp.status_code in (401, 403) + + async def test_refresh_token_cannot_access_endpoints(self, client, db_engine): + """A refresh token should not work as an access token.""" + reg = await client.post( + "/auth/register", + json={ + "email": "refresh-test@example.com", + "password": "securepass123", + "display_name": "Refresh Test", + }, + ) + refresh_token = reg.json()["refresh_token"] + resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {refresh_token}"}) + assert resp.status_code == 401 + + async def test_deleted_user_token_invalid(self, client, db_engine): + """After deleting an account, tokens should no longer work.""" + reg = await client.post( + "/auth/register", + json={ + "email": "delete-me@example.com", + "password": "securepass123", + "display_name": "Delete Me", + }, + ) + tokens = reg.json() + headers = {"Authorization": f"Bearer {tokens['access_token']}"} + + # Delete account + delete_resp = await client.delete("/auth/me", headers=headers) + assert delete_resp.status_code == 204 + + # Profile should fail + me = await client.get("/auth/me", headers=headers) + assert me.status_code in (401, 404) + + +@pytest.mark.asyncio +class TestAuthProtectedEndpoints: + """Verify auth is enforced on all user-specific endpoints.""" + + @pytest.mark.parametrize( + "method,path", + [ + ("GET", "/purchases"), + ("GET", "/products"), + ("GET", "/prices/trends"), + ("GET", "/prices/increases"), + ("GET", "/coupons"), + ("GET", "/alerts"), + ("GET", "/me/stores"), + ], + ) + async def test_endpoints_require_auth(self, client, db_engine, method, path): + resp = await client.request(method, path) + assert resp.status_code in (401, 403), f"{method} {path} should require auth" + + +@pytest.mark.asyncio +class TestCrossUserDataIsolation: + """Verify that users cannot access other users' data.""" + + async def test_user_b_cannot_access_user_a_purchases(self, client, seed_data): + """Register a second user and verify they cannot see User A's purchases.""" + # User A's purchase (from seed_data) + purchase_id = str(seed_data["purchases"]["meijer_trip"].id) + + # Register User B + reg = await client.post( + "/auth/register", + json={ + "email": "userb@example.com", + "password": "securepass123", + "display_name": "User B", + }, + ) + assert reg.status_code == 201 + user_b_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + + # User B tries to access User A's specific purchase + resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers) + assert resp.status_code in (403, 404), ( + "User B should not be able to access User A's purchase" + ) + + async def test_user_b_purchase_list_is_empty(self, client, seed_data): + """A new user should see no purchases (not User A's purchases).""" + reg = await client.post( + "/auth/register", + json={ + "email": "userc@example.com", + "password": "securepass123", + "display_name": "User C", + }, + ) + assert reg.status_code == 201 + user_c_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + + resp = await client.get("/purchases", headers=user_c_headers) + assert resp.status_code == 200 + assert len(resp.json()) == 0, "New user should have no purchases" + + async def test_user_b_stores_isolated(self, client, seed_data): + """User B's connected stores should be independent from User A.""" + reg = await client.post( + "/auth/register", + json={ + "email": "userd@example.com", + "password": "securepass123", + "display_name": "User D", + }, + ) + assert reg.status_code == 201 + user_d_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + + # User D should have no connected stores + resp = await client.get("/me/stores", headers=user_d_headers) + assert resp.status_code == 200 + assert len(resp.json()) == 0, "New user should have no connected stores" diff --git a/api/tests/test_e2e/test_cross_resource_flow.py b/api/tests/test_e2e/test_cross_resource_flow.py new file mode 100644 index 0000000..1f90671 --- /dev/null +++ b/api/tests/test_e2e/test_cross_resource_flow.py @@ -0,0 +1,114 @@ +"""E2E: Cross-resource flows — store connect → purchases → prices → coupons → alerts.""" + +import pytest + + +@pytest.mark.asyncio +class TestStoreConnectToPurchaseFlow: + """Connect a store, then verify purchases and related data are accessible.""" + + async def test_connect_store_then_list(self, client, seed_data): + headers = seed_data["headers"] + # Connect to Meijer + resp = await client.post("/me/stores/meijer/connect", json={}, headers=headers) + assert resp.status_code in (200, 201) + + # Verify store appears in user's connected stores + stores = await client.get("/me/stores", headers=headers) + assert stores.status_code == 200 + slugs = [s["store"]["slug"] for s in stores.json()] + assert "meijer" in slugs + + async def test_disconnect_store(self, client, seed_data): + headers = seed_data["headers"] + await client.post("/me/stores/kroger/connect", json={}, headers=headers) + resp = await client.delete("/me/stores/kroger", headers=headers) + assert resp.status_code in (200, 204) + + # Verify store no longer in connected list + stores = await client.get("/me/stores", headers=headers) + slugs = [s["store"]["slug"] for s in stores.json()] + assert "kroger" not in slugs + + +@pytest.mark.asyncio +class TestPurchaseToPriceFlow: + """Verify purchase data links to price comparison data.""" + + async def test_purchase_items_link_to_products(self, client, seed_data): + """Items from purchases reference products that have price data.""" + headers = seed_data["headers"] + purchase_id = str(seed_data["purchases"]["meijer_trip"].id) + + # Get purchase detail + purchase = await client.get(f"/purchases/{purchase_id}", headers=headers) + assert purchase.status_code == 200 + items = purchase.json()["line_items"] + + # Get product detail for an item that has a product_id + product_ids = [li["product_id"] for li in items if li.get("product_id")] + assert len(product_ids) >= 1 + + for pid in product_ids: + product = await client.get(f"/products/{pid}", headers=headers) + assert product.status_code == 200 + assert len(product.json()["prices_by_store"]) >= 1 + + +@pytest.mark.asyncio +class TestCouponFlow: + """Verify coupon listing and relevance filtering.""" + + async def test_list_all_coupons(self, client, seed_data): + headers = seed_data["headers"] + resp = await client.get("/coupons", headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 2 + descriptions = [c["description"] for c in data] + assert any("Cheerios" in d for d in descriptions) + + async def test_filter_coupons_by_store(self, client, seed_data): + headers = seed_data["headers"] + meijer_id = str(seed_data["stores"]["meijer"].id) + resp = await client.get("/coupons", params={"store_id": meijer_id}, headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert all(c["store_name"] == "Meijer" for c in data) + + async def test_relevant_coupons_for_user(self, client, seed_data): + """User bought Cheerios, so the Cheerios coupon should be relevant.""" + headers = seed_data["headers"] + resp = await client.get("/coupons/relevant", headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1, "Expected at least one relevant coupon for user with purchases" + descriptions = [c["description"] for c in data] + assert any("Cheerios" in d for d in descriptions) + + +@pytest.mark.asyncio +class TestAlertFlow: + """Verify alert listing with seeded data.""" + + async def test_list_alerts(self, client, seed_data): + """User bought Cheerios which has a shrinkflation event — may appear as alert.""" + headers = seed_data["headers"] + resp = await client.get("/alerts", headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, list) + # If alerts are generated synchronously, verify shrinkflation alert content + if len(data) > 0: + alert_types = [a["alert_type"] for a in data] + product_names = [a["product_name"] for a in data] + assert any(t in ("shrinkflation", "price_increase") for t in alert_types) + assert any("Cheerios" in name for name in product_names) + + async def test_alert_settings_default(self, client, seed_data): + headers = seed_data["headers"] + resp = await client.get("/alerts/settings", headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert "price_increase_threshold_pct" in data + assert "shrinkflation_enabled" in data diff --git a/api/tests/test_e2e/test_error_responses.py b/api/tests/test_e2e/test_error_responses.py new file mode 100644 index 0000000..c3ad16e --- /dev/null +++ b/api/tests/test_e2e/test_error_responses.py @@ -0,0 +1,127 @@ +"""E2E: Error responses for bad input across all endpoint categories.""" + +import pytest + +from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID + + +@pytest.mark.asyncio +class TestRegistrationErrors: + """Validation errors during user registration.""" + + async def test_short_password(self, client, db_engine): + resp = await client.post( + "/auth/register", + json={"email": "short@example.com", "password": "short", "display_name": "Test"}, + ) + assert resp.status_code == 422 + + async def test_invalid_email(self, client, db_engine): + resp = await client.post( + "/auth/register", + json={"email": "not-an-email", "password": "securepass123", "display_name": "Test"}, + ) + assert resp.status_code == 422 + + async def test_missing_fields(self, client, db_engine): + resp = await client.post("/auth/register", json={}) + assert resp.status_code == 422 + + async def test_empty_display_name(self, client, db_engine): + resp = await client.post( + "/auth/register", + json={"email": "empty@example.com", "password": "securepass123", "display_name": ""}, + ) + assert resp.status_code == 422 + + async def test_duplicate_email(self, client, db_engine): + payload = { + "email": "dupe@example.com", + "password": "securepass123", + "display_name": "First", + } + first = await client.post("/auth/register", json=payload) + assert first.status_code == 201 + second = await client.post("/auth/register", json=payload) + assert second.status_code == 409 + + +@pytest.mark.asyncio +class TestLoginErrors: + """Login failure modes.""" + + async def test_wrong_password(self, client, db_engine): + await client.post( + "/auth/register", + json={ + "email": "login-err@example.com", + "password": "correctpass1", + "display_name": "Login", + }, + ) + resp = await client.post( + "/auth/login", + json={"email": "login-err@example.com", "password": "wrongpass123"}, + ) + assert resp.status_code == 401 + + async def test_nonexistent_user(self, client, db_engine): + resp = await client.post( + "/auth/login", + json={"email": "nobody@example.com", "password": "doesntmatter"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +class TestNotFoundErrors: + """404 responses for missing resources.""" + + async def test_product_not_found(self, client, seed_data): + resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 404 + + async def test_purchase_not_found(self, client, seed_data): + resp = await client.get(f"/purchases/{ZERO_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 404 + + async def test_public_trend_not_found(self, client, seed_data): + resp = await client.get(f"/public/trends/{ZERO_UUID}") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +class TestMalformedInput: + """Invalid UUID formats and bad query params.""" + + async def test_invalid_uuid_product(self, client, seed_data): + resp = await client.get(f"/products/{BAD_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 422 + + async def test_invalid_uuid_purchase(self, client, seed_data): + resp = await client.get(f"/purchases/{BAD_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 422 + + async def test_invalid_uuid_public_trend(self, client, seed_data): + resp = await client.get(f"/public/trends/{BAD_UUID}") + assert resp.status_code == 422 + + +@pytest.mark.asyncio +class TestStoreConnectionErrors: + """Store connection edge cases.""" + + async def test_connect_nonexistent_store(self, client, seed_data): + resp = await client.post( + "/me/stores/nonexistent-store/connect", + json={}, + headers=seed_data["headers"], + ) + assert resp.status_code == 404 + + async def test_connect_store_twice(self, client, seed_data): + headers = seed_data["headers"] + first = await client.post("/me/stores/meijer/connect", json={}, headers=headers) + assert first.status_code in (200, 201) + second = await client.post("/me/stores/meijer/connect", json={}, headers=headers) + assert second.status_code == 409 diff --git a/api/tests/test_e2e/test_price_history.py b/api/tests/test_e2e/test_price_history.py new file mode 100644 index 0000000..3d53f06 --- /dev/null +++ b/api/tests/test_e2e/test_price_history.py @@ -0,0 +1,102 @@ +"""E2E: Price history queries returning correct data.""" + +import pytest + + +@pytest.mark.asyncio +class TestPriceTrends: + """Verify price trend aggregation against seeded history.""" + + async def test_trends_returns_all_products(self, client, seed_data): + resp = await client.get("/prices/trends", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + product_names = [t["product_name"] for t in data] + assert "Cheerios 18oz" in product_names + assert "Whole Milk 1gal" in product_names + + async def test_trends_filter_by_category(self, client, seed_data): + resp = await client.get( + "/prices/trends", params={"category": "dairy"}, headers=seed_data["headers"] + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + # Only dairy products should appear + for trend in data: + assert trend["product_name"] == "Whole Milk 1gal" + + async def test_trends_contain_data_points(self, client, seed_data): + resp = await client.get("/prices/trends", headers=seed_data["headers"]) + data = resp.json() + cheerios_trend = next(t for t in data if t["product_name"] == "Cheerios 18oz") + assert len(cheerios_trend["data_points"]) >= 3 + + +@pytest.mark.asyncio +class TestPriceIncreases: + """Detect price increases from seeded price history.""" + + async def test_increases_detected(self, client, seed_data): + resp = await client.get("/prices/increases", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + # Cheerios at Meijer went from 3.99 → 4.29 → 4.79 + cheerios_increases = [inc for inc in data if inc["product_name"] == "Cheerios 18oz"] + assert len(cheerios_increases) >= 1 + # Verify the increase data makes sense + for inc in cheerios_increases: + assert inc["new_price"] > inc["old_price"] + assert inc["increase_pct"] > 0 + assert inc["store_name"] == "Meijer" + + async def test_stable_prices_not_flagged(self, client, seed_data): + """Kroger Cheerios price is stable at $4.49 — should not appear as increase.""" + resp = await client.get("/prices/increases", headers=seed_data["headers"]) + data = resp.json() + kroger_increases = [ + inc + for inc in data + if inc["product_name"] == "Cheerios 18oz" and inc["store_name"] == "Kroger" + ] + assert len(kroger_increases) == 0 + + +@pytest.mark.asyncio +class TestPriceComparison: + """Compare prices across stores for specific products.""" + + async def test_compare_cheerios_across_stores(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get( + "/prices/comparison", + params={"product_ids": cheerios_id}, + headers=seed_data["headers"], + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + cheerios_cmp = data[0] + assert cheerios_cmp["product_name"] == "Cheerios 18oz" + store_names = [p["store_name"] for p in cheerios_cmp["prices"]] + assert "Meijer" in store_names + assert "Kroger" in store_names + + async def test_compare_requires_product_ids(self, client, seed_data): + """product_ids is required — omitting it must return 422.""" + resp = await client.get("/prices/comparison", headers=seed_data["headers"]) + assert resp.status_code == 422 + + async def test_compare_multiple_products(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + milk_id = str(seed_data["products"]["milk"].id) + resp = await client.get( + "/prices/comparison", + params=[("product_ids", cheerios_id), ("product_ids", milk_id)], + headers=seed_data["headers"], + ) + assert resp.status_code == 200 + data = resp.json() + names = [c["product_name"] for c in data] + assert "Cheerios 18oz" in names + assert "Whole Milk 1gal" in names diff --git a/api/tests/test_e2e/test_product_search_lookup.py b/api/tests/test_e2e/test_product_search_lookup.py new file mode 100644 index 0000000..ea97c34 --- /dev/null +++ b/api/tests/test_e2e/test_product_search_lookup.py @@ -0,0 +1,82 @@ +"""E2E: Product search/lookup endpoints with real DB fixtures.""" + +import pytest + +from tests.test_e2e.conftest import ZERO_UUID + + +@pytest.mark.asyncio +class TestProductSearch: + """Search and filter products against seeded data.""" + + async def test_list_all_products(self, client, seed_data): + resp = await client.get("/products", headers=seed_data["headers"]) + assert resp.status_code == 200 + products = resp.json() + names = [p["name"] for p in products] + assert "Cheerios 18oz" in names + assert "Whole Milk 1gal" in names + assert "Chicken Breast 1lb" in names + + async def test_search_by_name(self, client, seed_data): + resp = await client.get("/products", params={"q": "cheerios"}, headers=seed_data["headers"]) + assert resp.status_code == 200 + products = resp.json() + assert len(products) >= 1 + assert all("cheerios" in p["name"].lower() for p in products) + + async def test_search_by_category(self, client, seed_data): + resp = await client.get( + "/products", params={"category": "dairy"}, headers=seed_data["headers"] + ) + assert resp.status_code == 200 + products = resp.json() + assert len(products) >= 1 + assert all(p["category"] == "dairy" for p in products) + + async def test_search_no_results(self, client, seed_data): + resp = await client.get( + "/products", params={"q": "nonexistentxyz"}, headers=seed_data["headers"] + ) + assert resp.status_code == 200 + assert resp.json() == [] + + +@pytest.mark.asyncio +class TestProductLookup: + """Detailed product lookups with cross-store pricing.""" + + async def test_get_product_detail_with_prices(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "Cheerios 18oz" + assert data["brand"] == "General Mills" + assert data["category"] == "pantry" + # Should have prices from both Meijer and Kroger + store_names = [p["store_name"] for p in data["prices_by_store"]] + assert "Meijer" in store_names + assert "Kroger" in store_names + + async def test_product_prices_reflect_latest(self, client, seed_data): + """The latest Meijer price for Cheerios should be 4.79 (the increase).""" + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"]) + data = resp.json() + meijer_price = next(p for p in data["prices_by_store"] if p["store_name"] == "Meijer") + assert meijer_price["current_price"] == 4.79 + + async def test_product_not_found(self, client, seed_data): + resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 404 + + async def test_product_price_history(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get(f"/products/{cheerios_id}/prices", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data["data_points"]) >= 3 # At least the 3 Meijer observations + # Verify chronological ordering exists + prices = [dp["price"] for dp in data["data_points"]] + assert len(prices) >= 3 diff --git a/api/tests/test_e2e/test_public_endpoints.py b/api/tests/test_e2e/test_public_endpoints.py new file mode 100644 index 0000000..a0e24cf --- /dev/null +++ b/api/tests/test_e2e/test_public_endpoints.py @@ -0,0 +1,59 @@ +"""E2E: Public price transparency endpoints (no auth required).""" + +import uuid + +import pytest + + +@pytest.mark.asyncio +class TestPublicTrends: + """Public price trend endpoint — no auth, real data.""" + + async def test_public_trend_returns_data(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get(f"/public/trends/{cheerios_id}") + assert resp.status_code == 200 + data = resp.json() + assert data["product_name"] == "Cheerios 18oz" + assert len(data["data_points"]) >= 3 + + async def test_public_trend_no_auth_needed(self, client, seed_data): + """Confirm no Authorization header is required.""" + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get(f"/public/trends/{cheerios_id}") + assert resp.status_code == 200 + + +@pytest.mark.asyncio +class TestPublicStoreComparison: + """Public store comparison endpoint.""" + + async def test_store_comparison(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get( + "/public/store-comparison", + params=[("product_ids", cheerios_id)], + ) + assert resp.status_code == 200 + data = resp.json() + assert "products" in data + assert len(data["products"]) >= 1 + + async def test_store_comparison_rejects_more_than_20_ids(self, client): + """max_length=20 guard: 21 product IDs must return 422.""" + too_many = [("product_ids", str(uuid.uuid4())) for _ in range(21)] + resp = await client.get("/public/store-comparison", params=too_many) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +class TestPublicInflation: + """Public inflation index endpoint.""" + + async def test_inflation_returns_index(self, client, seed_data): + resp = await client.get("/public/inflation") + assert resp.status_code == 200 + data = resp.json() + assert "cartsnitch_index" in data + assert "cpi_baseline" in data + assert "categories" in data diff --git a/api/tests/test_e2e/test_purchase_flow.py b/api/tests/test_e2e/test_purchase_flow.py new file mode 100644 index 0000000..44de438 --- /dev/null +++ b/api/tests/test_e2e/test_purchase_flow.py @@ -0,0 +1,87 @@ +"""E2E: Purchase listing, detail, and stats against real DB fixtures.""" + +import pytest + +from tests.test_e2e.conftest import ZERO_UUID + + +@pytest.mark.asyncio +class TestPurchaseList: + """List and filter a user's purchases.""" + + async def test_list_user_purchases(self, client, seed_data): + resp = await client.get("/purchases", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 2 + store_names = [p["store_name"] for p in data] + assert "Meijer" in store_names + assert "Kroger" in store_names + + async def test_filter_purchases_by_store(self, client, seed_data): + meijer_id = str(seed_data["stores"]["meijer"].id) + resp = await client.get( + "/purchases", params={"store_id": meijer_id}, headers=seed_data["headers"] + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert all(p["store_name"] == "Meijer" for p in data) + + async def test_purchases_require_auth(self, client, seed_data): + resp = await client.get("/purchases") + assert resp.status_code in (401, 403) + + +@pytest.mark.asyncio +class TestPurchaseDetail: + """Retrieve individual purchase with line items.""" + + async def test_get_purchase_detail(self, client, seed_data): + purchase_id = str(seed_data["purchases"]["meijer_trip"].id) + resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["store_name"] == "Meijer" + assert data["total"] == 23.45 + assert len(data["line_items"]) == 2 + item_names = [li["name"] for li in data["line_items"]] + assert "Cheerios 18oz Box" in item_names + assert "Meijer Whole Milk 1gal" in item_names + + async def test_line_item_amounts_correct(self, client, seed_data): + purchase_id = str(seed_data["purchases"]["meijer_trip"].id) + resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"]) + data = resp.json() + cheerios_item = next(li for li in data["line_items"] if "Cheerios" in li["name"]) + assert cheerios_item["unit_price"] == 4.79 + assert cheerios_item["quantity"] == 1.0 + assert cheerios_item["total_price"] == 4.79 + + async def test_purchase_not_found(self, client, seed_data): + resp = await client.get( + f"/purchases/{ZERO_UUID}", + headers=seed_data["headers"], + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +class TestPurchaseStats: + """Verify spending aggregation across purchases.""" + + async def test_purchase_stats_totals(self, client, seed_data): + resp = await client.get("/purchases/stats", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["purchase_count"] == 2 + # 23.45 + 15.78 = 39.23 + assert abs(data["total_spent"] - 39.23) < 0.01 + + async def test_purchase_stats_by_store(self, client, seed_data): + resp = await client.get("/purchases/stats", headers=seed_data["headers"]) + data = resp.json() + assert "Meijer" in data["by_store"] + assert "Kroger" in data["by_store"] + assert abs(data["by_store"]["Meijer"] - 23.45) < 0.01 + assert abs(data["by_store"]["Kroger"] - 15.78) < 0.01 diff --git a/api/tests/test_encrypted_json.py b/api/tests/test_encrypted_json.py new file mode 100644 index 0000000..2ef3ccb --- /dev/null +++ b/api/tests/test_encrypted_json.py @@ -0,0 +1,130 @@ +"""Tests for EncryptedJSON TypeDecorator and session_data encryption.""" + +import json + +import pytest +from cryptography.fernet import Fernet +from pydantic import ValidationError +from sqlalchemy import column, create_engine, table, text +from sqlalchemy.orm import sessionmaker + +from cartsnitch_api.config import settings +from cartsnitch_api.models import Base +from cartsnitch_api.models.store import Store +from cartsnitch_api.models.user import User, UserStoreAccount + + +@pytest.fixture +def engine(): + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def session(engine): + factory = sessionmaker(bind=engine) + with factory() as sess: + yield sess + + +@pytest.fixture +def store(session): + s = Store(name="Test Store", slug="test-store") + session.add(s) + session.commit() + session.refresh(s) + return s + + +@pytest.fixture +def user(session): + u = User(email="alice@example.com", hashed_password="fakehash") + session.add(u) + session.commit() + session.refresh(u) + return u + + +class TestEncryptedJSONType: + """Unit tests for the EncryptedJSON TypeDecorator.""" + + def test_round_trip(self, session, user, store): + """Data written via the ORM comes back as the original dict.""" + original = {"token": "abc123", "cookies": {"session_id": "xyz"}} + account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original) + session.add(account) + session.commit() + + loaded = session.get(UserStoreAccount, account.id) + assert loaded.session_data == original + + def test_stored_value_is_encrypted(self, session, user, store): + """The raw value in the DB should be a Fernet token, not plaintext JSON.""" + original = {"secret": "do-not-leak"} + account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original) + session.add(account) + session.commit() + + # Use a raw table construct to bypass TypeDecorator on read + raw_table = table("user_store_accounts", column("id"), column("session_data")) + raw = session.execute(raw_table.select().where(raw_table.c.id == str(account.id))).first() + # If UUID matching fails with str, try bytes format + if raw is None: + raw = session.execute( + text("SELECT session_data FROM user_store_accounts LIMIT 1") + ).scalar_one() + else: + raw = raw[1] + + assert raw != json.dumps(original) + assert raw.startswith("gAAAAA") + + # Verify we can decrypt the raw value manually + f = Fernet(settings.fernet_key.encode()) + decrypted = json.loads(f.decrypt(raw.encode())) + assert decrypted == original + + def test_null_round_trip(self, session, user, store): + """NULL session_data stays NULL.""" + account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=None) + session.add(account) + session.commit() + + loaded = session.get(UserStoreAccount, account.id) + assert loaded.session_data is None + + def test_empty_dict_round_trip(self, session, user, store): + """Empty dict round-trips correctly.""" + account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={}) + session.add(account) + session.commit() + + loaded = session.get(UserStoreAccount, account.id) + assert loaded.session_data == {} + + def test_update_session_data(self, session, user, store): + """Updating session_data re-encrypts the new value.""" + account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={"v": 1}) + session.add(account) + session.commit() + + account.session_data = {"v": 2, "new_field": True} + session.commit() + + loaded = session.get(UserStoreAccount, account.id) + assert loaded.session_data == {"v": 2, "new_field": True} + + +class TestEncryptionKeyValidation: + """Test that invalid/missing keys are caught at startup.""" + + def test_invalid_fernet_key_rejected(self, monkeypatch): + """Settings validation rejects a bad key.""" + monkeypatch.setenv("CARTSNITCH_FERNET_KEY", "not-a-valid-key") + + with pytest.raises(ValidationError): + from cartsnitch_api.config import Settings + + Settings() diff --git a/api/tests/test_middleware/__init__.py b/api/tests/test_middleware/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/test_middleware/conftest.py b/api/tests/test_middleware/conftest.py new file mode 100644 index 0000000..12f6b47 --- /dev/null +++ b/api/tests/test_middleware/conftest.py @@ -0,0 +1,19 @@ +"""Conftest for middleware tests — re-enables rate limiting after global disable.""" + +import pytest + +from cartsnitch_api.config import settings as cartsnitch_settings + + +@pytest.fixture(autouse=True) +def enable_rate_limiting(): + """Re-enable rate limiting after the global disable_rate_limiting fixture runs. + + The root conftest disables rate limiting for all tests to prevent 429 + interference. Middleware tests need it active to verify headers and + enforcement. This fixture runs after the root fixture (more local = later + in setup order) so True is the effective value during the test body. + """ + cartsnitch_settings.rate_limit_enabled = True + yield + cartsnitch_settings.rate_limit_enabled = False diff --git a/api/tests/test_middleware/test_error_handler.py b/api/tests/test_middleware/test_error_handler.py new file mode 100644 index 0000000..950351d --- /dev/null +++ b/api/tests/test_middleware/test_error_handler.py @@ -0,0 +1,54 @@ +"""Tests for structured error responses and error monitoring.""" + +import pytest + + +@pytest.mark.asyncio +async def test_404_returns_structured_error(client): + """Non-existent route should return structured error.""" + resp = await client.get("/nonexistent") + assert resp.status_code == 404 + body = resp.json() + assert "detail" in body + assert "code" in body + assert body["code"] == "NOT_FOUND" + + +@pytest.mark.asyncio +async def test_validation_error_returns_422_with_field_errors(client): + """Invalid request body should return structured validation errors.""" + resp = await client.post( + "/auth/register", + json={"email": "not-an-email", "password": "short", "display_name": ""}, + ) + assert resp.status_code == 422 + body = resp.json() + assert body["code"] == "VALIDATION_ERROR" + assert "errors" in body + assert isinstance(body["errors"], list) + assert len(body["errors"]) > 0 + # Each error should have field, message, type + for err in body["errors"]: + assert "field" in err + assert "message" in err + assert "type" in err + + +@pytest.mark.asyncio +async def test_error_stats_requires_service_key(client): + """Error stats endpoint should require X-Service-Key.""" + resp = await client.get("/internal/error-stats") + assert resp.status_code == 422 # Missing required header + + +@pytest.mark.asyncio +async def test_error_stats_with_valid_key(client): + """Error stats endpoint returns monitoring data with valid key.""" + resp = await client.get( + "/internal/error-stats", + headers={"X-Service-Key": "change-me-in-production"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "error_counts" in body + assert "recent_5xx_count" in body diff --git a/api/tests/test_middleware/test_rate_limit.py b/api/tests/test_middleware/test_rate_limit.py new file mode 100644 index 0000000..d5b7691 --- /dev/null +++ b/api/tests/test_middleware/test_rate_limit.py @@ -0,0 +1,55 @@ +"""Tests for rate limiting middleware.""" + +import pytest + +from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter + + +class TestSlidingWindowCounter: + def test_allows_within_limit(self): + counter = _SlidingWindowCounter(max_requests=5, window_seconds=60) + for i in range(5): + allowed, remaining, retry = counter.is_allowed("test-key") + assert allowed is True + assert remaining == 4 - i + + def test_blocks_over_limit(self): + counter = _SlidingWindowCounter(max_requests=3, window_seconds=60) + for _ in range(3): + counter.is_allowed("test-key") + + allowed, remaining, retry = counter.is_allowed("test-key") + assert allowed is False + assert remaining == 0 + assert retry > 0 + + def test_separate_keys(self): + counter = _SlidingWindowCounter(max_requests=2, window_seconds=60) + # Fill key-a + counter.is_allowed("key-a") + counter.is_allowed("key-a") + allowed_a, _, _ = counter.is_allowed("key-a") + assert allowed_a is False + + # key-b should still be allowed + allowed_b, remaining, _ = counter.is_allowed("key-b") + assert allowed_b is True + assert remaining == 1 + + +@pytest.mark.asyncio +async def test_rate_limit_returns_429(client): + """Public endpoint should return 429 after limit exceeded.""" + # The default limit is 60/min — we won't hit it in normal tests, + # but we verify the middleware adds rate limit headers. + resp = await client.get("/public/inflation") + assert "x-ratelimit-limit" in resp.headers + assert "x-ratelimit-remaining" in resp.headers + + +@pytest.mark.asyncio +async def test_health_skips_rate_limit(client): + """Health endpoint should not have rate limit headers.""" + resp = await client.get("/health") + assert resp.status_code == 200 + assert "x-ratelimit-limit" not in resp.headers diff --git a/api/tests/test_models.py b/api/tests/test_models.py new file mode 100644 index 0000000..c0f8651 --- /dev/null +++ b/api/tests/test_models.py @@ -0,0 +1,376 @@ +"""Tests for SQLAlchemy ORM models.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +import pytest +from sqlalchemy import inspect + +from cartsnitch_api.constants import ( + AccountStatus, + DiscountType, + PriceSource, + ProductCategory, + SizeUnit, + StoreSlug, +) +from cartsnitch_api.models import ( + Coupon, + NormalizedProduct, + PriceHistory, + Purchase, + PurchaseItem, + ShrinkflationEvent, + Store, + StoreLocation, + User, + UserStoreAccount, +) + + +class TestTableCreation: + """Verify all expected tables are created.""" + + def test_all_tables_exist(self, engine): + inspector = inspect(engine) + table_names = set(inspector.get_table_names()) + expected = { + "stores", + "store_locations", + "users", + "user_store_accounts", + "purchases", + "purchase_items", + "normalized_products", + "price_history", + "coupons", + "shrinkflation_events", + } + assert expected.issubset(table_names) + + def test_ten_tables_total(self, engine): + inspector = inspect(engine) + assert len(inspector.get_table_names()) == 10 + + +class TestUUIDPrimaryKeys: + """All models use UUID PKs.""" + + def test_store_uuid_pk(self, session): + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.commit() + assert isinstance(store.id, uuid.UUID) + + def test_user_uuid_pk(self, session): + user = User( + id=uuid.uuid4(), + email="test@example.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(user) + session.commit() + assert isinstance(user.id, uuid.UUID) + + +class TestStoreModel: + def test_store_slug_enum(self, session): + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.commit() + assert store.slug == StoreSlug.KROGER + + def test_store_unique_slug(self, session): + s1 = Store( + id=uuid.uuid4(), + name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + s2 = Store( + id=uuid.uuid4(), + name="Target Duplicate", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(s1) + session.commit() + session.add(s2) + with pytest.raises(Exception): # noqa: B017 + session.commit() + session.rollback() + + +class TestStoreLocationModel: + def test_store_location_fields(self, session): + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.flush() + loc = StoreLocation( + id=uuid.uuid4(), + store_id=store.id, + address="123 Main St", + city="Ann Arbor", + state="MI", + zip="48104", + lat=42.2808, + lng=-83.7430, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(loc) + session.commit() + assert loc.city == "Ann Arbor" + assert loc.lat == pytest.approx(42.2808) + + +class TestUserStoreAccountModel: + def test_account_status_enum(self, session): + user = User( + id=uuid.uuid4(), + email="test@test.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([user, store]) + session.flush() + acct = UserStoreAccount( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + status=AccountStatus.ACTIVE, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(acct) + session.commit() + assert acct.status == AccountStatus.ACTIVE + + def test_unique_user_store_constraint(self, session): + """One account per user per store.""" + user = User( + id=uuid.uuid4(), + email="unique@test.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store = Store( + id=uuid.uuid4(), + name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([user, store]) + session.flush() + a1 = UserStoreAccount( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + status=AccountStatus.ACTIVE, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + a2 = UserStoreAccount( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + status=AccountStatus.EXPIRED, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(a1) + session.commit() + session.add(a2) + with pytest.raises(Exception): # noqa: B017 + session.commit() + session.rollback() + + +class TestPurchaseModel: + def test_purchase_with_items(self, session): + user = User( + id=uuid.uuid4(), + email="buyer@test.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([user, store]) + session.flush() + purchase = Purchase( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + receipt_id="RCP-001", + purchase_date=date(2026, 3, 15), + total=Decimal("42.50"), + ingested_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(purchase) + session.flush() + item = PurchaseItem( + id=uuid.uuid4(), + purchase_id=purchase.id, + product_name_raw="Meijer Whole Milk 1 Gallon", + upc="0041250000001", + quantity=Decimal("1"), + unit_price=Decimal("3.49"), + extended_price=Decimal("3.49"), + ) + session.add(item) + session.commit() + assert item.product_name_raw == "Meijer Whole Milk 1 Gallon" + assert item.unit_price == Decimal("3.49") + + +class TestNormalizedProductModel: + def test_product_with_upc_variants(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk, 1 Gallon", + category=ProductCategory.DAIRY, + brand="Store Brand", + size="128", + size_unit=SizeUnit.FL_OZ, + upc_variants=["0041250000001", "0041250000002"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + assert product.category == ProductCategory.DAIRY + assert product.size_unit == SizeUnit.FL_OZ + + +class TestPriceHistoryModel: + def test_price_source_enum(self, session): + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Eggs, Large, 12ct", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([store, product]) + session.flush() + ph = PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.99"), + sale_price=Decimal("3.99"), + source=PriceSource.RECEIPT, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(ph) + session.commit() + assert ph.source == PriceSource.RECEIPT + assert ph.regular_price == Decimal("4.99") + + +class TestCouponModel: + def test_coupon_discount_types(self, session): + store = Store( + id=uuid.uuid4(), + name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.flush() + coupon = Coupon( + id=uuid.uuid4(), + store_id=store.id, + title="$2 off eggs", + discount_type=DiscountType.FIXED, + discount_value=Decimal("2.00"), + requires_clip=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(coupon) + session.commit() + assert coupon.discount_type == DiscountType.FIXED + assert coupon.discount_value == Decimal("2.00") + + +class TestShrinkflationEventModel: + def test_shrinkflation_event(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Cereal, Honey Oats", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.flush() + event = ShrinkflationEvent( + id=uuid.uuid4(), + normalized_product_id=product.id, + detected_date=date(2026, 3, 10), + old_size="18", + new_size="15.4", + old_unit=SizeUnit.OZ, + new_unit=SizeUnit.OZ, + price_at_old_size=Decimal("4.99"), + price_at_new_size=Decimal("4.99"), + confidence=Decimal("0.95"), + notes="Size reduced by 14.4%, price unchanged", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(event) + session.commit() + assert event.confidence == Decimal("0.95") + assert event.old_unit == SizeUnit.OZ diff --git a/api/tests/test_openapi.py b/api/tests/test_openapi.py new file mode 100644 index 0000000..97eef19 --- /dev/null +++ b/api/tests/test_openapi.py @@ -0,0 +1,92 @@ +"""Verify all expected routes are present in the OpenAPI spec.""" + +import pytest +from httpx import ASGITransport, AsyncClient + +from cartsnitch_api.main import app + +EXPECTED_ROUTES = [ + # Auth (6) + ("post", "/auth/register"), + ("post", "/auth/login"), + ("post", "/auth/refresh"), + ("get", "/auth/me"), + ("patch", "/auth/me"), + ("delete", "/auth/me"), + # Stores (4) + ("get", "/stores"), + ("get", "/me/stores"), + ("post", "/me/stores/{store_slug}/connect"), + ("delete", "/me/stores/{store_slug}"), + # Purchases (3) + ("get", "/purchases"), + ("get", "/purchases/stats"), + ("get", "/purchases/{purchase_id}"), + # Products (3) + ("get", "/products"), + ("get", "/products/{product_id}"), + ("get", "/products/{product_id}/prices"), + # Prices (3) + ("get", "/prices/trends"), + ("get", "/prices/increases"), + ("get", "/prices/comparison"), + # Coupons (2) + ("get", "/coupons"), + ("get", "/coupons/relevant"), + # Shopping (2) + ("post", "/shopping/optimize"), + ("get", "/shopping/lists"), + # Alerts (3) + ("get", "/alerts"), + ("get", "/alerts/settings"), + ("put", "/alerts/settings"), + # Scraping (2) + ("post", "/scraping/{store_slug}/sync"), + ("get", "/scraping/status"), + # Public (3) + ("get", "/public/trends/{product_id}"), + ("get", "/public/store-comparison"), + ("get", "/public/inflation"), + # Health (1) + ("get", "/health"), +] + + +@pytest.mark.asyncio +async def test_all_routes_in_openapi(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/openapi.json") + assert resp.status_code == 200 + spec = resp.json() + paths = spec["paths"] + + registered = set() + for path, methods in paths.items(): + for method in methods: + if method in ("get", "post", "put", "delete", "patch"): + registered.add((method, path)) + + missing = [] + for method, path in EXPECTED_ROUTES: + if (method, path) not in registered: + missing.append(f"{method.upper()} {path}") + + assert not missing, "Missing routes in OpenAPI spec:\n" + "\n".join(missing) + + +@pytest.mark.asyncio +async def test_route_count(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/openapi.json") + spec = resp.json() + paths = spec["paths"] + + count = 0 + for _path, methods in paths.items(): + for method in methods: + if method in ("get", "post", "put", "delete", "patch"): + count += 1 + + assert count == 33, f"Expected 33 routes, found {count}" diff --git a/api/tests/test_routes/__init__.py b/api/tests/test_routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/test_routes/test_alerts.py b/api/tests/test_routes/test_alerts.py new file mode 100644 index 0000000..5b576a5 --- /dev/null +++ b/api/tests/test_routes/test_alerts.py @@ -0,0 +1,35 @@ +"""Integration tests for alert endpoints.""" + +import pytest + + +@pytest.mark.asyncio +async def test_list_alerts_empty(client, auth_headers): + """No purchases means no alerts.""" + resp = await client.get("/alerts", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json() == [] + + +@pytest.mark.asyncio +async def test_get_alert_settings(client, auth_headers): + resp = await client.get("/alerts/settings", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["price_increase_threshold_pct"] == 5.0 + assert data["shrinkflation_enabled"] is True + assert data["email_notifications"] is False + + +@pytest.mark.asyncio +async def test_update_alert_settings_returns_501(client, auth_headers): + resp = await client.put( + "/alerts/settings", + headers=auth_headers, + json={ + "price_increase_threshold_pct": 10.0, + "shrinkflation_enabled": False, + "email_notifications": True, + }, + ) + assert resp.status_code == 501 diff --git a/api/tests/test_routes/test_coupons.py b/api/tests/test_routes/test_coupons.py new file mode 100644 index 0000000..8687acc --- /dev/null +++ b/api/tests/test_routes/test_coupons.py @@ -0,0 +1,58 @@ +"""Integration tests for coupon endpoints.""" + +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.models import Coupon, Store + + +@pytest.fixture +async def coupon_data(db_engine, auth_headers): + """Seed stores and coupons.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Target", slug="target") + session.add(store) + await session.commit() + await session.refresh(store) + + coupon = Coupon( + store_id=store.id, + title="$2 off laundry", + description="$2 off any laundry detergent", + discount_value=Decimal("2.00"), + discount_type="fixed", + valid_from=date(2026, 1, 1), + valid_to=date(2026, 12, 31), + ) + session.add(coupon) + await session.commit() + + return {"store": store, "coupon": coupon, "headers": auth_headers} + + +@pytest.mark.asyncio +async def test_list_coupons(client, coupon_data): + resp = await client.get("/coupons", headers=coupon_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + + +@pytest.mark.asyncio +async def test_list_coupons_by_store(client, coupon_data): + store_id = str(coupon_data["store"].id) + resp = await client.get(f"/coupons?store_id={store_id}", headers=coupon_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) >= 1 + + +@pytest.mark.asyncio +async def test_relevant_coupons_empty(client, auth_headers): + """No purchases means no relevant coupons.""" + resp = await client.get("/coupons/relevant", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json() == [] diff --git a/api/tests/test_routes/test_prices.py b/api/tests/test_routes/test_prices.py new file mode 100644 index 0000000..7bdc60f --- /dev/null +++ b/api/tests/test_routes/test_prices.py @@ -0,0 +1,90 @@ +"""Integration tests for price endpoints.""" + +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store + + +@pytest.fixture +async def price_data(db_engine, auth_headers): + """Seed products with price history showing an increase.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Walmart", slug="walmart") + product = NormalizedProduct( + canonical_name="Tide Pods 42ct", + category="household", + brand="Tide", + ) + session.add_all([store, product]) + await session.commit() + await session.refresh(store) + await session.refresh(product) + + # Two price points — second is higher (increase) + ph1 = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 2, 1), + regular_price=Decimal("12.99"), + source="receipt", + ) + ph2 = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("14.49"), + source="receipt", + ) + session.add_all([ph1, ph2]) + await session.commit() + + return {"product": product, "store": store, "headers": auth_headers} + + +@pytest.mark.asyncio +async def test_price_trends(client, price_data): + resp = await client.get("/prices/trends", headers=price_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert data[0]["product_name"] == "Tide Pods 42ct" + assert len(data[0]["data_points"]) == 2 + + +@pytest.mark.asyncio +async def test_price_trends_by_category(client, price_data): + resp = await client.get("/prices/trends?category=household", headers=price_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) == 1 + + resp = await client.get("/prices/trends?category=nonexistent", headers=price_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) == 0 + + +@pytest.mark.asyncio +async def test_price_increases(client, price_data): + resp = await client.get("/prices/increases", headers=price_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + increase = data[0] + assert increase["old_price"] == 12.99 + assert increase["new_price"] == 14.49 + assert increase["increase_pct"] > 0 + + +@pytest.mark.asyncio +async def test_price_comparison(client, price_data): + pid = str(price_data["product"].id) + resp = await client.get(f"/prices/comparison?product_ids={pid}", headers=price_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert data[0]["product_name"] == "Tide Pods 42ct" + assert len(data[0]["prices"]) >= 1 diff --git a/api/tests/test_routes/test_products.py b/api/tests/test_routes/test_products.py new file mode 100644 index 0000000..7e27c9c --- /dev/null +++ b/api/tests/test_routes/test_products.py @@ -0,0 +1,94 @@ +"""Integration tests for product endpoints.""" + +import uuid +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store + + +@pytest.fixture +async def product_data(db_engine, auth_headers): + """Seed products and price history.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Meijer", slug="meijer") + product = NormalizedProduct( + canonical_name="Cheerios 18oz", + category="pantry", + brand="General Mills", + upc_variants=["016000275263"], + ) + session.add_all([store, product]) + await session.commit() + await session.refresh(store) + await session.refresh(product) + + ph1 = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("4.99"), + source="receipt", + ) + ph2 = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 10), + regular_price=Decimal("5.49"), + source="receipt", + ) + session.add_all([ph1, ph2]) + await session.commit() + + return {"product": product, "store": store, "headers": auth_headers} + + +@pytest.mark.asyncio +async def test_list_products(client, product_data): + resp = await client.get("/products", headers=product_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert data[0]["name"] == "Cheerios 18oz" + + +@pytest.mark.asyncio +async def test_search_products(client, product_data): + resp = await client.get("/products?q=Cheerios", headers=product_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) == 1 + + resp = await client.get("/products?q=nonexistent", headers=product_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) == 0 + + +@pytest.mark.asyncio +async def test_get_product_detail(client, product_data): + pid = str(product_data["product"].id) + resp = await client.get(f"/products/{pid}", headers=product_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "Cheerios 18oz" + assert data["brand"] == "General Mills" + assert len(data["prices_by_store"]) >= 1 + + +@pytest.mark.asyncio +async def test_get_product_not_found(client, auth_headers): + resp = await client.get(f"/products/{uuid.uuid4()}", headers=auth_headers) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_product_prices(client, product_data): + pid = str(product_data["product"].id) + resp = await client.get(f"/products/{pid}/prices", headers=product_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["product_name"] == "Cheerios 18oz" + assert len(data["data_points"]) == 2 diff --git a/api/tests/test_routes/test_public.py b/api/tests/test_routes/test_public.py new file mode 100644 index 0000000..08a5d29 --- /dev/null +++ b/api/tests/test_routes/test_public.py @@ -0,0 +1,73 @@ +"""Integration tests for public endpoints (no auth).""" + +import uuid +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store + + +@pytest.fixture +async def public_data(db_engine): + """Seed data for public endpoints.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Target", slug="target") + product = NormalizedProduct( + canonical_name="Skippy PB 16oz", + category="pantry", + brand="Skippy", + ) + session.add_all([store, product]) + await session.commit() + await session.refresh(store) + await session.refresh(product) + + ph = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 5), + regular_price=Decimal("3.99"), + source="receipt", + ) + session.add(ph) + await session.commit() + + return {"product": product, "store": store} + + +@pytest.mark.asyncio +async def test_public_trend(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get(f"/public/trends/{pid}") + assert resp.status_code == 200 + data = resp.json() + assert data["product_name"] == "Skippy PB 16oz" + assert len(data["data_points"]) == 1 + + +@pytest.mark.asyncio +async def test_public_trend_not_found(client): + resp = await client.get(f"/public/trends/{uuid.uuid4()}") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_public_store_comparison(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get(f"/public/store-comparison?product_ids={pid}") + assert resp.status_code == 200 + data = resp.json() + assert len(data["products"]) == 1 + + +@pytest.mark.asyncio +async def test_public_inflation(client, public_data): + resp = await client.get("/public/inflation") + assert resp.status_code == 200 + data = resp.json() + assert "categories" in data + assert "cartsnitch_index" in data diff --git a/api/tests/test_routes/test_purchases.py b/api/tests/test_routes/test_purchases.py new file mode 100644 index 0000000..14d5eb6 --- /dev/null +++ b/api/tests/test_routes/test_purchases.py @@ -0,0 +1,95 @@ +"""Integration tests for purchase endpoints.""" + +import uuid +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.auth.jwt import create_access_token +from cartsnitch_api.models import Purchase, PurchaseItem, Store, User + + +@pytest.fixture +async def purchase_data(db_engine): + """Seed a user, store, purchase, and items.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + from cartsnitch_api.auth.passwords import hash_password + + user = User( + email="buyer@example.com", + hashed_password=hash_password("testpass123"), + display_name="Buyer", + ) + store = Store(name="Kroger", slug="kroger") + session.add_all([user, store]) + await session.commit() + await session.refresh(user) + await session.refresh(store) + + purchase = Purchase( + user_id=user.id, + store_id=store.id, + receipt_id="receipt-001", + purchase_date=date(2026, 3, 10), + total=Decimal("42.50"), + ) + session.add(purchase) + await session.commit() + await session.refresh(purchase) + + item = PurchaseItem( + purchase_id=purchase.id, + product_name_raw="Organic Milk 1gal", + quantity=Decimal("1"), + unit_price=Decimal("5.99"), + extended_price=Decimal("5.99"), + ) + session.add(item) + await session.commit() + + token = create_access_token(user.id) + return { + "user": user, + "store": store, + "purchase": purchase, + "headers": {"Authorization": f"Bearer {token}"}, + } + + +@pytest.mark.asyncio +async def test_list_purchases(client, purchase_data): + resp = await client.get("/purchases", headers=purchase_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["store_name"] == "Kroger" + assert data[0]["total"] == 42.50 + + +@pytest.mark.asyncio +async def test_get_purchase_detail(client, purchase_data): + pid = str(purchase_data["purchase"].id) + resp = await client.get(f"/purchases/{pid}", headers=purchase_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data["line_items"]) == 1 + assert data["line_items"][0]["name"] == "Organic Milk 1gal" + + +@pytest.mark.asyncio +async def test_get_purchase_not_found(client, auth_headers): + resp = await client.get(f"/purchases/{uuid.uuid4()}", headers=auth_headers) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_purchase_stats(client, purchase_data): + resp = await client.get("/purchases/stats", headers=purchase_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["total_spent"] == 42.50 + assert data["purchase_count"] == 1 + assert "Kroger" in data["by_store"] diff --git a/api/tests/test_routes/test_stores.py b/api/tests/test_routes/test_stores.py new file mode 100644 index 0000000..002ff05 --- /dev/null +++ b/api/tests/test_routes/test_stores.py @@ -0,0 +1,77 @@ +"""Integration tests for store endpoints.""" + +import pytest + +from cartsnitch_api.models import Store + + +@pytest.fixture +async def seeded_store(db_engine): + """Insert a test store directly into the DB.""" + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Meijer", slug="meijer", logo_url=None, website_url=None) + session.add(store) + await session.commit() + await session.refresh(store) + return store + + +@pytest.mark.asyncio +async def test_list_stores(client, seeded_store): + resp = await client.get("/stores") + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert data[0]["slug"] == "meijer" + + +@pytest.mark.asyncio +async def test_list_user_stores_empty(client, auth_headers): + resp = await client.get("/me/stores", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json() == [] + + +@pytest.mark.asyncio +async def test_connect_and_disconnect_store(client, auth_headers, seeded_store): + # Connect + resp = await client.post( + "/me/stores/meijer/connect", + headers=auth_headers, + json={"credentials": None}, + ) + assert resp.status_code == 201 + assert resp.json()["connected"] is True + + # List should show connected + resp = await client.get("/me/stores", headers=auth_headers) + assert resp.status_code == 200 + assert len(resp.json()) == 1 + + # Disconnect + resp = await client.delete("/me/stores/meijer", headers=auth_headers) + assert resp.status_code == 204 + + # List should be empty again + resp = await client.get("/me/stores", headers=auth_headers) + assert resp.json() == [] + + +@pytest.mark.asyncio +async def test_connect_nonexistent_store(client, auth_headers): + resp = await client.post( + "/me/stores/nonexistent/connect", + headers=auth_headers, + json={}, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_connect_duplicate_store(client, auth_headers, seeded_store): + await client.post("/me/stores/meijer/connect", headers=auth_headers, json={}) + resp = await client.post("/me/stores/meijer/connect", headers=auth_headers, json={}) + assert resp.status_code == 409 diff --git a/api/tests/test_services/__init__.py b/api/tests/test_services/__init__.py new file mode 100644 index 0000000..e69de29