From 27fe95707444c383aefa4dc183f1bd4046800404 Mon Sep 17 00:00:00 2001 From: Coupon Carl Date: Sat, 28 Mar 2026 02:24:02 +0000 Subject: [PATCH 1/4] feat: merge cartsnitch/api into api/ subdirectory Consolidate API gateway service into monorepo. Squashed from https://github.com/cartsnitch/api main (89bacb1). Co-Authored-By: Paperclip --- api/.dockerignore | 14 + api/.github/workflows/ci.yml | 164 ++++++++ api/.gitignore | 9 + api/CLAUDE.md | 175 ++++++++ api/Dockerfile | 26 ++ api/alembic.ini | 36 ++ api/alembic/env.py | 55 +++ api/alembic/script.py.mako | 25 ++ .../versions/001_encrypt_session_data.py | 89 +++++ api/pyproject.toml | 58 +++ api/renovate.json | 4 + api/src/cartsnitch_api/__init__.py | 0 api/src/cartsnitch_api/auth/__init__.py | 0 api/src/cartsnitch_api/auth/dependencies.py | 39 ++ api/src/cartsnitch_api/auth/jwt.py | 31 ++ api/src/cartsnitch_api/auth/passwords.py | 11 + api/src/cartsnitch_api/auth/routes.py | 96 +++++ api/src/cartsnitch_api/cache.py | 26 ++ api/src/cartsnitch_api/config.py | 51 +++ api/src/cartsnitch_api/constants.py | 85 ++++ api/src/cartsnitch_api/database.py | 16 + api/src/cartsnitch_api/main.py | 62 +++ api/src/cartsnitch_api/middleware/__init__.py | 0 api/src/cartsnitch_api/middleware/cors.py | 16 + .../middleware/error_handler.py | 190 +++++++++ .../cartsnitch_api/middleware/rate_limit.py | 111 ++++++ api/src/cartsnitch_api/models/__init__.py | 26 ++ api/src/cartsnitch_api/models/base.py | 30 ++ api/src/cartsnitch_api/models/coupon.py | 42 ++ api/src/cartsnitch_api/models/price.py | 50 +++ api/src/cartsnitch_api/models/product.py | 39 ++ api/src/cartsnitch_api/models/purchase.py | 91 +++++ .../cartsnitch_api/models/shrinkflation.py | 41 ++ api/src/cartsnitch_api/models/store.py | 52 +++ api/src/cartsnitch_api/models/user.py | 50 +++ api/src/cartsnitch_api/routes/__init__.py | 0 api/src/cartsnitch_api/routes/alerts.py | 44 ++ api/src/cartsnitch_api/routes/coupons.py | 32 ++ api/src/cartsnitch_api/routes/health.py | 20 + 
api/src/cartsnitch_api/routes/prices.py | 47 +++ api/src/cartsnitch_api/routes/products.py | 56 +++ api/src/cartsnitch_api/routes/public.py | 48 +++ api/src/cartsnitch_api/routes/purchases.py | 49 +++ api/src/cartsnitch_api/routes/scraping.py | 42 ++ api/src/cartsnitch_api/routes/shopping.py | 48 +++ api/src/cartsnitch_api/routes/stores.py | 61 +++ api/src/cartsnitch_api/schemas.py | 291 ++++++++++++++ api/src/cartsnitch_api/services/__init__.py | 0 api/src/cartsnitch_api/services/alerts.py | 75 ++++ api/src/cartsnitch_api/services/auth.py | 125 ++++++ api/src/cartsnitch_api/services/clipartist.py | 52 +++ api/src/cartsnitch_api/services/coupons.py | 76 ++++ api/src/cartsnitch_api/services/prices.py | 183 +++++++++ api/src/cartsnitch_api/services/products.py | 124 ++++++ api/src/cartsnitch_api/services/public.py | 129 ++++++ api/src/cartsnitch_api/services/purchases.py | 116 ++++++ api/src/cartsnitch_api/services/queries.py | 23 ++ .../cartsnitch_api/services/receiptwitness.py | 33 ++ api/src/cartsnitch_api/services/shrinkray.py | 23 ++ .../cartsnitch_api/services/stickershock.py | 32 ++ api/src/cartsnitch_api/services/stores.py | 129 ++++++ api/src/cartsnitch_api/types.py | 36 ++ api/tests/__init__.py | 0 api/tests/conftest.py | 101 +++++ api/tests/test_auth/__init__.py | 0 api/tests/test_auth/test_auth_endpoints.py | 209 ++++++++++ api/tests/test_e2e/__init__.py | 0 api/tests/test_e2e/conftest.py | 250 ++++++++++++ api/tests/test_e2e/test_auth_validation.py | 213 ++++++++++ .../test_e2e/test_cross_resource_flow.py | 114 ++++++ api/tests/test_e2e/test_error_responses.py | 127 ++++++ api/tests/test_e2e/test_price_history.py | 102 +++++ .../test_e2e/test_product_search_lookup.py | 82 ++++ api/tests/test_e2e/test_public_endpoints.py | 59 +++ api/tests/test_e2e/test_purchase_flow.py | 87 ++++ api/tests/test_encrypted_json.py | 130 ++++++ api/tests/test_middleware/__init__.py | 0 api/tests/test_middleware/conftest.py | 19 + .../test_middleware/test_error_handler.py | 
54 +++ api/tests/test_middleware/test_rate_limit.py | 55 +++ api/tests/test_models.py | 376 ++++++++++++++++++ api/tests/test_openapi.py | 92 +++++ api/tests/test_routes/__init__.py | 0 api/tests/test_routes/test_alerts.py | 35 ++ api/tests/test_routes/test_coupons.py | 58 +++ api/tests/test_routes/test_prices.py | 90 +++++ api/tests/test_routes/test_products.py | 94 +++++ api/tests/test_routes/test_public.py | 73 ++++ api/tests/test_routes/test_purchases.py | 95 +++++ api/tests/test_routes/test_stores.py | 77 ++++ api/tests/test_services/__init__.py | 0 91 files changed, 6296 insertions(+) create mode 100644 api/.dockerignore create mode 100644 api/.github/workflows/ci.yml create mode 100644 api/.gitignore create mode 100644 api/CLAUDE.md create mode 100644 api/Dockerfile create mode 100644 api/alembic.ini create mode 100644 api/alembic/env.py create mode 100644 api/alembic/script.py.mako create mode 100644 api/alembic/versions/001_encrypt_session_data.py create mode 100644 api/pyproject.toml create mode 100644 api/renovate.json create mode 100644 api/src/cartsnitch_api/__init__.py create mode 100644 api/src/cartsnitch_api/auth/__init__.py create mode 100644 api/src/cartsnitch_api/auth/dependencies.py create mode 100644 api/src/cartsnitch_api/auth/jwt.py create mode 100644 api/src/cartsnitch_api/auth/passwords.py create mode 100644 api/src/cartsnitch_api/auth/routes.py create mode 100644 api/src/cartsnitch_api/cache.py create mode 100644 api/src/cartsnitch_api/config.py create mode 100644 api/src/cartsnitch_api/constants.py create mode 100644 api/src/cartsnitch_api/database.py create mode 100644 api/src/cartsnitch_api/main.py create mode 100644 api/src/cartsnitch_api/middleware/__init__.py create mode 100644 api/src/cartsnitch_api/middleware/cors.py create mode 100644 api/src/cartsnitch_api/middleware/error_handler.py create mode 100644 api/src/cartsnitch_api/middleware/rate_limit.py create mode 100644 api/src/cartsnitch_api/models/__init__.py create mode 100644 
api/src/cartsnitch_api/models/base.py create mode 100644 api/src/cartsnitch_api/models/coupon.py create mode 100644 api/src/cartsnitch_api/models/price.py create mode 100644 api/src/cartsnitch_api/models/product.py create mode 100644 api/src/cartsnitch_api/models/purchase.py create mode 100644 api/src/cartsnitch_api/models/shrinkflation.py create mode 100644 api/src/cartsnitch_api/models/store.py create mode 100644 api/src/cartsnitch_api/models/user.py create mode 100644 api/src/cartsnitch_api/routes/__init__.py create mode 100644 api/src/cartsnitch_api/routes/alerts.py create mode 100644 api/src/cartsnitch_api/routes/coupons.py create mode 100644 api/src/cartsnitch_api/routes/health.py create mode 100644 api/src/cartsnitch_api/routes/prices.py create mode 100644 api/src/cartsnitch_api/routes/products.py create mode 100644 api/src/cartsnitch_api/routes/public.py create mode 100644 api/src/cartsnitch_api/routes/purchases.py create mode 100644 api/src/cartsnitch_api/routes/scraping.py create mode 100644 api/src/cartsnitch_api/routes/shopping.py create mode 100644 api/src/cartsnitch_api/routes/stores.py create mode 100644 api/src/cartsnitch_api/schemas.py create mode 100644 api/src/cartsnitch_api/services/__init__.py create mode 100644 api/src/cartsnitch_api/services/alerts.py create mode 100644 api/src/cartsnitch_api/services/auth.py create mode 100644 api/src/cartsnitch_api/services/clipartist.py create mode 100644 api/src/cartsnitch_api/services/coupons.py create mode 100644 api/src/cartsnitch_api/services/prices.py create mode 100644 api/src/cartsnitch_api/services/products.py create mode 100644 api/src/cartsnitch_api/services/public.py create mode 100644 api/src/cartsnitch_api/services/purchases.py create mode 100644 api/src/cartsnitch_api/services/queries.py create mode 100644 api/src/cartsnitch_api/services/receiptwitness.py create mode 100644 api/src/cartsnitch_api/services/shrinkray.py create mode 100644 api/src/cartsnitch_api/services/stickershock.py create 
mode 100644 api/src/cartsnitch_api/services/stores.py create mode 100644 api/src/cartsnitch_api/types.py create mode 100644 api/tests/__init__.py create mode 100644 api/tests/conftest.py create mode 100644 api/tests/test_auth/__init__.py create mode 100644 api/tests/test_auth/test_auth_endpoints.py create mode 100644 api/tests/test_e2e/__init__.py create mode 100644 api/tests/test_e2e/conftest.py create mode 100644 api/tests/test_e2e/test_auth_validation.py create mode 100644 api/tests/test_e2e/test_cross_resource_flow.py create mode 100644 api/tests/test_e2e/test_error_responses.py create mode 100644 api/tests/test_e2e/test_price_history.py create mode 100644 api/tests/test_e2e/test_product_search_lookup.py create mode 100644 api/tests/test_e2e/test_public_endpoints.py create mode 100644 api/tests/test_e2e/test_purchase_flow.py create mode 100644 api/tests/test_encrypted_json.py create mode 100644 api/tests/test_middleware/__init__.py create mode 100644 api/tests/test_middleware/conftest.py create mode 100644 api/tests/test_middleware/test_error_handler.py create mode 100644 api/tests/test_middleware/test_rate_limit.py create mode 100644 api/tests/test_models.py create mode 100644 api/tests/test_openapi.py create mode 100644 api/tests/test_routes/__init__.py create mode 100644 api/tests/test_routes/test_alerts.py create mode 100644 api/tests/test_routes/test_coupons.py create mode 100644 api/tests/test_routes/test_prices.py create mode 100644 api/tests/test_routes/test_products.py create mode 100644 api/tests/test_routes/test_public.py create mode 100644 api/tests/test_routes/test_purchases.py create mode 100644 api/tests/test_routes/test_stores.py create mode 100644 api/tests/test_services/__init__.py diff --git a/api/.dockerignore b/api/.dockerignore new file mode 100644 index 0000000..d7f640e --- /dev/null +++ b/api/.dockerignore @@ -0,0 +1,14 @@ +.git +.github +.pytest_cache +.ruff_cache +__pycache__ +*.py[cod] +*.egg-info +dist +.venv +.env +tests 
+openapi.json +CLAUDE.md +README.md diff --git a/api/.github/workflows/ci.yml b/api/.github/workflows/ci.yml new file mode 100644 index 0000000..5c61bb7 --- /dev/null +++ b/api/.github/workflows/ci.yml @@ -0,0 +1,164 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ci-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: write + packages: write + +env: + REGISTRY: ghcr.io + IMAGE_NAME: cartsnitch/api + +jobs: + lint: + runs-on: runners-cartsnitch + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - run: pip install ruff + - name: Ruff lint + run: ruff check . + - name: Ruff format check + run: ruff format --check . + + typecheck: + runs-on: runners-cartsnitch + continue-on-error: true + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential + - name: Install cartsnitch-common from GitHub + run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git" + - run: pip install -e ".[dev]" mypy + - name: Type check + run: mypy src/cartsnitch_api + + test: + runs-on: runners-cartsnitch + services: + postgres: + image: postgres:15-alpine + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + env: + POSTGRES_USER: cartsnitch + POSTGRES_PASSWORD: cartsnitch_test + POSTGRES_DB: cartsnitch_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + redis: + image: redis:7-alpine + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + 
--health-retries 5 + env: + CARTSNITCH_DATABASE_URL: postgresql+asyncpg://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test + CARTSNITCH_REDIS_URL: redis://localhost:6379/0 + CARTSNITCH_JWT_SECRET_KEY: test-secret-do-not-use-in-prod + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential + - name: Install cartsnitch-common from GitHub + run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git" + - run: pip install -e ".[dev]" + - name: Run tests + run: pytest --tb=short -q + + build-and-push: + runs-on: runners-cartsnitch + needs: [lint, test] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Generate CalVer tag + id: calver + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: | + DATE_TAG=$(date -u +%Y.%m.%d) + EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1) + if [ -z "$EXISTING" ]; then + VERSION="$DATE_TAG" + elif [ "$EXISTING" = "v${DATE_TAG}" ]; then + VERSION="${DATE_TAG}.2" + else + BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//") + VERSION="${DATE_TAG}.$((BUILD_NUM + 1))" + fi + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + echo "CalVer tag: $VERSION" + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Log in to GHCR + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=sha,prefix=sha- + type=raw,value=${{ steps.calver.outputs.version }},enable=${{ 
github.ref == 'refs/heads/main' }} + type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }} + + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: . + push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + target: prod + + - name: Create git tag + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: | + git tag "v${{ steps.calver.outputs.version }}" + git push origin "v${{ steps.calver.outputs.version }}" \ No newline at end of file diff --git a/api/.gitignore b/api/.gitignore new file mode 100644 index 0000000..b0492c2 --- /dev/null +++ b/api/.gitignore @@ -0,0 +1,9 @@ +__pycache__/ +*.py[cod] +*.egg-info/ +dist/ +.venv/ +.env +.pytest_cache/ +.ruff_cache/ +openapi.json diff --git a/api/CLAUDE.md b/api/CLAUDE.md new file mode 100644 index 0000000..fcba89c --- /dev/null +++ b/api/CLAUDE.md @@ -0,0 +1,175 @@ +# CartSnitch API Gateway + +## Project Context + +CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/api`) is the public-facing API gateway that serves the frontend and proxies requests to internal services. 
+ +**GitHub org:** github.com/cartsnitch +**Domain:** cartsnitch.com + +### CartSnitch Services + +| Repo | Service | Purpose | +|------|---------|---------| +| `cartsnitch/common` | — | Shared models, schemas, utilities | +| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers | +| `cartsnitch/api` | API Gateway | Frontend-facing REST API (this repo) | +| `cartsnitch/cartsnitch` | Frontend | React PWA (mobile-first) | +| `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison | +| `cartsnitch/shrinkray` | ShrinkRay | Shrinkflation monitoring | +| `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization | +| `cartsnitch/infra` | — | K8s manifests, Flux kustomizations | + +### Architecture Decisions + +- **Polyrepo:** Each service has its own repo, Dockerfile, CI/CD pipeline. +- **Shared DB:** One PostgreSQL cluster. This service reads from all tables for serving frontend queries. Models come from `cartsnitch-common`. +- **Inter-service comms:** REST to internal services, Redis pub/sub for event subscriptions. +- **Target scale:** 500–1,000 users initially. + +## What This Service Does + +The API Gateway is the single entry point for the frontend PWA and any external consumers. It: + +1. **Handles user authentication** — registration, login, JWT token management +2. **Serves purchase/product/price data** — reads from the shared DB +3. **Proxies scraping operations** — forwards scrape triggers to ReceiptWitness +4. **Serves coupon/deal data** — reads from shared DB (written by ClipArtist) +5. **Serves alerts** — price increase alerts (StickerShock), shrinkflation alerts (ShrinkRay) +6. 
**Provides public data endpoints** — aggregate price trends for the transparency/shaming features + +## Tech Stack + +- Python 3.12+ +- FastAPI (async) +- SQLAlchemy 2.0 (via `cartsnitch-common`, read-heavy) +- Pydantic v2 (request/response validation) +- python-jose or PyJWT (JWT auth) +- passlib + bcrypt (password hashing) +- httpx (async HTTP client for proxying to internal services) +- Redis (subscribe to events for websocket push, caching) +- uvicorn (ASGI server) + +## Repo Structure + +``` +api/ +├── CLAUDE.md +├── README.md +├── pyproject.toml +├── Dockerfile +├── docker-compose.yml +├── src/ +│ └── cartsnitch_api/ +│ ├── __init__.py +│ ├── config.py # Service-specific settings +│ ├── main.py # FastAPI app factory, lifespan, middleware +│ ├── auth/ +│ │ ├── __init__.py +│ │ ├── jwt.py # JWT creation/validation +│ │ ├── passwords.py # Hashing, verification +│ │ ├── dependencies.py # FastAPI dependency injection (get_current_user) +│ │ └── routes.py # /auth/register, /auth/login, /auth/refresh +│ ├── routes/ +│ │ ├── __init__.py +│ │ ├── purchases.py # Purchase history endpoints +│ │ ├── products.py # Normalized product catalog +│ │ ├── prices.py # Price history and trends +│ │ ├── coupons.py # Active coupons and deals +│ │ ├── alerts.py # Price increase / shrinkflation alerts +│ │ ├── stores.py # Store info, user store account management +│ │ ├── scraping.py # Proxy to ReceiptWitness (trigger scrape, status) +│ │ ├── shopping.py # Optimized shopping list (proxy to ClipArtist) +│ │ ├── public.py # Public price transparency endpoints (no auth) +│ │ └── health.py +│ ├── services/ +│ │ ├── __init__.py +│ │ ├── receiptwitness.py # HTTP client for ReceiptWitness internal API +│ │ ├── stickershock.py # HTTP client for StickerShock internal API +│ │ ├── clipartist.py # HTTP client for ClipArtist internal API +│ │ └── shrinkray.py # HTTP client for ShrinkRay internal API +│ ├── middleware/ +│ │ ├── __init__.py +│ │ ├── cors.py +│ │ └── rate_limit.py +│ └── cache.py # 
Redis caching helpers +└── tests/ + ├── conftest.py + ├── test_auth/ + ├── test_routes/ + └── test_services/ +``` + +## API Endpoint Design + +### Auth +- `POST /auth/register` — create account +- `POST /auth/login` — get JWT access + refresh tokens +- `POST /auth/refresh` — refresh access token +- `GET /auth/me` — current user profile + +### Store Accounts +- `GET /stores` — list supported stores +- `GET /me/stores` — list user's connected store accounts + sync status +- `POST /me/stores/{store_slug}/connect` — initiate store connection flow +- `DELETE /me/stores/{store_slug}` — disconnect store account + +### Purchases +- `GET /purchases` — list user's purchases (paginated, filterable by store/date) +- `GET /purchases/{id}` — purchase detail with line items +- `GET /purchases/stats` — spending summary (by store, by category, by period) + +### Products +- `GET /products` — normalized product catalog (search, filter) +- `GET /products/{id}` — product detail with cross-store price comparison +- `GET /products/{id}/prices` — price history for a product across stores + +### Prices +- `GET /prices/trends` — aggregate price trends (public-capable) +- `GET /prices/increases` — recent significant price increases +- `GET /prices/comparison` — compare specific items across stores + +### Coupons +- `GET /coupons` — active coupons/deals (filterable by store) +- `GET /coupons/relevant` — coupons relevant to user's purchase history + +### Shopping +- `POST /shopping/optimize` — input: shopping list → output: store-split + coupons +- `GET /shopping/lists` — user's saved shopping lists + +### Alerts +- `GET /alerts` — user's price increase and shrinkflation alerts +- `PUT /alerts/settings` — configure alert thresholds + +### Public (No Auth) +- `GET /public/trends/{product_id}` — public price trend for a product +- `GET /public/store-comparison` — public store-vs-store price comparison +- `GET /public/inflation` — price changes vs CPI baseline + +### Scraping (Proxy to 
ReceiptWitness) +- `POST /scraping/{store_slug}/sync` — trigger a sync for the current user +- `GET /scraping/status` — sync status across all stores + +## Authentication + +- JWT-based auth with short-lived access tokens (15 min) and longer refresh tokens (7 days). +- Passwords hashed with bcrypt via passlib. +- All user-specific endpoints require a valid JWT in the `Authorization: Bearer` header. +- Public endpoints under `/public/` do not require auth. +- Internal service-to-service calls (ReceiptWitness, etc.) use a shared API key in the `X-Service-Key` header — not user JWTs. + +## Development Workflow + +- **Never push directly to main.** Always create feature branches and open PRs. +- Branch naming: `feature/` or `fix/` +- Use conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `chore:` +- OpenAPI docs auto-generated at `/docs` (Swagger) and `/redoc`. +- Write tests for all routes. Use httpx.AsyncClient with FastAPI's TestClient pattern. + +## Important Notes + +- This service is read-heavy on the shared DB. Use async SQLAlchemy sessions. +- Consider Redis caching for expensive queries (price trends, product comparisons). Cache invalidation via Redis pub/sub events from other services. +- Rate limiting on public endpoints is important — these could get hammered if the price transparency features get attention. +- CORS must allow the frontend origin (cartsnitch.com and localhost for dev). +- The store connection flow is the trickiest UX challenge: the user needs to authenticate with each retailer, and we need to capture the resulting session. This likely involves a controlled Playwright browser session that the user can see/interact with, or an OAuth-like redirect flow if the retailer supports it (Kroger does for its public API, but not for purchase history access). 
diff --git a/api/Dockerfile b/api/Dockerfile new file mode 100644 index 0000000..bb5d3bd --- /dev/null +++ b/api/Dockerfile @@ -0,0 +1,26 @@ +FROM python:3.12-slim AS build + +RUN apt-get update && apt-get install -y --no-install-recommends \ + libpq-dev \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app +COPY pyproject.toml ./ +COPY src/ ./src/ +RUN pip install --no-cache-dir --prefix=/install . + +FROM python:3.12-slim AS prod + +WORKDIR /app +RUN adduser --system --group --uid 1000 app +COPY --from=build /install /usr/local +COPY src/ ./src/ + +USER 1000 +EXPOSE 8000 + +HEALTHCHECK --interval=30s --timeout=3s \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" + +CMD ["uvicorn", "cartsnitch_api.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/api/alembic.ini b/api/alembic.ini new file mode 100644 index 0000000..42fafc3 --- /dev/null +++ b/api/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = alembic +sqlalchemy.url = postgresql://OVERRIDE_VIA_ENV_VAR + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/api/alembic/env.py b/api/alembic/env.py new file mode 100644 index 0000000..3e563e1 --- /dev/null +++ b/api/alembic/env.py @@ -0,0 +1,55 @@ +"""Alembic environment configuration for CartSnitch.""" + +import os +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool + +from alembic import context +from cartsnitch_api.models import Base # noqa: F401 — imports all 
models for autogenerate + +config = context.config +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +db_url = os.environ.get("CARTSNITCH_DATABASE_URL_SYNC") +if not db_url: + raise RuntimeError( + "CARTSNITCH_DATABASE_URL_SYNC must be set. " + "Example: postgresql://user:pass@localhost:5432/cartsnitch" + ) +config.set_main_option("sqlalchemy.url", db_url) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode.""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/api/alembic/script.py.mako b/api/alembic/script.py.mako new file mode 100644 index 0000000..fe3b097 --- /dev/null +++ b/api/alembic/script.py.mako @@ -0,0 +1,25 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +${imports if imports else ""} + +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() 
-> None: + ${downgrades if downgrades else "pass"} diff --git a/api/alembic/versions/001_encrypt_session_data.py b/api/alembic/versions/001_encrypt_session_data.py new file mode 100644 index 0000000..4932231 --- /dev/null +++ b/api/alembic/versions/001_encrypt_session_data.py @@ -0,0 +1,89 @@ +"""Encrypt existing plaintext session_data with Fernet. + +Revision ID: 001_encrypt_session_data +Revises: +Create Date: 2026-03-19 +""" + +import json +import os + +import sqlalchemy as sa +from cryptography.fernet import Fernet +from sqlalchemy import text + +from alembic import op + +revision = "001_encrypt_session_data" +down_revision = None +branch_labels = None +depends_on = None + + +def _get_fernet() -> Fernet: + key = os.environ.get("CARTSNITCH_FERNET_KEY") + if not key: + raise RuntimeError("CARTSNITCH_FERNET_KEY must be set to run this migration") + return Fernet(key.encode()) + + +def _is_fernet_token(value: str) -> bool: + """Check if a string looks like a Fernet token (base64 starting with gAAAAA).""" + return value.startswith("gAAAAA") + + +def upgrade() -> None: + # Change column type from JSON to TEXT to hold Fernet ciphertext + op.alter_column( + "user_store_accounts", + "session_data", + type_=sa.Text(), + existing_type=sa.JSON(), + existing_nullable=True, + postgresql_using="session_data::text", + ) + + conn = op.get_bind() + rows = conn.execute( + text("SELECT id, session_data FROM user_store_accounts WHERE session_data IS NOT NULL") + ).fetchall() + + f = _get_fernet() + for row_id, session_data in rows: + raw = str(session_data) + if _is_fernet_token(raw): + continue + plaintext = raw if isinstance(session_data, str) else json.dumps(session_data) + encrypted = f.encrypt(plaintext.encode()).decode() + conn.execute( + text("UPDATE user_store_accounts SET session_data = :data WHERE id = :id"), + {"data": encrypted, "id": row_id}, + ) + + +def downgrade() -> None: + conn = op.get_bind() + rows = conn.execute( + text("SELECT id, session_data FROM 
user_store_accounts WHERE session_data IS NOT NULL") + ).fetchall() + + f = _get_fernet() + for row_id, session_data in rows: + raw = str(session_data) + if not _is_fernet_token(raw): + continue + decrypted = f.decrypt(raw.encode()).decode() + conn.execute( + text("UPDATE user_store_accounts SET session_data = :data WHERE id = :id"), + {"data": decrypted, "id": row_id}, + ) + + # Revert column type from TEXT back to JSON + op.alter_column( + "user_store_accounts", + "session_data", + type_=sa.JSON(), + existing_type=sa.Text(), + existing_nullable=True, + postgresql_using="session_data::json", + ) diff --git a/api/pyproject.toml b/api/pyproject.toml new file mode 100644 index 0000000..8509182 --- /dev/null +++ b/api/pyproject.toml @@ -0,0 +1,58 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "cartsnitch-api" +version = "0.1.0" +description = "CartSnitch API Gateway — public-facing REST API" +requires-python = ">=3.12" +dependencies = [ + "fastapi>=0.115.0", + "uvicorn[standard]>=0.30.0", + "pydantic[email]>=2.9.0", + "pydantic-settings>=2.5.0", + "sqlalchemy[asyncio]>=2.0.35", + "asyncpg>=0.30.0", + "alembic>=1.13,<2.0", + "psycopg2>=2.9,<3.0", + "python-jose[cryptography]>=3.3.0", + "passlib[bcrypt]>=1.7.4", + "httpx>=0.27.0", + "redis[hiredis]>=5.2.0", + "cryptography>=43.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3.0", + "pytest-asyncio>=0.24.0", + "aiosqlite>=0.20.0", + "httpx>=0.27.0", + "ruff>=0.7.0", + "psycopg2-binary>=2.9,<3.0", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/cartsnitch_api"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] + +[tool.ruff] +target-version = "py312" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "UP", "B"] + +[tool.ruff.lint.per-file-ignores] +"src/cartsnitch_api/**/routes*.py" = ["B008"] +"src/cartsnitch_api/**/dependencies.py" = ["B008"] + +[tool.mypy] +python_version = "3.12" 
+ignore_missing_imports = true +warn_return_any = true +warn_unused_configs = true diff --git a/api/renovate.json b/api/renovate.json new file mode 100644 index 0000000..833ba3b --- /dev/null +++ b/api/renovate.json @@ -0,0 +1,4 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": ["local>cartsnitch/.github:renovate-config"] +} diff --git a/api/src/cartsnitch_api/__init__.py b/api/src/cartsnitch_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/cartsnitch_api/auth/__init__.py b/api/src/cartsnitch_api/auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/cartsnitch_api/auth/dependencies.py b/api/src/cartsnitch_api/auth/dependencies.py new file mode 100644 index 0000000..61735ee --- /dev/null +++ b/api/src/cartsnitch_api/auth/dependencies.py @@ -0,0 +1,39 @@ +"""FastAPI dependency injection for authentication.""" + +from uuid import UUID + +from fastapi import Depends, Header, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from cartsnitch_api.auth.jwt import decode_token +from cartsnitch_api.config import settings + +bearer_scheme = HTTPBearer() + + +async def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme), +) -> UUID: + try: + payload = decode_token(credentials.credentials) + except ValueError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + ) from None + + if payload.get("type") != "access": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token type", + ) from None + + return UUID(payload["sub"]) + + +async def verify_service_key(x_service_key: str = Header()) -> None: + if x_service_key != settings.service_key: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid service key", + ) diff --git a/api/src/cartsnitch_api/auth/jwt.py b/api/src/cartsnitch_api/auth/jwt.py 
def _encode_claims(user_id: UUID, expires_at: datetime, token_type: str) -> str:
    """Sign a token carrying ``sub``, ``exp`` and ``type`` claims (in that order)."""
    claims = {"sub": str(user_id), "exp": expires_at, "type": token_type}
    signed = jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
    return cast(str, signed)


def create_access_token(user_id: UUID) -> str:
    """Return a signed access token for *user_id*.

    The lifetime is ``settings.jwt_access_token_expire_minutes`` minutes from now (UTC).
    """
    expires_at = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes)
    return _encode_claims(user_id, expires_at, "access")


def create_refresh_token(user_id: UUID) -> str:
    """Return a signed refresh token for *user_id*.

    The lifetime is ``settings.jwt_refresh_token_expire_days`` days from now (UTC).
    """
    expires_at = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days)
    return _encode_claims(user_id, expires_at, "refresh")


def decode_token(token: str) -> dict:
    """Decode *token* with the configured secret and algorithm.

    Returns:
        The decoded claims.

    Raises:
        ValueError: when the JWT library rejects the token (wraps ``JWTError``).
    """
    try:
        claims = jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
    except JWTError as exc:
        raise ValueError(f"Invalid token: {exc}") from exc
    return cast(dict[str, Any], claims)
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Register a new user; a ``ValueError`` from the service becomes 409 with its message."""
    auth = AuthService(db)
    try:
        tokens = await auth.register(body.email, body.password, body.display_name)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
    return tokens


@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Log a user in; a ``ValueError`` from the service becomes a generic 401."""
    auth = AuthService(db)
    try:
        tokens = await auth.login(body.email, body.password)
    except ValueError:
        # Deliberately generic message; the cause is suppressed with `from None`.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password"
        ) from None
    return tokens


@router.post("/refresh", response_model=TokenResponse)
async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a refresh token for new tokens; a ``ValueError`` becomes 401."""
    auth = AuthService(db)
    try:
        tokens = await auth.refresh(body.refresh_token)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        ) from None
    return tokens


@router.get("/me", response_model=UserResponse)
async def get_me(
    user_id: UUID = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the authenticated user's record; ``LookupError`` becomes 404."""
    auth = AuthService(db)
    try:
        user = await auth.get_user(user_id)
    except LookupError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        ) from None
    return user
@router.patch("/me", response_model=UserResponse)
async def update_me(
    body: UpdateUserRequest,
    user_id: UUID = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update the authenticated user's email and/or display name.

    ``LookupError`` from the service becomes 404; ``ValueError`` becomes 409
    carrying the service's message.
    """
    auth = AuthService(db)
    try:
        user = await auth.update_user(user_id, email=body.email, display_name=body.display_name)
    except LookupError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        ) from None
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
    return user


@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT)
async def delete_me(
    user_id: UUID = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete the authenticated user's account; ``LookupError`` becomes 404."""
    auth = AuthService(db)
    try:
        await auth.delete_user(user_id)
    except LookupError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        ) from None
class Settings(BaseSettings):
    """Application settings, loaded from ``CARTSNITCH_``-prefixed environment variables.

    The defaults are suitable for local development only. The secret-bearing
    fields (``jwt_secret_key``, ``service_key``, ``fernet_key``) ship with
    placeholder values that are public in the repository; a warning is logged
    at load time whenever any of them is still at its default.
    """

    model_config = {"env_prefix": "CARTSNITCH_"}

    database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
    redis_url: str = "redis://localhost:6379/0"

    jwt_secret_key: str = "change-me-in-production"
    jwt_algorithm: str = "HS256"
    jwt_access_token_expire_minutes: int = 15
    jwt_refresh_token_expire_days: int = 7

    service_key: str = "change-me-in-production"
    # Valid Fernet key for local dev — MUST be overridden in production
    fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="

    cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]

    receiptwitness_url: str = "http://receiptwitness:8001"
    stickershock_url: str = "http://stickershock:8002"
    clipartist_url: str = "http://clipartist:8003"
    shrinkray_url: str = "http://shrinkray:8004"

    rate_limit_requests: int = 60
    rate_limit_window_seconds: int = 60
    rate_limit_enabled: bool = True

    @model_validator(mode="after")
    def validate_fernet_key(self) -> "Settings":
        """Validate fernet_key is a valid 32-byte url-safe base64 key at startup.

        Raises:
            ValueError: when the key does not decode to exactly 32 bytes.
        """
        try:
            decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
        except ValueError:
            # binascii.Error (bad padding / bad alphabet) is a ValueError subclass.
            decoded = b""
        if len(decoded) != 32:
            raise ValueError(
                "CARTSNITCH_FERNET_KEY must be a valid Fernet key "
                "(32 bytes, url-safe base64 encoded). "
                "Generate one with: python -c "
                "'from cryptography.fernet import Fernet; "
                "print(Fernet.generate_key().decode())'"
            )
        return self

    @model_validator(mode="after")
    def warn_on_placeholder_secrets(self) -> "Settings":
        """Log a warning for every secret still set to its repository default.

        This never raises, so local development and tests keep working; it
        only makes an un-overridden secret visible in the startup logs.
        """
        import logging  # local import: only this validator needs it

        declared = type(self).model_fields
        placeholders = [
            name
            for name in ("jwt_secret_key", "service_key", "fernet_key")
            if getattr(self, name) == declared[name].default
        ]
        if placeholders:
            logging.getLogger(__name__).warning(
                "Insecure default in use for: %s. Override via CARTSNITCH_* "
                "environment variables before deploying.",
                ", ".join(placeholders),
            )
        return self
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency that yields an async DB session.

    A session is created from ``async_session_factory`` and handed to the
    caller. The surrounding ``async with`` exits once the generator is
    resumed after the ``yield`` or closed. This function itself never
    commits.

    Yields:
        An ``AsyncSession`` created by ``async_session_factory``.
    """
    async with async_session_factory() as session:
        yield session
def create_app() -> FastAPI:
    """Build the CartSnitch FastAPI application.

    Registers middleware, exception handlers and all routers.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(
        title="CartSnitch API",
        description="Grocery price tracking and shrinkflation detection API",
        version="0.1.0",
        lifespan=lifespan,
    )

    # Middleware. Starlette's add_middleware() inserts at the FRONT of the
    # stack, so the LAST middleware added is the OUTERMOST one. The intended
    # nesting is CORS (outermost) -> error monitor -> rate limit (innermost),
    # so they are added in reverse. With CORS outermost, 429 responses
    # produced by the rate limiter still carry CORS headers and are still
    # counted by the error monitor.
    add_rate_limit_middleware(app)
    add_error_monitor_middleware(app)
    add_cors_middleware(app)

    # Exception handlers
    add_error_handlers(app)

    # Routers
    app.include_router(health_router)
    app.include_router(auth_router)
    app.include_router(stores_router)
    app.include_router(purchases_router)
    app.include_router(products_router)
    app.include_router(prices_router)
    app.include_router(coupons_router)
    app.include_router(shopping_router)
    app.include_router(alerts_router)
    app.include_router(scraping_router)
    app.include_router(public_router)

    return app
def _error_response(
    status_code: int,
    detail: str,
    code: str | None = None,
    errors: list[dict] | None = None,
) -> JSONResponse:
    """Build a consistent error response.

    Args:
        status_code: HTTP status of the response.
        detail: Human-readable message, always present in the body.
        code: Machine-readable error code; omitted from the body when falsy.
        errors: Field-level error entries; omitted from the body when falsy.
    """
    body: dict = {"detail": detail}
    if code:
        body["code"] = code
    if errors:
        body["errors"] = errors
    return JSONResponse(status_code=status_code, content=body)


def _field_path(loc) -> str:
    """Render a validation-error location as a dotted field path.

    The first element (the request part, e.g. ``body``) is dropped when more
    specific elements follow it. A single-element location is returned as
    that element rather than as the tuple's repr.
    """
    parts = [str(p) for p in loc]
    if len(parts) > 1:
        return ".".join(parts[1:])
    return parts[0] if parts else ""


def add_error_handlers(app: FastAPI) -> None:
    """Register global exception handlers for consistent error responses."""

    @app.exception_handler(RequestValidationError)
    async def validation_error_handler(
        request: Request, exc: RequestValidationError
    ) -> JSONResponse:
        """Return 422 with structured field-level error details."""
        field_errors = [
            {
                "field": _field_path(err.get("loc", ())),
                "message": err.get("msg", "Invalid value"),
                "type": err.get("type", "value_error"),
            }
            for err in exc.errors()
        ]
        return _error_response(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Validation error",
            code="VALIDATION_ERROR",
            errors=field_errors,
        )

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
        """Wrap HTTP exceptions (Starlette and FastAPI) in consistent format."""
        detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
        return _error_response(
            status_code=exc.status_code,
            detail=detail,
            code=_status_to_code(exc.status_code),
        )

    @app.exception_handler(Exception)
    async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
        """Catch-all: log full traceback, return safe 500 to client."""
        logger.error(
            "Unhandled exception on %s %s: %s\n%s",
            request.method,
            request.url.path,
            exc,
            traceback.format_exc(),
        )
        _notify_error_monitor(request, exc)

        # The client only ever sees a fixed message; details stay in the log.
        return _error_response(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
            code="INTERNAL_ERROR",
        )


def _status_to_code(status_code: int) -> str:
    """Map HTTP status code to a machine-readable error code.

    Statuses without an explicit entry fall back to ``HTTP_<status>``.
    """
    mapping = {
        400: "BAD_REQUEST",
        401: "UNAUTHORIZED",
        403: "FORBIDDEN",
        404: "NOT_FOUND",
        409: "CONFLICT",
        422: "VALIDATION_ERROR",
        429: "RATE_LIMITED",
        502: "BAD_GATEWAY",
        503: "SERVICE_UNAVAILABLE",
    }
    return mapping.get(status_code, f"HTTP_{status_code}")
+ """ + + def __init__(self) -> None: + self.error_counts: dict[int, int] = {} + self.recent_5xx: list[dict] = [] + self._max_recent = 100 + + def record(self, status_code: int, path: str, method: str, error: str | None = None) -> None: + self.error_counts[status_code] = self.error_counts.get(status_code, 0) + 1 + + if status_code >= 500: + entry = { + "timestamp": time.time(), + "status": status_code, + "path": path, + "method": method, + "error": error, + } + self.recent_5xx.append(entry) + if len(self.recent_5xx) > self._max_recent: + self.recent_5xx = self.recent_5xx[-self._max_recent :] + + logger.warning( + "5xx error recorded: %s %s -> %d (%s)", + method, + path, + status_code, + error or "unknown", + ) + + def get_stats(self) -> dict: + return { + "error_counts": dict(self.error_counts), + "recent_5xx_count": len(self.recent_5xx), + } + + +_monitor = _ErrorMonitor() + + +def get_error_monitor() -> _ErrorMonitor: + """Access the global error monitor (for health/metrics endpoints).""" + return _monitor + + +def _notify_error_monitor(request: Request, exc: Exception) -> None: + """Record unhandled exception in the error monitor.""" + _monitor.record( + status_code=500, + path=request.url.path, + method=request.method, + error=str(exc)[:200], + ) + + +class ErrorMonitorMiddleware(BaseHTTPMiddleware): + """Middleware to track all 4xx/5xx responses for monitoring.""" + + async def dispatch( + self, + request: Request, + call_next: Callable[[Request], Awaitable], + ): + response = await call_next(request) + + if response.status_code >= 400: + _monitor.record( + status_code=response.status_code, + path=request.url.path, + method=request.method, + ) + + return response + + +def add_error_monitor_middleware(app: FastAPI) -> None: + app.add_middleware(ErrorMonitorMiddleware) diff --git a/api/src/cartsnitch_api/middleware/rate_limit.py b/api/src/cartsnitch_api/middleware/rate_limit.py new file mode 100644 index 0000000..424ed19 --- /dev/null +++ 
b/api/src/cartsnitch_api/middleware/rate_limit.py @@ -0,0 +1,111 @@ +"""Rate limiting middleware for public and authenticated endpoints. + +Uses in-memory sliding window as fallback, Redis/DragonflyDB when available. +Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. +""" + +import time +from collections import defaultdict +from threading import Lock + +from fastapi import FastAPI, Request, status +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + +from cartsnitch_api.config import settings + + +class _SlidingWindowCounter: + """Thread-safe in-memory sliding window rate limiter.""" + + def __init__(self, max_requests: int, window_seconds: int) -> None: + self.max_requests = max_requests + self.window_seconds = window_seconds + self._hits: dict[str, list[float]] = defaultdict(list) + self._lock = Lock() + + def is_allowed(self, key: str) -> tuple[bool, int, int]: + """Check if request is allowed. Returns (allowed, remaining, retry_after).""" + now = time.monotonic() + cutoff = now - self.window_seconds + + with self._lock: + # Prune expired entries + self._hits[key] = [t for t in self._hits[key] if t > cutoff] + + current_count = len(self._hits[key]) + if current_count >= self.max_requests: + retry_after = int(self._hits[key][0] - cutoff) + 1 + return False, 0, retry_after + + self._hits[key].append(now) + remaining = self.max_requests - current_count - 1 + return True, remaining, 0 + + +# Module-level counters — one for public (per-IP), one for auth (per-token) +_public_limiter = _SlidingWindowCounter( + max_requests=settings.rate_limit_requests, + window_seconds=settings.rate_limit_window_seconds, +) +_auth_limiter = _SlidingWindowCounter( + max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users + window_seconds=settings.rate_limit_window_seconds, +) + + +def _get_client_ip(request: Request) -> str: + """Extract client IP, respecting X-Forwarded-For 
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
    """Determine rate limit key and which limiter to use.

    Paths starting with ``/public`` are limited per client IP. Other requests
    carrying a non-empty Bearer token are limited per token suffix on the
    authenticated limiter. Everything else falls back to per-IP limiting on
    the public limiter.

    NOTE(review): the token is not verified here, so any Bearer string
    selects the authenticated limiter. Confirm that is acceptable.
    """
    if request.url.path.startswith("/public"):
        return f"ip:{_get_client_ip(request)}", _public_limiter

    # For authenticated endpoints, use Bearer token as key if present.
    # The auth scheme is matched case-insensitively and an empty token is
    # ignored, so "bearer <token>" is classified as authenticated and a bare
    # "Bearer " does not collapse every such caller into one shared bucket.
    scheme, _, credentials = request.headers.get("authorization", "").partition(" ")
    token = credentials.strip()
    if scheme.lower() == "bearer" and token:
        # Use last 16 chars of token as key to avoid storing full tokens
        return f"token:{token[-16:]}", _auth_limiter

    # Fallback to IP for unauthenticated non-public endpoints
    return f"ip:{_get_client_ip(request)}", _public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Apply the sliding-window limiter to each request and expose limit headers."""

    async def dispatch(self, request: Request, call_next):
        """Pass the request through or answer 429 when the caller is over its limit."""
        # Skip rate limiting when disabled (e.g. in tests) or for health checks
        bypass = (not settings.rate_limit_enabled) or request.url.path == "/health"
        if bypass:
            return await call_next(request)

        key, limiter = _get_rate_limit_key(request)
        allowed, remaining, retry_after = limiter.is_allowed(key)
        limit_header = str(limiter.max_requests)

        if allowed:
            response = await call_next(request)
            response.headers["X-RateLimit-Limit"] = limit_header
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            return response

        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "detail": "Rate limit exceeded",
                "code": "RATE_LIMITED",
            },
            headers={
                "Retry-After": str(retry_after),
                "X-RateLimit-Limit": limit_header,
                "X-RateLimit-Remaining": "0",
            },
        )
class Base(DeclarativeBase):
    """Base class for all CartSnitch models."""


class TimestampMixin:
    """Mixin providing created_at / updated_at columns.

    Both columns are timezone-aware, non-nullable, and default to the
    database's ``now()`` via ``server_default``.
    """

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    # `onupdate=func.now()` refreshes the value on UPDATE statements issued
    # through SQLAlchemy.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
    )


class UUIDPrimaryKeyMixin:
    """Mixin providing a UUID primary key.

    The id is generated client-side with ``uuid.uuid4`` when the ORM inserts
    the row, and by the ``gen_random_uuid()`` server default otherwise.
    """

    # NOTE(review): gen_random_uuid() looks PostgreSQL-specific — confirm it
    # is available on every backend the models are created on.
    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid()
    )
class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A single price observation for a product at a store on a date.

    Only ``regular_price`` is required; the sale, loyalty and coupon prices
    are nullable. An observation may optionally reference the purchase item
    it came from.
    """

    __tablename__ = "price_history"
    # Composite index over (product, store, observed date).
    __table_args__ = (
        Index(
            "ix_price_history_product_store_date",
            "normalized_product_id",
            "store_id",
            "observed_date",
        ),
    )

    normalized_product_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("normalized_products.id"), nullable=False
    )
    store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
    observed_date: Mapped[date] = mapped_column(Date, nullable=False)
    regular_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
    sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    loyalty_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    coupon_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    # Annotated as PriceSource but stored in a plain String(20) column.
    source: Mapped[PriceSource] = mapped_column(String(20), nullable=False)
    purchase_item_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("purchase_items.id"))

    # Relationships
    normalized_product: Mapped["NormalizedProduct"] = relationship(back_populates="price_histories")
    store: Mapped["Store"] = relationship(back_populates="price_histories")
    purchase_item: Mapped["PurchaseItem | None"] = relationship(
        back_populates="price_history_entries"
    )
class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Canonical product identity — matches products across retailers.

    Only ``canonical_name`` is required; every other descriptive column is
    nullable.
    """

    __tablename__ = "normalized_products"

    canonical_name: Mapped[str] = mapped_column(String(300), nullable=False)
    # Annotated as ProductCategory but stored in a plain String(50) column.
    category: Mapped[ProductCategory | None] = mapped_column(String(50))
    subcategory: Mapped[str | None] = mapped_column(String(100))
    brand: Mapped[str | None] = mapped_column(String(200))
    size: Mapped[str | None] = mapped_column(String(50))
    # Annotated as SizeUnit but stored in a plain String(10) column.
    size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
    # JSON list of UPC strings; `default=list` gives each new row its own
    # empty list.
    upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list)

    # Relationships
    purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
    price_histories: Mapped[list["PriceHistory"]] = relationship(
        back_populates="normalized_product"
    )
    coupons: Mapped[list["Coupon"]] = relationship(back_populates="normalized_product")
    shrinkflation_events: Mapped[list["ShrinkflationEvent"]] = relationship(
        back_populates="normalized_product"
    )
class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A single shopping trip / receipt.

    A receipt is unique per (user, store, receipt_id); see
    ``uq_purchase_receipt`` in ``__table_args__``.
    """

    __tablename__ = "purchases"

    user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
    store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
    store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id"))
    receipt_id: Mapped[str] = mapped_column(String(200), nullable=False)
    purchase_date: Mapped[date] = mapped_column(Date, nullable=False)
    total: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
    subtotal: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    tax: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    savings_total: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    source_url: Mapped[str | None] = mapped_column(String(500))
    # Optional JSON payload kept alongside the parsed columns.
    raw_data: Mapped[dict | None] = mapped_column(JSON)
    # Defaults to the database's now() when the row is inserted.
    ingested_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    # Relationships
    user: Mapped["User"] = relationship(back_populates="purchases")
    store: Mapped["Store"] = relationship(back_populates="purchases")
    store_location: Mapped["StoreLocation | None"] = relationship(back_populates="purchases")
    items: Mapped[list["PurchaseItem"]] = relationship(back_populates="purchase")

    # Index over (user, store) plus the per-user, per-store receipt
    # uniqueness constraint.
    __table_args__ = (
        Index("ix_purchases_user_store", "user_id", "store_id"),
        UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"),
    )
nullable=False, default=1) + unit_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + extended_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + regular_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + coupon_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + loyalty_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + category_raw: Mapped[str | None] = mapped_column(String(100)) + normalized_product_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("normalized_products.id") + ) + + # Relationships + purchase: Mapped["Purchase"] = relationship(back_populates="items") + normalized_product: Mapped["NormalizedProduct | None"] = relationship( + back_populates="purchase_items" + ) + price_history_entries: Mapped[list["PriceHistory"]] = relationship( + back_populates="purchase_item" + ) diff --git a/api/src/cartsnitch_api/models/shrinkflation.py b/api/src/cartsnitch_api/models/shrinkflation.py new file mode 100644 index 0000000..2ce6f9d --- /dev/null +++ b/api/src/cartsnitch_api/models/shrinkflation.py @@ -0,0 +1,41 @@ +"""ShrinkflationEvent model.""" + +import uuid +from datetime import date +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import Date, ForeignKey, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.constants import SizeUnit +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_api.models.product import NormalizedProduct + + +class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Detected shrinkflation event — product size changed while price held or rose.""" + + __tablename__ = "shrinkflation_events" + + normalized_product_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("normalized_products.id"), nullable=False + ) + 
detected_date: Mapped[date] = mapped_column(Date, nullable=False) + old_size: Mapped[str] = mapped_column(String(50), nullable=False) + new_size: Mapped[str] = mapped_column(String(50), nullable=False) + old_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False) + new_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False) + price_at_old_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + price_at_new_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + confidence: Mapped[Decimal] = mapped_column( + Numeric(3, 2), nullable=False, default=Decimal("1.00") + ) + notes: Mapped[str | None] = mapped_column(String(1000)) + + # Relationships + normalized_product: Mapped["NormalizedProduct"] = relationship( + back_populates="shrinkflation_events" + ) diff --git a/api/src/cartsnitch_api/models/store.py b/api/src/cartsnitch_api/models/store.py new file mode 100644 index 0000000..f75897f --- /dev/null +++ b/api/src/cartsnitch_api/models/store.py @@ -0,0 +1,52 @@ +"""Store and StoreLocation models.""" + +import uuid +from typing import TYPE_CHECKING + +from sqlalchemy import Float, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.constants import StoreSlug +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_api.models.coupon import Coupon + from cartsnitch_api.models.price import PriceHistory + from cartsnitch_api.models.purchase import Purchase + from cartsnitch_api.models.user import UserStoreAccount + + +class Store(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Supported retailer.""" + + __tablename__ = "stores" + + name: Mapped[str] = mapped_column(String(100), nullable=False) + slug: Mapped[StoreSlug] = mapped_column(String(20), nullable=False, unique=True) + logo_url: Mapped[str | None] = mapped_column(String(500)) + website_url: Mapped[str | None] = mapped_column(String(500)) + + # Relationships + 
locations: Mapped[list["StoreLocation"]] = relationship(back_populates="store") + purchases: Mapped[list["Purchase"]] = relationship(back_populates="store") + user_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="store") + price_histories: Mapped[list["PriceHistory"]] = relationship(back_populates="store") + coupons: Mapped[list["Coupon"]] = relationship(back_populates="store") + + +class StoreLocation(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Physical store location.""" + + __tablename__ = "store_locations" + + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + address: Mapped[str] = mapped_column(String(300), nullable=False) + city: Mapped[str] = mapped_column(String(100), nullable=False) + state: Mapped[str] = mapped_column(String(2), nullable=False) + zip: Mapped[str] = mapped_column(String(10), nullable=False) + lat: Mapped[float | None] = mapped_column(Float) + lng: Mapped[float | None] = mapped_column(Float) + + # Relationships + store: Mapped["Store"] = relationship(back_populates="locations") + purchases: Mapped[list["Purchase"]] = relationship(back_populates="store_location") diff --git a/api/src/cartsnitch_api/models/user.py b/api/src/cartsnitch_api/models/user.py new file mode 100644 index 0000000..56482b0 --- /dev/null +++ b/api/src/cartsnitch_api/models/user.py @@ -0,0 +1,50 @@ +"""User and UserStoreAccount models.""" + +import uuid +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_api.constants import AccountStatus +from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin +from cartsnitch_api.types import EncryptedJSON + +if TYPE_CHECKING: + from cartsnitch_api.models.purchase import Purchase + from cartsnitch_api.models.store import Store + + +class User(UUIDPrimaryKeyMixin, TimestampMixin, 
Base): + """Application user.""" + + __tablename__ = "users" + + email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) + hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) + display_name: Mapped[str | None] = mapped_column(String(100)) + + # Relationships + store_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="user") + purchases: Mapped[list["Purchase"]] = relationship(back_populates="user") + + +class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Link between a user and their retailer account credentials.""" + + __tablename__ = "user_store_accounts" + __table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),) + + user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False) + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + session_data: Mapped[dict | None] = mapped_column(EncryptedJSON) + session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + last_sync_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + status: Mapped[AccountStatus] = mapped_column( + String(20), nullable=False, default=AccountStatus.ACTIVE + ) + + # Relationships + user: Mapped["User"] = relationship(back_populates="store_accounts") + store: Mapped["Store"] = relationship(back_populates="user_accounts") diff --git a/api/src/cartsnitch_api/routes/__init__.py b/api/src/cartsnitch_api/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/cartsnitch_api/routes/alerts.py b/api/src/cartsnitch_api/routes/alerts.py new file mode 100644 index 0000000..45ab33f --- /dev/null +++ b/api/src/cartsnitch_api/routes/alerts.py @@ -0,0 +1,44 @@ +"""Alert routes: list alerts, manage settings.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from 
cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import AlertResponse, AlertSettingsRequest, AlertSettingsResponse +from cartsnitch_api.services.alerts import AlertService + +router = APIRouter(prefix="/alerts", tags=["alerts"]) + + +@router.get("", response_model=list[AlertResponse]) +async def list_alerts( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = AlertService(db) + return await svc.list_alerts(user_id) + + +@router.get("/settings", response_model=AlertSettingsResponse) +async def get_alert_settings( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = AlertService(db) + return await svc.get_settings(user_id) + + +@router.put("/settings") +async def update_alert_settings( + body: AlertSettingsRequest, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Alert settings persistence not yet implemented. 
" + "Use GET /alerts/settings for current defaults.", + ) diff --git a/api/src/cartsnitch_api/routes/coupons.py b/api/src/cartsnitch_api/routes/coupons.py new file mode 100644 index 0000000..d33d98a --- /dev/null +++ b/api/src/cartsnitch_api/routes/coupons.py @@ -0,0 +1,32 @@ +"""Coupon routes: browse, relevant matches.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import CouponResponse +from cartsnitch_api.services.coupons import CouponService + +router = APIRouter(prefix="/coupons", tags=["coupons"]) + + +@router.get("", response_model=list[CouponResponse]) +async def list_coupons( + store_id: UUID | None = Query(None), + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = CouponService(db) + return await svc.list_coupons(store_id) + + +@router.get("/relevant", response_model=list[CouponResponse]) +async def relevant_coupons( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = CouponService(db) + return await svc.relevant_coupons(user_id) diff --git a/api/src/cartsnitch_api/routes/health.py b/api/src/cartsnitch_api/routes/health.py new file mode 100644 index 0000000..0574b10 --- /dev/null +++ b/api/src/cartsnitch_api/routes/health.py @@ -0,0 +1,20 @@ +"""Health check and error metrics endpoints.""" + +from fastapi import APIRouter, Depends + +from cartsnitch_api.auth.dependencies import verify_service_key +from cartsnitch_api.middleware.error_handler import get_error_monitor + +router = APIRouter(tags=["health"]) + + +@router.get("/health") +async def health(): + return {"status": "ok"} + + +@router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)]) +async def error_stats(): + """Error monitoring stats — internal only (requires 
X-Service-Key).""" + monitor = get_error_monitor() + return monitor.get_stats() diff --git a/api/src/cartsnitch_api/routes/prices.py b/api/src/cartsnitch_api/routes/prices.py new file mode 100644 index 0000000..487dd92 --- /dev/null +++ b/api/src/cartsnitch_api/routes/prices.py @@ -0,0 +1,47 @@ +"""Price routes: trends, increases, comparison.""" + +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import ( + PriceComparisonResponse, + PriceIncreaseResponse, + PriceTrendResponse, +) +from cartsnitch_api.services.prices import PriceService + +router = APIRouter(prefix="/prices", tags=["prices"]) + + +@router.get("/trends", response_model=list[PriceTrendResponse]) +async def price_trends( + user_id: UUID = Depends(get_current_user), + category: str | None = Query(None), + db: AsyncSession = Depends(get_db), +): + svc = PriceService(db) + return await svc.get_trends(category) + + +@router.get("/increases", response_model=list[PriceIncreaseResponse]) +async def price_increases( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = PriceService(db) + return await svc.get_increases() + + +@router.get("/comparison", response_model=list[PriceComparisonResponse]) +async def price_comparison( + product_ids: Annotated[list[UUID], Query()], + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = PriceService(db) + return await svc.get_comparison(product_ids) diff --git a/api/src/cartsnitch_api/routes/products.py b/api/src/cartsnitch_api/routes/products.py new file mode 100644 index 0000000..473cefe --- /dev/null +++ b/api/src/cartsnitch_api/routes/products.py @@ -0,0 +1,56 @@ +"""Product routes: search/list, detail, price history.""" + +from uuid import UUID + +from 
fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import PriceTrendResponse, ProductDetailResponse, ProductResponse +from cartsnitch_api.services.products import ProductService + +router = APIRouter(prefix="/products", tags=["products"]) + + +@router.get("", response_model=list[ProductResponse]) +async def list_products( + user_id: UUID = Depends(get_current_user), + q: str | None = Query(None), + category: str | None = Query(None), + page: int = Query(1, ge=1), + page_size: int = Query(20, ge=1, le=100), + db: AsyncSession = Depends(get_db), +): + svc = ProductService(db) + return await svc.list_products(q, category, page, page_size) + + +@router.get("/{product_id}", response_model=ProductDetailResponse) +async def get_product( + product_id: UUID, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = ProductService(db) + try: + return await svc.get_product(product_id) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" + ) from None + + +@router.get("/{product_id}/prices", response_model=PriceTrendResponse) +async def get_product_prices( + product_id: UUID, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = ProductService(db) + try: + return await svc.get_price_history(product_id) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" + ) from None diff --git a/api/src/cartsnitch_api/routes/public.py b/api/src/cartsnitch_api/routes/public.py new file mode 100644 index 0000000..5d0b87b --- /dev/null +++ b/api/src/cartsnitch_api/routes/public.py @@ -0,0 +1,48 @@ +"""Public endpoints: price transparency data (no auth required).""" + +from typing import Annotated +from 
uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import ( + PublicInflationResponse, + PublicStoreComparisonResponse, + PublicTrendResponse, +) +from cartsnitch_api.services.public import PublicService + +router = APIRouter(prefix="/public", tags=["public"]) + + +@router.get("/trends/{product_id}", response_model=PublicTrendResponse) +async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)): + svc = PublicService(db) + try: + return await svc.get_trend(product_id) + except LookupError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" + ) from None + + +@router.get("/store-comparison", response_model=PublicStoreComparisonResponse) +async def public_store_comparison( + product_ids: Annotated[list[UUID], Query(max_length=20)], + db: AsyncSession = Depends(get_db), +): + if not product_ids: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one product_id is required", + ) + svc = PublicService(db) + return await svc.get_store_comparison(product_ids) + + +@router.get("/inflation", response_model=PublicInflationResponse) +async def public_inflation(db: AsyncSession = Depends(get_db)): + svc = PublicService(db) + return await svc.get_inflation() diff --git a/api/src/cartsnitch_api/routes/purchases.py b/api/src/cartsnitch_api/routes/purchases.py new file mode 100644 index 0000000..eba86ac --- /dev/null +++ b/api/src/cartsnitch_api/routes/purchases.py @@ -0,0 +1,49 @@ +"""Purchase routes: list, detail, stats.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import 
PurchaseDetailResponse, PurchaseResponse, PurchaseStatsResponse
+from cartsnitch_api.services.purchases import PurchaseService
+
+router = APIRouter(prefix="/purchases", tags=["purchases"])
+
+
+@router.get("", response_model=list[PurchaseResponse])
+async def list_purchases(
+    user_id: UUID = Depends(get_current_user),
+    store_id: UUID | None = Query(None),
+    page: int = Query(1, ge=1),
+    page_size: int = Query(20, ge=1, le=100),
+    db: AsyncSession = Depends(get_db),
+):
+    svc = PurchaseService(db)
+    return await svc.list_purchases(user_id, store_id, page, page_size)
+
+
+@router.get("/stats", response_model=PurchaseStatsResponse)
+async def purchase_stats(
+    user_id: UUID = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db),
+):
+    svc = PurchaseService(db)
+    return await svc.get_stats(user_id)
+
+
+@router.get("/{purchase_id}", response_model=PurchaseDetailResponse)
+async def get_purchase(
+    purchase_id: UUID,
+    user_id: UUID = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db),
+):
+    svc = PurchaseService(db)
+    try:
+        return await svc.get_purchase(purchase_id, user_id)
+    except LookupError:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail="Purchase not found"
+        ) from None
diff --git a/api/src/cartsnitch_api/routes/scraping.py b/api/src/cartsnitch_api/routes/scraping.py
new file mode 100644
index 0000000..d8bbd5f
--- /dev/null
+++ b/api/src/cartsnitch_api/routes/scraping.py
@@ -0,0 +1,58 @@
+"""Scraping routes: trigger sync, check status (proxy to ReceiptWitness)."""
+
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, HTTPException, status
+from httpx import HTTPStatusError, RequestError
+
+from cartsnitch_api.auth.dependencies import get_current_user
+from cartsnitch_api.schemas import SyncStatusResponse, SyncTriggerResponse
+from cartsnitch_api.services.receiptwitness import ReceiptWitnessClient
+
+router = APIRouter(prefix="/scraping", tags=["scraping"])
+
+
+@router.post("/{store_slug}/sync", response_model=SyncTriggerResponse)
+async def trigger_sync(store_slug: str, user_id: UUID = Depends(get_current_user)):
+    """Ask ReceiptWitness to start a sync of ``store_slug`` for the caller.
+
+    An error status returned by the sync service is relayed with the same
+    status code; a transport failure is reported as 502 Bad Gateway.
+    """
+    client = ReceiptWitnessClient()
+    try:
+        result = await client.trigger_sync(str(user_id), store_slug)
+        return result
+    except HTTPStatusError as e:
+        raise HTTPException(
+            status_code=e.response.status_code,
+            detail="Sync service error",
+        ) from e
+    except RequestError:
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail="Unable to reach sync service",
+        ) from None
+
+
+@router.get("/status", response_model=list[SyncStatusResponse])
+async def sync_status(user_id: UUID = Depends(get_current_user)):
+    """Return the caller's sync status entries from ReceiptWitness.
+
+    Every upstream failure is reported as 502 Bad Gateway; the detail text
+    tells an error response apart from a transport failure.
+    """
+    client = ReceiptWitnessClient()
+    try:
+        return await client.get_sync_status(str(user_id))
+    except HTTPStatusError as e:
+        # The service did answer, so do not claim it was unreachable.
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail="Sync service error",
+        ) from e
+    except RequestError:
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail="Unable to reach sync service",
+        ) from None
diff --git a/api/src/cartsnitch_api/routes/shopping.py b/api/src/cartsnitch_api/routes/shopping.py
new file mode 100644
index 0000000..c64d5fd
--- /dev/null
+++ b/api/src/cartsnitch_api/routes/shopping.py
@@ -0,0 +1,48 @@
+"""Shopping routes: optimize list, saved lists."""
+
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, HTTPException, status
+from httpx import HTTPStatusError, RequestError
+
+from cartsnitch_api.auth.dependencies import get_current_user
+from cartsnitch_api.schemas import OptimizeRequest, OptimizeResponse, ShoppingListResponse
+from cartsnitch_api.services.clipartist import ClipArtistClient
+
+router = APIRouter(prefix="/shopping", tags=["shopping"])
+
+
+@router.post("/optimize", response_model=OptimizeResponse)
+async def optimize_shopping(body: OptimizeRequest, user_id: UUID = Depends(get_current_user)):
+    client = ClipArtistClient()
+    try:
+        result = await client.optimize(
+            user_id=str(user_id),
+            items=[item.model_dump() for item in body.items],
+            preferred_stores=(
+                [str(s) for s in body.preferred_stores] if
body.preferred_stores else None + ), + ) + return result + except HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, + detail="Shopping optimization service error", + ) from e + except RequestError: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Unable to reach shopping optimization service", + ) from None + + +@router.get("/lists", response_model=list[ShoppingListResponse]) +async def list_shopping_lists(user_id: UUID = Depends(get_current_user)): + client = ClipArtistClient() + try: + return await client.get_shopping_lists(str(user_id)) + except (HTTPStatusError, RequestError): + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Unable to reach shopping service", + ) from None diff --git a/api/src/cartsnitch_api/routes/stores.py b/api/src/cartsnitch_api/routes/stores.py new file mode 100644 index 0000000..1ab7947 --- /dev/null +++ b/api/src/cartsnitch_api/routes/stores.py @@ -0,0 +1,61 @@ +"""Store routes: list stores, manage user store connections.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.dependencies import get_current_user +from cartsnitch_api.database import get_db +from cartsnitch_api.schemas import ConnectStoreRequest, StoreAccountResponse, StoreResponse +from cartsnitch_api.services.stores import StoreService + +router = APIRouter(tags=["stores"]) + + +@router.get("/stores", response_model=list[StoreResponse]) +async def list_stores(db: AsyncSession = Depends(get_db)): + svc = StoreService(db) + return await svc.list_stores() + + +@router.get("/me/stores", response_model=list[StoreAccountResponse]) +async def list_user_stores( + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = StoreService(db) + return await svc.list_user_stores(user_id) + + +@router.post( + "/me/stores/{store_slug}/connect", + 
response_model=StoreAccountResponse, + status_code=status.HTTP_201_CREATED, +) +async def connect_store( + store_slug: str, + body: ConnectStoreRequest, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = StoreService(db) + try: + return await svc.connect_store(user_id, store_slug, body.credentials) + except LookupError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + +@router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT) +async def disconnect_store( + store_slug: str, + user_id: UUID = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + svc = StoreService(db) + try: + await svc.disconnect_store(user_id, store_slug) + except LookupError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e diff --git a/api/src/cartsnitch_api/schemas.py b/api/src/cartsnitch_api/schemas.py new file mode 100644 index 0000000..19e351a --- /dev/null +++ b/api/src/cartsnitch_api/schemas.py @@ -0,0 +1,291 @@ +"""Pydantic v2 request/response schemas for all API endpoints.""" + +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, EmailStr, Field + +# ---------- Auth ---------- + + +class RegisterRequest(BaseModel): + email: EmailStr + password: str = Field(min_length=8, max_length=128) + display_name: str = Field(min_length=1, max_length=100) + + +class LoginRequest(BaseModel): + email: EmailStr + password: str + + +class RefreshRequest(BaseModel): + refresh_token: str + + +class TokenResponse(BaseModel): + access_token: str + refresh_token: str + token_type: str = "bearer" + expires_in: int + + +class UpdateUserRequest(BaseModel): + email: EmailStr | None = None + display_name: str | None = Field(None, min_length=1, max_length=100) + + +class UserResponse(BaseModel): + id: UUID + 
email: str
+    display_name: str | None = None
+    created_at: datetime
+
+
+# ---------- Stores ----------
+
+
+class StoreResponse(BaseModel):
+    id: UUID
+    name: str
+    slug: str
+    logo_url: str | None = None
+    supported: bool = True
+
+
+class StoreAccountResponse(BaseModel):
+    store: StoreResponse
+    connected: bool
+    last_sync_at: datetime | None = None
+    sync_status: str | None = None
+
+
+class ConnectStoreRequest(BaseModel):
+    credentials: dict | None = None
+
+
+# ---------- Purchases ----------
+
+
+class LineItemResponse(BaseModel):
+    id: UUID
+    product_id: UUID | None = None
+    name: str
+    quantity: float
+    unit_price: float
+    total_price: float
+
+
+class PurchaseResponse(BaseModel):
+    id: UUID
+    store_id: UUID
+    store_name: str
+    purchased_at: datetime
+    total: float
+    item_count: int
+
+
+class PurchaseDetailResponse(PurchaseResponse):
+    line_items: list[LineItemResponse]
+
+
+class PurchaseStatsResponse(BaseModel):
+    total_spent: float
+    purchase_count: int
+    by_store: dict[str, float]
+    by_period: dict[str, float]
+
+
+# ---------- Products ----------
+
+
+class ProductResponse(BaseModel):
+    id: UUID
+    name: str
+    brand: str | None = None
+    category: str | None = None
+    upc: str | None = None
+    image_url: str | None = None
+
+
+class ProductDetailResponse(ProductResponse):
+    prices_by_store: list["StorePriceResponse"]
+
+
+class StorePriceResponse(BaseModel):
+    store_id: UUID
+    store_name: str
+    current_price: float
+    last_seen_at: datetime
+
+
+# ---------- Prices ----------
+
+
+class PriceTrendResponse(BaseModel):
+    product_id: UUID
+    product_name: str
+    data_points: list["PricePointResponse"]
+
+
+class PricePointResponse(BaseModel):
+    date: datetime
+    price: float
+    store_id: UUID
+    store_name: str
+
+
+class PriceIncreaseResponse(BaseModel):
+    product_id: UUID
+    product_name: str
+    store_name: str
+    old_price: float
+    new_price: float
+    increase_pct: float
+    detected_at: datetime
+
+
+class PriceComparisonResponse(BaseModel):
+    product_id: UUID
+    product_name:
str + prices: list[StorePriceResponse] + + +# ---------- Coupons ---------- + + +class CouponResponse(BaseModel): + id: UUID + store_id: UUID + store_name: str + description: str + discount_value: float + discount_type: str + product_id: UUID | None = None + expires_at: datetime | None = None + + +# ---------- Shopping ---------- + + +class ShoppingListItemRequest(BaseModel): + product_id: UUID | None = None + name: str + quantity: int = 1 + + +class OptimizeRequest(BaseModel): + items: list[ShoppingListItemRequest] + preferred_stores: list[UUID] | None = None + + +class OptimizedStoreTrip(BaseModel): + store_id: UUID + store_name: str + items: list["OptimizedItemResponse"] + subtotal: float + coupons: list[CouponResponse] + savings: float + + +class OptimizedItemResponse(BaseModel): + name: str + price: float + product_id: UUID | None = None + + +class OptimizeResponse(BaseModel): + trips: list[OptimizedStoreTrip] + total_cost: float + total_savings: float + + +class ShoppingListResponse(BaseModel): + id: UUID + name: str + item_count: int + created_at: datetime + updated_at: datetime + + +# ---------- Alerts ---------- + + +class AlertResponse(BaseModel): + id: UUID + alert_type: str + product_id: UUID + product_name: str + message: str + triggered_at: datetime + read: bool = False + + +class AlertSettingsRequest(BaseModel): + price_increase_threshold_pct: float | None = None + shrinkflation_enabled: bool | None = None + email_notifications: bool | None = None + + +class AlertSettingsResponse(BaseModel): + price_increase_threshold_pct: float + shrinkflation_enabled: bool + email_notifications: bool + + +# ---------- Scraping ---------- + + +class SyncTriggerResponse(BaseModel): + job_id: UUID + status: str + message: str + + +class SyncStatusResponse(BaseModel): + store_slug: str + status: str + last_sync_at: datetime | None = None + items_synced: int | None = None + + +# ---------- Public ---------- + + +class PublicTrendResponse(BaseModel): + product_id: UUID + 
product_name: str + data_points: list[PricePointResponse] + + +class PublicStoreComparisonResponse(BaseModel): + products: list[PriceComparisonResponse] + + +class PublicInflationResponse(BaseModel): + period: str + cartsnitch_index: float + cpi_baseline: float + categories: dict[str, float] + + +# ---------- Common ---------- + + +class PaginatedResponse(BaseModel): + items: list + total: int + page: int + page_size: int + pages: int + + +class ErrorResponse(BaseModel): + detail: str + code: str | None = None + + +# Rebuild forward refs +ProductDetailResponse.model_rebuild() +PriceTrendResponse.model_rebuild() +OptimizedStoreTrip.model_rebuild() diff --git a/api/src/cartsnitch_api/services/__init__.py b/api/src/cartsnitch_api/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/cartsnitch_api/services/alerts.py b/api/src/cartsnitch_api/services/alerts.py new file mode 100644 index 0000000..fc3ddd4 --- /dev/null +++ b/api/src/cartsnitch_api/services/alerts.py @@ -0,0 +1,75 @@ +"""Alert service — price and shrinkflation alerts for users. + +Alerts are generated by StickerShock and ShrinkRay services and written to the DB. +This service reads them for the API gateway. 
+""" + +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + + +class AlertService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_alerts(self, user_id: UUID) -> list[dict]: + """List shrinkflation events for products the user has purchased.""" + from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent + + # Get product IDs from user's purchases + items_result = await self.db.execute( + select(PurchaseItem.normalized_product_id) + .join(Purchase) + .where( + Purchase.user_id == user_id, + PurchaseItem.normalized_product_id.isnot(None), + ) + .distinct() + ) + product_ids = [row[0] for row in items_result.all()] + + if not product_ids: + return [] + + result = await self.db.execute( + select(ShrinkflationEvent) + .where(ShrinkflationEvent.normalized_product_id.in_(product_ids)) + .options(selectinload(ShrinkflationEvent.normalized_product)) + .order_by(ShrinkflationEvent.detected_date.desc()) + ) + events = result.scalars().all() + + return [ + { + "id": e.id, + "alert_type": "shrinkflation", + "product_id": e.normalized_product_id, + "product_name": e.normalized_product.canonical_name, + "message": ( + f"Size changed from {e.old_size}{e.old_unit} to {e.new_size}{e.new_unit}" + ), + "triggered_at": e.detected_date, + "read": False, + } + for e in events + ] + + async def get_settings(self, user_id: UUID) -> dict: + # Alert settings would be stored in a user_settings table. + # For now, return defaults since the table doesn't exist yet in common lib. + return { + "price_increase_threshold_pct": 5.0, + "shrinkflation_enabled": True, + "email_notifications": False, + } + + async def update_settings(self, user_id: UUID, **fields) -> dict: + # Would update user_settings table. Return merged defaults for now. 
+ current = await self.get_settings(user_id) + for k, v in fields.items(): + if v is not None and k in current: + current[k] = v + return current diff --git a/api/src/cartsnitch_api/services/auth.py b/api/src/cartsnitch_api/services/auth.py new file mode 100644 index 0000000..5ea6b77 --- /dev/null +++ b/api/src/cartsnitch_api/services/auth.py @@ -0,0 +1,125 @@ +"""Auth service — user registration, login, token management.""" + +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from cartsnitch_api.auth.jwt import create_access_token, create_refresh_token, decode_token +from cartsnitch_api.auth.passwords import hash_password, verify_password +from cartsnitch_api.config import settings + + +class AuthService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def register(self, email: str, password: str, display_name: str) -> dict: + from cartsnitch_api.models import User + + existing = await self.db.execute(select(User).where(User.email == email)) + if existing.scalar_one_or_none(): + raise ValueError("Email already registered") + + user = User( + email=email, + hashed_password=hash_password(password), + display_name=display_name, + ) + self.db.add(user) + await self.db.commit() + await self.db.refresh(user) + + return self._make_token_response(user.id) + + async def login(self, email: str, password: str) -> dict: + from cartsnitch_api.models import User + + result = await self.db.execute(select(User).where(User.email == email)) + user = result.scalar_one_or_none() + if not user or not verify_password(password, user.hashed_password): + raise ValueError("Invalid email or password") + + return self._make_token_response(user.id) + + async def refresh(self, refresh_token: str) -> dict: + from cartsnitch_api.models import User + + try: + payload = decode_token(refresh_token) + except ValueError: + raise ValueError("Invalid refresh token") from None + + if payload.get("type") != "refresh": + 
raise ValueError("Invalid token type") from None + + user_id = UUID(payload["sub"]) + + # Verify the user still exists before issuing new tokens + result = await self.db.execute(select(User).where(User.id == user_id)) + if not result.scalar_one_or_none(): + raise ValueError("User no longer exists") + + return self._make_token_response(user_id) + + async def get_user(self, user_id: UUID) -> dict: + from cartsnitch_api.models import User + + result = await self.db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if not user: + raise LookupError("User not found") + + return { + "id": user.id, + "email": user.email, + "display_name": user.display_name, + "created_at": user.created_at, + } + + async def update_user(self, user_id: UUID, **fields) -> dict: + from cartsnitch_api.models import User + + result = await self.db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if not user: + raise LookupError("User not found") + + if "display_name" in fields and fields["display_name"] is not None: + user.display_name = fields["display_name"] + if "email" in fields and fields["email"] is not None: + existing = await self.db.execute( + select(User).where(User.email == fields["email"], User.id != user_id) + ) + if existing.scalar_one_or_none(): + raise ValueError("Email already in use") + user.email = fields["email"] + + await self.db.commit() + await self.db.refresh(user) + + return { + "id": user.id, + "email": user.email, + "display_name": user.display_name, + "created_at": user.created_at, + } + + async def delete_user(self, user_id: UUID) -> None: + from cartsnitch_api.models import User + + result = await self.db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if not user: + raise LookupError("User not found") + + await self.db.delete(user) + await self.db.commit() + + def _make_token_response(self, user_id: UUID) -> dict: + return { + "access_token": 
create_access_token(user_id), + "refresh_token": create_refresh_token(user_id), + "token_type": "bearer", + "expires_in": settings.jwt_access_token_expire_minutes * 60, + } diff --git a/api/src/cartsnitch_api/services/clipartist.py b/api/src/cartsnitch_api/services/clipartist.py new file mode 100644 index 0000000..86d6c62 --- /dev/null +++ b/api/src/cartsnitch_api/services/clipartist.py @@ -0,0 +1,52 @@ +"""HTTP client for ClipArtist internal API.""" + +from typing import Any, cast + +import httpx + +from cartsnitch_api.config import settings + + +class ClipArtistClient: + def __init__(self) -> None: + self.base_url = settings.clipartist_url + self.headers = {"X-Service-Key": settings.service_key} + + async def optimize( + self, + user_id: str, + items: list[dict], + preferred_stores: list[str] | None = None, + ) -> dict: + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{self.base_url}/optimize", + headers=self.headers, + json={ + "user_id": user_id, + "items": items, + "preferred_stores": preferred_stores, + }, + ) + resp.raise_for_status() + return cast(dict[str, Any], resp.json()) + + async def get_shopping_lists(self, user_id: str) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/shopping-lists", + headers=self.headers, + params={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) + + async def get_relevant_coupons(self, user_id: str) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/coupons/relevant", + headers=self.headers, + params={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) diff --git a/api/src/cartsnitch_api/services/coupons.py b/api/src/cartsnitch_api/services/coupons.py new file mode 100644 index 0000000..9b1543e --- /dev/null +++ b/api/src/cartsnitch_api/services/coupons.py @@ -0,0 +1,76 @@ +"""Coupon service — 
browse coupons, find relevant ones.""" + +from datetime import date +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + + +class CouponService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_coupons(self, store_id: UUID | None = None) -> list[dict]: + from cartsnitch_api.models import Coupon + + today = date.today() + query = ( + select(Coupon) + .where((Coupon.valid_to >= today) | (Coupon.valid_to.is_(None))) + .options(selectinload(Coupon.store)) + .order_by(Coupon.valid_to.asc().nullslast()) + ) + if store_id: + query = query.where(Coupon.store_id == store_id) + + result = await self.db.execute(query) + coupons = result.scalars().all() + return [self._to_dict(c) for c in coupons] + + async def relevant_coupons(self, user_id: UUID) -> list[dict]: + """Coupons for products the user has purchased.""" + from cartsnitch_api.models import Coupon, PurchaseItem + + today = date.today() + + # Get product IDs from user's purchase history + from cartsnitch_api.models import Purchase + + items_result = await self.db.execute( + select(PurchaseItem.normalized_product_id) + .join(Purchase) + .where( + Purchase.user_id == user_id, + PurchaseItem.normalized_product_id.isnot(None), + ) + .distinct() + ) + product_ids = [row[0] for row in items_result.all()] + + if not product_ids: + return [] + + result = await self.db.execute( + select(Coupon) + .where( + Coupon.normalized_product_id.in_(product_ids), + (Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)), + ) + .options(selectinload(Coupon.store)) + ) + coupons = result.scalars().all() + return [self._to_dict(c) for c in coupons] + + def _to_dict(self, c) -> dict: + return { + "id": c.id, + "store_id": c.store_id, + "store_name": c.store.name, + "description": c.description or c.title, + "discount_value": float(c.discount_value) if c.discount_value else 0, + "discount_type": c.discount_type, + 
"product_id": c.normalized_product_id, + "expires_at": c.valid_to, + } diff --git a/api/src/cartsnitch_api/services/prices.py b/api/src/cartsnitch_api/services/prices.py new file mode 100644 index 0000000..44b74a0 --- /dev/null +++ b/api/src/cartsnitch_api/services/prices.py @@ -0,0 +1,183 @@ +"""Price service — trends, increases, comparison.""" + +from uuid import UUID + +from sqlalchemy import and_, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from cartsnitch_api.services.queries import latest_price_per_store + + +class PriceService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def get_trends(self, category: str | None = None) -> list[dict]: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + query = ( + select(PriceHistory) + .join(NormalizedProduct) + .options( + selectinload(PriceHistory.store), + selectinload(PriceHistory.normalized_product), + ) + .order_by(PriceHistory.observed_date) + ) + if category: + query = query.where(NormalizedProduct.category == category) + + result = await self.db.execute(query) + prices = result.scalars().all() + + # Group by product + by_product: dict[UUID, dict] = {} + for ph in prices: + pid = ph.normalized_product_id + if pid not in by_product: + by_product[pid] = { + "product_id": pid, + "product_name": ph.normalized_product.canonical_name, + "data_points": [], + } + by_product[pid]["data_points"].append( + { + "date": ph.observed_date, + "price": float(ph.regular_price), + "store_id": ph.store_id, + "store_name": ph.store.name, + } + ) + return list(by_product.values()) + + async def get_increases(self) -> list[dict]: + """Find products with recent significant price increases. + + Uses a window function (lag) to compare each price observation with the + previous one per product+store, avoiding the N+1 query pattern. 
+ """ + from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store + + # Use lag() window function to get previous price in a single query + prev_price = ( + func.lag(PriceHistory.regular_price) + .over( + partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id], + order_by=PriceHistory.observed_date, + ) + .label("prev_price") + ) + + row_num = ( + func.row_number() + .over( + partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id], + order_by=PriceHistory.observed_date.desc(), + ) + .label("rn") + ) + + inner = select( + PriceHistory.normalized_product_id, + PriceHistory.store_id, + PriceHistory.regular_price, + PriceHistory.observed_date, + prev_price, + row_num, + ).subquery() + + # Only keep the latest row (rn=1) where price increased + result = await self.db.execute( + select( + inner.c.normalized_product_id, + inner.c.store_id, + inner.c.regular_price, + inner.c.observed_date, + inner.c.prev_price, + NormalizedProduct.canonical_name, + Store.name.label("store_name"), + ) + .join(NormalizedProduct, NormalizedProduct.id == inner.c.normalized_product_id) + .join(Store, Store.id == inner.c.store_id) + .where( + inner.c.rn == 1, + inner.c.prev_price.isnot(None), + inner.c.regular_price > inner.c.prev_price, + ) + ) + + increases = [] + for row in result.all(): + old = float(row.prev_price) + new = float(row.regular_price) + increases.append( + { + "product_id": row.normalized_product_id, + "product_name": row.canonical_name, + "store_name": row.store_name, + "old_price": old, + "new_price": new, + "increase_pct": round((new - old) / old * 100, 2), + "detected_at": row.observed_date, + } + ) + + increases.sort(key=lambda x: x["increase_pct"], reverse=True) + return increases + + async def get_comparison(self, product_ids: list[UUID]) -> list[dict]: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + if not product_ids: + return [] + + # Fetch all requested products in one query + prod_result = 
await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) + ) + products_by_id = {p.id: p for p in prod_result.scalars().all()} + + # Latest prices for all requested products in one query + subq = latest_price_per_store(product_ids) + prices_result = await self.db.execute( + select(PriceHistory) + .join( + subq, + and_( + PriceHistory.store_id == subq.c.store_id, + PriceHistory.observed_date == subq.c.max_date, + PriceHistory.normalized_product_id == subq.c.normalized_product_id, + ), + ) + .where(PriceHistory.normalized_product_id.in_(product_ids)) + .options(selectinload(PriceHistory.store)) + ) + all_prices = prices_result.scalars().all() + + # Group prices by product + prices_by_product: dict[UUID, list] = {pid: [] for pid in product_ids} + for ph in all_prices: + prices_by_product.setdefault(ph.normalized_product_id, []).append(ph) + + comparisons = [] + for pid in product_ids: + product = products_by_id.get(pid) + if not product: + continue + comparisons.append( + { + "product_id": pid, + "product_name": product.canonical_name, + "prices": [ + { + "store_id": ph.store_id, + "store_name": ph.store.name, + "current_price": float(ph.regular_price), + "last_seen_at": ph.observed_date, + } + for ph in prices_by_product.get(pid, []) + ], + } + ) + return comparisons diff --git a/api/src/cartsnitch_api/services/products.py b/api/src/cartsnitch_api/services/products.py new file mode 100644 index 0000000..ad35987 --- /dev/null +++ b/api/src/cartsnitch_api/services/products.py @@ -0,0 +1,124 @@ +"""Product service — catalog, detail, price history.""" + +from uuid import UUID + +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from cartsnitch_api.services.queries import latest_price_per_store + + +class ProductService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_products( + self, + q: str | None = None, + category: 
str | None = None, + page: int = 1, + page_size: int = 20, + ) -> list[dict]: + from cartsnitch_api.models import NormalizedProduct + + query = select(NormalizedProduct) + if q: + # Escape SQL LIKE wildcards in user input + safe_q = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + query = query.where(NormalizedProduct.canonical_name.ilike(f"%{safe_q}%")) + if category: + query = query.where(NormalizedProduct.category == category) + query = query.order_by(NormalizedProduct.canonical_name) + query = query.offset((page - 1) * page_size).limit(page_size) + + result = await self.db.execute(query) + products = result.scalars().all() + return [ + { + "id": p.id, + "name": p.canonical_name, + "brand": p.brand, + "category": p.category, + "upc": (p.upc_variants[0] if p.upc_variants else None), + "image_url": None, + } + for p in products + ] + + async def get_product(self, product_id: UUID) -> dict: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + result = await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id == product_id) + ) + product = result.scalar_one_or_none() + if not product: + raise LookupError("Product not found") + + # Get latest price per store + subq = latest_price_per_store([product_id]) + prices_result = await self.db.execute( + select(PriceHistory) + .join( + subq, + and_( + PriceHistory.store_id == subq.c.store_id, + PriceHistory.observed_date == subq.c.max_date, + PriceHistory.normalized_product_id == subq.c.normalized_product_id, + ), + ) + .where(PriceHistory.normalized_product_id == product_id) + .options(selectinload(PriceHistory.store)) + ) + prices = prices_result.scalars().all() + + return { + "id": product.id, + "name": product.canonical_name, + "brand": product.brand, + "category": product.category, + "upc": (product.upc_variants[0] if product.upc_variants else None), + "image_url": None, + "prices_by_store": [ + { + "store_id": ph.store_id, + "store_name": ph.store.name, + 
"current_price": float(ph.regular_price), + "last_seen_at": ph.observed_date, + } + for ph in prices + ], + } + + async def get_price_history(self, product_id: UUID) -> dict: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + result = await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id == product_id) + ) + product = result.scalar_one_or_none() + if not product: + raise LookupError("Product not found") + + prices_result = await self.db.execute( + select(PriceHistory) + .where(PriceHistory.normalized_product_id == product_id) + .options(selectinload(PriceHistory.store)) + .order_by(PriceHistory.observed_date) + ) + prices = prices_result.scalars().all() + + return { + "product_id": product.id, + "product_name": product.canonical_name, + "data_points": [ + { + "date": ph.observed_date, + "price": float(ph.regular_price), + "store_id": ph.store_id, + "store_name": ph.store.name, + } + for ph in prices + ], + } diff --git a/api/src/cartsnitch_api/services/public.py b/api/src/cartsnitch_api/services/public.py new file mode 100644 index 0000000..f1ccbeb --- /dev/null +++ b/api/src/cartsnitch_api/services/public.py @@ -0,0 +1,129 @@ +"""Public service — unauthenticated price transparency endpoints.""" + +from uuid import UUID + +from sqlalchemy import and_, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from cartsnitch_api.services.queries import latest_price_per_store + + +class PublicService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def get_trend(self, product_id: UUID) -> dict: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + result = await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id == product_id) + ) + product = result.scalar_one_or_none() + if not product: + raise LookupError("Product not found") + + prices_result = await self.db.execute( + select(PriceHistory) + 
.where(PriceHistory.normalized_product_id == product_id) + .options(selectinload(PriceHistory.store)) + .order_by(PriceHistory.observed_date) + ) + prices = prices_result.scalars().all() + + return { + "product_id": product.id, + "product_name": product.canonical_name, + "data_points": [ + { + "date": ph.observed_date, + "price": float(ph.regular_price), + "store_id": ph.store_id, + "store_name": ph.store.name, + } + for ph in prices + ], + } + + async def get_store_comparison(self, product_ids: list[UUID]) -> dict: + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + if not product_ids: + return {"products": []} + + # Fetch all products in one query + prod_result = await self.db.execute( + select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) + ) + products_by_id = {p.id: p for p in prod_result.scalars().all()} + + # Latest prices for all requested products in one query + subq = latest_price_per_store(product_ids) + prices_result = await self.db.execute( + select(PriceHistory) + .join( + subq, + and_( + PriceHistory.store_id == subq.c.store_id, + PriceHistory.observed_date == subq.c.max_date, + PriceHistory.normalized_product_id == subq.c.normalized_product_id, + ), + ) + .where(PriceHistory.normalized_product_id.in_(product_ids)) + .options(selectinload(PriceHistory.store)) + ) + all_prices = prices_result.scalars().all() + + # Group by product + prices_by_product: dict[UUID, list] = {} + for ph in all_prices: + prices_by_product.setdefault(ph.normalized_product_id, []).append(ph) + + products = [] + for pid in product_ids: + product = products_by_id.get(pid) + if not product: + continue + products.append( + { + "product_id": pid, + "product_name": product.canonical_name, + "prices": [ + { + "store_id": ph.store_id, + "store_name": ph.store.name, + "current_price": float(ph.regular_price), + "last_seen_at": ph.observed_date, + } + for ph in prices_by_product.get(pid, []) + ], + } + ) + + return {"products": products} + + async 
def get_inflation(self) -> dict: + """Aggregate price change stats. Compares average prices across periods.""" + from cartsnitch_api.models import NormalizedProduct, PriceHistory + + # Get average prices grouped by category for recent vs older data + result = await self.db.execute( + select( + NormalizedProduct.category, + func.avg(PriceHistory.regular_price), + ) + .join(NormalizedProduct) + .group_by(NormalizedProduct.category) + ) + categories = {} + for row in result.all(): + cat, avg_price = row + if cat: + categories[cat] = float(avg_price) if avg_price else 0.0 + + return { + "period": "all-time", + "cartsnitch_index": sum(categories.values()) / max(len(categories), 1), + "cpi_baseline": 100.0, + "categories": categories, + } diff --git a/api/src/cartsnitch_api/services/purchases.py b/api/src/cartsnitch_api/services/purchases.py new file mode 100644 index 0000000..41776f4 --- /dev/null +++ b/api/src/cartsnitch_api/services/purchases.py @@ -0,0 +1,116 @@ +"""Purchase service — list, detail, stats.""" + +from uuid import UUID + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + + +class PurchaseService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_purchases( + self, + user_id: UUID, + store_id: UUID | None = None, + page: int = 1, + page_size: int = 20, + ) -> list[dict]: + from cartsnitch_api.models import Purchase, PurchaseItem, Store + + # Count items per purchase in a single subquery instead of N+1 + item_counts = ( + select( + PurchaseItem.purchase_id, + func.count().label("item_count"), + ) + .group_by(PurchaseItem.purchase_id) + .subquery() + ) + + query = ( + select(Purchase, item_counts.c.item_count, Store.name.label("store_name")) + .join(Store, Store.id == Purchase.store_id) + .outerjoin(item_counts, item_counts.c.purchase_id == Purchase.id) + .where(Purchase.user_id == user_id) + ) + if store_id: + query = query.where(Purchase.store_id 
== store_id) + + query = query.order_by(Purchase.purchase_date.desc()) + query = query.offset((page - 1) * page_size).limit(page_size) + + result = await self.db.execute(query) + + return [ + { + "id": p.id, + "store_id": p.store_id, + "store_name": store_name, + "purchased_at": p.purchase_date, + "total": float(p.total), + "item_count": item_count or 0, + } + for p, item_count, store_name in result.all() + ] + + async def get_purchase(self, purchase_id: UUID, user_id: UUID) -> dict: + from cartsnitch_api.models import Purchase + + result = await self.db.execute( + select(Purchase) + .where(Purchase.id == purchase_id, Purchase.user_id == user_id) + .options(selectinload(Purchase.store), selectinload(Purchase.items)) + ) + purchase = result.scalar_one_or_none() + if not purchase: + raise LookupError("Purchase not found") + + return { + "id": purchase.id, + "store_id": purchase.store_id, + "store_name": purchase.store.name, + "purchased_at": purchase.purchase_date, + "total": float(purchase.total), + "item_count": len(purchase.items), + "line_items": [ + { + "id": item.id, + "product_id": item.normalized_product_id, + "name": item.product_name_raw, + "quantity": float(item.quantity), + "unit_price": float(item.unit_price), + "total_price": float(item.extended_price), + } + for item in purchase.items + ], + } + + async def get_stats(self, user_id: UUID) -> dict: + from cartsnitch_api.models import Purchase + + result = await self.db.execute( + select(Purchase) + .where(Purchase.user_id == user_id) + .options(selectinload(Purchase.store)) + ) + purchases = result.scalars().all() + + total_spent = sum(float(p.total) for p in purchases) + by_store: dict[str, float] = {} + by_period: dict[str, float] = {} + + for p in purchases: + store_name = p.store.name + by_store[store_name] = by_store.get(store_name, 0) + float(p.total) + period = p.purchase_date.strftime("%Y-%m") + by_period[period] = by_period.get(period, 0) + float(p.total) + + return { + "total_spent": 
total_spent, + "purchase_count": len(purchases), + "by_store": by_store, + "by_period": by_period, + } diff --git a/api/src/cartsnitch_api/services/queries.py b/api/src/cartsnitch_api/services/queries.py new file mode 100644 index 0000000..8a94f7c --- /dev/null +++ b/api/src/cartsnitch_api/services/queries.py @@ -0,0 +1,23 @@ +"""Shared query helpers for service layer.""" + +from uuid import UUID + +from sqlalchemy import func, select + + +def latest_price_per_store(product_ids: list[UUID] | None = None): + """Subquery returning the latest observed_date per product+store. + + Optionally filtered to a list of product IDs. Returns a subquery with + columns: normalized_product_id, store_id, max_date. + """ + from cartsnitch_api.models import PriceHistory + + query = select( + PriceHistory.normalized_product_id, + PriceHistory.store_id, + func.max(PriceHistory.observed_date).label("max_date"), + ).group_by(PriceHistory.normalized_product_id, PriceHistory.store_id) + if product_ids is not None: + query = query.where(PriceHistory.normalized_product_id.in_(product_ids)) + return query.subquery() diff --git a/api/src/cartsnitch_api/services/receiptwitness.py b/api/src/cartsnitch_api/services/receiptwitness.py new file mode 100644 index 0000000..e6200a9 --- /dev/null +++ b/api/src/cartsnitch_api/services/receiptwitness.py @@ -0,0 +1,33 @@ +"""HTTP client for ReceiptWitness internal API.""" + +from typing import Any, cast + +import httpx + +from cartsnitch_api.config import settings + + +class ReceiptWitnessClient: + def __init__(self) -> None: + self.base_url = settings.receiptwitness_url + self.headers = {"X-Service-Key": settings.service_key} + + async def trigger_sync(self, user_id: str, store_slug: str) -> dict: + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{self.base_url}/sync/{store_slug}", + headers=self.headers, + json={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(dict[str, Any], resp.json()) + + async def 
get_sync_status(self, user_id: str) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/sync/status", + headers=self.headers, + params={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) diff --git a/api/src/cartsnitch_api/services/shrinkray.py b/api/src/cartsnitch_api/services/shrinkray.py new file mode 100644 index 0000000..862881e --- /dev/null +++ b/api/src/cartsnitch_api/services/shrinkray.py @@ -0,0 +1,23 @@ +"""HTTP client for ShrinkRay internal API.""" + +from typing import Any, cast + +import httpx + +from cartsnitch_api.config import settings + + +class ShrinkRayClient: + def __init__(self) -> None: + self.base_url = settings.shrinkray_url + self.headers = {"X-Service-Key": settings.service_key} + + async def get_shrinkflation_alerts(self, user_id: str) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/alerts", + headers=self.headers, + params={"user_id": user_id}, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) diff --git a/api/src/cartsnitch_api/services/stickershock.py b/api/src/cartsnitch_api/services/stickershock.py new file mode 100644 index 0000000..3a7928d --- /dev/null +++ b/api/src/cartsnitch_api/services/stickershock.py @@ -0,0 +1,32 @@ +"""HTTP client for StickerShock internal API.""" + +from typing import Any, cast + +import httpx + +from cartsnitch_api.config import settings + + +class StickerShockClient: + def __init__(self) -> None: + self.base_url = settings.stickershock_url + self.headers = {"X-Service-Key": settings.service_key} + + async def get_price_increases(self, params: dict | None = None) -> list[dict]: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/increases", + headers=self.headers, + params=params, + ) + resp.raise_for_status() + return cast(list[dict[str, Any]], resp.json()) + + async def 
get_inflation_data(self) -> dict: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/inflation", + headers=self.headers, + ) + resp.raise_for_status() + return cast(dict[str, Any], resp.json()) diff --git a/api/src/cartsnitch_api/services/stores.py b/api/src/cartsnitch_api/services/stores.py new file mode 100644 index 0000000..610f47e --- /dev/null +++ b/api/src/cartsnitch_api/services/stores.py @@ -0,0 +1,129 @@ +"""Store service — list stores, manage user store account connections.""" + +import json +from uuid import UUID + +from cryptography.fernet import Fernet +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from cartsnitch_api.config import settings + + +def _get_fernet() -> Fernet: + return Fernet(settings.fernet_key.encode()) + + +class StoreService: + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def list_stores(self) -> list[dict]: + from cartsnitch_api.models import Store + + result = await self.db.execute(select(Store).order_by(Store.name)) + stores = result.scalars().all() + return [ + { + "id": s.id, + "name": s.name, + "slug": s.slug, + "logo_url": s.logo_url, + "supported": True, + } + for s in stores + ] + + async def list_user_stores(self, user_id: UUID) -> list[dict]: + from cartsnitch_api.models import UserStoreAccount + + result = await self.db.execute( + select(UserStoreAccount) + .where(UserStoreAccount.user_id == user_id) + .options(selectinload(UserStoreAccount.store)) + ) + accounts = result.scalars().all() + return [ + { + "store": { + "id": a.store.id, + "name": a.store.name, + "slug": a.store.slug, + "logo_url": a.store.logo_url, + "supported": True, + }, + "connected": a.status == "active", + "last_sync_at": a.last_sync_at, + "sync_status": a.status, + } + for a in accounts + ] + + async def connect_store(self, user_id: UUID, store_slug: str, credentials: dict | None) -> dict: + from 
cartsnitch_api.models import Store, UserStoreAccount + + result = await self.db.execute(select(Store).where(Store.slug == store_slug)) + store = result.scalar_one_or_none() + if not store: + raise LookupError(f"Store '{store_slug}' not found") + + existing = await self.db.execute( + select(UserStoreAccount).where( + UserStoreAccount.user_id == user_id, + UserStoreAccount.store_id == store.id, + ) + ) + if existing.scalar_one_or_none(): + raise ValueError("Store account already connected") + + encrypted_data = None + if credentials: + fernet = _get_fernet() + encrypted_data = { + "encrypted": fernet.encrypt(json.dumps(credentials).encode()).decode() + } + + account = UserStoreAccount( + user_id=user_id, + store_id=store.id, + session_data=encrypted_data, + status="active", + ) + self.db.add(account) + await self.db.commit() + await self.db.refresh(account) + + return { + "store": { + "id": store.id, + "name": store.name, + "slug": store.slug, + "logo_url": store.logo_url, + "supported": True, + }, + "connected": True, + "last_sync_at": None, + "sync_status": "active", + } + + async def disconnect_store(self, user_id: UUID, store_slug: str) -> None: + from cartsnitch_api.models import Store, UserStoreAccount + + result = await self.db.execute(select(Store).where(Store.slug == store_slug)) + store = result.scalar_one_or_none() + if not store: + raise LookupError(f"Store '{store_slug}' not found") + + result = await self.db.execute( + select(UserStoreAccount).where( + UserStoreAccount.user_id == user_id, + UserStoreAccount.store_id == store.id, + ) + ) + account = result.scalar_one_or_none() + if not account: + raise LookupError("Store account not connected") + + await self.db.delete(account) + await self.db.commit() diff --git a/api/src/cartsnitch_api/types.py b/api/src/cartsnitch_api/types.py new file mode 100644 index 0000000..13a7820 --- /dev/null +++ b/api/src/cartsnitch_api/types.py @@ -0,0 +1,36 @@ +"""Custom SQLAlchemy column types.""" + +import json + +from 
cryptography.fernet import Fernet +from sqlalchemy import Text +from sqlalchemy.types import TypeDecorator + +from cartsnitch_api.config import settings + + +def _get_fernet() -> Fernet: + return Fernet(settings.fernet_key.encode()) + + +class EncryptedJSON(TypeDecorator): + """SQLAlchemy type that transparently encrypts/decrypts JSON using Fernet. + + Stores data as a Fernet-encrypted text blob in the database. + On read, decrypts and deserialises back to a Python dict/list. + """ + + impl = Text + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + plaintext = json.dumps(value).encode() + return _get_fernet().encrypt(plaintext).decode() + + def process_result_value(self, value, dialect): + if value is None: + return None + decrypted = _get_fernet().decrypt(value.encode()) + return json.loads(decrypted) diff --git a/api/tests/__init__.py b/api/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/conftest.py b/api/tests/conftest.py new file mode 100644 index 0000000..9873903 --- /dev/null +++ b/api/tests/conftest.py @@ -0,0 +1,101 @@ +"""Shared test fixtures with in-memory SQLite database.""" + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import create_engine, event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import sessionmaker + +from cartsnitch_api.config import settings as cartsnitch_settings +from cartsnitch_api.database import get_db +from cartsnitch_api.main import create_app +from cartsnitch_api.models import Base + +TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" + + +@pytest.fixture(autouse=True) +def disable_rate_limiting(): + """Disable rate limiting for all tests to prevent 429 interference.""" + cartsnitch_settings.rate_limit_enabled = False + yield + cartsnitch_settings.rate_limit_enabled = True + + +@pytest.fixture +def engine(): + """Sync in-memory SQLite engine for 
model unit tests.""" + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def session(engine): + """Sync SQLAlchemy session for model unit tests.""" + factory = sessionmaker(bind=engine) + with factory() as sess: + yield sess + + +@pytest.fixture +async def db_engine(): + engine = create_async_engine(TEST_DATABASE_URL, echo=False) + + @event.listens_for(engine.sync_engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + await engine.dispose() + + +@pytest.fixture +async def db_session(db_engine): + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + yield session + + +@pytest.fixture +async def client(db_engine): + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + + async def override_get_db(): + async with factory() as session: + yield session + + app = create_app() + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + app.dependency_overrides.clear() + + +@pytest.fixture +async def auth_headers(client): + """Register a test user and return auth headers.""" + resp = await client.post( + "/auth/register", + json={ + "email": "test@example.com", + "password": "testpass123", + "display_name": "Test User", + }, + ) + assert resp.status_code == 201 + token = resp.json()["access_token"] + return {"Authorization": f"Bearer {token}"} diff --git a/api/tests/test_auth/__init__.py b/api/tests/test_auth/__init__.py new file mode 100644 index 0000000..e69de29 diff 
--git a/api/tests/test_auth/test_auth_endpoints.py b/api/tests/test_auth/test_auth_endpoints.py new file mode 100644 index 0000000..878cbc5 --- /dev/null +++ b/api/tests/test_auth/test_auth_endpoints.py @@ -0,0 +1,209 @@ +"""Integration tests for auth endpoints.""" + +import pytest + + +@pytest.mark.asyncio +async def test_register_success(client): + resp = await client.post( + "/auth/register", + json={ + "email": "new@example.com", + "password": "securepass123", + "display_name": "New User", + }, + ) + assert resp.status_code == 201 + data = resp.json() + assert "access_token" in data + assert "refresh_token" in data + assert data["token_type"] == "bearer" + assert data["expires_in"] == 900 # 15 min * 60 + + +@pytest.mark.asyncio +async def test_register_duplicate_email(client): + await client.post( + "/auth/register", + json={ + "email": "dupe@example.com", + "password": "securepass123", + "display_name": "User One", + }, + ) + resp = await client.post( + "/auth/register", + json={ + "email": "dupe@example.com", + "password": "securepass456", + "display_name": "User Two", + }, + ) + assert resp.status_code == 409 + + +@pytest.mark.asyncio +async def test_register_short_password(client): + resp = await client.post( + "/auth/register", + json={ + "email": "short@example.com", + "password": "short", + "display_name": "Short Pass", + }, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_login_success(client): + await client.post( + "/auth/register", + json={ + "email": "login@example.com", + "password": "securepass123", + "display_name": "Login User", + }, + ) + resp = await client.post( + "/auth/login", + json={ + "email": "login@example.com", + "password": "securepass123", + }, + ) + assert resp.status_code == 200 + assert "access_token" in resp.json() + + +@pytest.mark.asyncio +async def test_login_wrong_password(client): + await client.post( + "/auth/register", + json={ + "email": "wrong@example.com", + "password": "securepass123", + 
"display_name": "Wrong Pass", + }, + ) + resp = await client.post( + "/auth/login", + json={ + "email": "wrong@example.com", + "password": "badpassword1", + }, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_login_nonexistent_user(client): + resp = await client.post( + "/auth/login", + json={ + "email": "ghost@example.com", + "password": "doesntmatter", + }, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_refresh_token(client): + reg = await client.post( + "/auth/register", + json={ + "email": "refresh@example.com", + "password": "securepass123", + "display_name": "Refresh User", + }, + ) + refresh_token = reg.json()["refresh_token"] + + resp = await client.post( + "/auth/refresh", + json={ + "refresh_token": refresh_token, + }, + ) + assert resp.status_code == 200 + assert "access_token" in resp.json() + + +@pytest.mark.asyncio +async def test_refresh_with_invalid_token(client): + resp = await client.post( + "/auth/refresh", + json={ + "refresh_token": "invalid.token.here", + }, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_get_me(client, auth_headers): + resp = await client.get("/auth/me", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["email"] == "test@example.com" + assert data["display_name"] == "Test User" + assert "id" in data + assert "created_at" in data + + +@pytest.mark.asyncio +async def test_get_me_unauthorized(client): + resp = await client.get("/auth/me") + assert resp.status_code in (401, 403) # No auth header + + +@pytest.mark.asyncio +async def test_update_me(client, auth_headers): + resp = await client.patch( + "/auth/me", + headers=auth_headers, + json={ + "display_name": "Updated Name", + }, + ) + assert resp.status_code == 200 + assert resp.json()["display_name"] == "Updated Name" + + +@pytest.mark.asyncio +async def test_delete_me(client, auth_headers): + resp = await client.delete("/auth/me", 
headers=auth_headers) + assert resp.status_code == 204 + + # Verify user is gone (token still valid but user deleted) + resp = await client.get("/auth/me", headers=auth_headers) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_refresh_after_delete_fails(client): + """Refresh token for a deleted user must be rejected.""" + reg = await client.post( + "/auth/register", + json={ + "email": "ghost@example.com", + "password": "securepass123", + "display_name": "Ghost User", + }, + ) + tokens = reg.json() + headers = {"Authorization": f"Bearer {tokens['access_token']}"} + + # Delete the user + resp = await client.delete("/auth/me", headers=headers) + assert resp.status_code == 204 + + # Refresh token should now fail + resp = await client.post( + "/auth/refresh", + json={ + "refresh_token": tokens["refresh_token"], + }, + ) + assert resp.status_code == 401 diff --git a/api/tests/test_e2e/__init__.py b/api/tests/test_e2e/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/test_e2e/conftest.py b/api/tests/test_e2e/conftest.py new file mode 100644 index 0000000..f1390fd --- /dev/null +++ b/api/tests/test_e2e/conftest.py @@ -0,0 +1,250 @@ +"""Shared fixtures for E2E integration tests. + +Seeds a realistic dataset with stores, products, price history, +purchases, coupons, and shrinkflation events so E2E flows can +exercise cross-resource queries against real data. 
+""" + +from datetime import date, timedelta +from decimal import Decimal +from uuid import UUID + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.auth.jwt import decode_token +from cartsnitch_api.models import ( + Coupon, + NormalizedProduct, + PriceHistory, + Purchase, + PurchaseItem, + ShrinkflationEvent, + Store, +) + +# Shared test constants +ZERO_UUID = "00000000-0000-0000-0000-000000000000" +BAD_UUID = "not-a-uuid" +# Fixed anchor date for deterministic tests +ANCHOR_DATE = date(2026, 3, 15) + + +@pytest.fixture +async def seed_data(db_engine, auth_headers): + """Seed a full dataset and return identifiers for test assertions.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + # -- Stores -- + meijer = Store(name="Meijer", slug="meijer") + kroger = Store(name="Kroger", slug="kroger") + target = Store(name="Target", slug="target") + session.add_all([meijer, kroger, target]) + await session.flush() + + # -- Products -- + cheerios = NormalizedProduct( + canonical_name="Cheerios 18oz", + category="pantry", + brand="General Mills", + size="18", + size_unit="oz", + upc_variants=["016000275263"], + ) + milk = NormalizedProduct( + canonical_name="Whole Milk 1gal", + category="dairy", + brand="Meijer", + size="1", + size_unit="gal", + ) + chicken = NormalizedProduct( + canonical_name="Chicken Breast 1lb", + category="meat", + brand=None, + size="1", + size_unit="lb", + ) + session.add_all([cheerios, milk, chicken]) + await session.flush() + + # -- Price history (multiple dates, multiple stores) -- + today = ANCHOR_DATE + prices = [] + # Cheerios at Meijer: price increase over time + for i, price_val in enumerate([Decimal("3.99"), Decimal("4.29"), Decimal("4.79")]): + prices.append( + PriceHistory( + normalized_product_id=cheerios.id, + store_id=meijer.id, + observed_date=today - timedelta(days=60 - i * 30), + regular_price=price_val, + 
source="receipt", + ) + ) + # Cheerios at Kroger: stable price + for i in range(3): + prices.append( + PriceHistory( + normalized_product_id=cheerios.id, + store_id=kroger.id, + observed_date=today - timedelta(days=60 - i * 30), + regular_price=Decimal("4.49"), + source="catalog", + ) + ) + # Milk at Meijer + prices.append( + PriceHistory( + normalized_product_id=milk.id, + store_id=meijer.id, + observed_date=today - timedelta(days=7), + regular_price=Decimal("3.29"), + source="receipt", + ) + ) + # Milk at Kroger + prices.append( + PriceHistory( + normalized_product_id=milk.id, + store_id=kroger.id, + observed_date=today - timedelta(days=5), + regular_price=Decimal("3.49"), + source="catalog", + ) + ) + # Chicken at Target + prices.append( + PriceHistory( + normalized_product_id=chicken.id, + store_id=target.id, + observed_date=today - timedelta(days=3), + regular_price=Decimal("5.99"), + source="catalog", + ) + ) + session.add_all(prices) + await session.flush() + + # -- Purchases (need the user_id from the registered test user) -- + token = auth_headers["Authorization"].split(" ")[1] + payload = decode_token(token) + user_id = UUID(payload["sub"]) + + purchase1 = Purchase( + user_id=user_id, + store_id=meijer.id, + receipt_id="meijer-2026-001", + purchase_date=today - timedelta(days=10), + total=Decimal("23.45"), + subtotal=Decimal("21.50"), + tax=Decimal("1.95"), + ) + purchase2 = Purchase( + user_id=user_id, + store_id=kroger.id, + receipt_id="kroger-2026-001", + purchase_date=today - timedelta(days=5), + total=Decimal("15.78"), + subtotal=Decimal("14.50"), + tax=Decimal("1.28"), + ) + session.add_all([purchase1, purchase2]) + await session.flush() + + # -- Purchase Items -- + item1 = PurchaseItem( + purchase_id=purchase1.id, + product_name_raw="Cheerios 18oz Box", + quantity=Decimal("1"), + unit_price=Decimal("4.79"), + extended_price=Decimal("4.79"), + normalized_product_id=cheerios.id, + ) + item2 = PurchaseItem( + purchase_id=purchase1.id, + 
product_name_raw="Meijer Whole Milk 1gal", + quantity=Decimal("2"), + unit_price=Decimal("3.29"), + extended_price=Decimal("6.58"), + normalized_product_id=milk.id, + ) + item3 = PurchaseItem( + purchase_id=purchase2.id, + product_name_raw="KRO CHEERIOS 18OZ", + quantity=Decimal("1"), + unit_price=Decimal("4.49"), + extended_price=Decimal("4.49"), + normalized_product_id=cheerios.id, + ) + session.add_all([item1, item2, item3]) + await session.flush() + + # -- Coupons -- + coupon1 = Coupon( + store_id=meijer.id, + normalized_product_id=cheerios.id, + title="$1 off Cheerios", + description="Save $1 on any Cheerios 18oz or larger", + discount_type="fixed", + discount_value=Decimal("1.00"), + valid_from=today - timedelta(days=7), + valid_to=today + timedelta(days=30), + ) + coupon2 = Coupon( + store_id=kroger.id, + normalized_product_id=None, + title="10% off dairy", + description="10% off all dairy products", + discount_type="percent", + discount_value=Decimal("10.00"), + valid_from=today - timedelta(days=3), + valid_to=today + timedelta(days=14), + ) + session.add_all([coupon1, coupon2]) + await session.flush() + + # -- Shrinkflation events -- + shrink = ShrinkflationEvent( + normalized_product_id=cheerios.id, + detected_date=today - timedelta(days=15), + old_size="20", + new_size="18", + old_unit="oz", + new_unit="oz", + price_at_old_size=Decimal("3.99"), + price_at_new_size=Decimal("4.29"), + confidence=Decimal("0.95"), + notes="Size reduced from 20oz to 18oz while price increased", + ) + session.add(shrink) + await session.commit() + + for obj in [ + meijer, + kroger, + target, + cheerios, + milk, + chicken, + purchase1, + purchase2, + item1, + item2, + item3, + coupon1, + coupon2, + shrink, + ]: + await session.refresh(obj) + + return { + "headers": auth_headers, + "user_id": user_id, + "stores": {"meijer": meijer, "kroger": kroger, "target": target}, + "products": {"cheerios": cheerios, "milk": milk, "chicken": chicken}, + "purchases": {"meijer_trip": 
purchase1, "kroger_trip": purchase2}, + "items": {"cheerios_meijer": item1, "milk_meijer": item2, "cheerios_kroger": item3}, + "coupons": {"cheerios_coupon": coupon1, "dairy_coupon": coupon2}, + "shrinkflation": {"cheerios_shrink": shrink}, + } diff --git a/api/tests/test_e2e/test_auth_validation.py b/api/tests/test_e2e/test_auth_validation.py new file mode 100644 index 0000000..bbded83 --- /dev/null +++ b/api/tests/test_e2e/test_auth_validation.py @@ -0,0 +1,213 @@ +"""E2E: Auth and token validation flows.""" + +import asyncio + +import pytest + + +@pytest.mark.asyncio +class TestAuthRegistrationLogin: + """Full registration → login → token refresh → profile flow.""" + + async def test_full_auth_lifecycle(self, client, db_engine): + """Register → login → get profile → refresh → get profile again.""" + # Register + reg = await client.post( + "/auth/register", + json={ + "email": "lifecycle@example.com", + "password": "securepass123", + "display_name": "Lifecycle User", + }, + ) + assert reg.status_code == 201 + tokens = reg.json() + assert "access_token" in tokens + assert "refresh_token" in tokens + assert tokens["token_type"] == "bearer" + assert tokens["expires_in"] > 0 + + headers = {"Authorization": f"Bearer {tokens['access_token']}"} + + # Get profile with access token + me = await client.get("/auth/me", headers=headers) + assert me.status_code == 200 + assert me.json()["email"] == "lifecycle@example.com" + assert me.json()["display_name"] == "Lifecycle User" + + # Sleep 1s so the new token has a different exp than the registration token + await asyncio.sleep(1) + + # Login with same credentials + login = await client.post( + "/auth/login", + json={"email": "lifecycle@example.com", "password": "securepass123"}, + ) + assert login.status_code == 200 + login_tokens = login.json() + assert login_tokens["access_token"] != tokens["access_token"] + + # Refresh token + refresh = await client.post( + "/auth/refresh", + json={"refresh_token": tokens["refresh_token"]}, 
+ ) + assert refresh.status_code == 200 + new_tokens = refresh.json() + assert new_tokens["access_token"] != tokens["access_token"] + + # Use refreshed token to access profile + new_headers = {"Authorization": f"Bearer {new_tokens['access_token']}"} + me2 = await client.get("/auth/me", headers=new_headers) + assert me2.status_code == 200 + assert me2.json()["email"] == "lifecycle@example.com" + + +@pytest.mark.asyncio +class TestTokenValidation: + """Token edge cases and error responses.""" + + async def test_expired_token_rejected(self, client, db_engine): + """Manually craft an expired token and verify rejection.""" + import uuid + from datetime import UTC, datetime, timedelta + + from jose import jwt + + from cartsnitch_api.config import settings + + payload = { + "sub": str(uuid.uuid4()), + "exp": datetime.now(UTC) - timedelta(minutes=5), + "type": "access", + } + token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm) + resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"}) + assert resp.status_code == 401 + + async def test_invalid_token_rejected(self, client, db_engine): + resp = await client.get("/auth/me", headers={"Authorization": "Bearer not-a-real-token"}) + assert resp.status_code == 401 + + async def test_missing_auth_header(self, client, db_engine): + resp = await client.get("/auth/me") + assert resp.status_code in (401, 403) + + async def test_refresh_token_cannot_access_endpoints(self, client, db_engine): + """A refresh token should not work as an access token.""" + reg = await client.post( + "/auth/register", + json={ + "email": "refresh-test@example.com", + "password": "securepass123", + "display_name": "Refresh Test", + }, + ) + refresh_token = reg.json()["refresh_token"] + resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {refresh_token}"}) + assert resp.status_code == 401 + + async def test_deleted_user_token_invalid(self, client, db_engine): + """After deleting an 
account, tokens should no longer work.""" + reg = await client.post( + "/auth/register", + json={ + "email": "delete-me@example.com", + "password": "securepass123", + "display_name": "Delete Me", + }, + ) + tokens = reg.json() + headers = {"Authorization": f"Bearer {tokens['access_token']}"} + + # Delete account + delete_resp = await client.delete("/auth/me", headers=headers) + assert delete_resp.status_code == 204 + + # Profile should fail + me = await client.get("/auth/me", headers=headers) + assert me.status_code in (401, 404) + + +@pytest.mark.asyncio +class TestAuthProtectedEndpoints: + """Verify auth is enforced on all user-specific endpoints.""" + + @pytest.mark.parametrize( + "method,path", + [ + ("GET", "/purchases"), + ("GET", "/products"), + ("GET", "/prices/trends"), + ("GET", "/prices/increases"), + ("GET", "/coupons"), + ("GET", "/alerts"), + ("GET", "/me/stores"), + ], + ) + async def test_endpoints_require_auth(self, client, db_engine, method, path): + resp = await client.request(method, path) + assert resp.status_code in (401, 403), f"{method} {path} should require auth" + + +@pytest.mark.asyncio +class TestCrossUserDataIsolation: + """Verify that users cannot access other users' data.""" + + async def test_user_b_cannot_access_user_a_purchases(self, client, seed_data): + """Register a second user and verify they cannot see User A's purchases.""" + # User A's purchase (from seed_data) + purchase_id = str(seed_data["purchases"]["meijer_trip"].id) + + # Register User B + reg = await client.post( + "/auth/register", + json={ + "email": "userb@example.com", + "password": "securepass123", + "display_name": "User B", + }, + ) + assert reg.status_code == 201 + user_b_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + + # User B tries to access User A's specific purchase + resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers) + assert resp.status_code in (403, 404), ( + "User B should not be able to access User A's 
purchase" + ) + + async def test_user_b_purchase_list_is_empty(self, client, seed_data): + """A new user should see no purchases (not User A's purchases).""" + reg = await client.post( + "/auth/register", + json={ + "email": "userc@example.com", + "password": "securepass123", + "display_name": "User C", + }, + ) + assert reg.status_code == 201 + user_c_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + + resp = await client.get("/purchases", headers=user_c_headers) + assert resp.status_code == 200 + assert len(resp.json()) == 0, "New user should have no purchases" + + async def test_user_b_stores_isolated(self, client, seed_data): + """User B's connected stores should be independent from User A.""" + reg = await client.post( + "/auth/register", + json={ + "email": "userd@example.com", + "password": "securepass123", + "display_name": "User D", + }, + ) + assert reg.status_code == 201 + user_d_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"} + + # User D should have no connected stores + resp = await client.get("/me/stores", headers=user_d_headers) + assert resp.status_code == 200 + assert len(resp.json()) == 0, "New user should have no connected stores" diff --git a/api/tests/test_e2e/test_cross_resource_flow.py b/api/tests/test_e2e/test_cross_resource_flow.py new file mode 100644 index 0000000..1f90671 --- /dev/null +++ b/api/tests/test_e2e/test_cross_resource_flow.py @@ -0,0 +1,114 @@ +"""E2E: Cross-resource flows — store connect → purchases → prices → coupons → alerts.""" + +import pytest + + +@pytest.mark.asyncio +class TestStoreConnectToPurchaseFlow: + """Connect a store, then verify purchases and related data are accessible.""" + + async def test_connect_store_then_list(self, client, seed_data): + headers = seed_data["headers"] + # Connect to Meijer + resp = await client.post("/me/stores/meijer/connect", json={}, headers=headers) + assert resp.status_code in (200, 201) + + # Verify store appears in user's connected stores 
+ stores = await client.get("/me/stores", headers=headers) + assert stores.status_code == 200 + slugs = [s["store"]["slug"] for s in stores.json()] + assert "meijer" in slugs + + async def test_disconnect_store(self, client, seed_data): + headers = seed_data["headers"] + await client.post("/me/stores/kroger/connect", json={}, headers=headers) + resp = await client.delete("/me/stores/kroger", headers=headers) + assert resp.status_code in (200, 204) + + # Verify store no longer in connected list + stores = await client.get("/me/stores", headers=headers) + slugs = [s["store"]["slug"] for s in stores.json()] + assert "kroger" not in slugs + + +@pytest.mark.asyncio +class TestPurchaseToPriceFlow: + """Verify purchase data links to price comparison data.""" + + async def test_purchase_items_link_to_products(self, client, seed_data): + """Items from purchases reference products that have price data.""" + headers = seed_data["headers"] + purchase_id = str(seed_data["purchases"]["meijer_trip"].id) + + # Get purchase detail + purchase = await client.get(f"/purchases/{purchase_id}", headers=headers) + assert purchase.status_code == 200 + items = purchase.json()["line_items"] + + # Get product detail for an item that has a product_id + product_ids = [li["product_id"] for li in items if li.get("product_id")] + assert len(product_ids) >= 1 + + for pid in product_ids: + product = await client.get(f"/products/{pid}", headers=headers) + assert product.status_code == 200 + assert len(product.json()["prices_by_store"]) >= 1 + + +@pytest.mark.asyncio +class TestCouponFlow: + """Verify coupon listing and relevance filtering.""" + + async def test_list_all_coupons(self, client, seed_data): + headers = seed_data["headers"] + resp = await client.get("/coupons", headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 2 + descriptions = [c["description"] for c in data] + assert any("Cheerios" in d for d in descriptions) + + async def 
test_filter_coupons_by_store(self, client, seed_data): + headers = seed_data["headers"] + meijer_id = str(seed_data["stores"]["meijer"].id) + resp = await client.get("/coupons", params={"store_id": meijer_id}, headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert all(c["store_name"] == "Meijer" for c in data) + + async def test_relevant_coupons_for_user(self, client, seed_data): + """User bought Cheerios, so the Cheerios coupon should be relevant.""" + headers = seed_data["headers"] + resp = await client.get("/coupons/relevant", headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1, "Expected at least one relevant coupon for user with purchases" + descriptions = [c["description"] for c in data] + assert any("Cheerios" in d for d in descriptions) + + +@pytest.mark.asyncio +class TestAlertFlow: + """Verify alert listing with seeded data.""" + + async def test_list_alerts(self, client, seed_data): + """User bought Cheerios which has a shrinkflation event — may appear as alert.""" + headers = seed_data["headers"] + resp = await client.get("/alerts", headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, list) + # If alerts are generated synchronously, verify shrinkflation alert content + if len(data) > 0: + alert_types = [a["alert_type"] for a in data] + product_names = [a["product_name"] for a in data] + assert any(t in ("shrinkflation", "price_increase") for t in alert_types) + assert any("Cheerios" in name for name in product_names) + + async def test_alert_settings_default(self, client, seed_data): + headers = seed_data["headers"] + resp = await client.get("/alerts/settings", headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert "price_increase_threshold_pct" in data + assert "shrinkflation_enabled" in data diff --git a/api/tests/test_e2e/test_error_responses.py b/api/tests/test_e2e/test_error_responses.py new file mode 100644 index 
0000000..c3ad16e --- /dev/null +++ b/api/tests/test_e2e/test_error_responses.py @@ -0,0 +1,127 @@ +"""E2E: Error responses for bad input across all endpoint categories.""" + +import pytest + +from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID + + +@pytest.mark.asyncio +class TestRegistrationErrors: + """Validation errors during user registration.""" + + async def test_short_password(self, client, db_engine): + resp = await client.post( + "/auth/register", + json={"email": "short@example.com", "password": "short", "display_name": "Test"}, + ) + assert resp.status_code == 422 + + async def test_invalid_email(self, client, db_engine): + resp = await client.post( + "/auth/register", + json={"email": "not-an-email", "password": "securepass123", "display_name": "Test"}, + ) + assert resp.status_code == 422 + + async def test_missing_fields(self, client, db_engine): + resp = await client.post("/auth/register", json={}) + assert resp.status_code == 422 + + async def test_empty_display_name(self, client, db_engine): + resp = await client.post( + "/auth/register", + json={"email": "empty@example.com", "password": "securepass123", "display_name": ""}, + ) + assert resp.status_code == 422 + + async def test_duplicate_email(self, client, db_engine): + payload = { + "email": "dupe@example.com", + "password": "securepass123", + "display_name": "First", + } + first = await client.post("/auth/register", json=payload) + assert first.status_code == 201 + second = await client.post("/auth/register", json=payload) + assert second.status_code == 409 + + +@pytest.mark.asyncio +class TestLoginErrors: + """Login failure modes.""" + + async def test_wrong_password(self, client, db_engine): + await client.post( + "/auth/register", + json={ + "email": "login-err@example.com", + "password": "correctpass1", + "display_name": "Login", + }, + ) + resp = await client.post( + "/auth/login", + json={"email": "login-err@example.com", "password": "wrongpass123"}, + ) + assert resp.status_code == 
401 + + async def test_nonexistent_user(self, client, db_engine): + resp = await client.post( + "/auth/login", + json={"email": "nobody@example.com", "password": "doesntmatter"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +class TestNotFoundErrors: + """404 responses for missing resources.""" + + async def test_product_not_found(self, client, seed_data): + resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 404 + + async def test_purchase_not_found(self, client, seed_data): + resp = await client.get(f"/purchases/{ZERO_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 404 + + async def test_public_trend_not_found(self, client, seed_data): + resp = await client.get(f"/public/trends/{ZERO_UUID}") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +class TestMalformedInput: + """Invalid UUID formats and bad query params.""" + + async def test_invalid_uuid_product(self, client, seed_data): + resp = await client.get(f"/products/{BAD_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 422 + + async def test_invalid_uuid_purchase(self, client, seed_data): + resp = await client.get(f"/purchases/{BAD_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 422 + + async def test_invalid_uuid_public_trend(self, client, seed_data): + resp = await client.get(f"/public/trends/{BAD_UUID}") + assert resp.status_code == 422 + + +@pytest.mark.asyncio +class TestStoreConnectionErrors: + """Store connection edge cases.""" + + async def test_connect_nonexistent_store(self, client, seed_data): + resp = await client.post( + "/me/stores/nonexistent-store/connect", + json={}, + headers=seed_data["headers"], + ) + assert resp.status_code == 404 + + async def test_connect_store_twice(self, client, seed_data): + headers = seed_data["headers"] + first = await client.post("/me/stores/meijer/connect", json={}, headers=headers) + assert first.status_code in (200, 201) + 
second = await client.post("/me/stores/meijer/connect", json={}, headers=headers) + assert second.status_code == 409 diff --git a/api/tests/test_e2e/test_price_history.py b/api/tests/test_e2e/test_price_history.py new file mode 100644 index 0000000..3d53f06 --- /dev/null +++ b/api/tests/test_e2e/test_price_history.py @@ -0,0 +1,102 @@ +"""E2E: Price history queries returning correct data.""" + +import pytest + + +@pytest.mark.asyncio +class TestPriceTrends: + """Verify price trend aggregation against seeded history.""" + + async def test_trends_returns_all_products(self, client, seed_data): + resp = await client.get("/prices/trends", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + product_names = [t["product_name"] for t in data] + assert "Cheerios 18oz" in product_names + assert "Whole Milk 1gal" in product_names + + async def test_trends_filter_by_category(self, client, seed_data): + resp = await client.get( + "/prices/trends", params={"category": "dairy"}, headers=seed_data["headers"] + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + # Only dairy products should appear + for trend in data: + assert trend["product_name"] == "Whole Milk 1gal" + + async def test_trends_contain_data_points(self, client, seed_data): + resp = await client.get("/prices/trends", headers=seed_data["headers"]) + data = resp.json() + cheerios_trend = next(t for t in data if t["product_name"] == "Cheerios 18oz") + assert len(cheerios_trend["data_points"]) >= 3 + + +@pytest.mark.asyncio +class TestPriceIncreases: + """Detect price increases from seeded price history.""" + + async def test_increases_detected(self, client, seed_data): + resp = await client.get("/prices/increases", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + # Cheerios at Meijer went from 3.99 → 4.29 → 4.79 + cheerios_increases = [inc for inc in data if inc["product_name"] == "Cheerios 18oz"] + assert 
len(cheerios_increases) >= 1 + # Verify the increase data makes sense + for inc in cheerios_increases: + assert inc["new_price"] > inc["old_price"] + assert inc["increase_pct"] > 0 + assert inc["store_name"] == "Meijer" + + async def test_stable_prices_not_flagged(self, client, seed_data): + """Kroger Cheerios price is stable at $4.49 — should not appear as increase.""" + resp = await client.get("/prices/increases", headers=seed_data["headers"]) + data = resp.json() + kroger_increases = [ + inc + for inc in data + if inc["product_name"] == "Cheerios 18oz" and inc["store_name"] == "Kroger" + ] + assert len(kroger_increases) == 0 + + +@pytest.mark.asyncio +class TestPriceComparison: + """Compare prices across stores for specific products.""" + + async def test_compare_cheerios_across_stores(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get( + "/prices/comparison", + params={"product_ids": cheerios_id}, + headers=seed_data["headers"], + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + cheerios_cmp = data[0] + assert cheerios_cmp["product_name"] == "Cheerios 18oz" + store_names = [p["store_name"] for p in cheerios_cmp["prices"]] + assert "Meijer" in store_names + assert "Kroger" in store_names + + async def test_compare_requires_product_ids(self, client, seed_data): + """product_ids is required — omitting it must return 422.""" + resp = await client.get("/prices/comparison", headers=seed_data["headers"]) + assert resp.status_code == 422 + + async def test_compare_multiple_products(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + milk_id = str(seed_data["products"]["milk"].id) + resp = await client.get( + "/prices/comparison", + params=[("product_ids", cheerios_id), ("product_ids", milk_id)], + headers=seed_data["headers"], + ) + assert resp.status_code == 200 + data = resp.json() + names = [c["product_name"] for c in data] + assert "Cheerios 
18oz" in names + assert "Whole Milk 1gal" in names diff --git a/api/tests/test_e2e/test_product_search_lookup.py b/api/tests/test_e2e/test_product_search_lookup.py new file mode 100644 index 0000000..ea97c34 --- /dev/null +++ b/api/tests/test_e2e/test_product_search_lookup.py @@ -0,0 +1,82 @@ +"""E2E: Product search/lookup endpoints with real DB fixtures.""" + +import pytest + +from tests.test_e2e.conftest import ZERO_UUID + + +@pytest.mark.asyncio +class TestProductSearch: + """Search and filter products against seeded data.""" + + async def test_list_all_products(self, client, seed_data): + resp = await client.get("/products", headers=seed_data["headers"]) + assert resp.status_code == 200 + products = resp.json() + names = [p["name"] for p in products] + assert "Cheerios 18oz" in names + assert "Whole Milk 1gal" in names + assert "Chicken Breast 1lb" in names + + async def test_search_by_name(self, client, seed_data): + resp = await client.get("/products", params={"q": "cheerios"}, headers=seed_data["headers"]) + assert resp.status_code == 200 + products = resp.json() + assert len(products) >= 1 + assert all("cheerios" in p["name"].lower() for p in products) + + async def test_search_by_category(self, client, seed_data): + resp = await client.get( + "/products", params={"category": "dairy"}, headers=seed_data["headers"] + ) + assert resp.status_code == 200 + products = resp.json() + assert len(products) >= 1 + assert all(p["category"] == "dairy" for p in products) + + async def test_search_no_results(self, client, seed_data): + resp = await client.get( + "/products", params={"q": "nonexistentxyz"}, headers=seed_data["headers"] + ) + assert resp.status_code == 200 + assert resp.json() == [] + + +@pytest.mark.asyncio +class TestProductLookup: + """Detailed product lookups with cross-store pricing.""" + + async def test_get_product_detail_with_prices(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await 
client.get(f"/products/{cheerios_id}", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "Cheerios 18oz" + assert data["brand"] == "General Mills" + assert data["category"] == "pantry" + # Should have prices from both Meijer and Kroger + store_names = [p["store_name"] for p in data["prices_by_store"]] + assert "Meijer" in store_names + assert "Kroger" in store_names + + async def test_product_prices_reflect_latest(self, client, seed_data): + """The latest Meijer price for Cheerios should be 4.79 (the increase).""" + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"]) + data = resp.json() + meijer_price = next(p for p in data["prices_by_store"] if p["store_name"] == "Meijer") + assert meijer_price["current_price"] == 4.79 + + async def test_product_not_found(self, client, seed_data): + resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"]) + assert resp.status_code == 404 + + async def test_product_price_history(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get(f"/products/{cheerios_id}/prices", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data["data_points"]) >= 3 # At least the 3 Meijer observations + # Verify chronological ordering exists + prices = [dp["price"] for dp in data["data_points"]] + assert len(prices) >= 3 diff --git a/api/tests/test_e2e/test_public_endpoints.py b/api/tests/test_e2e/test_public_endpoints.py new file mode 100644 index 0000000..a0e24cf --- /dev/null +++ b/api/tests/test_e2e/test_public_endpoints.py @@ -0,0 +1,59 @@ +"""E2E: Public price transparency endpoints (no auth required).""" + +import uuid + +import pytest + + +@pytest.mark.asyncio +class TestPublicTrends: + """Public price trend endpoint — no auth, real data.""" + + async def 
test_public_trend_returns_data(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get(f"/public/trends/{cheerios_id}") + assert resp.status_code == 200 + data = resp.json() + assert data["product_name"] == "Cheerios 18oz" + assert len(data["data_points"]) >= 3 + + async def test_public_trend_no_auth_needed(self, client, seed_data): + """Confirm no Authorization header is required.""" + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get(f"/public/trends/{cheerios_id}") + assert resp.status_code == 200 + + +@pytest.mark.asyncio +class TestPublicStoreComparison: + """Public store comparison endpoint.""" + + async def test_store_comparison(self, client, seed_data): + cheerios_id = str(seed_data["products"]["cheerios"].id) + resp = await client.get( + "/public/store-comparison", + params=[("product_ids", cheerios_id)], + ) + assert resp.status_code == 200 + data = resp.json() + assert "products" in data + assert len(data["products"]) >= 1 + + async def test_store_comparison_rejects_more_than_20_ids(self, client): + """max_length=20 guard: 21 product IDs must return 422.""" + too_many = [("product_ids", str(uuid.uuid4())) for _ in range(21)] + resp = await client.get("/public/store-comparison", params=too_many) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +class TestPublicInflation: + """Public inflation index endpoint.""" + + async def test_inflation_returns_index(self, client, seed_data): + resp = await client.get("/public/inflation") + assert resp.status_code == 200 + data = resp.json() + assert "cartsnitch_index" in data + assert "cpi_baseline" in data + assert "categories" in data diff --git a/api/tests/test_e2e/test_purchase_flow.py b/api/tests/test_e2e/test_purchase_flow.py new file mode 100644 index 0000000..44de438 --- /dev/null +++ b/api/tests/test_e2e/test_purchase_flow.py @@ -0,0 +1,87 @@ +"""E2E: Purchase listing, detail, and stats against real DB fixtures.""" 
+ +import pytest + +from tests.test_e2e.conftest import ZERO_UUID + + +@pytest.mark.asyncio +class TestPurchaseList: + """List and filter a user's purchases.""" + + async def test_list_user_purchases(self, client, seed_data): + resp = await client.get("/purchases", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 2 + store_names = [p["store_name"] for p in data] + assert "Meijer" in store_names + assert "Kroger" in store_names + + async def test_filter_purchases_by_store(self, client, seed_data): + meijer_id = str(seed_data["stores"]["meijer"].id) + resp = await client.get( + "/purchases", params={"store_id": meijer_id}, headers=seed_data["headers"] + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert all(p["store_name"] == "Meijer" for p in data) + + async def test_purchases_require_auth(self, client, seed_data): + resp = await client.get("/purchases") + assert resp.status_code in (401, 403) + + +@pytest.mark.asyncio +class TestPurchaseDetail: + """Retrieve individual purchase with line items.""" + + async def test_get_purchase_detail(self, client, seed_data): + purchase_id = str(seed_data["purchases"]["meijer_trip"].id) + resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["store_name"] == "Meijer" + assert data["total"] == 23.45 + assert len(data["line_items"]) == 2 + item_names = [li["name"] for li in data["line_items"]] + assert "Cheerios 18oz Box" in item_names + assert "Meijer Whole Milk 1gal" in item_names + + async def test_line_item_amounts_correct(self, client, seed_data): + purchase_id = str(seed_data["purchases"]["meijer_trip"].id) + resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"]) + data = resp.json() + cheerios_item = next(li for li in data["line_items"] if "Cheerios" in li["name"]) + assert cheerios_item["unit_price"] == 4.79 
+ assert cheerios_item["quantity"] == 1.0 + assert cheerios_item["total_price"] == 4.79 + + async def test_purchase_not_found(self, client, seed_data): + resp = await client.get( + f"/purchases/{ZERO_UUID}", + headers=seed_data["headers"], + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +class TestPurchaseStats: + """Verify spending aggregation across purchases.""" + + async def test_purchase_stats_totals(self, client, seed_data): + resp = await client.get("/purchases/stats", headers=seed_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["purchase_count"] == 2 + # 23.45 + 15.78 = 39.23 + assert abs(data["total_spent"] - 39.23) < 0.01 + + async def test_purchase_stats_by_store(self, client, seed_data): + resp = await client.get("/purchases/stats", headers=seed_data["headers"]) + data = resp.json() + assert "Meijer" in data["by_store"] + assert "Kroger" in data["by_store"] + assert abs(data["by_store"]["Meijer"] - 23.45) < 0.01 + assert abs(data["by_store"]["Kroger"] - 15.78) < 0.01 diff --git a/api/tests/test_encrypted_json.py b/api/tests/test_encrypted_json.py new file mode 100644 index 0000000..2ef3ccb --- /dev/null +++ b/api/tests/test_encrypted_json.py @@ -0,0 +1,130 @@ +"""Tests for EncryptedJSON TypeDecorator and session_data encryption.""" + +import json + +import pytest +from cryptography.fernet import Fernet +from pydantic import ValidationError +from sqlalchemy import column, create_engine, table, text +from sqlalchemy.orm import sessionmaker + +from cartsnitch_api.config import settings +from cartsnitch_api.models import Base +from cartsnitch_api.models.store import Store +from cartsnitch_api.models.user import User, UserStoreAccount + + +@pytest.fixture +def engine(): + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def session(engine): + factory = sessionmaker(bind=engine) + with factory() as sess: + yield sess + + 
+@pytest.fixture +def store(session): + s = Store(name="Test Store", slug="test-store") + session.add(s) + session.commit() + session.refresh(s) + return s + + +@pytest.fixture +def user(session): + u = User(email="alice@example.com", hashed_password="fakehash") + session.add(u) + session.commit() + session.refresh(u) + return u + + +class TestEncryptedJSONType: + """Unit tests for the EncryptedJSON TypeDecorator.""" + + def test_round_trip(self, session, user, store): + """Data written via the ORM comes back as the original dict.""" + original = {"token": "abc123", "cookies": {"session_id": "xyz"}} + account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original) + session.add(account) + session.commit() + + loaded = session.get(UserStoreAccount, account.id) + assert loaded.session_data == original + + def test_stored_value_is_encrypted(self, session, user, store): + """The raw value in the DB should be a Fernet token, not plaintext JSON.""" + original = {"secret": "do-not-leak"} + account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original) + session.add(account) + session.commit() + + # Use a raw table construct to bypass TypeDecorator on read + raw_table = table("user_store_accounts", column("id"), column("session_data")) + raw = session.execute(raw_table.select().where(raw_table.c.id == str(account.id))).first() + # If UUID matching fails with str, try bytes format + if raw is None: + raw = session.execute( + text("SELECT session_data FROM user_store_accounts LIMIT 1") + ).scalar_one() + else: + raw = raw[1] + + assert raw != json.dumps(original) + assert raw.startswith("gAAAAA") + + # Verify we can decrypt the raw value manually + f = Fernet(settings.fernet_key.encode()) + decrypted = json.loads(f.decrypt(raw.encode())) + assert decrypted == original + + def test_null_round_trip(self, session, user, store): + """NULL session_data stays NULL.""" + account = UserStoreAccount(user_id=user.id, store_id=store.id, 
session_data=None) + session.add(account) + session.commit() + + loaded = session.get(UserStoreAccount, account.id) + assert loaded.session_data is None + + def test_empty_dict_round_trip(self, session, user, store): + """Empty dict round-trips correctly.""" + account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={}) + session.add(account) + session.commit() + + loaded = session.get(UserStoreAccount, account.id) + assert loaded.session_data == {} + + def test_update_session_data(self, session, user, store): + """Updating session_data re-encrypts the new value.""" + account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={"v": 1}) + session.add(account) + session.commit() + + account.session_data = {"v": 2, "new_field": True} + session.commit() + + loaded = session.get(UserStoreAccount, account.id) + assert loaded.session_data == {"v": 2, "new_field": True} + + +class TestEncryptionKeyValidation: + """Test that invalid/missing keys are caught at startup.""" + + def test_invalid_fernet_key_rejected(self, monkeypatch): + """Settings validation rejects a bad key.""" + monkeypatch.setenv("CARTSNITCH_FERNET_KEY", "not-a-valid-key") + + with pytest.raises(ValidationError): + from cartsnitch_api.config import Settings + + Settings() diff --git a/api/tests/test_middleware/__init__.py b/api/tests/test_middleware/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/test_middleware/conftest.py b/api/tests/test_middleware/conftest.py new file mode 100644 index 0000000..12f6b47 --- /dev/null +++ b/api/tests/test_middleware/conftest.py @@ -0,0 +1,19 @@ +"""Conftest for middleware tests — re-enables rate limiting after global disable.""" + +import pytest + +from cartsnitch_api.config import settings as cartsnitch_settings + + +@pytest.fixture(autouse=True) +def enable_rate_limiting(): + """Re-enable rate limiting after the global disable_rate_limiting fixture runs. 
+ + The root conftest disables rate limiting for all tests to prevent 429 + interference. Middleware tests need it active to verify headers and + enforcement. This fixture runs after the root fixture (more local = later + in setup order) so True is the effective value during the test body. + """ + cartsnitch_settings.rate_limit_enabled = True + yield + cartsnitch_settings.rate_limit_enabled = False diff --git a/api/tests/test_middleware/test_error_handler.py b/api/tests/test_middleware/test_error_handler.py new file mode 100644 index 0000000..950351d --- /dev/null +++ b/api/tests/test_middleware/test_error_handler.py @@ -0,0 +1,54 @@ +"""Tests for structured error responses and error monitoring.""" + +import pytest + + +@pytest.mark.asyncio +async def test_404_returns_structured_error(client): + """Non-existent route should return structured error.""" + resp = await client.get("/nonexistent") + assert resp.status_code == 404 + body = resp.json() + assert "detail" in body + assert "code" in body + assert body["code"] == "NOT_FOUND" + + +@pytest.mark.asyncio +async def test_validation_error_returns_422_with_field_errors(client): + """Invalid request body should return structured validation errors.""" + resp = await client.post( + "/auth/register", + json={"email": "not-an-email", "password": "short", "display_name": ""}, + ) + assert resp.status_code == 422 + body = resp.json() + assert body["code"] == "VALIDATION_ERROR" + assert "errors" in body + assert isinstance(body["errors"], list) + assert len(body["errors"]) > 0 + # Each error should have field, message, type + for err in body["errors"]: + assert "field" in err + assert "message" in err + assert "type" in err + + +@pytest.mark.asyncio +async def test_error_stats_requires_service_key(client): + """Error stats endpoint should require X-Service-Key.""" + resp = await client.get("/internal/error-stats") + assert resp.status_code == 422 # Missing required header + + +@pytest.mark.asyncio +async def 
test_error_stats_with_valid_key(client): + """Error stats endpoint returns monitoring data with valid key.""" + resp = await client.get( + "/internal/error-stats", + headers={"X-Service-Key": "change-me-in-production"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "error_counts" in body + assert "recent_5xx_count" in body diff --git a/api/tests/test_middleware/test_rate_limit.py b/api/tests/test_middleware/test_rate_limit.py new file mode 100644 index 0000000..d5b7691 --- /dev/null +++ b/api/tests/test_middleware/test_rate_limit.py @@ -0,0 +1,55 @@ +"""Tests for rate limiting middleware.""" + +import pytest + +from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter + + +class TestSlidingWindowCounter: + def test_allows_within_limit(self): + counter = _SlidingWindowCounter(max_requests=5, window_seconds=60) + for i in range(5): + allowed, remaining, retry = counter.is_allowed("test-key") + assert allowed is True + assert remaining == 4 - i + + def test_blocks_over_limit(self): + counter = _SlidingWindowCounter(max_requests=3, window_seconds=60) + for _ in range(3): + counter.is_allowed("test-key") + + allowed, remaining, retry = counter.is_allowed("test-key") + assert allowed is False + assert remaining == 0 + assert retry > 0 + + def test_separate_keys(self): + counter = _SlidingWindowCounter(max_requests=2, window_seconds=60) + # Fill key-a + counter.is_allowed("key-a") + counter.is_allowed("key-a") + allowed_a, _, _ = counter.is_allowed("key-a") + assert allowed_a is False + + # key-b should still be allowed + allowed_b, remaining, _ = counter.is_allowed("key-b") + assert allowed_b is True + assert remaining == 1 + + +@pytest.mark.asyncio +async def test_rate_limit_returns_429(client): + """Public endpoint should return 429 after limit exceeded.""" + # The default limit is 60/min — we won't hit it in normal tests, + # but we verify the middleware adds rate limit headers. 
+ resp = await client.get("/public/inflation") + assert "x-ratelimit-limit" in resp.headers + assert "x-ratelimit-remaining" in resp.headers + + +@pytest.mark.asyncio +async def test_health_skips_rate_limit(client): + """Health endpoint should not have rate limit headers.""" + resp = await client.get("/health") + assert resp.status_code == 200 + assert "x-ratelimit-limit" not in resp.headers diff --git a/api/tests/test_models.py b/api/tests/test_models.py new file mode 100644 index 0000000..c0f8651 --- /dev/null +++ b/api/tests/test_models.py @@ -0,0 +1,376 @@ +"""Tests for SQLAlchemy ORM models.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +import pytest +from sqlalchemy import inspect + +from cartsnitch_api.constants import ( + AccountStatus, + DiscountType, + PriceSource, + ProductCategory, + SizeUnit, + StoreSlug, +) +from cartsnitch_api.models import ( + Coupon, + NormalizedProduct, + PriceHistory, + Purchase, + PurchaseItem, + ShrinkflationEvent, + Store, + StoreLocation, + User, + UserStoreAccount, +) + + +class TestTableCreation: + """Verify all expected tables are created.""" + + def test_all_tables_exist(self, engine): + inspector = inspect(engine) + table_names = set(inspector.get_table_names()) + expected = { + "stores", + "store_locations", + "users", + "user_store_accounts", + "purchases", + "purchase_items", + "normalized_products", + "price_history", + "coupons", + "shrinkflation_events", + } + assert expected.issubset(table_names) + + def test_ten_tables_total(self, engine): + inspector = inspect(engine) + assert len(inspector.get_table_names()) == 10 + + +class TestUUIDPrimaryKeys: + """All models use UUID PKs.""" + + def test_store_uuid_pk(self, session): + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.commit() + assert isinstance(store.id, uuid.UUID) + + def 
test_user_uuid_pk(self, session): + user = User( + id=uuid.uuid4(), + email="test@example.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(user) + session.commit() + assert isinstance(user.id, uuid.UUID) + + +class TestStoreModel: + def test_store_slug_enum(self, session): + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.commit() + assert store.slug == StoreSlug.KROGER + + def test_store_unique_slug(self, session): + s1 = Store( + id=uuid.uuid4(), + name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + s2 = Store( + id=uuid.uuid4(), + name="Target Duplicate", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(s1) + session.commit() + session.add(s2) + with pytest.raises(Exception): # noqa: B017 + session.commit() + session.rollback() + + +class TestStoreLocationModel: + def test_store_location_fields(self, session): + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.flush() + loc = StoreLocation( + id=uuid.uuid4(), + store_id=store.id, + address="123 Main St", + city="Ann Arbor", + state="MI", + zip="48104", + lat=42.2808, + lng=-83.7430, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(loc) + session.commit() + assert loc.city == "Ann Arbor" + assert loc.lat == pytest.approx(42.2808) + + +class TestUserStoreAccountModel: + def test_account_status_enum(self, session): + user = User( + id=uuid.uuid4(), + email="test@test.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + 
created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([user, store]) + session.flush() + acct = UserStoreAccount( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + status=AccountStatus.ACTIVE, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(acct) + session.commit() + assert acct.status == AccountStatus.ACTIVE + + def test_unique_user_store_constraint(self, session): + """One account per user per store.""" + user = User( + id=uuid.uuid4(), + email="unique@test.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store = Store( + id=uuid.uuid4(), + name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([user, store]) + session.flush() + a1 = UserStoreAccount( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + status=AccountStatus.ACTIVE, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + a2 = UserStoreAccount( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + status=AccountStatus.EXPIRED, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(a1) + session.commit() + session.add(a2) + with pytest.raises(Exception): # noqa: B017 + session.commit() + session.rollback() + + +class TestPurchaseModel: + def test_purchase_with_items(self, session): + user = User( + id=uuid.uuid4(), + email="buyer@test.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([user, store]) + session.flush() + purchase = Purchase( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + receipt_id="RCP-001", + purchase_date=date(2026, 3, 15), + total=Decimal("42.50"), + ingested_at=datetime.now(UTC), + 
created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(purchase) + session.flush() + item = PurchaseItem( + id=uuid.uuid4(), + purchase_id=purchase.id, + product_name_raw="Meijer Whole Milk 1 Gallon", + upc="0041250000001", + quantity=Decimal("1"), + unit_price=Decimal("3.49"), + extended_price=Decimal("3.49"), + ) + session.add(item) + session.commit() + assert item.product_name_raw == "Meijer Whole Milk 1 Gallon" + assert item.unit_price == Decimal("3.49") + + +class TestNormalizedProductModel: + def test_product_with_upc_variants(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk, 1 Gallon", + category=ProductCategory.DAIRY, + brand="Store Brand", + size="128", + size_unit=SizeUnit.FL_OZ, + upc_variants=["0041250000001", "0041250000002"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + assert product.category == ProductCategory.DAIRY + assert product.size_unit == SizeUnit.FL_OZ + + +class TestPriceHistoryModel: + def test_price_source_enum(self, session): + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Eggs, Large, 12ct", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([store, product]) + session.flush() + ph = PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.99"), + sale_price=Decimal("3.99"), + source=PriceSource.RECEIPT, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(ph) + session.commit() + assert ph.source == PriceSource.RECEIPT + assert ph.regular_price == Decimal("4.99") + + +class TestCouponModel: + def test_coupon_discount_types(self, session): + store = Store( + id=uuid.uuid4(), + 
name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.flush() + coupon = Coupon( + id=uuid.uuid4(), + store_id=store.id, + title="$2 off eggs", + discount_type=DiscountType.FIXED, + discount_value=Decimal("2.00"), + requires_clip=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(coupon) + session.commit() + assert coupon.discount_type == DiscountType.FIXED + assert coupon.discount_value == Decimal("2.00") + + +class TestShrinkflationEventModel: + def test_shrinkflation_event(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Cereal, Honey Oats", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.flush() + event = ShrinkflationEvent( + id=uuid.uuid4(), + normalized_product_id=product.id, + detected_date=date(2026, 3, 10), + old_size="18", + new_size="15.4", + old_unit=SizeUnit.OZ, + new_unit=SizeUnit.OZ, + price_at_old_size=Decimal("4.99"), + price_at_new_size=Decimal("4.99"), + confidence=Decimal("0.95"), + notes="Size reduced by 14.4%, price unchanged", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(event) + session.commit() + assert event.confidence == Decimal("0.95") + assert event.old_unit == SizeUnit.OZ diff --git a/api/tests/test_openapi.py b/api/tests/test_openapi.py new file mode 100644 index 0000000..97eef19 --- /dev/null +++ b/api/tests/test_openapi.py @@ -0,0 +1,92 @@ +"""Verify all expected routes are present in the OpenAPI spec.""" + +import pytest +from httpx import ASGITransport, AsyncClient + +from cartsnitch_api.main import app + +EXPECTED_ROUTES = [ + # Auth (6) + ("post", "/auth/register"), + ("post", "/auth/login"), + ("post", "/auth/refresh"), + ("get", "/auth/me"), + ("patch", "/auth/me"), + ("delete", "/auth/me"), + # Stores (4) + ("get", "/stores"), + ("get", "/me/stores"), + ("post", 
"/me/stores/{store_slug}/connect"), + ("delete", "/me/stores/{store_slug}"), + # Purchases (3) + ("get", "/purchases"), + ("get", "/purchases/stats"), + ("get", "/purchases/{purchase_id}"), + # Products (3) + ("get", "/products"), + ("get", "/products/{product_id}"), + ("get", "/products/{product_id}/prices"), + # Prices (3) + ("get", "/prices/trends"), + ("get", "/prices/increases"), + ("get", "/prices/comparison"), + # Coupons (2) + ("get", "/coupons"), + ("get", "/coupons/relevant"), + # Shopping (2) + ("post", "/shopping/optimize"), + ("get", "/shopping/lists"), + # Alerts (3) + ("get", "/alerts"), + ("get", "/alerts/settings"), + ("put", "/alerts/settings"), + # Scraping (2) + ("post", "/scraping/{store_slug}/sync"), + ("get", "/scraping/status"), + # Public (3) + ("get", "/public/trends/{product_id}"), + ("get", "/public/store-comparison"), + ("get", "/public/inflation"), + # Health (1) + ("get", "/health"), +] + + +@pytest.mark.asyncio +async def test_all_routes_in_openapi(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/openapi.json") + assert resp.status_code == 200 + spec = resp.json() + paths = spec["paths"] + + registered = set() + for path, methods in paths.items(): + for method in methods: + if method in ("get", "post", "put", "delete", "patch"): + registered.add((method, path)) + + missing = [] + for method, path in EXPECTED_ROUTES: + if (method, path) not in registered: + missing.append(f"{method.upper()} {path}") + + assert not missing, "Missing routes in OpenAPI spec:\n" + "\n".join(missing) + + +@pytest.mark.asyncio +async def test_route_count(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/openapi.json") + spec = resp.json() + paths = spec["paths"] + + count = 0 + for _path, methods in paths.items(): + for method in methods: + if method in ("get", 
"post", "put", "delete", "patch"): + count += 1 + + assert count == 33, f"Expected 33 routes, found {count}" diff --git a/api/tests/test_routes/__init__.py b/api/tests/test_routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/test_routes/test_alerts.py b/api/tests/test_routes/test_alerts.py new file mode 100644 index 0000000..5b576a5 --- /dev/null +++ b/api/tests/test_routes/test_alerts.py @@ -0,0 +1,35 @@ +"""Integration tests for alert endpoints.""" + +import pytest + + +@pytest.mark.asyncio +async def test_list_alerts_empty(client, auth_headers): + """No purchases means no alerts.""" + resp = await client.get("/alerts", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json() == [] + + +@pytest.mark.asyncio +async def test_get_alert_settings(client, auth_headers): + resp = await client.get("/alerts/settings", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["price_increase_threshold_pct"] == 5.0 + assert data["shrinkflation_enabled"] is True + assert data["email_notifications"] is False + + +@pytest.mark.asyncio +async def test_update_alert_settings_returns_501(client, auth_headers): + resp = await client.put( + "/alerts/settings", + headers=auth_headers, + json={ + "price_increase_threshold_pct": 10.0, + "shrinkflation_enabled": False, + "email_notifications": True, + }, + ) + assert resp.status_code == 501 diff --git a/api/tests/test_routes/test_coupons.py b/api/tests/test_routes/test_coupons.py new file mode 100644 index 0000000..8687acc --- /dev/null +++ b/api/tests/test_routes/test_coupons.py @@ -0,0 +1,58 @@ +"""Integration tests for coupon endpoints.""" + +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.models import Coupon, Store + + +@pytest.fixture +async def coupon_data(db_engine, auth_headers): + """Seed stores and coupons.""" + factory = 
async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Target", slug="target") + session.add(store) + await session.commit() + await session.refresh(store) + + coupon = Coupon( + store_id=store.id, + title="$2 off laundry", + description="$2 off any laundry detergent", + discount_value=Decimal("2.00"), + discount_type="fixed", + valid_from=date(2026, 1, 1), + valid_to=date(2026, 12, 31), + ) + session.add(coupon) + await session.commit() + + return {"store": store, "coupon": coupon, "headers": auth_headers} + + +@pytest.mark.asyncio +async def test_list_coupons(client, coupon_data): + resp = await client.get("/coupons", headers=coupon_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + + +@pytest.mark.asyncio +async def test_list_coupons_by_store(client, coupon_data): + store_id = str(coupon_data["store"].id) + resp = await client.get(f"/coupons?store_id={store_id}", headers=coupon_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) >= 1 + + +@pytest.mark.asyncio +async def test_relevant_coupons_empty(client, auth_headers): + """No purchases means no relevant coupons.""" + resp = await client.get("/coupons/relevant", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json() == [] diff --git a/api/tests/test_routes/test_prices.py b/api/tests/test_routes/test_prices.py new file mode 100644 index 0000000..7bdc60f --- /dev/null +++ b/api/tests/test_routes/test_prices.py @@ -0,0 +1,90 @@ +"""Integration tests for price endpoints.""" + +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store + + +@pytest.fixture +async def price_data(db_engine, auth_headers): + """Seed products with price history showing an increase.""" + factory = 
async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Walmart", slug="walmart") + product = NormalizedProduct( + canonical_name="Tide Pods 42ct", + category="household", + brand="Tide", + ) + session.add_all([store, product]) + await session.commit() + await session.refresh(store) + await session.refresh(product) + + # Two price points — second is higher (increase) + ph1 = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 2, 1), + regular_price=Decimal("12.99"), + source="receipt", + ) + ph2 = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("14.49"), + source="receipt", + ) + session.add_all([ph1, ph2]) + await session.commit() + + return {"product": product, "store": store, "headers": auth_headers} + + +@pytest.mark.asyncio +async def test_price_trends(client, price_data): + resp = await client.get("/prices/trends", headers=price_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert data[0]["product_name"] == "Tide Pods 42ct" + assert len(data[0]["data_points"]) == 2 + + +@pytest.mark.asyncio +async def test_price_trends_by_category(client, price_data): + resp = await client.get("/prices/trends?category=household", headers=price_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) == 1 + + resp = await client.get("/prices/trends?category=nonexistent", headers=price_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) == 0 + + +@pytest.mark.asyncio +async def test_price_increases(client, price_data): + resp = await client.get("/prices/increases", headers=price_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + increase = data[0] + assert increase["old_price"] == 12.99 + assert increase["new_price"] == 14.49 + assert 
increase["increase_pct"] > 0 + + +@pytest.mark.asyncio +async def test_price_comparison(client, price_data): + pid = str(price_data["product"].id) + resp = await client.get(f"/prices/comparison?product_ids={pid}", headers=price_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert data[0]["product_name"] == "Tide Pods 42ct" + assert len(data[0]["prices"]) >= 1 diff --git a/api/tests/test_routes/test_products.py b/api/tests/test_routes/test_products.py new file mode 100644 index 0000000..7e27c9c --- /dev/null +++ b/api/tests/test_routes/test_products.py @@ -0,0 +1,94 @@ +"""Integration tests for product endpoints.""" + +import uuid +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store + + +@pytest.fixture +async def product_data(db_engine, auth_headers): + """Seed products and price history.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Meijer", slug="meijer") + product = NormalizedProduct( + canonical_name="Cheerios 18oz", + category="pantry", + brand="General Mills", + upc_variants=["016000275263"], + ) + session.add_all([store, product]) + await session.commit() + await session.refresh(store) + await session.refresh(product) + + ph1 = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("4.99"), + source="receipt", + ) + ph2 = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 10), + regular_price=Decimal("5.49"), + source="receipt", + ) + session.add_all([ph1, ph2]) + await session.commit() + + return {"product": product, "store": store, "headers": auth_headers} + + +@pytest.mark.asyncio +async def test_list_products(client, 
product_data): + resp = await client.get("/products", headers=product_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert data[0]["name"] == "Cheerios 18oz" + + +@pytest.mark.asyncio +async def test_search_products(client, product_data): + resp = await client.get("/products?q=Cheerios", headers=product_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) == 1 + + resp = await client.get("/products?q=nonexistent", headers=product_data["headers"]) + assert resp.status_code == 200 + assert len(resp.json()) == 0 + + +@pytest.mark.asyncio +async def test_get_product_detail(client, product_data): + pid = str(product_data["product"].id) + resp = await client.get(f"/products/{pid}", headers=product_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "Cheerios 18oz" + assert data["brand"] == "General Mills" + assert len(data["prices_by_store"]) >= 1 + + +@pytest.mark.asyncio +async def test_get_product_not_found(client, auth_headers): + resp = await client.get(f"/products/{uuid.uuid4()}", headers=auth_headers) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_product_prices(client, product_data): + pid = str(product_data["product"].id) + resp = await client.get(f"/products/{pid}/prices", headers=product_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["product_name"] == "Cheerios 18oz" + assert len(data["data_points"]) == 2 diff --git a/api/tests/test_routes/test_public.py b/api/tests/test_routes/test_public.py new file mode 100644 index 0000000..08a5d29 --- /dev/null +++ b/api/tests/test_routes/test_public.py @@ -0,0 +1,73 @@ +"""Integration tests for public endpoints (no auth).""" + +import uuid +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.models import NormalizedProduct, 
PriceHistory, Store + + +@pytest.fixture +async def public_data(db_engine): + """Seed data for public endpoints.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Target", slug="target") + product = NormalizedProduct( + canonical_name="Skippy PB 16oz", + category="pantry", + brand="Skippy", + ) + session.add_all([store, product]) + await session.commit() + await session.refresh(store) + await session.refresh(product) + + ph = PriceHistory( + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 5), + regular_price=Decimal("3.99"), + source="receipt", + ) + session.add(ph) + await session.commit() + + return {"product": product, "store": store} + + +@pytest.mark.asyncio +async def test_public_trend(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get(f"/public/trends/{pid}") + assert resp.status_code == 200 + data = resp.json() + assert data["product_name"] == "Skippy PB 16oz" + assert len(data["data_points"]) == 1 + + +@pytest.mark.asyncio +async def test_public_trend_not_found(client): + resp = await client.get(f"/public/trends/{uuid.uuid4()}") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_public_store_comparison(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get(f"/public/store-comparison?product_ids={pid}") + assert resp.status_code == 200 + data = resp.json() + assert len(data["products"]) == 1 + + +@pytest.mark.asyncio +async def test_public_inflation(client, public_data): + resp = await client.get("/public/inflation") + assert resp.status_code == 200 + data = resp.json() + assert "categories" in data + assert "cartsnitch_index" in data diff --git a/api/tests/test_routes/test_purchases.py b/api/tests/test_routes/test_purchases.py new file mode 100644 index 0000000..14d5eb6 --- /dev/null +++ b/api/tests/test_routes/test_purchases.py @@ 
-0,0 +1,95 @@ +"""Integration tests for purchase endpoints.""" + +import uuid +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from cartsnitch_api.auth.jwt import create_access_token +from cartsnitch_api.models import Purchase, PurchaseItem, Store, User + + +@pytest.fixture +async def purchase_data(db_engine): + """Seed a user, store, purchase, and items.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + from cartsnitch_api.auth.passwords import hash_password + + user = User( + email="buyer@example.com", + hashed_password=hash_password("testpass123"), + display_name="Buyer", + ) + store = Store(name="Kroger", slug="kroger") + session.add_all([user, store]) + await session.commit() + await session.refresh(user) + await session.refresh(store) + + purchase = Purchase( + user_id=user.id, + store_id=store.id, + receipt_id="receipt-001", + purchase_date=date(2026, 3, 10), + total=Decimal("42.50"), + ) + session.add(purchase) + await session.commit() + await session.refresh(purchase) + + item = PurchaseItem( + purchase_id=purchase.id, + product_name_raw="Organic Milk 1gal", + quantity=Decimal("1"), + unit_price=Decimal("5.99"), + extended_price=Decimal("5.99"), + ) + session.add(item) + await session.commit() + + token = create_access_token(user.id) + return { + "user": user, + "store": store, + "purchase": purchase, + "headers": {"Authorization": f"Bearer {token}"}, + } + + +@pytest.mark.asyncio +async def test_list_purchases(client, purchase_data): + resp = await client.get("/purchases", headers=purchase_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["store_name"] == "Kroger" + assert data[0]["total"] == 42.50 + + +@pytest.mark.asyncio +async def test_get_purchase_detail(client, purchase_data): + pid = str(purchase_data["purchase"].id) 
+ resp = await client.get(f"/purchases/{pid}", headers=purchase_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert len(data["line_items"]) == 1 + assert data["line_items"][0]["name"] == "Organic Milk 1gal" + + +@pytest.mark.asyncio +async def test_get_purchase_not_found(client, auth_headers): + resp = await client.get(f"/purchases/{uuid.uuid4()}", headers=auth_headers) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_purchase_stats(client, purchase_data): + resp = await client.get("/purchases/stats", headers=purchase_data["headers"]) + assert resp.status_code == 200 + data = resp.json() + assert data["total_spent"] == 42.50 + assert data["purchase_count"] == 1 + assert "Kroger" in data["by_store"] diff --git a/api/tests/test_routes/test_stores.py b/api/tests/test_routes/test_stores.py new file mode 100644 index 0000000..002ff05 --- /dev/null +++ b/api/tests/test_routes/test_stores.py @@ -0,0 +1,77 @@ +"""Integration tests for store endpoints.""" + +import pytest + +from cartsnitch_api.models import Store + + +@pytest.fixture +async def seeded_store(db_engine): + """Insert a test store directly into the DB.""" + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + store = Store(name="Meijer", slug="meijer", logo_url=None, website_url=None) + session.add(store) + await session.commit() + await session.refresh(store) + return store + + +@pytest.mark.asyncio +async def test_list_stores(client, seeded_store): + resp = await client.get("/stores") + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 1 + assert data[0]["slug"] == "meijer" + + +@pytest.mark.asyncio +async def test_list_user_stores_empty(client, auth_headers): + resp = await client.get("/me/stores", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json() == [] + + 
+@pytest.mark.asyncio +async def test_connect_and_disconnect_store(client, auth_headers, seeded_store): + # Connect + resp = await client.post( + "/me/stores/meijer/connect", + headers=auth_headers, + json={"credentials": None}, + ) + assert resp.status_code == 201 + assert resp.json()["connected"] is True + + # List should show connected + resp = await client.get("/me/stores", headers=auth_headers) + assert resp.status_code == 200 + assert len(resp.json()) == 1 + + # Disconnect + resp = await client.delete("/me/stores/meijer", headers=auth_headers) + assert resp.status_code == 204 + + # List should be empty again + resp = await client.get("/me/stores", headers=auth_headers) + assert resp.json() == [] + + +@pytest.mark.asyncio +async def test_connect_nonexistent_store(client, auth_headers): + resp = await client.post( + "/me/stores/nonexistent/connect", + headers=auth_headers, + json={}, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_connect_duplicate_store(client, auth_headers, seeded_store): + await client.post("/me/stores/meijer/connect", headers=auth_headers, json={}) + resp = await client.post("/me/stores/meijer/connect", headers=auth_headers, json={}) + assert resp.status_code == 409 diff --git a/api/tests/test_services/__init__.py b/api/tests/test_services/__init__.py new file mode 100644 index 0000000..e69de29 From 4cf6f91e954b770198578bcb8db5d98ac964bfed Mon Sep 17 00:00:00 2001 From: Coupon Carl Date: Sat, 28 Mar 2026 02:24:14 +0000 Subject: [PATCH 2/4] Squashed 'common/' content from commit 28b2939 git-subtree-dir: common git-subtree-split: 28b2939037b5932ca5d5a6c734b292c012ac675f --- .github/workflows/ci.yml | 88 ++ .gitignore | 8 + CLAUDE.md | 185 ++++ Dockerfile | 19 + alembic.ini | 36 + alembic/env.py | 51 + alembic/script.py.mako | 25 + pyproject.toml | 47 + renovate.json | 4 + scripts/stats/README.md | 60 ++ scripts/stats/savings_potential.sql | 121 +++ scripts/stats/shrinkflation_count.sql | 39 + 
scripts/stats/validate_launch_stats.py | 267 +++++ src/cartsnitch_common/__init__.py | 3 + src/cartsnitch_common/config.py | 18 + src/cartsnitch_common/constants.py | 85 ++ src/cartsnitch_common/database.py | 45 + src/cartsnitch_common/events.py | 28 + src/cartsnitch_common/models/__init__.py | 25 + src/cartsnitch_common/models/base.py | 30 + src/cartsnitch_common/models/coupon.py | 42 + src/cartsnitch_common/models/price.py | 50 + src/cartsnitch_common/models/product.py | 39 + src/cartsnitch_common/models/purchase.py | 91 ++ src/cartsnitch_common/models/shrinkflation.py | 41 + src/cartsnitch_common/models/store.py | 52 + src/cartsnitch_common/models/user.py | 51 + src/cartsnitch_common/normalization.py | 156 +++ src/cartsnitch_common/pipeline/__init__.py | 26 + src/cartsnitch_common/pipeline/matching.py | 136 +++ .../pipeline/price_tracking.py | 130 +++ src/cartsnitch_common/pipeline/receipt.py | 144 +++ .../pipeline/shrinkflation.py | 165 +++ src/cartsnitch_common/py.typed | 0 src/cartsnitch_common/schemas/__init__.py | 49 + src/cartsnitch_common/schemas/coupon.py | 45 + src/cartsnitch_common/schemas/events.py | 17 + src/cartsnitch_common/schemas/price.py | 38 + src/cartsnitch_common/schemas/product.py | 33 + src/cartsnitch_common/schemas/purchase.py | 73 ++ .../schemas/shrinkflation.py | 40 + src/cartsnitch_common/schemas/store.py | 52 + src/cartsnitch_common/schemas/user.py | 44 + src/cartsnitch_common/seed/__init__.py | 1 + src/cartsnitch_common/seed/__main__.py | 50 + src/cartsnitch_common/seed/config.py | 38 + .../seed/generators/__init__.py | 1 + .../seed/generators/coupons.py | 107 ++ .../seed/generators/prices.py | 162 +++ .../seed/generators/products.py | 253 +++++ .../seed/generators/purchases.py | 156 +++ .../seed/generators/shrinkflation.py | 114 +++ .../seed/generators/stores.py | 203 ++++ .../seed/generators/users.py | 105 ++ src/cartsnitch_common/seed/runner.py | 189 ++++ tests/__init__.py | 0 tests/conftest.py | 24 + tests/test_models.py | 376 
+++++++ tests/test_normalization.py | 157 +++ tests/test_pipeline_e2e.py | 949 ++++++++++++++++++ tests/test_pipeline_matching.py | 160 +++ tests/test_pipeline_price.py | 282 ++++++ tests/test_pipeline_receipt.py | 204 ++++ tests/test_pipeline_shrinkflation.py | 233 +++++ tests/test_schemas.py | 225 +++++ tests/test_seed.py | 357 +++++++ 66 files changed, 7044 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 Dockerfile create mode 100644 alembic.ini create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 pyproject.toml create mode 100644 renovate.json create mode 100644 scripts/stats/README.md create mode 100644 scripts/stats/savings_potential.sql create mode 100644 scripts/stats/shrinkflation_count.sql create mode 100644 scripts/stats/validate_launch_stats.py create mode 100644 src/cartsnitch_common/__init__.py create mode 100644 src/cartsnitch_common/config.py create mode 100644 src/cartsnitch_common/constants.py create mode 100644 src/cartsnitch_common/database.py create mode 100644 src/cartsnitch_common/events.py create mode 100644 src/cartsnitch_common/models/__init__.py create mode 100644 src/cartsnitch_common/models/base.py create mode 100644 src/cartsnitch_common/models/coupon.py create mode 100644 src/cartsnitch_common/models/price.py create mode 100644 src/cartsnitch_common/models/product.py create mode 100644 src/cartsnitch_common/models/purchase.py create mode 100644 src/cartsnitch_common/models/shrinkflation.py create mode 100644 src/cartsnitch_common/models/store.py create mode 100644 src/cartsnitch_common/models/user.py create mode 100644 src/cartsnitch_common/normalization.py create mode 100644 src/cartsnitch_common/pipeline/__init__.py create mode 100644 src/cartsnitch_common/pipeline/matching.py create mode 100644 src/cartsnitch_common/pipeline/price_tracking.py create mode 100644 src/cartsnitch_common/pipeline/receipt.py 
create mode 100644 src/cartsnitch_common/pipeline/shrinkflation.py create mode 100644 src/cartsnitch_common/py.typed create mode 100644 src/cartsnitch_common/schemas/__init__.py create mode 100644 src/cartsnitch_common/schemas/coupon.py create mode 100644 src/cartsnitch_common/schemas/events.py create mode 100644 src/cartsnitch_common/schemas/price.py create mode 100644 src/cartsnitch_common/schemas/product.py create mode 100644 src/cartsnitch_common/schemas/purchase.py create mode 100644 src/cartsnitch_common/schemas/shrinkflation.py create mode 100644 src/cartsnitch_common/schemas/store.py create mode 100644 src/cartsnitch_common/schemas/user.py create mode 100644 src/cartsnitch_common/seed/__init__.py create mode 100644 src/cartsnitch_common/seed/__main__.py create mode 100644 src/cartsnitch_common/seed/config.py create mode 100644 src/cartsnitch_common/seed/generators/__init__.py create mode 100644 src/cartsnitch_common/seed/generators/coupons.py create mode 100644 src/cartsnitch_common/seed/generators/prices.py create mode 100644 src/cartsnitch_common/seed/generators/products.py create mode 100644 src/cartsnitch_common/seed/generators/purchases.py create mode 100644 src/cartsnitch_common/seed/generators/shrinkflation.py create mode 100644 src/cartsnitch_common/seed/generators/stores.py create mode 100644 src/cartsnitch_common/seed/generators/users.py create mode 100644 src/cartsnitch_common/seed/runner.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_models.py create mode 100644 tests/test_normalization.py create mode 100644 tests/test_pipeline_e2e.py create mode 100644 tests/test_pipeline_matching.py create mode 100644 tests/test_pipeline_price.py create mode 100644 tests/test_pipeline_receipt.py create mode 100644 tests/test_pipeline_shrinkflation.py create mode 100644 tests/test_schemas.py create mode 100644 tests/test_seed.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file 
mode 100644 index 0000000..7483099 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,88 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ci-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + lint: + runs-on: runners-cartsnitch + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - run: pip install ruff + - name: Ruff lint + run: ruff check . + - name: Ruff format check + run: ruff format --check . + + typecheck: + runs-on: runners-cartsnitch + continue-on-error: true + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - run: pip install -e ".[dev]" + - name: Type check + run: mypy src/cartsnitch_common + + test: + runs-on: runners-cartsnitch + services: + postgres: + image: postgres:15-alpine + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + env: + POSTGRES_USER: cartsnitch + POSTGRES_PASSWORD: cartsnitch_test + POSTGRES_DB: cartsnitch_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + env: + DATABASE_URL: postgresql://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test + CARTSNITCH_DATABASE_URL_SYNC: postgresql://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - run: pip install -e ".[dev]" + - name: Run migrations + run: alembic upgrade head + - name: Run tests + run: pytest --tb=short -q + + build: + runs-on: runners-cartsnitch + needs: [lint, test] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install build + - name: Build package + run: python -m build diff --git a/.gitignore 
b/.gitignore new file mode 100644 index 0000000..e044f1f --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.pyc +*.egg-info/ +dist/ +build/ +.pytest_cache/ +*.egg +.venv/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..ef248cd --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,185 @@ +# CartSnitch Common Library + +## Project Context + +CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/common`) is the shared Python library that all CartSnitch services depend on. + +**GitHub org:** github.com/cartsnitch +**Domain:** cartsnitch.com + +### CartSnitch Services + +| Repo | Service | Purpose | +|------|---------|---------| +| `cartsnitch/common` | — | Shared models, schemas, utilities (this repo) | +| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers | +| `cartsnitch/api` | API Gateway | Frontend-facing REST API | +| `cartsnitch/cartsnitch` | Frontend | React PWA (mobile-first) | +| `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison | +| `cartsnitch/shrinkray` | ShrinkRay | Shrinkflation monitoring | +| `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization | +| `cartsnitch/infra` | — | K8s manifests, Flux kustomizations | + +### Architecture Decisions + +- **Polyrepo:** Each service has its own repo, Dockerfile, CI/CD pipeline. +- **Shared DB:** One PostgreSQL cluster (CNPG on K8s, docker-compose locally). Each service owns its tables but shares the database. Services access other services' data via REST API, not direct cross-service DB queries. +- **Inter-service comms:** REST (synchronous) + Redis pub/sub (async events). +- **Target scale:** 500–1,000 users initially. +- **Target retailers (MVP):** Meijer (mPerks), Kroger, Target (Circle) in Southeast Michigan. 
+ +## What This Repo Contains + +This is a Python package (`cartsnitch-common`) that provides: + +1. **SQLAlchemy ORM models** — the canonical database schema shared across services +2. **Pydantic schemas** — request/response models for inter-service API contracts +3. **Database utilities** — engine/session factory, connection management +4. **Configuration** — shared settings via pydantic-settings (DB URL, Redis URL, etc.) +5. **Event definitions** — Redis pub/sub event types and payloads +6. **Constants** — store slugs, category enums, etc. + +## Tech Stack + +- Python 3.12+ +- SQLAlchemy 2.0 (async support) +- Alembic (migrations live in this repo since it owns the schema) +- Pydantic v2 +- pydantic-settings (env-based configuration) +- Redis (redis-py for pub/sub event definitions) + +## Database Schema + +All migrations are managed from this repo via Alembic. Services depend on `cartsnitch-common` to get the models. + +### Core Tables + +``` +stores + id (PK), name, slug (meijer|kroger|target), logo_url, website_url, created_at + +store_locations + id (PK), store_id (FK), address, city, state, zip, lat, lng + +users + id (PK), email, hashed_password, display_name, created_at, updated_at + +user_store_accounts + id (PK), user_id (FK), store_id (FK), session_data (encrypted JSONB), session_expires_at, last_sync_at, status (active|expired|error) + +purchases + id (PK), user_id (FK), store_id (FK), store_location_id (FK), receipt_id (unique per store), purchase_date, total, subtotal, tax, savings_total, source_url, raw_data (JSONB), ingested_at + +purchase_items + id (PK), purchase_id (FK), product_name_raw, upc, quantity, unit_price, extended_price, regular_price, sale_price, coupon_discount, loyalty_discount, category_raw, normalized_product_id (FK, nullable) + +normalized_products + id (PK), canonical_name, category, subcategory, brand, size, size_unit, upc_variants (JSONB), created_at, updated_at + +price_history + id (PK), normalized_product_id (FK), store_id
(FK), observed_date, regular_price, sale_price, loyalty_price, coupon_price, source (receipt|catalog|weekly_ad), purchase_item_id (FK, nullable) + +coupons + id (PK), store_id (FK), normalized_product_id (FK, nullable), title, description, discount_type (percent|fixed|bogo|buy_x_get_y), discount_value, min_purchase, valid_from, valid_to, requires_clip, coupon_code, source_url, scraped_at + +shrinkflation_events + id (PK), normalized_product_id (FK), detected_date, old_size, new_size, old_unit, new_unit, price_at_old_size, price_at_new_size, confidence, notes +``` + +## Repo Structure + +``` +cartsnitch-common/ +├── CLAUDE.md +├── README.md +├── pyproject.toml # Package definition, installable via pip +├── alembic.ini +├── alembic/ +│ ├── env.py +│ └── versions/ +├── src/ +│ └── cartsnitch_common/ +│ ├── __init__.py +│ ├── config.py # Shared settings (DB_URL, REDIS_URL, etc.) +│ ├── database.py # Engine, session factory, async support +│ ├── models/ +│ │ ├── __init__.py # Re-exports all models +│ │ ├── base.py # DeclarativeBase, common mixins (timestamps, etc.) 
+│ │ ├── store.py # Store, StoreLocation +│ │ ├── user.py # User, UserStoreAccount +│ │ ├── purchase.py # Purchase, PurchaseItem +│ │ ├── product.py # NormalizedProduct +│ │ ├── price.py # PriceHistory +│ │ ├── coupon.py # Coupon +│ │ └── shrinkflation.py # ShrinkflationEvent +│ ├── schemas/ +│ │ ├── __init__.py +│ │ ├── purchase.py # Pydantic request/response schemas +│ │ ├── product.py +│ │ ├── price.py +│ │ ├── coupon.py +│ │ └── events.py # Redis pub/sub event payloads +│ ├── events.py # Event bus helpers (publish/subscribe) +│ └── constants.py # Store slugs, enums +└── tests/ + ├── conftest.py + ├── test_models.py + └── test_schemas.py +``` + +## Packaging + +This package is published as `cartsnitch-common` and installed by other services via: + +``` +# In each service's pyproject.toml +dependencies = [ + "cartsnitch-common @ git+https://github.com/cartsnitch/common.git@main", +] +``` + +Or if using a private PyPI registry, publish there. For local dev, install in editable mode: + +```bash +pip install -e /path/to/common +``` + +## Development Workflow + +- **Never push directly to main.** Always create feature branches and open PRs. +- Branch naming: `feature/` or `fix/` +- Use conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `chore:` +- Alembic migrations must be reviewed carefully — they affect all services. +- Bump the version in `pyproject.toml` when changing schemas or models so downstream services can pin versions. +- Run `alembic upgrade head` in local dev after pulling changes. + +## Event Bus (Redis Pub/Sub) + +Events are the primary async communication mechanism between services. Event types are defined in this repo so all services share the same contract. 
+ +### Event Channels + +- `cartsnitch.receipts.ingested` — ReceiptWitness publishes when new receipt data is saved +- `cartsnitch.prices.updated` — Published when new price data points are recorded +- `cartsnitch.products.normalized` — Published when product normalization resolves a match +- `cartsnitch.coupons.updated` — ClipArtist publishes when coupon data refreshes +- `cartsnitch.alerts.price_increase` — StickerShock publishes when a significant price increase is detected +- `cartsnitch.alerts.shrinkflation` — ShrinkRay publishes when shrinkflation is detected + +### Event Payload Structure + +```json +{ + "event_type": "cartsnitch.receipts.ingested", + "timestamp": "2026-03-15T12:00:00Z", + "service": "receiptwitness", + "payload": { ... } +} +``` + +## Important Notes + +- This is the schema owner. All Alembic migrations live here. No other service runs its own migrations. +- When adding new models or changing existing ones, always create a migration and bump the package version. +- Pydantic schemas in `schemas/` define the API contracts between services. These are the source of truth for inter-service communication. +- The `database.py` module should support both sync and async sessions since different services may use different patterns. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2ce4733 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,19 @@ +# syntax=docker/dockerfile:1 +FROM python:3.12-slim AS base + +WORKDIR /app + +COPY pyproject.toml ./ +RUN pip install --no-cache-dir . 
+ +COPY src/ src/ +COPY alembic/ alembic/ +COPY alembic.ini ./ + +FROM base AS test +RUN pip install --no-cache-dir ".[dev]" +COPY tests/ tests/ +CMD ["pytest", "--tb=short", "-q"] + +FROM base AS prod +CMD ["python", "-c", "import cartsnitch_common; print(f'cartsnitch-common ready')"] diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..00a0b14 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = alembic +sqlalchemy.url = postgresql://localhost:5432/cartsnitch + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..cd893e0 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,51 @@ +"""Alembic environment configuration for CartSnitch.""" + +import os +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool + +from alembic import context +from cartsnitch_common.models.base import Base + +config = context.config +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +db_url = os.environ.get("CARTSNITCH_DATABASE_URL_SYNC") +if db_url: + config.set_main_option("sqlalchemy.url", db_url) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode.""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + with context.begin_transaction(): + context.run_migrations() + + +def 
run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..fe3b097 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,25 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +${imports if imports else ""} + +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ee348c5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,47 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "cartsnitch-common" +version = "2026.321.0" +description = "Shared models, schemas, and utilities for CartSnitch services" +requires-python = ">=3.12" +dependencies = [ + "sqlalchemy[asyncio]>=2.0,<3.0", + "alembic>=1.13,<2.0", + "pydantic[email]>=2.0,<3.0", + "pydantic-settings>=2.0,<3.0", + "asyncpg>=0.29,<1.0", + "redis>=5.0,<6.0", + "psycopg2-binary>=2.9,<3.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "ruff>=0.4", + 
"mypy>=1.10", + "faker>=33.0,<34.0", +] +seed = [ + "faker>=33.0,<34.0", +] + +[project.scripts] +cartsnitch-seed = "cartsnitch_common.seed.__main__:main" + +[tool.hatch.build.targets.wheel] +packages = ["src/cartsnitch_common"] + +[tool.ruff] +target-version = "py312" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I", "UP", "B", "SIM"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" diff --git a/renovate.json b/renovate.json new file mode 100644 index 0000000..833ba3b --- /dev/null +++ b/renovate.json @@ -0,0 +1,4 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": ["local>cartsnitch/.github:renovate-config"] +} diff --git a/scripts/stats/README.md b/scripts/stats/README.md new file mode 100644 index 0000000..86d424f --- /dev/null +++ b/scripts/stats/README.md @@ -0,0 +1,60 @@ +# Launch Announcement Validation Queries + +Scripts to validate the two statistics cited in the CartSnitch launch announcement: + +1. **847 products that shrank in the past 12 months** +2. **$336/year potential savings from buying the same items at the cheapest store** + +## Status + +These queries are written against the production data model but **cannot be run yet** — production infrastructure (CAR-99, CAR-104) is still being deployed. Once production data is available, run these scripts to confirm the cited numbers. + +## Queries + +### Stat 1: Shrinkflation count (`shrinkflation_count.sql`) + +Counts distinct `normalized_product_id` values with at least one `ShrinkflationEvent` where `detected_date` falls within the past 12 months. + +**Key assumptions:** +- "Past 12 months" is relative to query execution date (`CURRENT_DATE - INTERVAL '12 months'`). +- A product counts once even if it has multiple shrinkflation events in the window. +- The 847 figure was generated from a specific date — re-running will drift as the window slides. 
+ +### Stat 2: Annual savings potential (`savings_potential.sql`) + +**Methodology:** + +For each `normalized_product_id` with price observations from **two or more distinct stores** in the past 90 days: + +1. Take the **most recent `regular_price`** per `(normalized_product_id, store_id)` pair. +2. Compute `cheapest_price` = MIN across stores, `avg_price` = AVG across stores. +3. `savings_per_purchase` = `avg_price - cheapest_price`. + +To arrive at **annual** savings per family: + +- Assume a **typical family purchases each product ~N times per year** (see the `DEFAULT_PURCHASE_FREQUENCY_PER_YEAR` constant in `validate_launch_stats.py`). +- Default assumption: products purchased on average 26×/year (~every 2 weeks for regularly bought items). +- Sum across all eligible products: `Σ(savings_per_purchase × frequency)`. + +**Sensitivity knobs:** +- `DEFAULT_PURCHASE_FREQUENCY_PER_YEAR` — adjust purchase cadence assumption (overridable at run time with `--freq`) +- `PRICE_LOOKBACK_DAYS` — how recent a price observation must be to be "current" (default: 90 days) +- `MIN_STORES_FOR_COMPARISON` — minimum number of stores a product must appear at (default: 2) + +The $336 figure assumes the defaults above. If actual purchase frequencies differ significantly, rerun `validate_launch_stats.py --freq <N>`. + +## Running + +```bash +# Requires DATABASE_URL env var pointing at production Postgres +python scripts/stats/validate_launch_stats.py + +# Adjust purchase frequency assumption (default: 26 times/year) +python scripts/stats/validate_launch_stats.py --freq 20 + +# Run just stat 1 or stat 2 +python scripts/stats/validate_launch_stats.py --stat 1 +python scripts/stats/validate_launch_stats.py --stat 2 +``` + +Raw SQL files (`shrinkflation_count.sql`, `savings_potential.sql`) can also be run directly with `psql`.
-- =============================================================================
-- Stat 2: Annual savings potential from cross-store price comparison
-- Validates: "$336/year potential savings from buying the same items
--             at the cheapest store" (launch announcement)
--
-- Methodology:
--   1. For each (normalized_product_id, store_id), take the MOST RECENT
--      regular_price within the past 90 days ("current" price).
--   2. Keep only products observed at 2+ distinct stores.
--   3. For each product: savings_per_purchase = avg_price - min_price across stores.
--   4. Annualise: multiply by an assumed purchase frequency of 26x/year
--      (~every 2 weeks for regularly purchased grocery items).
--   5. Sum across all eligible products to get total annual savings potential.
--
-- Rounding:
--   Per-product savings are kept UNROUNDED inside the CTEs and rounded once,
--   after aggregation. Rounding each product first would make the 20x/26x/52x
--   totals non-proportional to each other and would not match
--   validate_launch_stats.py, which also rounds only the aggregated values.
--
-- Sensitivity:
--   Change the frequency constant (26) and lookback interval (90 days) to
--   explore how sensitive the $336 figure is to these assumptions.
--
-- Run against production Postgres once infrastructure is available.
-- =============================================================================

-- Step 1: most-recent price per (product, store) within the past 90 days
WITH latest_prices AS (
    SELECT DISTINCT ON (ph.normalized_product_id, ph.store_id)
        ph.normalized_product_id,
        ph.store_id,
        s.slug AS store_slug,
        ph.regular_price AS current_price,
        ph.observed_date
    FROM price_history ph
    JOIN stores s ON s.id = ph.store_id
    WHERE ph.observed_date >= CURRENT_DATE - INTERVAL '90 days'
      AND ph.regular_price > 0
    ORDER BY
        ph.normalized_product_id,
        ph.store_id,
        ph.observed_date DESC
),

-- Step 2: aggregate per product — only keep products seen at 2+ stores
product_price_spread AS (
    SELECT
        lp.normalized_product_id,
        COUNT(DISTINCT lp.store_id) AS store_count,
        MIN(lp.current_price) AS cheapest_price,
        AVG(lp.current_price) AS avg_price,
        MAX(lp.current_price) AS most_expensive_price,
        MAX(lp.current_price) - MIN(lp.current_price) AS price_range
    FROM latest_prices lp
    GROUP BY lp.normalized_product_id
    HAVING COUNT(DISTINCT lp.store_id) >= 2
),

-- Step 3: per-product savings (unrounded; see "Rounding" above).
-- Purchase frequency assumption applied in the final SELECT:
-- 26 purchases/year per product (~every 2 weeks).
savings_per_product AS (
    SELECT
        pps.normalized_product_id,
        np.canonical_name,
        np.category,
        pps.store_count,
        pps.cheapest_price,
        pps.avg_price,
        pps.price_range,
        pps.avg_price - pps.cheapest_price AS savings_per_purchase
    FROM product_price_spread pps
    JOIN normalized_products np ON np.id = pps.normalized_product_id
)

-- Final summary: total annual savings potential.
-- All three totals are the same unrounded sum scaled by a frequency, so they
-- stay exactly proportional (20 : 26 : 52) up to the final 2-decimal rounding.
SELECT
    COUNT(*) AS eligible_product_count,
    ROUND(AVG(savings_per_purchase), 4) AS avg_savings_per_purchase,
    ROUND(SUM(savings_per_purchase) * 26, 2) AS total_annual_savings_26x_freq,
    -- Sensitivity: alternative frequencies
    ROUND(SUM(savings_per_purchase) * 20, 2) AS total_annual_savings_20x_freq,
    ROUND(SUM(savings_per_purchase) * 52, 2) AS total_annual_savings_52x_freq
FROM savings_per_product;


-- Per-product detail (top 50 by annual savings opportunity).
-- Values here are rounded for display only; nothing is aggregated afterwards.
WITH latest_prices AS (
    SELECT DISTINCT ON (ph.normalized_product_id, ph.store_id)
        ph.normalized_product_id,
        ph.store_id,
        s.slug AS store_slug,
        ph.regular_price AS current_price,
        ph.observed_date
    FROM price_history ph
    JOIN stores s ON s.id = ph.store_id
    WHERE ph.observed_date >= CURRENT_DATE - INTERVAL '90 days'
      AND ph.regular_price > 0
    ORDER BY ph.normalized_product_id, ph.store_id, ph.observed_date DESC
),
product_price_spread AS (
    SELECT
        lp.normalized_product_id,
        COUNT(DISTINCT lp.store_id) AS store_count,
        MIN(lp.current_price) AS cheapest_price,
        AVG(lp.current_price) AS avg_price
    FROM latest_prices lp
    GROUP BY lp.normalized_product_id
    HAVING COUNT(DISTINCT lp.store_id) >= 2
)
SELECT
    np.canonical_name,
    np.category,
    np.brand,
    np.size,
    np.size_unit,
    pps.store_count,
    pps.cheapest_price,
    ROUND(pps.avg_price, 2) AS avg_price,
    ROUND(pps.avg_price - pps.cheapest_price, 2) AS savings_per_purchase,
    ROUND((pps.avg_price - pps.cheapest_price) * 26, 2) AS annual_savings_at_26x
FROM product_price_spread pps
JOIN normalized_products np ON np.id = pps.normalized_product_id
ORDER BY annual_savings_at_26x DESC
LIMIT 50;
DESC; diff --git a/scripts/stats/validate_launch_stats.py b/scripts/stats/validate_launch_stats.py new file mode 100644 index 0000000..b0e152f --- /dev/null +++ b/scripts/stats/validate_launch_stats.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +""" +validate_launch_stats.py — Validate CartSnitch launch announcement statistics. + +Validates two statistics from content/marketing/launch-announcement.md: + 1. "847 products that shrank in the past 12 months" + 2. "$336/year potential savings from buying the same items at the cheapest store" + +Usage: + DATABASE_URL=postgresql+asyncpg://... python scripts/stats/validate_launch_stats.py + python scripts/stats/validate_launch_stats.py --freq 20 # change purchase frequency + python scripts/stats/validate_launch_stats.py --stat 1 # run stat 1 only + python scripts/stats/validate_launch_stats.py --stat 2 # run stat 2 only + +NOTE: Production infrastructure is not yet deployed (CAR-99, CAR-104). This script +cannot be run against real data until those are complete. The data model has been +verified to support both queries. + +Ref: CAR-162 +""" + +from __future__ import annotations + +import argparse +import asyncio +import os +import sys +from decimal import Decimal + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + +# ────────────────────────────────────────────────────────────────────────────── +# Configuration / assumptions +# ────────────────────────────────────────────────────────────────────────────── + +DEFAULT_PURCHASE_FREQUENCY_PER_YEAR: int = 26 +"""Default purchase frequency assumption. + +26 = roughly every 2 weeks for a typical grocery staple. +Adjust with --freq to explore sensitivity. 
async def run_stat_1(session: AsyncSession) -> None:
    """Print the shrinkflation product count and compare it with the announced 847.

    Executes the two module-level shrinkflation queries on *session*: the
    distinct-product count first, then the per-category breakdown. The count is
    reported as matching when it is within 10 products of the announced figure.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("STAT 1: Products with shrinkflation events in the past 12 months")
    print("Expected: ~847")
    print(rule)

    count_row = (await session.execute(SHRINKFLATION_COUNT_SQL)).fetchone()
    # No row at all is treated the same as a zero count.
    count = 0 if not count_row else count_row[0]
    print(f"\n Distinct products: {count:,}")

    announced = 847
    delta = count - announced
    if announced:
        pct = abs(delta) / announced * 100
    else:
        pct = 0

    if abs(delta) <= 10:
        status = "✓ MATCHES"
    else:
        status = f"⚠ DIFFERS by {delta:+d} ({pct:.1f}%)"
    print(f" Announced value: {announced:,}")
    print(f" Status: {status}")

    print("\n Breakdown by category:")
    breakdown = await session.execute(SHRINKFLATION_BY_CATEGORY_SQL)
    for entry in breakdown.fetchall():
        category, product_count = entry[0], entry[1]
        print(f" {category:<20s} {product_count:>5,}")
# Shared CTE prefix for both savings queries.
#   latest_prices        — most recent positive regular_price per (product, store)
#                          observed within the lookback window.
#   product_price_spread — cheapest and average current price per product,
#                          restricted to products seen at >= :min_stores stores.
# Runtime values are bound parameters (:lookback_days, :min_stores), never
# interpolated into the SQL string.
_PRICE_SPREAD_CTES = """
    WITH latest_prices AS (
        SELECT DISTINCT ON (ph.normalized_product_id, ph.store_id)
            ph.normalized_product_id,
            ph.store_id,
            ph.regular_price AS current_price
        FROM price_history ph
        WHERE ph.observed_date >= CURRENT_DATE - make_interval(days => :lookback_days)
          AND ph.regular_price > 0
        ORDER BY ph.normalized_product_id, ph.store_id, ph.observed_date DESC
    ),
    product_price_spread AS (
        SELECT
            lp.normalized_product_id,
            COUNT(DISTINCT lp.store_id) AS store_count,
            MIN(lp.current_price) AS cheapest_price,
            AVG(lp.current_price) AS avg_price
        FROM latest_prices lp
        GROUP BY lp.normalized_product_id
        HAVING COUNT(DISTINCT lp.store_id) >= :min_stores
    )
"""


def _bound_savings_query(
    select_sql: str, freq: int, lookback_days: int, min_stores: int
) -> sa.TextClause:
    """Prepend the shared CTEs to *select_sql* and bind the three runtime values."""
    return sa.text(_PRICE_SPREAD_CTES + select_sql).bindparams(
        freq=freq,
        lookback_days=lookback_days,
        min_stores=min_stores,
    )


def savings_summary_sql(freq: int, lookback_days: int, min_stores: int) -> sa.TextClause:
    """Build the savings summary query with runtime parameters.

    Args:
        freq: Assumed purchases per product per year.
        lookback_days: Maximum age, in days, of a price observation.
        min_stores: Minimum number of distinct stores a product must appear at.

    Returns:
        A ``TextClause`` with all three values already bound. It yields one row:
        ``eligible_products``, ``avg_savings_per_purchase``, ``total_annual_savings``.
    """
    return _bound_savings_query(
        """
    SELECT
        COUNT(*) AS eligible_products,
        ROUND(AVG(avg_price - cheapest_price)::numeric, 4) AS avg_savings_per_purchase,
        ROUND(SUM((avg_price - cheapest_price) * :freq)::numeric, 2)
            AS total_annual_savings
    FROM product_price_spread
    """,
        freq,
        lookback_days,
        min_stores,
    )


def savings_top_products_sql(freq: int, lookback_days: int, min_stores: int) -> sa.TextClause:
    """Top 20 products by annual savings opportunity.

    Takes the same arguments as :func:`savings_summary_sql`. Each result row is
    ``canonical_name``, ``brand``, ``category``, ``savings_per_purchase``,
    ``annual_savings``, ordered by ``annual_savings`` descending.
    """
    return _bound_savings_query(
        """
    SELECT
        np.canonical_name,
        np.brand,
        np.category,
        ROUND((pps.avg_price - pps.cheapest_price)::numeric, 2) AS savings_per_purchase,
        ROUND(((pps.avg_price - pps.cheapest_price) * :freq)::numeric, 2) AS annual_savings
    FROM product_price_spread pps
    JOIN normalized_products np ON np.id = pps.normalized_product_id
    ORDER BY annual_savings DESC
    LIMIT 20
    """,
        freq,
        lookback_days,
        min_stores,
    )
async def main(stat: int | None, freq: int) -> None:
    """Connect to the database named by ``DATABASE_URL`` and run the requested stats.

    Args:
        stat: ``1`` or ``2`` to run a single validation, ``None`` to run both.
        freq: Purchase-frequency assumption passed to the stat 2 validation.

    Exits the process with status 1 when ``DATABASE_URL`` is unset or empty.
    """
    db_url = os.getenv("DATABASE_URL")
    if not db_url:
        print("ERROR: DATABASE_URL environment variable is not set.", file=sys.stderr)
        print("Set it to your production Postgres URL, e.g.:", file=sys.stderr)
        print(" export DATABASE_URL=postgresql+asyncpg://user:pass@host/db", file=sys.stderr)
        sys.exit(1)

    engine = create_async_engine(db_url, echo=False)
    try:
        async with AsyncSession(engine) as session:
            if stat is None or stat == 1:
                await run_stat_1(session)
            if stat is None or stat == 2:
                await run_stat_2(session, freq)
    finally:
        # Release pooled connections even when a query raises; otherwise the
        # script exits with connections that were never closed cleanly.
        await engine.dispose()
    print("\nDone.\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--stat",
        type=int,
        choices=[1, 2],
        default=None,
        help="Run only stat 1 or stat 2 (default: both)",
    )
    parser.add_argument(
        "--freq",
        type=int,
        default=DEFAULT_PURCHASE_FREQUENCY_PER_YEAR,
        help=(
            "Purchase frequency per product per year "
            f"(default: {DEFAULT_PURCHASE_FREQUENCY_PER_YEAR})"
        ),
    )
    args = parser.parse_args()
    asyncio.run(main(stat=args.stat, freq=args.freq))
0000000..b7a716c --- /dev/null +++ b/src/cartsnitch_common/constants.py @@ -0,0 +1,85 @@ +"""Constants and enums shared across CartSnitch services.""" + +from enum import StrEnum + + +class StoreSlug(StrEnum): + """Supported retailer slugs.""" + + MEIJER = "meijer" + KROGER = "kroger" + TARGET = "target" + + +class AccountStatus(StrEnum): + """User store account link status.""" + + ACTIVE = "active" + EXPIRED = "expired" + ERROR = "error" + + +class DiscountType(StrEnum): + """Coupon discount type.""" + + PERCENT = "percent" + FIXED = "fixed" + BOGO = "bogo" + BUY_X_GET_Y = "buy_x_get_y" + + +class PriceSource(StrEnum): + """Source of a price observation.""" + + RECEIPT = "receipt" + CATALOG = "catalog" + WEEKLY_AD = "weekly_ad" + + +class EventType(StrEnum): + """Redis pub/sub event types.""" + + RECEIPTS_INGESTED = "cartsnitch.receipts.ingested" + PRICES_UPDATED = "cartsnitch.prices.updated" + PRODUCTS_NORMALIZED = "cartsnitch.products.normalized" + COUPONS_UPDATED = "cartsnitch.coupons.updated" + ALERT_PRICE_INCREASE = "cartsnitch.alerts.price_increase" + ALERT_SHRINKFLATION = "cartsnitch.alerts.shrinkflation" + + +class ProductCategory(StrEnum): + """Top-level product categories.""" + + PRODUCE = "produce" + DAIRY = "dairy" + MEAT = "meat" + BAKERY = "bakery" + FROZEN = "frozen" + PANTRY = "pantry" + BEVERAGES = "beverages" + SNACKS = "snacks" + HOUSEHOLD = "household" + PERSONAL_CARE = "personal_care" + OTHER = "other" + + +class MatchConfidence(StrEnum): + """Confidence level for product matching.""" + + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + +class SizeUnit(StrEnum): + """Standardized product size units.""" + + OZ = "oz" + FL_OZ = "fl_oz" + LB = "lb" + G = "g" + KG = "kg" + ML = "ml" + L = "l" + CT = "ct" + PK = "pk" diff --git a/src/cartsnitch_common/database.py b/src/cartsnitch_common/database.py new file mode 100644 index 0000000..76a4f35 --- /dev/null +++ b/src/cartsnitch_common/database.py @@ -0,0 +1,45 @@ +"""Database engine and session 
def get_async_engine(url: str | None = None):
    """Create an async SQLAlchemy engine.

    Uses *url* when given, otherwise ``settings.database_url``. A new engine
    (with its own connection pool) is created on every call; the caller owns it
    and is responsible for disposing it.
    """
    return create_async_engine(url or settings.database_url, echo=settings.debug)


def get_sync_engine(url: str | None = None):
    """Create a sync SQLAlchemy engine.

    Uses *url* when given, otherwise ``settings.database_url_sync``. A new engine
    is created on every call; the caller owns it.
    """
    return create_engine(url or settings.database_url_sync, echo=settings.debug)


def _async_factory_for(engine) -> async_sessionmaker[AsyncSession]:
    """Build the standard async session factory bound to *engine*."""
    return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)


def _sync_factory_for(engine) -> sessionmaker[Session]:
    """Build the standard sync session factory bound to *engine*."""
    return sessionmaker(engine, expire_on_commit=False)


def get_async_session_factory(url: str | None = None) -> async_sessionmaker[AsyncSession]:
    """Create an async session factory backed by a freshly created engine.

    Intended to be called once (e.g. at service startup) and reused; every call
    creates another engine.
    """
    return _async_factory_for(get_async_engine(url))


def get_sync_session_factory(url: str | None = None) -> sessionmaker[Session]:
    """Create a sync session factory backed by a freshly created engine.

    Intended to be called once and reused; every call creates another engine.
    """
    return _sync_factory_for(get_sync_engine(url))


async def get_async_session(url: str | None = None) -> AsyncGenerator[AsyncSession, None]:
    """Dependency for async session injection.

    Yields one session and closes it when the consumer is done. The engine
    created for this call is disposed afterwards, so its pooled connections are
    closed instead of being left for the garbage collector.

    NOTE: this still builds an engine per call. Services that handle many
    requests should create one factory via ``get_async_session_factory`` and
    reuse it.
    """
    engine = get_async_engine(url)
    try:
        async with _async_factory_for(engine)() as session:
            yield session
    finally:
        await engine.dispose()


def get_sync_session(url: str | None = None) -> Generator[Session, None, None]:
    """Dependency for sync session injection.

    Yields one session, closes it when the consumer is done, then disposes the
    engine created for this call so its pooled connections are released.
    """
    engine = get_sync_engine(url)
    try:
        with _sync_factory_for(engine)() as session:
            yield session
    finally:
        engine.dispose()
def publish_event(
    redis_client: Redis,
    event_type: EventType,
    service: str,
    payload: dict[str, Any],
) -> int:
    """Wrap *payload* in an event envelope and publish it over Redis pub/sub.

    The envelope carries the event type, the current UTC time, the publishing
    service name and the payload. It is serialized to JSON and published on the
    channel named by ``event_type.value``.

    Returns the number of subscribers that received the message.
    """
    message = EventEnvelope(
        event_type=event_type,
        timestamp=datetime.now(UTC),
        service=service,
        payload=payload,
    )
    channel = event_type.value
    body = message.model_dump_json()
    # The Redis client's return annotation is broader than int; narrow it for callers.
    receiver_count = redis_client.publish(channel, body)
    return cast(int, receiver_count)
class TimestampMixin:
    """Mixin providing created_at / updated_at columns.

    Both columns are timezone-aware, non-nullable, and default to the
    database server's ``now()`` via ``server_default``.
    """

    # Filled in by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    # Same insert default as created_at; onupdate=func.now() additionally asks
    # SQLAlchemy to set the column to now() in the UPDATE statements it emits.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
    )


class UUIDPrimaryKeyMixin:
    """Mixin providing a UUID primary key.

    The key has two defaults: ``uuid.uuid4`` generated in Python when a row is
    created through the ORM, and a ``gen_random_uuid()`` server default for
    rows inserted without an explicit id.
    """

    # NOTE(review): gen_random_uuid() is a database-side function; presumably
    # this targets PostgreSQL — confirm before relying on it on other backends.
    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid()
    )
class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A single price observation for a product at a store on a date.

    Only ``regular_price`` is required; the sale, loyalty and coupon prices
    are nullable. Rows may optionally point back at the purchase line item
    they were derived from via ``purchase_item_id``.
    """

    __tablename__ = "price_history"
    # Composite index over (product, store, date), the same columns the
    # table is keyed on conceptually per the class docstring.
    __table_args__ = (
        Index(
            "ix_price_history_product_store_date",
            "normalized_product_id",
            "store_id",
            "observed_date",
        ),
    )

    normalized_product_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("normalized_products.id"), nullable=False
    )
    store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
    observed_date: Mapped[date] = mapped_column(Date, nullable=False)
    # All prices are stored as NUMERIC(10, 2).
    regular_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
    sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    loyalty_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    coupon_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    # Annotated as PriceSource but persisted as a plain String(20) column,
    # not a database enum type.
    source: Mapped[PriceSource] = mapped_column(String(20), nullable=False)
    # Nullable: an observation does not have to originate from a purchase item.
    purchase_item_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("purchase_items.id"))

    # Relationships
    normalized_product: Mapped["NormalizedProduct"] = relationship(back_populates="price_histories")
    store: Mapped["Store"] = relationship(back_populates="price_histories")
    purchase_item: Mapped["PurchaseItem | None"] = relationship(
        back_populates="price_history_entries"
    )
class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A single shopping trip / receipt.

    A receipt is unique per (user_id, store_id, receipt_id); see
    ``__table_args__`` at the bottom of the class.
    """

    __tablename__ = "purchases"

    user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
    store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
    # Optional: the specific physical location is not always known.
    store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id"))
    receipt_id: Mapped[str] = mapped_column(String(200), nullable=False)
    purchase_date: Mapped[date] = mapped_column(Date, nullable=False)
    # Monetary amounts are NUMERIC(10, 2); only the total is required.
    total: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
    subtotal: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    tax: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    savings_total: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    source_url: Mapped[str | None] = mapped_column(String(500))
    # Free-form JSON column. NOTE(review): presumably the unmodified scraper
    # payload — confirm against the ingestion code.
    raw_data: Mapped[dict | None] = mapped_column(JSON)
    # Set by the database to now() on insert; separate from created_at on the mixin.
    ingested_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    # Relationships
    user: Mapped["User"] = relationship(back_populates="purchases")
    store: Mapped["Store"] = relationship(back_populates="purchases")
    store_location: Mapped["StoreLocation | None"] = relationship(back_populates="purchases")
    items: Mapped[list["PurchaseItem"]] = relationship(back_populates="purchase")

    # Index for per-user/per-store lookups, plus the uniqueness rule that
    # prevents the same receipt being stored twice for a user at a store.
    __table_args__ = (
        Index("ix_purchases_user_store", "user_id", "store_id"),
        UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"),
    )
class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Detected shrinkflation event — product size changed while price held or rose.

    Stores the before/after size and unit for one normalized product, the
    prices observed at each size when known, and a confidence score.
    """

    __tablename__ = "shrinkflation_events"

    normalized_product_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("normalized_products.id"), nullable=False
    )
    detected_date: Mapped[date] = mapped_column(Date, nullable=False)
    # Sizes are kept as strings (String(50)), not numeric columns.
    old_size: Mapped[str] = mapped_column(String(50), nullable=False)
    new_size: Mapped[str] = mapped_column(String(50), nullable=False)
    # Annotated as SizeUnit but persisted as plain String(10) columns.
    old_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
    new_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
    # Prices are optional; a size change can be recorded without them.
    price_at_old_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    price_at_new_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
    # NUMERIC(3, 2) holds values with two decimal places; defaults to 1.00
    # when the ORM creates a row without an explicit confidence.
    confidence: Mapped[Decimal] = mapped_column(
        Numeric(3, 2), nullable=False, default=Decimal("1.00")
    )
    notes: Mapped[str | None] = mapped_column(String(1000))

    # Relationships
    normalized_product: Mapped["NormalizedProduct"] = relationship(
        back_populates="shrinkflation_events"
    )
class StoreLocation(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Physical store location.

    Belongs to exactly one ``Store`` (``store_id`` is required) and carries a
    postal address plus optional coordinates.
    """

    __tablename__ = "store_locations"

    store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
    address: Mapped[str] = mapped_column(String(300), nullable=False)
    city: Mapped[str] = mapped_column(String(100), nullable=False)
    # Two-character column. NOTE(review): presumably a US state abbreviation — confirm.
    state: Mapped[str] = mapped_column(String(2), nullable=False)
    # Stored as text (String(10)); the attribute name shadows the builtin
    # ``zip`` only inside this class body.
    zip: Mapped[str] = mapped_column(String(10), nullable=False)
    # Coordinates are optional.
    lat: Mapped[float | None] = mapped_column(Float)
    lng: Mapped[float | None] = mapped_column(Float)

    # Relationships
    store: Mapped["Store"] = relationship(back_populates="locations")
    purchases: Mapped[list["Purchase"]] = relationship(back_populates="store_location")
class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Link between a user and their retailer account credentials.

    At most one row exists per (user_id, store_id) pair, enforced by the
    ``uq_user_store_account`` unique constraint.
    """

    __tablename__ = "user_store_accounts"
    __table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),)

    user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
    store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
    # WARNING: Contains retailer session cookies/tokens. Encryption-at-rest
    # required before production deployment (e.g., pgcrypto or app-level encryption).
    # As declared here the column is plain JSON with no encryption applied.
    session_data: Mapped[dict | None] = mapped_column(JSON)
    # Both timestamps are timezone-aware and nullable.
    session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
    last_sync_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
    # Annotated as AccountStatus but persisted as String(20); new rows created
    # through the ORM default to AccountStatus.ACTIVE.
    status: Mapped[AccountStatus] = mapped_column(
        String(20), nullable=False, default=AccountStatus.ACTIVE
    )

    # Relationships
    user: Mapped["User"] = relationship(back_populates="store_accounts")
    store: Mapped["Store"] = relationship(back_populates="user_accounts")
Fuzzy name matching via token-based Jaccard similarity (lower confidence) +""" + +import re +from dataclasses import dataclass +from enum import StrEnum + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from cartsnitch_common.models.product import NormalizedProduct + + +class MatchMethod(StrEnum): + """How a product match was determined.""" + + UPC = "upc" + NAME = "name" + + +@dataclass(frozen=True) +class MatchResult: + """Result of a product normalization attempt.""" + + product: NormalizedProduct + confidence: float + method: MatchMethod + + +# Noise words stripped during name cleaning +_NOISE_WORDS = frozenset( + { + "the", + "a", + "an", + "and", + "or", + "of", + "with", + "in", + "for", + "to", + "brand", + "original", + "classic", + "new", + "improved", + } +) + +# Regex for extracting size info (e.g., "16 oz", "1.5 lb", "12 ct") +_SIZE_PATTERN = re.compile( + r"(\d+(?:\.\d+)?)\s*(oz|fl\s*oz|lb|lbs|g|kg|ml|l|ct|pk|count|pack)\b", + re.IGNORECASE, +) + + +def clean_name(name: str) -> str: + """Normalize a product name for comparison. 
+ + - Lowercase + - Remove size info (e.g., "16 oz") + - Strip noise words + - Collapse whitespace + """ + cleaned = name.lower() + cleaned = _SIZE_PATTERN.sub("", cleaned) + cleaned = re.sub(r"[^\w\s]", " ", cleaned) + tokens = cleaned.split() + tokens = [t for t in tokens if t not in _NOISE_WORDS] + return " ".join(tokens) + + +def extract_size_info(name: str) -> tuple[str, str] | None: + """Extract (size, unit) from a product name, if present.""" + match = _SIZE_PATTERN.search(name) + if match: + return match.group(1), match.group(2).lower().replace(" ", "_") + return None + + +def jaccard_similarity(a: str, b: str) -> float: + """Token-based Jaccard similarity between two cleaned names.""" + tokens_a = set(a.split()) + tokens_b = set(b.split()) + if not tokens_a or not tokens_b: + return 0.0 + intersection = tokens_a & tokens_b + union = tokens_a | tokens_b + return len(intersection) / len(union) + + +def match_by_upc(session: Session, upc: str) -> MatchResult | None: + """Find a normalized product by exact UPC match. + + Loads products with upc_variants and checks membership in Python + for cross-database compatibility (works on both PostgreSQL and SQLite). + """ + # TODO: Use PostgreSQL JSON containment query (@>) for production. + # Current approach loads all products into memory — acceptable for tests + # and small datasets, but will not scale. + stmt = select(NormalizedProduct).where(NormalizedProduct.upc_variants.is_not(None)) + products = session.execute(stmt).scalars().all() + for product in products: + if product.upc_variants and upc in product.upc_variants: + return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC) + return None + + +def match_by_name( + session: Session, + name: str, + threshold: float = 0.5, +) -> MatchResult | None: + """Find the best normalized product by fuzzy name matching. + + Loads all normalized products and computes Jaccard similarity. + Returns the best match above the threshold, or None. 
def normalize_product(
    session: Session,
    name: str,
    upc: str | None = None,
    name_threshold: float = 0.5,
) -> MatchResult | None:
    """Resolve a raw product to a normalized product.

    An exact UPC lookup is attempted first, but only when *upc* is truthy.
    If that yields nothing, the result of fuzzy name matching with
    *name_threshold* is returned (which may itself be ``None``).
    """
    upc_hit = match_by_upc(session, upc) if upc else None
    if upc_hit:
        return upc_hit
    return match_by_name(session, name, threshold=name_threshold)
def classify_confidence(score: float, method: MatchMethod) -> MatchConfidence:
    """Map a match score and method onto a high/medium/low confidence level.

    UPC matches are always HIGH regardless of *score*. For every other
    method: ``score >= 0.8`` is HIGH, ``score >= 0.5`` is MEDIUM, anything
    lower is LOW.
    """
    if method == MatchMethod.UPC or score >= 0.8:
        return MatchConfidence.HIGH
    return MatchConfidence.MEDIUM if score >= 0.5 else MatchConfidence.LOW
class ProductMatcher:
    """Match purchase items to normalized products during ingestion.

    Usage:
        matcher = ProductMatcher(session)
        outcomes = matcher.match_items(items)
    """

    def __init__(
        self,
        session: Session,
        name_threshold: float = 0.4,
        auto_create: bool = True,
    ):
        self.session = session
        self.name_threshold = name_threshold
        self.auto_create = auto_create

    def match_single(
        self,
        item: PurchaseItemCreate,
    ) -> tuple[NormalizedProduct | None, MatchResult | None, MatchConfidence]:
        """Match one purchase item and return (product, match_result, confidence_level).

        When nothing matches, a new product is created if ``auto_create`` is
        set; otherwise the product is ``None``. In both no-match cases the
        match result is ``None`` and the confidence is LOW.
        """
        found = normalize_product(
            self.session,
            item.product_name_raw,
            upc=item.upc,
            name_threshold=self.name_threshold,
        )

        if not found:
            if not self.auto_create:
                return None, None, MatchConfidence.LOW
            new_product = _create_product_from_item(self.session, item)
            return new_product, None, MatchConfidence.LOW

        level = classify_confidence(found.confidence, found.method)
        return found.product, found, level

    def _outcome_for(self, position: int, item: PurchaseItemCreate) -> MatchOutcome:
        """Match *item* and package the result as the outcome for *position*."""
        product, found, level = self.match_single(item)
        return MatchOutcome(
            item_index=position,
            match=found,
            confidence_level=level,
            # A product without a match result can only come from auto-creation.
            created_new=found is None and product is not None,
        )

    def match_items(self, items: list[PurchaseItemCreate]) -> list[MatchOutcome]:
        """Match a batch of purchase items. Returns outcomes in input order."""
        return [self._outcome_for(position, item) for position, item in enumerate(items)]
@dataclass(frozen=True)
class PriceDelta:
    """Immutable record of a price change for one product at one store.

    Carries the old and new price, the absolute and percentage change, and
    the dates of the two observations being compared.
    """

    product_id: uuid.UUID
    store_id: uuid.UUID
    old_price: Decimal
    new_price: Decimal
    change_amount: Decimal
    change_percent: Decimal
    old_date: date
    new_date: date

    @property
    def is_increase(self) -> bool:
        """True when ``change_amount`` is strictly positive."""
        return 0 < self.change_amount

    @property
    def is_decrease(self) -> bool:
        """True when ``change_amount`` is strictly negative."""
        return 0 > self.change_amount
def get_price_trend(
    session: Session,
    product_id: uuid.UUID,
    store_id: uuid.UUID,
    limit: int = 30,
) -> list[PriceHistory]:
    """Return up to *limit* price observations for a product at a store.

    Rows are ordered by ``observed_date`` descending, i.e. newest first.
    """
    same_product_and_store = and_(
        PriceHistory.normalized_product_id == product_id,
        PriceHistory.store_id == store_id,
    )
    newest_first = PriceHistory.observed_date.desc()
    stmt = select(PriceHistory).where(same_product_and_store).order_by(newest_first).limit(limit)
    rows = session.execute(stmt).scalars().all()
    return list(rows)
+""" + +import re +from datetime import date +from decimal import Decimal, InvalidOperation + +from cartsnitch_common.schemas.purchase import PurchaseCreate, PurchaseItemCreate + + +def _clean_product_name(raw: str) -> str: + """Clean raw product name from scraper output.""" + cleaned = raw.strip() + # Remove leading/trailing non-alphanumeric chars + cleaned = re.sub(r"^\W+|\W+$", "", cleaned) + # Collapse internal whitespace + cleaned = re.sub(r"\s+", " ", cleaned) + return cleaned + + +def _safe_decimal( + value: str | float | int | Decimal | None, + default: Decimal = Decimal("0"), +) -> Decimal: + """Safely convert a value to Decimal.""" + if value is None: + return default + try: + return Decimal(str(value)) + except (InvalidOperation, ValueError): + return default + + +def parse_meijer_item(raw_item: dict) -> PurchaseItemCreate: + """Parse a single Meijer scraper line item into a PurchaseItemCreate. + + Expected raw_item keys (from Meijer scraper): + - description / name: product name + - upc / upcCode: UPC barcode + - quantity / qty: number of units + - unitPrice / price: per-unit price + - extendedPrice / totalPrice: line total + - regularPrice: shelf price before discounts + - salePrice: sale price if applicable + - couponAmount / couponDiscount: coupon savings + - loyaltyAmount / loyaltyDiscount: loyalty savings + - category / department: raw category + """ + name = raw_item.get("description") or raw_item.get("name") or "" + cleaned_name = _clean_product_name(name) + + upc = raw_item.get("upc") or raw_item.get("upcCode") + if upc: + upc = str(upc).strip().lstrip("0") or str(upc).strip() + + qty = _safe_decimal( + raw_item.get("quantity") or raw_item.get("qty"), + default=Decimal("1"), + ) + + unit_price = _safe_decimal(raw_item.get("unitPrice") or raw_item.get("price")) + extended = _safe_decimal(raw_item.get("extendedPrice") or raw_item.get("totalPrice")) + if extended == Decimal("0") and unit_price > 0: + extended = unit_price * qty + + regular = 
raw_item.get("regularPrice") + sale = raw_item.get("salePrice") + coupon = raw_item.get("couponAmount") or raw_item.get("couponDiscount") + loyalty = raw_item.get("loyaltyAmount") or raw_item.get("loyaltyDiscount") + category = raw_item.get("category") or raw_item.get("department") + + return PurchaseItemCreate( + product_name_raw=cleaned_name, + upc=upc, + quantity=qty, + unit_price=unit_price, + extended_price=extended, + regular_price=_safe_decimal(regular) if regular is not None else None, + sale_price=_safe_decimal(sale) if sale is not None else None, + coupon_discount=_safe_decimal(coupon) if coupon is not None else None, + loyalty_discount=_safe_decimal(loyalty) if loyalty is not None else None, + category_raw=str(category).strip() if category else None, + ) + + +def normalize_receipt( + raw_receipt: dict, + user_id: str, + store_id: str, +) -> PurchaseCreate: + """Parse a complete Meijer raw receipt into a PurchaseCreate. + + Expected raw_receipt keys: + - receiptId / receipt_id / id: unique receipt identifier + - date / purchaseDate / purchase_date: purchase date (YYYY-MM-DD or similar) + - total / totalAmount: receipt total + - subtotal: pre-tax subtotal + - tax / taxAmount: tax amount + - savings / totalSavings: total discount savings + - items: list of raw line item dicts + """ + import uuid + + receipt_id = str( + raw_receipt.get("receiptId") + or raw_receipt.get("receipt_id") + or raw_receipt.get("id") + or uuid.uuid4() + ) + + raw_date = ( + raw_receipt.get("date") + or raw_receipt.get("purchaseDate") + or raw_receipt.get("purchase_date") + ) + if isinstance(raw_date, str): + purchase_date = date.fromisoformat(raw_date[:10]) + elif isinstance(raw_date, date): + purchase_date = raw_date + else: + purchase_date = date.today() + + total = _safe_decimal(raw_receipt.get("total") or raw_receipt.get("totalAmount")) + subtotal = raw_receipt.get("subtotal") + tax = raw_receipt.get("tax") or raw_receipt.get("taxAmount") + savings = raw_receipt.get("savings") 
or raw_receipt.get("totalSavings") + + raw_items = raw_receipt.get("items") or [] + items = [parse_meijer_item(item) for item in raw_items] + + return PurchaseCreate( + user_id=uuid.UUID(user_id) if isinstance(user_id, str) else user_id, + store_id=uuid.UUID(store_id) if isinstance(store_id, str) else store_id, + receipt_id=receipt_id, + purchase_date=purchase_date, + total=total, + subtotal=_safe_decimal(subtotal) if subtotal is not None else None, + tax=_safe_decimal(tax) if tax is not None else None, + savings_total=_safe_decimal(savings) if savings is not None else None, + raw_data=raw_receipt, + items=items, + ) diff --git a/src/cartsnitch_common/pipeline/shrinkflation.py b/src/cartsnitch_common/pipeline/shrinkflation.py new file mode 100644 index 0000000..0e6d2b3 --- /dev/null +++ b/src/cartsnitch_common/pipeline/shrinkflation.py @@ -0,0 +1,165 @@ +"""Shrinkflation detection — compare unit sizes across price history. + +Flags cases where a product's size decreased while price stayed flat or increased. 
# Conversion factors to a common base unit (grams for weight, ml for volume, count for discrete)
_WEIGHT_TO_GRAMS: dict[SizeUnit, Decimal] = {
    SizeUnit.G: Decimal("1"),
    SizeUnit.KG: Decimal("1000"),
    SizeUnit.OZ: Decimal("28.3495"),
    SizeUnit.LB: Decimal("453.592"),
}

_VOLUME_TO_ML: dict[SizeUnit, Decimal] = {
    SizeUnit.ML: Decimal("1"),
    SizeUnit.L: Decimal("1000"),
    SizeUnit.FL_OZ: Decimal("29.5735"),
}

_COUNT_UNITS: set[SizeUnit] = {SizeUnit.CT, SizeUnit.PK}


def _to_comparable(size: str, unit: SizeUnit) -> Decimal | None:
    """Convert a size+unit to a value in the unit family's base unit.

    Weights are expressed in grams, volumes in millilitres, and count units
    are returned unchanged.

    Returns None if ``size`` is not a finite number or ``unit`` belongs to
    none of the known unit families.
    """
    try:
        size_val = Decimal(size)
    except (ArithmeticError, TypeError, ValueError):
        # decimal.InvalidOperation is an ArithmeticError subclass.
        return None
    if not size_val.is_finite():
        # NaN / Infinity would make the comparisons in the caller raise.
        return None

    if unit in _WEIGHT_TO_GRAMS:
        return size_val * _WEIGHT_TO_GRAMS[unit]
    if unit in _VOLUME_TO_ML:
        return size_val * _VOLUME_TO_ML[unit]
    if unit in _COUNT_UNITS:
        return size_val
    return None


def _units_comparable(unit_a: SizeUnit, unit_b: SizeUnit) -> bool:
    """Check if two units are in the same measurement system."""
    if unit_a in _WEIGHT_TO_GRAMS and unit_b in _WEIGHT_TO_GRAMS:
        return True
    if unit_a in _VOLUME_TO_ML and unit_b in _VOLUME_TO_ML:
        return True
    return unit_a in _COUNT_UNITS and unit_b in _COUNT_UNITS


@dataclass(frozen=True)
class ShrinkflationCandidate:
    """A potential shrinkflation detection before writing to DB."""

    product: NormalizedProduct
    old_size: str
    new_size: str
    old_unit: SizeUnit
    new_unit: SizeUnit
    old_price: Decimal | None
    new_price: Decimal | None
    confidence: Decimal
    size_change_pct: Decimal


def detect_shrinkflation(
    session: Session,
    product: NormalizedProduct,
    new_size: str,
    new_unit: SizeUnit,
    new_price: Decimal | None = None,
    detected_date: date | None = None,
    min_size_decrease_pct: Decimal = Decimal("1"),
) -> ShrinkflationEvent | None:
    """Check if a product's size has decreased (shrinkflation).

    Compares the new size against the product's recorded size. When the size
    shrank by at least ``min_size_decrease_pct`` percent, an event is added to
    the session and flushed. Confidence grows with the size of the decrease
    and is reduced when ``new_price`` is lower than the latest recorded
    regular price.

    Returns the new ShrinkflationEvent, an already-recorded event for the same
    product / old size / new size if one exists, or None when no shrink is
    detected (missing or non-comparable sizes, non-positive sizes, no
    decrease, or a decrease below the threshold).
    """
    if not product.size or not product.size_unit:
        return None

    old_unit = SizeUnit(product.size_unit)
    if not _units_comparable(old_unit, new_unit):
        return None

    old_comparable = _to_comparable(product.size, old_unit)
    new_comparable = _to_comparable(new_size, new_unit)

    if old_comparable is None or new_comparable is None:
        return None

    # A zero or negative size is bad data: a zero old size would divide by
    # zero below, and a zero new size would report a bogus 100% shrink.
    if old_comparable <= 0 or new_comparable <= 0:
        return None

    if new_comparable >= old_comparable:
        return None  # Size didn't decrease

    size_change_pct = ((old_comparable - new_comparable) / old_comparable * 100).quantize(
        Decimal("0.01")
    )
    if size_change_pct < min_size_decrease_pct:
        return None

    # Check existing events to avoid duplicates. first() tolerates several
    # matching rows, where scalar_one_or_none() would raise.
    existing = (
        session.execute(
            select(ShrinkflationEvent)
            .where(
                and_(
                    ShrinkflationEvent.normalized_product_id == product.id,
                    ShrinkflationEvent.old_size == product.size,
                    ShrinkflationEvent.new_size == new_size,
                )
            )
            .limit(1)
        )
        .scalars()
        .first()
    )

    if existing:
        return existing

    # Confidence: higher if size change is significant and price didn't drop
    confidence = Decimal("0.70")
    if size_change_pct >= Decimal("5"):
        confidence = Decimal("0.85")
    if size_change_pct >= Decimal("10"):
        confidence = Decimal("0.95")

    # Get the last known price for comparison
    old_price: Decimal | None = None
    if product.price_histories:
        latest = max(product.price_histories, key=lambda ph: ph.observed_date)
        old_price = latest.regular_price

    if old_price is not None and new_price is not None and new_price < old_price:
        # Price actually dropped — less likely to be shrinkflation
        confidence = max(Decimal("0.30"), confidence - Decimal("0.30"))

    event = ShrinkflationEvent(
        id=uuid.uuid4(),
        normalized_product_id=product.id,
        detected_date=detected_date or date.today(),
        old_size=product.size,
        new_size=new_size,
        old_unit=old_unit,
        new_unit=new_unit,
        price_at_old_size=old_price,
        price_at_new_size=new_price,
        confidence=confidence,
        notes=(
            f"Size decreased {size_change_pct}% ({product.size} {old_unit} → {new_size} {new_unit})"
        ),
    )
    session.add(event)
    session.flush()

    return event
class CouponCreate(BaseModel):
    """Payload for creating a coupon.

    Only ``store_id``, ``title`` and ``discount_type`` are required; every
    other field defaults to ``None`` (``requires_clip`` defaults to ``False``).
    """

    store_id: uuid.UUID
    # Optional link to a specific product; None when no product is attached.
    normalized_product_id: uuid.UUID | None = None
    title: str
    description: str | None = None
    discount_type: DiscountType
    discount_value: Decimal | None = None
    min_purchase: Decimal | None = None
    valid_from: date | None = None
    valid_to: date | None = None
    requires_clip: bool = False
    coupon_code: str | None = None
    source_url: str | None = None


class CouponRead(BaseModel):
    """Coupon as returned to callers.

    ``from_attributes`` lets the model be built from attribute-bearing
    objects rather than only from dicts. Adds ``id``, ``scraped_at`` and the
    ``created_at`` / ``updated_at`` timestamps to the create fields.
    """

    model_config = {"from_attributes": True}

    id: uuid.UUID
    store_id: uuid.UUID
    normalized_product_id: uuid.UUID | None
    title: str
    description: str | None
    discount_type: DiscountType
    discount_value: Decimal | None
    min_purchase: Decimal | None
    valid_from: date | None
    valid_to: date | None
    requires_clip: bool
    coupon_code: str | None
    source_url: str | None
    scraped_at: datetime | None
    created_at: datetime
    updated_at: datetime


class EventEnvelope(BaseModel):
    """Standard event wrapper for all Redis pub/sub messages."""

    event_type: EventType
    timestamp: datetime
    # Name of the emitting service (free-form string; no validation here).
    service: str
    # Event-specific data; its shape is not constrained by this schema.
    payload: dict[str, Any]
class PriceHistoryCreate(BaseModel):
    """Payload for recording one price observation of a product at a store.

    ``regular_price`` is required; the sale, loyalty and coupon prices are
    optional and default to ``None``.
    """

    normalized_product_id: uuid.UUID
    store_id: uuid.UUID
    observed_date: date
    regular_price: Decimal
    sale_price: Decimal | None = None
    loyalty_price: Decimal | None = None
    coupon_price: Decimal | None = None
    source: PriceSource
    # Optional link to the purchase line item the observation came from.
    purchase_item_id: uuid.UUID | None = None


class PriceHistoryRead(BaseModel):
    """Price observation as returned to callers (buildable from attributes)."""

    model_config = {"from_attributes": True}

    id: uuid.UUID
    normalized_product_id: uuid.UUID
    store_id: uuid.UUID
    observed_date: date
    regular_price: Decimal
    sale_price: Decimal | None
    loyalty_price: Decimal | None
    coupon_price: Decimal | None
    source: PriceSource
    purchase_item_id: uuid.UUID | None
    created_at: datetime
    updated_at: datetime


class NormalizedProductCreate(BaseModel):
    """Payload for creating a normalized product; only ``canonical_name`` is required."""

    canonical_name: str
    category: ProductCategory | None = None
    subcategory: str | None = None
    brand: str | None = None
    # Size is kept as text alongside its unit.
    size: str | None = None
    size_unit: SizeUnit | None = None
    upc_variants: list[str] = []


class NormalizedProductRead(BaseModel):
    """Normalized product as returned to callers (buildable from attributes)."""

    model_config = {"from_attributes": True}

    id: uuid.UUID
    canonical_name: str
    category: ProductCategory | None
    subcategory: str | None
    brand: str | None
    size: str | None
    size_unit: SizeUnit | None
    # Element type matches NormalizedProductCreate.upc_variants; previously a
    # bare ``list`` that let non-string items through on read.
    upc_variants: list[str] | None
    created_at: datetime
    updated_at: datetime
class PurchaseItemCreate(BaseModel):
    """Payload for one purchase line item.

    ``product_name_raw``, ``unit_price`` and ``extended_price`` are required;
    ``quantity`` defaults to 1 and every other field to ``None``.
    """

    product_name_raw: str
    upc: str | None = None
    quantity: Decimal = Decimal("1")
    unit_price: Decimal
    extended_price: Decimal
    regular_price: Decimal | None = None
    sale_price: Decimal | None = None
    coupon_discount: Decimal | None = None
    loyalty_discount: Decimal | None = None
    category_raw: str | None = None
    # Optional link to a normalized product; None by default.
    normalized_product_id: uuid.UUID | None = None


class PurchaseItemRead(BaseModel):
    """Purchase line item as returned to callers (buildable from attributes)."""

    model_config = {"from_attributes": True}

    id: uuid.UUID
    purchase_id: uuid.UUID
    product_name_raw: str
    upc: str | None
    quantity: Decimal
    unit_price: Decimal
    extended_price: Decimal
    regular_price: Decimal | None
    sale_price: Decimal | None
    coupon_discount: Decimal | None
    loyalty_discount: Decimal | None
    category_raw: str | None
    normalized_product_id: uuid.UUID | None


class PurchaseCreate(BaseModel):
    """Payload for creating a purchase (one receipt) together with its line items."""

    user_id: uuid.UUID
    store_id: uuid.UUID
    store_location_id: uuid.UUID | None = None
    receipt_id: str
    purchase_date: date
    total: Decimal
    subtotal: Decimal | None = None
    tax: Decimal | None = None
    savings_total: Decimal | None = None
    source_url: str | None = None
    # Optional free-form dict stored alongside the parsed fields.
    raw_data: dict | None = None
    items: list[PurchaseItemCreate] = []


class PurchaseRead(BaseModel):
    """Purchase as returned to callers (buildable from attributes).

    Unlike ``PurchaseCreate`` this schema carries neither ``items`` nor
    ``raw_data``; it adds ``id``, ``ingested_at`` and the timestamps.
    """

    model_config = {"from_attributes": True}

    id: uuid.UUID
    user_id: uuid.UUID
    store_id: uuid.UUID
    store_location_id: uuid.UUID | None
    receipt_id: str
    purchase_date: date
    total: Decimal
    subtotal: Decimal | None
    tax: Decimal | None
    savings_total: Decimal | None
    source_url: str | None
    ingested_at: datetime
    created_at: datetime
    updated_at: datetime
class ShrinkflationEventCreate(BaseModel):
    """Payload for recording a detected size decrease on a product.

    Sizes are kept as text with their units; both prices and ``notes`` are
    optional. ``confidence`` defaults to 1.00.
    """

    normalized_product_id: uuid.UUID
    detected_date: date
    old_size: str
    new_size: str
    old_unit: SizeUnit
    new_unit: SizeUnit
    price_at_old_size: Decimal | None = None
    price_at_new_size: Decimal | None = None
    confidence: Decimal = Decimal("1.00")
    notes: str | None = None


class ShrinkflationEventRead(BaseModel):
    """Shrinkflation event as returned to callers (buildable from attributes)."""

    model_config = {"from_attributes": True}

    id: uuid.UUID
    normalized_product_id: uuid.UUID
    detected_date: date
    old_size: str
    new_size: str
    old_unit: SizeUnit
    new_unit: SizeUnit
    price_at_old_size: Decimal | None
    price_at_new_size: Decimal | None
    confidence: Decimal
    notes: str | None
    created_at: datetime
    updated_at: datetime


class StoreCreate(BaseModel):
    """Payload for creating a store; ``slug`` must be a ``StoreSlug`` value."""

    name: str
    slug: StoreSlug
    logo_url: str | None = None
    website_url: str | None = None


class StoreRead(BaseModel):
    """Store as returned to callers (buildable from attributes)."""

    model_config = {"from_attributes": True}

    id: uuid.UUID
    name: str
    slug: StoreSlug
    logo_url: str | None
    website_url: str | None
    created_at: datetime
    updated_at: datetime


class StoreLocationCreate(BaseModel):
    """Payload for creating a physical location belonging to a store."""

    store_id: uuid.UUID
    address: str
    city: str
    state: str
    zip: str
    # Coordinates are optional.
    lat: float | None = None
    lng: float | None = None


class StoreLocationRead(BaseModel):
    """Store location as returned to callers (buildable from attributes)."""

    model_config = {"from_attributes": True}

    id: uuid.UUID
    store_id: uuid.UUID
    address: str
    city: str
    state: str
    zip: str
    lat: float | None
    lng: float | None
    created_at: datetime
    updated_at: datetime


class UserCreate(BaseModel):
    """Payload for creating a user; ``email`` is validated as an e-mail address."""

    email: EmailStr
    password: str
    display_name: str | None = None


class UserRead(BaseModel):
    """User as returned to callers; carries no password field."""

    model_config = {"from_attributes": True}

    id: uuid.UUID
    email: str
    display_name: str | None
    created_at: datetime
    updated_at: datetime


class UserStoreAccountCreate(BaseModel):
    """Payload for linking a user to a store account; status defaults to ACTIVE."""

    user_id: uuid.UUID
    store_id: uuid.UUID
    # Optional free-form dict; not exposed by UserStoreAccountRead.
    session_data: dict | None = None
    status: AccountStatus = AccountStatus.ACTIVE


class UserStoreAccountRead(BaseModel):
    """User/store account link as returned to callers.

    Exposes ``status`` and the session/sync timestamps but not
    ``session_data``.
    """

    model_config = {"from_attributes": True}

    id: uuid.UUID
    user_id: uuid.UUID
    store_id: uuid.UUID
    status: AccountStatus
    session_expires_at: datetime | None
    last_sync_at: datetime | None
    created_at: datetime
    updated_at: datetime
def main() -> None:
    """Run the seed CLI: parse options, invoke the seed runner, exit 1 on failure.

    Any exception raised while importing or running the seed runner is
    printed to stderr as ``ERROR: <message>`` before exiting with status 1.
    """
    database_url_help = (
        "PostgreSQL connection URL (sync driver). "
        "Defaults to CARTSNITCH_DATABASE_URL_SYNC env var or built-in default."
    )
    seed_help = f"Random seed for deterministic output (default: {SEED_VALUE})."

    cli = argparse.ArgumentParser(
        prog="cartsnitch-seed",
        description="Generate deterministic seed data for the CartSnitch dev environment.",
    )
    cli.add_argument("--database-url", default=None, help=database_url_help)
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Print planned record counts without writing to the database.",
    )
    cli.add_argument("--seed", type=int, default=SEED_VALUE, help=seed_help)

    options = cli.parse_args()

    try:
        # Imported here so that import failures are reported like any other error.
        from cartsnitch_common.seed.runner import run_seed

        run_seed(
            database_url=options.database_url,
            seed_value=options.seed,
            dry_run=options.dry_run,
        )
    except Exception as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
Price-increase products (for StickerShock detection) +# 10% of products should show a significant price increase (>10%) over the window +NUM_PRICE_INCREASE_PRODUCTS: int = 50 # ~10% of 500 + +# Coupon mix +COUPON_EXPIRED_PCT: float = 0.60 +COUPON_ACTIVE_PCT: float = 0.40 + +# Items per purchase (target avg to hit 25K total from 5K purchases) +AVG_ITEMS_PER_PURCHASE: int = 5 + +# Price history: ~100 observations per product (500 products * 100 = 50K) +PRICE_OBS_PER_PRODUCT: int = 100 diff --git a/src/cartsnitch_common/seed/generators/__init__.py b/src/cartsnitch_common/seed/generators/__init__.py new file mode 100644 index 0000000..08cbc5f --- /dev/null +++ b/src/cartsnitch_common/seed/generators/__init__.py @@ -0,0 +1 @@ +"""Seed data generators.""" diff --git a/src/cartsnitch_common/seed/generators/coupons.py b/src/cartsnitch_common/seed/generators/coupons.py new file mode 100644 index 0000000..890e90a --- /dev/null +++ b/src/cartsnitch_common/seed/generators/coupons.py @@ -0,0 +1,107 @@ +"""Generate Coupon seed data.""" + +import random +import uuid +from datetime import UTC, datetime, timedelta +from decimal import Decimal + +from faker import Faker + +from cartsnitch_common.constants import DiscountType +from cartsnitch_common.seed.config import ( + COUPON_EXPIRED_PCT, + NUM_COUPONS, + SEED_END_DATE, + SEED_START_DATE, +) + + +def _decimal(val: float) -> Decimal: + return Decimal(str(round(val, 2))) + + +_COUPON_TITLES: list[str] = [ + "Save {val} on {product}", + "{val} off your next {product} purchase", + "Get {val} off {product}", + "Buy {product}, save {val}", + "Weekend special: {val} off {product}", + "Member exclusive: {val} off {product}", + "Digital coupon: {val} off {product}", +] + + +def generate_coupons( + fake: Faker, + products: list[dict], + stores: list[dict], +) -> list[dict]: + """Return NUM_COUPONS coupon records with realistic mix of active/expired.""" + now = datetime.now(tz=UTC) + today = SEED_END_DATE + coupons = [] + + num_expired = 
int(NUM_COUPONS * COUPON_EXPIRED_PCT) + num_active = NUM_COUPONS - num_expired + + def make_coupon(is_active: bool) -> dict: + store = random.choice(stores) + product = random.choice(products) if random.random() > 0.1 else None + product_name = product["canonical_name"].split(" ", 2)[-1] if product else "any item" + + discount_type = random.choice(list(DiscountType)) + + if discount_type == DiscountType.PERCENT: + discount_value = _decimal(random.choice([5, 10, 15, 20, 25, 30])) + title = f"Save {int(discount_value)}% on {product_name}" + elif discount_type == DiscountType.FIXED: + discount_value = _decimal(random.choice([0.50, 1.00, 1.50, 2.00, 2.50, 3.00, 5.00])) + title = f"Save ${discount_value} on {product_name}" + elif discount_type == DiscountType.BOGO: + discount_value = None + title = f"BOGO: Buy one {product_name}, get one free" + else: # BUY_X_GET_Y + discount_value = None + title = f"Buy 2 {product_name}, get 1 free" + + if is_active: + valid_from = today - timedelta(days=random.randint(1, 30)) + valid_to = today + timedelta(days=random.randint(1, 60)) + else: + valid_to = today - timedelta(days=random.randint(1, 180)) + valid_from = valid_to - timedelta(days=random.randint(7, 30)) + + requires_clip = random.random() > 0.5 + coupon_code = fake.bothify(text="??##-??##").upper() if not requires_clip else None + min_purchase = _decimal(random.choice([0, 0, 0, 5.00, 10.00, 15.00])) or None + + scraped_at = datetime( + SEED_START_DATE.year, SEED_START_DATE.month, SEED_START_DATE.day, tzinfo=UTC + ) + timedelta(days=random.randint(0, 180)) + + return { + "id": uuid.uuid4(), + "store_id": store["id"], + "normalized_product_id": product["id"] if product else None, + "title": title, + "description": fake.sentence(nb_words=10), + "discount_type": discount_type, + "discount_value": discount_value, + "min_purchase": min_purchase, + "valid_from": valid_from, + "valid_to": valid_to, + "requires_clip": requires_clip, + "coupon_code": coupon_code, + "source_url": None, 
_DATE_RANGE_DAYS = (SEED_END_DATE - SEED_START_DATE).days

# Holidays within the seed window for seasonal sales (approx)
_SALE_PERIODS: list[tuple[date, date]] = [
    (date(2025, 11, 27), date(2025, 11, 30)),  # Thanksgiving / Black Friday
    (date(2025, 12, 20), date(2025, 12, 26)),  # Christmas
    (date(2026, 1, 1), date(2026, 1, 2)),  # New Year
    (date(2026, 2, 14), date(2026, 2, 15)),  # Valentine's Day
]


def _is_sale_period(d: date) -> bool:
    """Return True when ``d`` falls inside any holiday sale window (inclusive)."""
    return any(start <= d <= end for start, end in _SALE_PERIODS)


def _decimal(val: float) -> Decimal:
    """Return ``val`` rounded to cents as a Decimal with exactly two places."""
    return Decimal(str(round(val, 2))).quantize(Decimal("0.01"))


def _base_price_for_product(product: dict) -> float:
    """Assign a realistic base price based on category.

    Draws uniformly from the category's price range; products with a missing
    or unlisted category use the 1.99-9.99 range.
    """
    from cartsnitch_common.constants import ProductCategory

    category_ranges: dict[ProductCategory, tuple[float, float]] = {
        ProductCategory.PRODUCE: (1.49, 6.99),
        ProductCategory.DAIRY: (2.99, 8.99),
        ProductCategory.MEAT: (4.99, 19.99),
        ProductCategory.BAKERY: (2.49, 7.99),
        ProductCategory.FROZEN: (3.99, 12.99),
        ProductCategory.PANTRY: (1.99, 9.99),
        ProductCategory.BEVERAGES: (0.99, 6.99),
        ProductCategory.SNACKS: (2.49, 6.99),
        ProductCategory.HOUSEHOLD: (3.99, 19.99),
        ProductCategory.PERSONAL_CARE: (3.99, 14.99),
    }
    # A None category simply misses the lookup and gets the default range.
    lo, hi = category_ranges.get(product.get("category"), (1.99, 9.99))
    return random.uniform(lo, hi)


def generate_price_history(
    products: list[dict],
    stores: list[dict],
    purchase_items: list[dict],
) -> list[dict]:
    """Return ~NUM_PRICE_HISTORY price history records with realistic patterns.

    Pattern types (assigned per product) drive the regular price:
    - sudden_jump: flat then >10% price increase at a random point
    - gradual_creep: slow steady increase over the window
    - stable: nearly flat price with small noise
    - sale_driven: nearly flat regular price
    - volatile: independent draws within ±15% of the base price

    Independently of the pattern, any observation inside a holiday sale window
    also gets a ``sale_price`` 10-25% below its regular price.

    Up to NUM_PRICE_INCREASE_PRODUCTS products (capped at the number of
    products supplied) are forced onto sudden_jump or gradual_creep so that
    StickerShock has increases to flag.

    ``purchase_items`` is accepted for interface compatibility but is not
    used: every record is emitted with ``purchase_item_id`` set to None.
    Returns an empty list when ``products`` or ``stores`` is empty.
    """
    if not products or not stores:
        # Nothing to observe; also avoids dividing by zero below.
        return []

    now = datetime.now(tz=UTC)
    records: list[dict] = []

    per_product_per_store = max(NUM_PRICE_HISTORY // (len(products) * len(stores)), 1)

    # Assign patterns. random.sample raises if asked for more items than exist.
    num_increasing = min(NUM_PRICE_INCREASE_PRODUCTS, len(products))
    price_increase_indices = set(random.sample(range(len(products)), num_increasing))
    pattern_pool = ["sale_driven", "stable", "gradual_creep", "volatile"]
    product_patterns: list[str] = [
        random.choice(["sudden_jump", "gradual_creep"])
        if i in price_increase_indices
        else random.choice(pattern_pool)
        for i in range(len(products))
    ]

    for product, pattern in zip(products, product_patterns):
        base_price = _base_price_for_product(product)

        # Jump point for sudden_jump (50-80% through window)
        jump_day = int(_DATE_RANGE_DAYS * random.uniform(0.5, 0.8))
        jump_factor = random.uniform(1.10, 1.25)  # 10-25% increase

        for store in stores:
            # Generate obs dates spread across the window
            obs_days = sorted(
                random.sample(
                    range(_DATE_RANGE_DAYS + 1),
                    min(per_product_per_store, _DATE_RANGE_DAYS + 1),
                )
            )

            for day_offset in obs_days:
                obs_date = SEED_START_DATE + timedelta(days=day_offset)
                progress = day_offset / max(_DATE_RANGE_DAYS, 1)

                # Compute regular price by pattern
                if pattern == "sudden_jump":
                    if day_offset < jump_day:
                        price = base_price + random.uniform(-0.05, 0.05)
                    else:
                        price = base_price * jump_factor + random.uniform(-0.05, 0.05)
                elif pattern == "gradual_creep":
                    price = base_price * (1 + 0.12 * progress) + random.uniform(-0.10, 0.10)
                elif pattern == "stable":
                    price = base_price + random.uniform(-0.10, 0.10)
                elif pattern == "volatile":
                    price = base_price * random.uniform(0.85, 1.15)
                else:  # sale_driven
                    price = base_price + random.uniform(-0.05, 0.05)

                price = max(0.99, price)
                regular_price = _decimal(price)

                # Sale price during holiday periods
                sale_price: Decimal | None = None
                if _is_sale_period(obs_date):
                    sale_price = _decimal(price * random.uniform(0.75, 0.90))

                records.append(
                    {
                        "id": uuid.uuid4(),
                        "normalized_product_id": product["id"],
                        "store_id": store["id"],
                        "observed_date": obs_date,
                        "regular_price": regular_price,
                        "sale_price": sale_price,
                        "loyalty_price": None,
                        "coupon_price": None,
                        "source": (
                            PriceSource.RECEIPT if random.random() > 0.3 else PriceSource.CATALOG
                        ),
                        "purchase_item_id": None,
                        "created_at": now,
                        "updated_at": now,
                    }
                )

                if len(records) >= NUM_PRICE_HISTORY:
                    return records

    return records
--- /dev/null +++ b/src/cartsnitch_common/seed/generators/products.py @@ -0,0 +1,253 @@ +"""Generate NormalizedProduct seed data.""" + +import random +import uuid +from datetime import UTC, datetime + +from faker import Faker + +from cartsnitch_common.constants import ProductCategory, SizeUnit +from cartsnitch_common.seed.config import NUM_PRODUCTS + +# Product templates per category: (category, brands, names, sizes, default_unit) +_PRODUCT_TEMPLATES: list[tuple[ProductCategory, list[str], list[str], list[str], SizeUnit]] = [ + ( + ProductCategory.PRODUCE, + ["Organic Valley", "Earthbound Farm", "Local Farm", "Fresh Farms"], + [ + "Bananas", + "Apples", + "Baby Carrots", + "Spinach", + "Broccoli", + "Strawberries", + "Blueberries", + "Grapes", + "Tomatoes", + "Lettuce", + ], + ["1 lb", "2 lb", "16 oz", "12 oz", "5 oz", "6 oz", "32 oz"], + SizeUnit.LB, + ), + ( + ProductCategory.DAIRY, + ["Kraft", "Tillamook", "Great Value", "Land O'Lakes", "Daisy", "Organic Valley"], + [ + "Whole Milk", + "2% Milk", + "Cheddar Cheese", + "Mozzarella", + "Greek Yogurt", + "Butter", + "Cream Cheese", + "Sour Cream", + "Heavy Cream", + "Cottage Cheese", + ], + ["16 oz", "32 oz", "64 oz", "1 gallon", "8 oz", "12 oz", "5 oz"], + SizeUnit.FL_OZ, + ), + ( + ProductCategory.MEAT, + ["Tyson", "Perdue", "Smithfield", "Oscar Mayer", "Applegate", "Kirkland"], + [ + "Chicken Breast", + "Ground Beef", + "Pork Chops", + "Bacon", + "Turkey", + "Salmon", + "Tilapia", + "Sausage", + "Hot Dogs", + "Deli Ham", + ], + ["1 lb", "2 lb", "3 lb", "12 oz", "16 oz", "24 oz"], + SizeUnit.LB, + ), + ( + ProductCategory.BAKERY, + ["Nature's Own", "Dave's Killer Bread", "Pepperidge Farm", "Sara Lee", "Arnold"], + [ + "White Bread", + "Whole Wheat Bread", + "Sourdough", + "Bagels", + "English Muffins", + "Croissants", + "Dinner Rolls", + "Hamburger Buns", + "Hot Dog Buns", + "Muffins", + ], + ["20 oz", "24 oz", "6 ct", "8 ct", "12 ct", "16 oz"], + SizeUnit.OZ, + ), + ( + ProductCategory.FROZEN, + ["Stouffer's", 
"Amy's", "Birds Eye", "Green Giant", "Totino's", "DiGiorno"], + [ + "Frozen Pizza", + "Mac and Cheese", + "Frozen Burritos", + "Chicken Nuggets", + "Fish Sticks", + "Frozen Vegetables", + "Ice Cream", + "Frozen Waffles", + "Tater Tots", + "Frozen Lasagna", + ], + ["12 oz", "16 oz", "24 oz", "32 oz", "4 ct", "8 ct"], + SizeUnit.OZ, + ), + ( + ProductCategory.PANTRY, + ["Campbell's", "Hunt's", "Kraft", "Heinz", "Del Monte", "General Mills", "Kellogg's"], + [ + "Pasta Sauce", + "Canned Tomatoes", + "Chicken Noodle Soup", + "Peanut Butter", + "Jelly", + "Olive Oil", + "Rice", + "Pasta", + "Oatmeal", + "Cereal", + ], + ["15 oz", "24 oz", "32 oz", "18 oz", "16 oz", "24 oz", "48 oz", "64 oz"], + SizeUnit.OZ, + ), + ( + ProductCategory.BEVERAGES, + ["Coca-Cola", "Pepsi", "Tropicana", "Minute Maid", "Gatorade", "LaCroix", "Nestle"], + [ + "Cola", + "Diet Cola", + "Orange Juice", + "Apple Juice", + "Sports Drink", + "Sparkling Water", + "Iced Coffee", + "Energy Drink", + "Lemonade", + "Green Tea", + ], + ["12 fl oz", "20 fl oz", "32 fl oz", "64 fl oz", "2 liter", "6 pk", "12 pk"], + SizeUnit.FL_OZ, + ), + ( + ProductCategory.SNACKS, + ["Frito-Lay", "Nabisco", "Kellogg's", "Pepperidge Farm", "Clif Bar", "KIND", "Planters"], + [ + "Potato Chips", + "Tortilla Chips", + "Pretzels", + "Crackers", + "Granola Bars", + "Trail Mix", + "Popcorn", + "Cookies", + "Nuts", + "Fruit Snacks", + ], + ["7 oz", "10 oz", "16 oz", "6 ct", "12 ct", "18 ct", "3.5 oz"], + SizeUnit.OZ, + ), + ( + ProductCategory.HOUSEHOLD, + ["Tide", "Dawn", "Bounty", "Charmin", "Clorox", "Method", "Seventh Generation"], + [ + "Laundry Detergent", + "Dish Soap", + "Paper Towels", + "Toilet Paper", + "Bleach", + "All-Purpose Cleaner", + "Fabric Softener", + "Dryer Sheets", + "Trash Bags", + "Sponges", + ], + ["32 oz", "64 oz", "100 oz", "6 pk", "12 pk", "24 ct", "2 pk"], + SizeUnit.OZ, + ), + ( + ProductCategory.PERSONAL_CARE, + ["Dove", "Pantene", "Colgate", "Crest", "Gillette", "L'Oreal", "Neutrogena"], + [ + 
"Shampoo", + "Conditioner", + "Body Wash", + "Toothpaste", + "Deodorant", + "Face Wash", + "Lotion", + "Razor", + "Shaving Cream", + "Hand Soap", + ], + ["12 oz", "24 oz", "32 oz", "3.4 oz", "6 oz", "8 oz", "2 pk"], + SizeUnit.OZ, + ), +] + + +def _generate_upc() -> str: + """Generate a fake 12-digit UPC.""" + digits = [random.randint(0, 9) for _ in range(11)] + odd_sum = sum(digits[i] for i in range(0, 11, 2)) + even_sum = sum(digits[i] for i in range(1, 11, 2)) + check = (10 - ((odd_sum * 3 + even_sum) % 10)) % 10 + digits.append(check) + return "".join(str(d) for d in digits) + + +def generate_products(fake: Faker) -> list[dict]: + """Return NUM_PRODUCTS normalized product records.""" + now = datetime.now(tz=UTC) + products = [] + used_upcs: set[str] = set() + + per_category = NUM_PRODUCTS // len(_PRODUCT_TEMPLATES) + remainder = NUM_PRODUCTS % len(_PRODUCT_TEMPLATES) + + for i, (category, brands, names, sizes, default_unit) in enumerate(_PRODUCT_TEMPLATES): + count = per_category + (1 if i < remainder else 0) + for _ in range(count): + brand = random.choice(brands) + product_name = random.choice(names) + size_str = random.choice(sizes) + canonical_name = f"{brand} {product_name} {size_str}" + + size_parts = size_str.split(" ", 1) + size_val = size_parts[0] + + num_upcs = random.randint(1, 3) + upcs: list[str] = [] + for _ in range(num_upcs): + upc = _generate_upc() + attempts = 0 + while upc in used_upcs and attempts < 10: + upc = _generate_upc() + attempts += 1 + used_upcs.add(upc) + upcs.append(upc) + + products.append( + { + "id": uuid.uuid4(), + "canonical_name": canonical_name, + "category": category, + "subcategory": product_name, + "brand": brand, + "size": size_val, + "size_unit": default_unit, + "upc_variants": upcs, + "created_at": now, + "updated_at": now, + } + ) + + return products diff --git a/src/cartsnitch_common/seed/generators/purchases.py b/src/cartsnitch_common/seed/generators/purchases.py new file mode 100644 index 0000000..d023c73 --- 
/dev/null +++ b/src/cartsnitch_common/seed/generators/purchases.py @@ -0,0 +1,156 @@ +"""Generate Purchase and PurchaseItem seed data.""" + +import random +import uuid +from datetime import UTC, date, datetime, timedelta +from decimal import Decimal + +from cartsnitch_common.seed.config import ( + NUM_PURCHASE_ITEMS, + NUM_PURCHASES, + SEED_END_DATE, + SEED_START_DATE, +) + +_DATE_RANGE_DAYS = (SEED_END_DATE - SEED_START_DATE).days + + +def _random_date() -> date: + return SEED_START_DATE + timedelta(days=random.randint(0, _DATE_RANGE_DAYS)) + + +def _decimal(val: float, places: int = 2) -> Decimal: + return Decimal(str(round(val, places))) + + +def generate_purchases( + users: list[dict], + stores: list[dict], + store_locations: list[dict], +) -> list[dict]: + """Return NUM_PURCHASES purchase records.""" + now = datetime.now(tz=UTC) + active_users = [u for u in users if u["_active"]] + inactive_users = [u for u in users if not u["_active"]] + + # Build location index by store_id + locs_by_store: dict = {} + for loc in store_locations: + locs_by_store.setdefault(loc["store_id"], []).append(loc) + + purchases = [] + seen_receipts: set[tuple] = set() + + # Active users get 80% of purchases + active_count = int(NUM_PURCHASES * 0.8) + inactive_count = NUM_PURCHASES - active_count + + def make_purchase(user: dict, store: dict) -> dict | None: + receipt_id = f"RCT-{random.randint(100000, 999999)}" + key = (user["id"], store["id"], receipt_id) + if key in seen_receipts: + return None + seen_receipts.add(key) + subtotal = _decimal(random.uniform(5.0, 150.0)) + tax = _decimal(float(subtotal) * 0.06) + savings = _decimal(random.uniform(0.0, float(subtotal) * 0.3)) + total = _decimal(float(subtotal) + float(tax) - float(savings)) + purchase_date = _random_date() + store_locs = locs_by_store.get(store["id"], []) + store_location_id = random.choice(store_locs)["id"] if store_locs else None + ingested_at = datetime( + purchase_date.year, purchase_date.month, purchase_date.day, 
tzinfo=UTC + ) + timedelta(hours=random.randint(1, 48)) + return { + "id": uuid.uuid4(), + "user_id": user["id"], + "store_id": store["id"], + "store_location_id": store_location_id, + "receipt_id": receipt_id, + "purchase_date": purchase_date, + "total": total, + "subtotal": subtotal, + "tax": tax, + "savings_total": savings if float(savings) > 0 else None, + "source_url": None, + "raw_data": None, + "ingested_at": ingested_at, + "created_at": now, + "updated_at": now, + } + + for _ in range(active_count): + user = random.choice(active_users) + store = random.choice(stores) + p = make_purchase(user, store) + if p: + purchases.append(p) + + for _ in range(inactive_count): + user = random.choice(inactive_users) + store = random.choice(stores) + p = make_purchase(user, store) + if p: + purchases.append(p) + + return purchases[:NUM_PURCHASES] + + +def generate_purchase_items( + purchases: list[dict], + products: list[dict], +) -> list[dict]: + """Return ~NUM_PURCHASE_ITEMS purchase item records distributed across purchases.""" + now = datetime.now(tz=UTC) + items: list[dict] = [] + total_target = NUM_PURCHASE_ITEMS + num_purchases = len(purchases) + + # Distribute items: avg 5 per purchase with variance + for i, purchase in enumerate(purchases): + # Remaining purchases get proportional share + remaining_purchases = num_purchases - i + remaining_items = total_target - len(items) + if remaining_purchases <= 0 or remaining_items <= 0: + break + avg = remaining_items / remaining_purchases + count = max(1, min(15, int(random.gauss(avg, 2)))) + count = min(count, remaining_items) + + for _ in range(count): + product = random.choice(products) + unit_price = _decimal(random.uniform(0.99, 25.99)) + quantity = Decimal("1.000") + extended_price = _decimal(float(unit_price) * float(quantity)) + has_sale = random.random() > 0.7 + sale_price = ( + _decimal(float(unit_price) * random.uniform(0.7, 0.95)) if has_sale else None + ) + has_coupon = random.random() > 0.85 + 
coupon_discount = _decimal(random.uniform(0.25, 2.00)) if has_coupon else None + + upc = None + if product["upc_variants"]: + upc = random.choice(product["upc_variants"]) + + items.append( + { + "id": uuid.uuid4(), + "purchase_id": purchase["id"], + "product_name_raw": product["canonical_name"], + "upc": upc, + "quantity": quantity, + "unit_price": unit_price, + "extended_price": extended_price, + "regular_price": unit_price, + "sale_price": sale_price, + "coupon_discount": coupon_discount, + "loyalty_discount": None, + "category_raw": product["category"].value if product["category"] else None, + "normalized_product_id": product["id"], + "created_at": now, + "updated_at": now, + } + ) + + return items diff --git a/src/cartsnitch_common/seed/generators/shrinkflation.py b/src/cartsnitch_common/seed/generators/shrinkflation.py new file mode 100644 index 0000000..9d833bb --- /dev/null +++ b/src/cartsnitch_common/seed/generators/shrinkflation.py @@ -0,0 +1,114 @@ +"""Generate ShrinkflationEvent seed data.""" + +import random +import uuid +from datetime import UTC, datetime, timedelta +from decimal import Decimal + +from cartsnitch_common.constants import SizeUnit +from cartsnitch_common.seed.config import ( + NUM_SHRINKFLATION_EVENTS, + SEED_END_DATE, + SEED_START_DATE, +) + +_DATE_RANGE_DAYS = (SEED_END_DATE - SEED_START_DATE).days + +# Shrinkflation patterns: (old_size, new_size, unit, size_reduction_pct) +_SHRINK_PATTERNS: list[tuple[str, str, SizeUnit, float]] = [ + ("16", "14", SizeUnit.OZ, 0.125), + ("32", "28", SizeUnit.OZ, 0.125), + ("64", "56", SizeUnit.FL_OZ, 0.125), + ("18", "16", SizeUnit.OZ, 0.111), + ("20", "18", SizeUnit.OZ, 0.10), + ("2", "1.75", SizeUnit.LB, 0.125), + ("24", "21", SizeUnit.OZ, 0.125), + ("12", "10.5", SizeUnit.OZ, 0.125), + ("48", "42", SizeUnit.OZ, 0.125), + ("8", "7", SizeUnit.OZ, 0.125), + ("1", "0.875", SizeUnit.LB, 0.125), + ("36", "32", SizeUnit.OZ, 0.111), + ("6", "5", SizeUnit.CT, 0.167), + ("12", "10", SizeUnit.CT, 0.167), + 
("100", "90", SizeUnit.CT, 0.10), + ("16.9", "15", SizeUnit.FL_OZ, 0.112), + ("3", "2.5", SizeUnit.LB, 0.167), + ("40", "35", SizeUnit.OZ, 0.125), + ("28", "24", SizeUnit.OZ, 0.143), + ("14.5", "12.5", SizeUnit.OZ, 0.138), +] + + +def _decimal(val: float) -> Decimal: + return Decimal(str(round(val, 2))) + + +def generate_shrinkflation_events(products: list[dict]) -> list[dict]: + """Return NUM_SHRINKFLATION_EVENTS shrinkflation event records. + + Selects products and assigns size changes where price is maintained or + increased despite the smaller package — valid inputs for ShrinkRay. + """ + now = datetime.now(tz=UTC) + events = [] + + # Pick NUM_SHRINKFLATION_EVENTS unique products (prefer pantry/snacks/household) + from cartsnitch_common.constants import ProductCategory + + preferred_cats = { + ProductCategory.PANTRY, + ProductCategory.SNACKS, + ProductCategory.HOUSEHOLD, + ProductCategory.PERSONAL_CARE, + ProductCategory.FROZEN, + ProductCategory.DAIRY, + ProductCategory.BEVERAGES, + } + preferred = [p for p in products if p.get("category") in preferred_cats] + fallback = [p for p in products if p not in preferred] + pool = preferred + fallback + + selected = random.sample(pool, min(NUM_SHRINKFLATION_EVENTS, len(pool))) + + for i, product in enumerate(selected): + pattern = _SHRINK_PATTERNS[i % len(_SHRINK_PATTERNS)] + old_size, new_size, unit, reduction_pct = pattern + + # Detection date: at least 60 days into window so there's history before + min_day = 60 + detected_day = random.randint(min_day, _DATE_RANGE_DAYS) + detected_date = SEED_START_DATE + timedelta(days=detected_day) + + # Price maintained or slightly increased despite size reduction + base_price = random.uniform(2.99, 12.99) + price_at_old_size = _decimal(base_price) + # flat or small increase despite size reduction + price_at_new_size = _decimal(base_price * random.uniform(0.98, 1.08)) + + confidence = _decimal(random.uniform(0.70, 0.99)) + + notes = ( + f"Package reduced from {old_size}{unit} to 
{new_size}{unit} " + f"({reduction_pct * 100:.1f}% reduction). " + f"Price {'increased' if price_at_new_size > price_at_old_size else 'held steady'}." + ) + + events.append( + { + "id": uuid.uuid4(), + "normalized_product_id": product["id"], + "detected_date": detected_date, + "old_size": old_size, + "new_size": new_size, + "old_unit": unit, + "new_unit": unit, + "price_at_old_size": price_at_old_size, + "price_at_new_size": price_at_new_size, + "confidence": confidence, + "notes": notes, + "created_at": now, + "updated_at": now, + } + ) + + return events diff --git a/src/cartsnitch_common/seed/generators/stores.py b/src/cartsnitch_common/seed/generators/stores.py new file mode 100644 index 0000000..e2ebb82 --- /dev/null +++ b/src/cartsnitch_common/seed/generators/stores.py @@ -0,0 +1,203 @@ +"""Generate Store and StoreLocation seed data.""" + +import uuid +from datetime import UTC, datetime + +from cartsnitch_common.constants import StoreSlug +from cartsnitch_common.seed.config import NUM_LOCATIONS_PER_STORE + +# Fixed store definitions +_STORE_DEFS: list[dict] = [ + { + "name": "Meijer", + "slug": StoreSlug.MEIJER, + "logo_url": "https://www.meijer.com/favicon.ico", + "website_url": "https://www.meijer.com", + }, + { + "name": "Kroger", + "slug": StoreSlug.KROGER, + "logo_url": "https://www.kroger.com/favicon.ico", + "website_url": "https://www.kroger.com", + }, + { + "name": "Target", + "slug": StoreSlug.TARGET, + "logo_url": "https://www.target.com/favicon.ico", + "website_url": "https://www.target.com", + }, +] + +# SE Michigan locations per store (5 each = 15 total) +_LOCATION_DEFS: dict[StoreSlug, list[dict]] = { + StoreSlug.MEIJER: [ + { + "address": "3145 Ann Arbor-Saline Rd", + "city": "Ann Arbor", + "state": "MI", + "zip": "48103", + "lat": 42.2434, + "lng": -83.8102, + }, + { + "address": "700 W Ellsworth Rd", + "city": "Ann Arbor", + "state": "MI", + "zip": "48108", + "lat": 42.2318, + "lng": -83.7581, + }, + { + "address": "5100 Oakman Blvd", + 
"city": "Dearborn", + "state": "MI", + "zip": "48126", + "lat": 42.3223, + "lng": -83.1952, + }, + { + "address": "15555 Northline Rd", + "city": "Southgate", + "state": "MI", + "zip": "48195", + "lat": 42.2089, + "lng": -83.1953, + }, + { + "address": "2855 Washtenaw Ave", + "city": "Ypsilanti", + "state": "MI", + "zip": "48197", + "lat": 42.2461, + "lng": -83.6388, + }, + ], + StoreSlug.KROGER: [ + { + "address": "2010 W Stadium Blvd", + "city": "Ann Arbor", + "state": "MI", + "zip": "48103", + "lat": 42.2706, + "lng": -83.7807, + }, + { + "address": "1100 S Main St", + "city": "Ann Arbor", + "state": "MI", + "zip": "48104", + "lat": 42.2555, + "lng": -83.7469, + }, + { + "address": "23650 Michigan Ave", + "city": "Dearborn", + "state": "MI", + "zip": "48124", + "lat": 42.3221, + "lng": -83.2135, + }, + { + "address": "14000 Michigan Ave", + "city": "Dearborn", + "state": "MI", + "zip": "48126", + "lat": 42.3281, + "lng": -83.1789, + }, + { + "address": "3965 Packard St", + "city": "Ann Arbor", + "state": "MI", + "zip": "48108", + "lat": 42.2298, + "lng": -83.7196, + }, + ], + StoreSlug.TARGET: [ + { + "address": "3165 Ann Arbor-Saline Rd", + "city": "Ann Arbor", + "state": "MI", + "zip": "48103", + "lat": 42.2431, + "lng": -83.8097, + }, + { + "address": "4001 Carpenter Rd", + "city": "Ypsilanti", + "state": "MI", + "zip": "48197", + "lat": 42.2373, + "lng": -83.6617, + }, + { + "address": "16000 Ford Rd", + "city": "Dearborn", + "state": "MI", + "zip": "48126", + "lat": 42.3312, + "lng": -83.2098, + }, + { + "address": "17300 Eureka Rd", + "city": "Southgate", + "state": "MI", + "zip": "48195", + "lat": 42.2001, + "lng": -83.2014, + }, + { + "address": "2400 E Stadium Blvd", + "city": "Ann Arbor", + "state": "MI", + "zip": "48104", + "lat": 42.2624, + "lng": -83.7102, + }, + ], +} + + +def generate_stores() -> list[dict]: + """Return 3 fixed store records.""" + now = datetime.now(tz=UTC) + stores = [] + for defn in _STORE_DEFS: + stores.append( + { + "id": 
uuid.uuid4(), + "name": defn["name"], + "slug": defn["slug"], + "logo_url": defn["logo_url"], + "website_url": defn["website_url"], + "created_at": now, + "updated_at": now, + } + ) + return stores + + +def generate_store_locations(stores: list[dict]) -> list[dict]: + """Return 5 locations per store (15 total).""" + now = datetime.now(tz=UTC) + slug_to_id = {s["slug"]: s["id"] for s in stores} + locations = [] + for slug, loc_defs in _LOCATION_DEFS.items(): + store_id = slug_to_id[slug] + for loc in loc_defs[:NUM_LOCATIONS_PER_STORE]: + locations.append( + { + "id": uuid.uuid4(), + "store_id": store_id, + "address": loc["address"], + "city": loc["city"], + "state": loc["state"], + "zip": loc["zip"], + "lat": loc["lat"], + "lng": loc["lng"], + "created_at": now, + "updated_at": now, + } + ) + return locations diff --git a/src/cartsnitch_common/seed/generators/users.py b/src/cartsnitch_common/seed/generators/users.py new file mode 100644 index 0000000..6757b23 --- /dev/null +++ b/src/cartsnitch_common/seed/generators/users.py @@ -0,0 +1,105 @@ +"""Generate User and UserStoreAccount seed data.""" + +import random +import uuid +from datetime import UTC, datetime, timedelta + +from faker import Faker + +from cartsnitch_common.constants import AccountStatus +from cartsnitch_common.seed.config import ( + NUM_ACTIVE_USERS, + NUM_USER_STORE_ACCOUNTS, + NUM_USERS, + SEED_END_DATE, +) + + +def generate_users(fake: Faker) -> list[dict]: + """Return NUM_USERS user records. 
First NUM_ACTIVE_USERS are active.""" + now = datetime.now(tz=UTC) + users = [] + for i in range(NUM_USERS): + created_at = now - timedelta(days=random.randint(30, 365)) + users.append( + { + "id": uuid.uuid4(), + "email": fake.unique.email(), + "hashed_password": fake.sha256(), + "display_name": fake.name() if random.random() > 0.2 else None, + "created_at": created_at, + "updated_at": created_at, + "_active": i < NUM_ACTIVE_USERS, + } + ) + return users + + +def generate_user_store_accounts( + users: list[dict], + stores: list[dict], +) -> list[dict]: + """Return ~NUM_USER_STORE_ACCOUNTS user-store account links. + + Active users get accounts at multiple stores; inactive users may have none. + """ + now = datetime.now(tz=UTC) + accounts = [] + seen: set[tuple] = set() + + active_users = [u for u in users if u["_active"]] + inactive_users = [u for u in users if not u["_active"]] + + # Active users: each gets 1-3 store accounts + for user in active_users: + num_accounts = random.randint(1, 3) + selected_stores = random.sample(stores, min(num_accounts, len(stores))) + for store in selected_stores: + key = (user["id"], store["id"]) + if key in seen: + continue + seen.add(key) + last_sync = datetime( + SEED_END_DATE.year, + SEED_END_DATE.month, + SEED_END_DATE.day, + tzinfo=UTC, + ) - timedelta(days=random.randint(0, 14)) + accounts.append( + { + "id": uuid.uuid4(), + "user_id": user["id"], + "store_id": store["id"], + "session_data": {"token": "SEED_FAKE_TOKEN", "expires": "2026-12-31"}, + "session_expires_at": now + timedelta(days=random.randint(1, 90)), + "last_sync_at": last_sync, + "status": AccountStatus.ACTIVE, + "created_at": user["created_at"], + "updated_at": user["updated_at"], + } + ) + + # Fill remaining slots from inactive users + remaining = NUM_USER_STORE_ACCOUNTS - len(accounts) + for user in random.sample(inactive_users, min(remaining, len(inactive_users))): + store = random.choice(stores) + key = (user["id"], store["id"]) + if key in seen: + 
continue + seen.add(key) + status = random.choice([AccountStatus.EXPIRED, AccountStatus.ERROR, AccountStatus.ACTIVE]) + accounts.append( + { + "id": uuid.uuid4(), + "user_id": user["id"], + "store_id": store["id"], + "session_data": None, + "session_expires_at": None, + "last_sync_at": None, + "status": status, + "created_at": user["created_at"], + "updated_at": user["updated_at"], + } + ) + + return accounts[: NUM_USER_STORE_ACCOUNTS + len(active_users) * 3] diff --git a/src/cartsnitch_common/seed/runner.py b/src/cartsnitch_common/seed/runner.py new file mode 100644 index 0000000..c2b7784 --- /dev/null +++ b/src/cartsnitch_common/seed/runner.py @@ -0,0 +1,189 @@ +"""Seed runner: orchestrates generation and DB insertion in FK-safe order.""" + +import random +import time +from typing import Any + +from faker import Faker +from sqlalchemy import text +from sqlalchemy.orm import Session + +from cartsnitch_common.database import get_sync_session_factory +from cartsnitch_common.models.coupon import Coupon +from cartsnitch_common.models.price import PriceHistory +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.models.purchase import Purchase, PurchaseItem +from cartsnitch_common.models.shrinkflation import ShrinkflationEvent +from cartsnitch_common.models.store import Store, StoreLocation +from cartsnitch_common.models.user import User, UserStoreAccount +from cartsnitch_common.seed.config import SEED_VALUE +from cartsnitch_common.seed.generators.coupons import generate_coupons +from cartsnitch_common.seed.generators.prices import generate_price_history +from cartsnitch_common.seed.generators.products import generate_products +from cartsnitch_common.seed.generators.purchases import generate_purchase_items, generate_purchases +from cartsnitch_common.seed.generators.shrinkflation import generate_shrinkflation_events +from cartsnitch_common.seed.generators.stores import generate_store_locations, generate_stores +from 
cartsnitch_common.seed.generators.users import generate_user_store_accounts, generate_users + +# FK-safe truncation order (reverse of insertion order) +_TRUNCATE_TABLES: list[str] = [ + "shrinkflation_events", + "coupons", + "price_history", + "purchase_items", + "purchases", + "user_store_accounts", + "normalized_products", + "users", + "store_locations", + "stores", +] + + +def _log(msg: str) -> None: + print(msg, flush=True) + + +def _bulk_insert(session: Session, model: type, rows: list[dict[str, Any]]) -> None: + """Insert rows using core INSERT for performance, stripping private keys.""" + if not rows: + return + # Strip internal keys (prefixed with _) + clean = [{k: v for k, v in row.items() if not k.startswith("_")} for row in rows] + session.execute(model.__table__.insert(), clean) # type: ignore[attr-defined] + + +def run_seed( + database_url: str | None = None, + seed_value: int = SEED_VALUE, + dry_run: bool = False, +) -> None: + """Generate and insert all seed data. + + Args: + database_url: Optional override for the DB connection URL. + seed_value: Random seed for deterministic output. + dry_run: If True, print planned counts without touching the DB. 
+ """ + random.seed(seed_value) + fake = Faker() + Faker.seed(seed_value) + + _log("=== CartSnitch Seed Data Generator ===") + _log(f"Seed: {seed_value}") + + # --- Generation phase --- + t0 = time.monotonic() + + _log("Generating stores...") + stores = generate_stores() + _log(f" {len(stores)} stores ({time.monotonic() - t0:.2f}s)") + + _log("Generating store locations...") + store_locations = generate_store_locations(stores) + _log(f" {len(store_locations)} store locations ({time.monotonic() - t0:.2f}s)") + + _log("Generating users...") + users = generate_users(fake) + _log(f" {len(users)} users ({time.monotonic() - t0:.2f}s)") + + _log("Generating user store accounts...") + user_store_accounts = generate_user_store_accounts(users, stores) + _log(f" {len(user_store_accounts)} user store accounts ({time.monotonic() - t0:.2f}s)") + + _log("Generating products...") + products = generate_products(fake) + _log(f" {len(products)} products ({time.monotonic() - t0:.2f}s)") + + _log("Generating purchases...") + purchases = generate_purchases(users, stores, store_locations) + _log(f" {len(purchases)} purchases ({time.monotonic() - t0:.2f}s)") + + _log("Generating purchase items...") + purchase_items = generate_purchase_items(purchases, products) + _log(f" {len(purchase_items)} purchase items ({time.monotonic() - t0:.2f}s)") + + _log("Generating price history...") + price_history = generate_price_history(products, stores, purchase_items) + _log(f" {len(price_history)} price history records ({time.monotonic() - t0:.2f}s)") + + _log("Generating coupons...") + coupons = generate_coupons(fake, products, stores) + _log(f" {len(coupons)} coupons ({time.monotonic() - t0:.2f}s)") + + _log("Generating shrinkflation events...") + shrinkflation_events = generate_shrinkflation_events(products) + _log(f" {len(shrinkflation_events)} shrinkflation events ({time.monotonic() - t0:.2f}s)") + + _log("") + _log("=== Summary ===") + _log(f" stores: {len(stores)}") + _log(f" store_locations: 
{len(store_locations)}") + _log(f" users: {len(users)}") + _log(f" user_store_accounts: {len(user_store_accounts)}") + _log(f" normalized_products: {len(products)}") + _log(f" purchases: {len(purchases)}") + _log(f" purchase_items: {len(purchase_items)}") + _log(f" price_history: {len(price_history)}") + _log(f" coupons: {len(coupons)}") + _log(f" shrinkflation_events: {len(shrinkflation_events)}") + + if dry_run: + _log("") + _log("Dry run — no data written.") + return + + # --- DB insertion phase --- + factory = get_sync_session_factory(database_url) + with factory() as session: + _log("") + _log("Truncating tables (reverse FK order)...") + for table in _TRUNCATE_TABLES: + session.execute(text(f"TRUNCATE TABLE {table} CASCADE")) + _log(" done") + + _log("Inserting stores...") + _bulk_insert(session, Store, stores) + _log(f" {len(stores)} inserted") + + _log("Inserting store locations...") + _bulk_insert(session, StoreLocation, store_locations) + _log(f" {len(store_locations)} inserted") + + _log("Inserting users...") + _bulk_insert(session, User, users) + _log(f" {len(users)} inserted") + + _log("Inserting user store accounts...") + _bulk_insert(session, UserStoreAccount, user_store_accounts) + _log(f" {len(user_store_accounts)} inserted") + + _log("Inserting products...") + _bulk_insert(session, NormalizedProduct, products) + _log(f" {len(products)} inserted") + + _log("Inserting purchases...") + _bulk_insert(session, Purchase, purchases) + _log(f" {len(purchases)} inserted") + + _log("Inserting purchase items...") + _bulk_insert(session, PurchaseItem, purchase_items) + _log(f" {len(purchase_items)} inserted") + + _log("Inserting price history...") + _bulk_insert(session, PriceHistory, price_history) + _log(f" {len(price_history)} inserted") + + _log("Inserting coupons...") + _bulk_insert(session, Coupon, coupons) + _log(f" {len(coupons)} inserted") + + _log("Inserting shrinkflation events...") + _bulk_insert(session, ShrinkflationEvent, shrinkflation_events) + 
_log(f" {len(shrinkflation_events)} inserted") + + session.commit() + + elapsed = time.monotonic() - t0 + _log("") + _log(f"Seed complete in {elapsed:.1f}s") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6bfc994 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,24 @@ +"""Shared test fixtures for cartsnitch-common tests.""" + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from cartsnitch_common.models.base import Base + + +@pytest.fixture +def engine(): + """In-memory SQLite engine for unit tests.""" + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def session(engine): + """SQLAlchemy session bound to in-memory SQLite.""" + factory = sessionmaker(bind=engine) + with factory() as sess: + yield sess diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..8b5eb68 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,376 @@ +"""Tests for SQLAlchemy ORM models.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +import pytest +from sqlalchemy import inspect + +from cartsnitch_common.constants import ( + AccountStatus, + DiscountType, + PriceSource, + ProductCategory, + SizeUnit, + StoreSlug, +) +from cartsnitch_common.models import ( + Coupon, + NormalizedProduct, + PriceHistory, + Purchase, + PurchaseItem, + ShrinkflationEvent, + Store, + StoreLocation, + User, + UserStoreAccount, +) + + +class TestTableCreation: + """Verify all expected tables are created.""" + + def test_all_tables_exist(self, engine): + inspector = inspect(engine) + table_names = set(inspector.get_table_names()) + expected = { + "stores", + "store_locations", + "users", + "user_store_accounts", + "purchases", + "purchase_items", + "normalized_products", + 
"price_history", + "coupons", + "shrinkflation_events", + } + assert expected.issubset(table_names) + + def test_ten_tables_total(self, engine): + inspector = inspect(engine) + assert len(inspector.get_table_names()) == 10 + + +class TestUUIDPrimaryKeys: + """All models use UUID PKs.""" + + def test_store_uuid_pk(self, session): + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.commit() + assert isinstance(store.id, uuid.UUID) + + def test_user_uuid_pk(self, session): + user = User( + id=uuid.uuid4(), + email="test@example.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(user) + session.commit() + assert isinstance(user.id, uuid.UUID) + + +class TestStoreModel: + def test_store_slug_enum(self, session): + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.commit() + assert store.slug == StoreSlug.KROGER + + def test_store_unique_slug(self, session): + s1 = Store( + id=uuid.uuid4(), + name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + s2 = Store( + id=uuid.uuid4(), + name="Target Duplicate", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(s1) + session.commit() + session.add(s2) + with pytest.raises(Exception): # noqa: B017 + session.commit() + session.rollback() + + +class TestStoreLocationModel: + def test_store_location_fields(self, session): + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.flush() + loc = StoreLocation( + id=uuid.uuid4(), + store_id=store.id, + address="123 Main St", + city="Ann Arbor", + 
state="MI", + zip="48104", + lat=42.2808, + lng=-83.7430, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(loc) + session.commit() + assert loc.city == "Ann Arbor" + assert loc.lat == pytest.approx(42.2808) + + +class TestUserStoreAccountModel: + def test_account_status_enum(self, session): + user = User( + id=uuid.uuid4(), + email="test@test.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([user, store]) + session.flush() + acct = UserStoreAccount( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + status=AccountStatus.ACTIVE, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(acct) + session.commit() + assert acct.status == AccountStatus.ACTIVE + + def test_unique_user_store_constraint(self, session): + """One account per user per store.""" + user = User( + id=uuid.uuid4(), + email="unique@test.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store = Store( + id=uuid.uuid4(), + name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([user, store]) + session.flush() + a1 = UserStoreAccount( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + status=AccountStatus.ACTIVE, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + a2 = UserStoreAccount( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + status=AccountStatus.EXPIRED, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(a1) + session.commit() + session.add(a2) + with pytest.raises(Exception): # noqa: B017 + session.commit() + session.rollback() + + +class TestPurchaseModel: + def test_purchase_with_items(self, session): + user = User( + 
id=uuid.uuid4(), + email="buyer@test.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([user, store]) + session.flush() + purchase = Purchase( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + receipt_id="RCP-001", + purchase_date=date(2026, 3, 15), + total=Decimal("42.50"), + ingested_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(purchase) + session.flush() + item = PurchaseItem( + id=uuid.uuid4(), + purchase_id=purchase.id, + product_name_raw="Meijer Whole Milk 1 Gallon", + upc="0041250000001", + quantity=Decimal("1"), + unit_price=Decimal("3.49"), + extended_price=Decimal("3.49"), + ) + session.add(item) + session.commit() + assert item.product_name_raw == "Meijer Whole Milk 1 Gallon" + assert item.unit_price == Decimal("3.49") + + +class TestNormalizedProductModel: + def test_product_with_upc_variants(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk, 1 Gallon", + category=ProductCategory.DAIRY, + brand="Store Brand", + size="128", + size_unit=SizeUnit.FL_OZ, + upc_variants=["0041250000001", "0041250000002"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + assert product.category == ProductCategory.DAIRY + assert product.size_unit == SizeUnit.FL_OZ + + +class TestPriceHistoryModel: + def test_price_source_enum(self, session): + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Eggs, Large, 12ct", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add_all([store, product]) + session.flush() 
+ ph = PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.99"), + sale_price=Decimal("3.99"), + source=PriceSource.RECEIPT, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(ph) + session.commit() + assert ph.source == PriceSource.RECEIPT + assert ph.regular_price == Decimal("4.99") + + +class TestCouponModel: + def test_coupon_discount_types(self, session): + store = Store( + id=uuid.uuid4(), + name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.flush() + coupon = Coupon( + id=uuid.uuid4(), + store_id=store.id, + title="$2 off eggs", + discount_type=DiscountType.FIXED, + discount_value=Decimal("2.00"), + requires_clip=True, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(coupon) + session.commit() + assert coupon.discount_type == DiscountType.FIXED + assert coupon.discount_value == Decimal("2.00") + + +class TestShrinkflationEventModel: + def test_shrinkflation_event(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Cereal, Honey Oats", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.flush() + event = ShrinkflationEvent( + id=uuid.uuid4(), + normalized_product_id=product.id, + detected_date=date(2026, 3, 10), + old_size="18", + new_size="15.4", + old_unit=SizeUnit.OZ, + new_unit=SizeUnit.OZ, + price_at_old_size=Decimal("4.99"), + price_at_new_size=Decimal("4.99"), + confidence=Decimal("0.95"), + notes="Size reduced by 14.4%, price unchanged", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(event) + session.commit() + assert event.confidence == Decimal("0.95") + assert event.old_unit == SizeUnit.OZ diff --git a/tests/test_normalization.py b/tests/test_normalization.py new file mode 100644 index 
0000000..d60c4d5 --- /dev/null +++ b/tests/test_normalization.py @@ -0,0 +1,157 @@ +"""Tests for product normalization module.""" + +import uuid +from datetime import UTC, datetime + +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.normalization import ( + MatchMethod, + clean_name, + extract_size_info, + jaccard_similarity, + match_by_name, + match_by_upc, + normalize_product, +) + + +class TestCleanName: + def test_lowercase(self): + assert clean_name("Kroger WHOLE MILK") == "kroger whole milk" + + def test_removes_size_info(self): + assert "oz" not in clean_name("Milk 16 oz Whole") + + def test_removes_noise_words(self): + cleaned = clean_name("The Original Brand Milk") + assert "the" not in cleaned.split() + assert "original" not in cleaned.split() + assert "brand" not in cleaned.split() + + def test_collapses_whitespace(self): + assert " " not in clean_name("Milk Whole Gallon") + + def test_removes_punctuation(self): + cleaned = clean_name("Meijer's Best (Organic) Milk!") + assert "'" not in cleaned + assert "(" not in cleaned + + +class TestExtractSizeInfo: + def test_extracts_oz(self): + result = extract_size_info("Cereal 18 oz box") + assert result == ("18", "oz") + + def test_extracts_fl_oz(self): + result = extract_size_info("Juice 64 fl oz") + assert result == ("64", "fl_oz") + + def test_extracts_lb(self): + result = extract_size_info("Ground Beef 1.5 lb") + assert result == ("1.5", "lb") + + def test_extracts_ct(self): + result = extract_size_info("Eggs Large 12 ct") + assert result == ("12", "ct") + + def test_no_size_returns_none(self): + assert extract_size_info("Bananas") is None + + +class TestJaccardSimilarity: + def test_identical_strings(self): + assert jaccard_similarity("whole milk gallon", "whole milk gallon") == 1.0 + + def test_completely_different(self): + assert jaccard_similarity("apple juice", "ground beef") == 0.0 + + def test_partial_overlap(self): + score = jaccard_similarity("kroger whole 
milk", "meijer whole milk") + assert 0.4 < score < 0.8 # "whole" and "milk" overlap + + def test_empty_strings(self): + assert jaccard_similarity("", "") == 0.0 + assert jaccard_similarity("milk", "") == 0.0 + + +class TestMatchByUPC: + def test_match_found(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk, Gallon", + upc_variants=["0041250000001", "0041250000002"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + # SQLite doesn't support JSONB containment — this will raise + # In production (PostgreSQL), this would work + result = match_by_upc(session, "0041250000001") + assert result is not None + assert result.method == MatchMethod.UPC + assert result.confidence == 1.0 + + def test_no_match(self, session): + result = match_by_upc(session, "9999999999999") + assert result is None + + +class TestMatchByName: + def test_exact_name_match(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk, Gallon", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + result = match_by_name(session, "Whole Milk Gallon") + assert result is not None + assert result.method == MatchMethod.NAME + assert result.confidence > 0.5 + + def test_fuzzy_match(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Kroger Whole Milk, 1 Gallon", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + result = match_by_name(session, "Meijer Whole Milk 1 Gallon", threshold=0.3) + assert result is not None + assert result.confidence > 0.3 + + def test_no_match_below_threshold(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Ground Beef 80/20", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + result = 
match_by_name(session, "Apple Juice 64 oz", threshold=0.5) + assert result is None + + +class TestNormalizeProduct: + def test_name_fallback(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Large Eggs, 12 count", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + result = normalize_product(session, "Large Eggs 12 ct", upc=None) + assert result is not None + assert result.method == MatchMethod.NAME + + def test_no_match(self, session): + result = normalize_product(session, "Nonexistent Product XYZ", upc=None) + assert result is None diff --git a/tests/test_pipeline_e2e.py b/tests/test_pipeline_e2e.py new file mode 100644 index 0000000..03737f2 --- /dev/null +++ b/tests/test_pipeline_e2e.py @@ -0,0 +1,949 @@ +"""End-to-end integration tests for the data pipeline. + +Tests the full flow: scraper output → normalization → product matching → DB storage +→ price tracking → shrinkflation detection → event publishing. + +Uses real test fixtures with an in-memory SQLite database, not mocks. 
+""" + +import uuid +from datetime import date +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session, sessionmaker + +from cartsnitch_common.constants import ( + EventType, + SizeUnit, + StoreSlug, +) +from cartsnitch_common.events import publish_event +from cartsnitch_common.models import ( + Base, + NormalizedProduct, + PriceHistory, + Purchase, + PurchaseItem, + ShrinkflationEvent, + Store, + User, +) +from cartsnitch_common.pipeline.matching import ProductMatcher +from cartsnitch_common.pipeline.price_tracking import ( + PriceDelta, + get_price_trend, + record_price_from_item, +) +from cartsnitch_common.pipeline.receipt import normalize_receipt, parse_meijer_item +from cartsnitch_common.pipeline.shrinkflation import detect_shrinkflation +from cartsnitch_common.schemas.events import EventEnvelope +from cartsnitch_common.schemas.purchase import PurchaseCreate + +# --------------------------------------------------------------------------- +# Fixtures: realistic scraper output from Meijer +# --------------------------------------------------------------------------- + +MEIJER_RECEIPT_FIXTURE = { + "receiptId": "MJ-2026-03-15-00042", + "date": "2026-03-15", + "total": "47.82", + "subtotal": "44.50", + "taxAmount": "3.32", + "totalSavings": "6.20", + "items": [ + { + "description": " Meijer Whole Milk 1 Gallon ", + "upcCode": "00041250010001", + "quantity": 1, + "unitPrice": "3.29", + "extendedPrice": "3.29", + "regularPrice": "3.49", + "salePrice": "3.29", + "category": "Dairy", + }, + { + "name": "BARILLA SPAGHETTI 16 OZ", + "upc": "076808280753", + "qty": 2, + "price": "1.69", + "totalPrice": "3.38", + "regularPrice": "1.89", + "couponDiscount": "0.40", + "department": "Pantry", + }, + { + "description": "Meijer Lean Ground Beef 1 lb", + "upcCode": "00041250022004", + "quantity": 1, + "unitPrice": "5.99", + "extendedPrice": "5.99", + "regularPrice": 
"6.49", + "loyaltyDiscount": "0.50", + "category": "Meat", + }, + { + "description": "Cheerios Original 12 oz", + "upcCode": "016000275645", + "quantity": 1, + "unitPrice": "4.49", + "extendedPrice": "4.49", + "regularPrice": "4.49", + "category": "Snacks", + }, + { + "description": "Fresh Bananas", + "quantity": 1, + "unitPrice": "0.69", + "extendedPrice": "0.69", + "category": "Produce", + }, + ], +} + +MEIJER_RECEIPT_SECOND_VISIT = { + "receiptId": "MJ-2026-03-18-00099", + "date": "2026-03-18", + "total": "12.47", + "items": [ + { + "description": "Meijer Whole Milk 1 Gallon", + "upcCode": "00041250010001", + "quantity": 1, + "unitPrice": "3.49", + "extendedPrice": "3.49", + "regularPrice": "3.49", + "category": "Dairy", + }, + { + "description": "BARILLA SPAGHETTI 16 OZ", + "upc": "076808280753", + "qty": 1, + "price": "1.99", + "totalPrice": "1.99", + "regularPrice": "1.99", + "department": "Pantry", + }, + { + "description": "Cheerios Original 10.8 oz", + "upcCode": "016000275645", + "quantity": 1, + "unitPrice": "4.49", + "extendedPrice": "4.49", + "regularPrice": "4.49", + "category": "Snacks", + }, + ], +} + + +@pytest.fixture +def e2e_engine(): + """In-memory SQLite engine for E2E tests.""" + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def e2e_session(e2e_engine): + """SQLAlchemy session with pre-seeded store and user.""" + factory = sessionmaker(bind=e2e_engine) + with factory() as sess: + yield sess + + +@pytest.fixture +def store(e2e_session: Session) -> Store: + """Seed a Meijer store.""" + s = Store(id=uuid.uuid4(), name="Meijer", slug=StoreSlug.MEIJER) + e2e_session.add(s) + e2e_session.flush() + return s + + +@pytest.fixture +def user(e2e_session: Session) -> User: + """Seed a test user.""" + u = User( + id=uuid.uuid4(), + email="tester@cartsnitch.com", + hashed_password="hashed_test_password", + display_name="Test User", + ) + e2e_session.add(u) + e2e_session.flush() + 
return u + + +@pytest.fixture +def redis_mock(): + """A lightweight Redis mock that captures published messages.""" + client = MagicMock() + published: list[tuple[str, str]] = [] + + def _publish(channel: str, message: str) -> int: + published.append((channel, message)) + return 1 + + client.publish = MagicMock(side_effect=_publish) + client._published = published + return client + + +# =========================================================================== +# Test class: Full pipeline E2E — scraper → normalization → matching → storage +# =========================================================================== + + +class TestFullPipelineE2E: + """Scraper output → normalize_receipt → ProductMatcher → DB storage.""" + + def test_normalize_meijer_receipt(self, user: User, store: Store): + """Raw Meijer receipt normalizes into a valid PurchaseCreate.""" + purchase = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + + assert isinstance(purchase, PurchaseCreate) + assert purchase.receipt_id == "MJ-2026-03-15-00042" + assert purchase.purchase_date == date(2026, 3, 15) + assert purchase.total == Decimal("47.82") + assert purchase.subtotal == Decimal("44.50") + assert purchase.tax == Decimal("3.32") + assert purchase.savings_total == Decimal("6.20") + assert len(purchase.items) == 5 + assert purchase.raw_data == MEIJER_RECEIPT_FIXTURE + + def test_item_field_normalization(self, user: User, store: Store): + """Items parse correctly regardless of field name variants.""" + purchase = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + + # Item using 'description' / 'upcCode' fields + milk = purchase.items[0] + assert milk.product_name_raw == "Meijer Whole Milk 1 Gallon" + assert milk.upc == "41250010001" # leading zeros stripped + assert milk.unit_price == Decimal("3.29") + + # Item using 'name' / 'upc' / 'qty' / 'price' / 'totalPrice' fields + pasta = purchase.items[1] + 
assert pasta.product_name_raw == "BARILLA SPAGHETTI 16 OZ" + assert pasta.upc == "76808280753" + assert pasta.quantity == Decimal("2") + assert pasta.extended_price == Decimal("3.38") + assert pasta.coupon_discount == Decimal("0.40") + + def test_upc_product_matching_and_storage(self, e2e_session: Session, user: User, store: Store): + """Full flow: normalize → match → store in DB. UPC matching works E2E.""" + purchase_schema = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + + # Run product matching + matcher = ProductMatcher(e2e_session, auto_create=True) + outcomes = matcher.match_items(purchase_schema.items) + + assert len(outcomes) == 5 + + # First item has a UPC — auto_create makes a new product + assert outcomes[0].created_new is True + + # Store the purchase in DB + purchase_db = Purchase( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + receipt_id=purchase_schema.receipt_id, + purchase_date=purchase_schema.purchase_date, + total=purchase_schema.total, + subtotal=purchase_schema.subtotal, + tax=purchase_schema.tax, + savings_total=purchase_schema.savings_total, + raw_data=purchase_schema.raw_data, + ) + e2e_session.add(purchase_db) + e2e_session.flush() + + # Store items linked to the purchase and matched products + for _i, item_schema in enumerate(purchase_schema.items): + item_db = PurchaseItem( + id=uuid.uuid4(), + purchase_id=purchase_db.id, + product_name_raw=item_schema.product_name_raw, + upc=item_schema.upc, + quantity=item_schema.quantity, + unit_price=item_schema.unit_price, + extended_price=item_schema.extended_price, + regular_price=item_schema.regular_price, + sale_price=item_schema.sale_price, + coupon_discount=item_schema.coupon_discount, + loyalty_discount=item_schema.loyalty_discount, + category_raw=item_schema.category_raw, + ) + e2e_session.add(item_db) + e2e_session.flush() + + # Verify data persisted correctly + stored_purchase = e2e_session.execute( + 
select(Purchase).where(Purchase.receipt_id == "MJ-2026-03-15-00042") + ).scalar_one() + assert stored_purchase.total == Decimal("47.82") + assert stored_purchase.user_id == user.id + assert stored_purchase.store_id == store.id + + stored_items = ( + e2e_session.execute( + select(PurchaseItem).where(PurchaseItem.purchase_id == stored_purchase.id) + ) + .scalars() + .all() + ) + assert len(stored_items) == 5 + + # Verify products were created in normalized_products table + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + assert len(products) == 5 # all 5 items auto-created products + + def test_second_visit_reuses_existing_products( + self, e2e_session: Session, user: User, store: Store + ): + """On second receipt, products matched by UPC reuse existing records.""" + # Ingest first receipt + first = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + matcher = ProductMatcher(e2e_session, auto_create=True) + matcher.match_items(first.items) + + products_after_first = e2e_session.execute(select(NormalizedProduct)).scalars().all() + first_count = len(products_after_first) + + # Ingest second receipt — overlapping UPCs + second = normalize_receipt( + MEIJER_RECEIPT_SECOND_VISIT, + user_id=str(user.id), + store_id=str(store.id), + ) + second_outcomes = matcher.match_items(second.items) + + # Milk, pasta, cheerios should match existing by UPC + assert second_outcomes[0].created_new is False # milk — UPC match + assert second_outcomes[1].created_new is False # pasta — UPC match + assert second_outcomes[2].created_new is False # cheerios — UPC match + + products_after_second = e2e_session.execute(select(NormalizedProduct)).scalars().all() + assert len(products_after_second) == first_count # no new products created + + +# =========================================================================== +# Test class: Price tracking and shrinkflation detection E2E +# 
=========================================================================== + + +class TestPriceTrackingE2E: + """Price recording from stored items and price delta detection.""" + + def test_price_recorded_from_ingested_receipt( + self, e2e_session: Session, user: User, store: Store + ): + """Ingest receipt → match products → record prices → verify price history.""" + purchase_schema = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + matcher = ProductMatcher(e2e_session, auto_create=True) + outcomes = matcher.match_items(purchase_schema.items) + + # Record prices for each matched item + price_entries = [] + for i, item_schema in enumerate(purchase_schema.items): + product = outcomes[i].match.product if outcomes[i].match else None + if product is None: + # Was auto-created — find the product directly + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == item_schema.product_name_raw: + product = p + break + + if product: + entry, delta = record_price_from_item( + e2e_session, + product_id=product.id, + store_id=store.id, + observed_date=purchase_schema.purchase_date, + regular_price=item_schema.regular_price or item_schema.unit_price, + sale_price=item_schema.sale_price, + ) + price_entries.append((entry, delta)) + + # First ingestion — no deltas expected + assert all(delta is None for _, delta in price_entries) + + # Verify price history stored + all_prices = e2e_session.execute(select(PriceHistory)).scalars().all() + assert len(all_prices) >= 4 # at least the items with regular_price + + def test_price_increase_detected_on_second_receipt( + self, e2e_session: Session, user: User, store: Store + ): + """Second receipt with higher price triggers a PriceDelta.""" + # Ingest first receipt + first = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + matcher = ProductMatcher(e2e_session, 
auto_create=True) + first_outcomes = matcher.match_items(first.items) + + # Record first prices + for i, item_schema in enumerate(first.items): + product = first_outcomes[i].match.product if first_outcomes[i].match else None + if product is None: + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == item_schema.product_name_raw: + product = p + break + if product: + record_price_from_item( + e2e_session, + product_id=product.id, + store_id=store.id, + observed_date=first.purchase_date, + regular_price=item_schema.regular_price or item_schema.unit_price, + sale_price=item_schema.sale_price, + ) + + # Ingest second receipt — pasta price went up ($1.89 → $1.99) + second = normalize_receipt( + MEIJER_RECEIPT_SECOND_VISIT, + user_id=str(user.id), + store_id=str(store.id), + ) + second_outcomes = matcher.match_items(second.items) + + # Record second prices and capture deltas + deltas: list[PriceDelta] = [] + for i, item_schema in enumerate(second.items): + product = second_outcomes[i].match.product if second_outcomes[i].match else None + if product is None: + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == item_schema.product_name_raw: + product = p + break + if product: + _, delta = record_price_from_item( + e2e_session, + product_id=product.id, + store_id=store.id, + observed_date=second.purchase_date, + regular_price=item_schema.regular_price or item_schema.unit_price, + sale_price=item_schema.sale_price, + ) + if delta: + deltas.append(delta) + + # Milk went from $3.49 → $3.49 (no change); pasta from $1.89 → $1.99 (increase) + price_increases = [d for d in deltas if d.is_increase] + assert len(price_increases) >= 1 + + pasta_delta = next( + (d for d in price_increases if d.old_price == Decimal("1.89")), + None, + ) + assert pasta_delta is not None + assert pasta_delta.new_price == Decimal("1.99") + assert 
pasta_delta.change_amount == Decimal("0.10") + assert pasta_delta.is_increase is True + + def test_price_trend_across_visits(self, e2e_session: Session, user: User, store: Store): + """get_price_trend returns ordered history after multiple ingestions.""" + # Create a product manually + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Test Product", + upc_variants=["1234567890"], + ) + e2e_session.add(product) + e2e_session.flush() + + # Record 3 prices on different dates + dates_prices = [ + (date(2026, 3, 10), Decimal("2.99")), + (date(2026, 3, 13), Decimal("3.19")), + (date(2026, 3, 16), Decimal("2.79")), + ] + for obs_date, price in dates_prices: + record_price_from_item( + e2e_session, + product_id=product.id, + store_id=store.id, + observed_date=obs_date, + regular_price=price, + ) + + trend = get_price_trend(e2e_session, product.id, store.id) + assert len(trend) == 3 + # Newest first + assert trend[0].regular_price == Decimal("2.79") + assert trend[1].regular_price == Decimal("3.19") + assert trend[2].regular_price == Decimal("2.99") + + +class TestShrinkflationE2E: + """Shrinkflation detection integrated with product matching.""" + + def test_shrinkflation_detected_from_receipt_data( + self, e2e_session: Session, user: User, store: Store + ): + """Cheerios went from 12 oz → 10.8 oz between receipts. 
Detect shrinkflation.""" + # Ingest first receipt — creates Cheerios product with size from name + first = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + matcher = ProductMatcher(e2e_session, auto_create=True) + first_outcomes = matcher.match_items(first.items) + + # Find the Cheerios product (index 3 in fixture) + cheerios_product = None + for outcome in first_outcomes: + if outcome.match and outcome.match.product: + p = outcome.match.product + else: + # Check auto-created products + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if "cheerios" in p.canonical_name.lower(): + cheerios_product = p + break + if cheerios_product: + break + else: + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if "cheerios" in p.canonical_name.lower(): + cheerios_product = p + break + + assert cheerios_product is not None + # The auto-created product should have extracted "12" and "oz" from name + assert cheerios_product.size == "12" + assert cheerios_product.size_unit == SizeUnit.OZ + + # Now detect shrinkflation: 12 oz → 10.8 oz + event = detect_shrinkflation( + e2e_session, + product=cheerios_product, + new_size="10.8", + new_unit=SizeUnit.OZ, + new_price=Decimal("4.49"), + detected_date=date(2026, 3, 18), + ) + + assert event is not None + assert isinstance(event, ShrinkflationEvent) + assert event.old_size == "12" + assert event.new_size == "10.8" + assert event.old_unit == SizeUnit.OZ + assert event.new_unit == SizeUnit.OZ + assert event.confidence >= Decimal("0.85") # 10% decrease → 0.95 + + # Verify stored in DB + stored = e2e_session.execute( + select(ShrinkflationEvent).where( + ShrinkflationEvent.normalized_product_id == cheerios_product.id + ) + ).scalar_one() + assert stored.id == event.id + + def test_shrinkflation_dedup_on_repeat_detection( + self, e2e_session: Session, user: User, store: Store + ): + """Same 
shrinkflation detected twice returns the existing event, not a duplicate.""" + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Brand X Cereal 15 oz", + size="15", + size_unit=SizeUnit.OZ, + upc_variants=["999888777"], + ) + e2e_session.add(product) + e2e_session.flush() + + first = detect_shrinkflation(e2e_session, product, new_size="13.5", new_unit=SizeUnit.OZ) + second = detect_shrinkflation(e2e_session, product, new_size="13.5", new_unit=SizeUnit.OZ) + + assert first is not None + assert second is not None + assert first.id == second.id # same event, not duplicated + + count = len( + e2e_session.execute( + select(ShrinkflationEvent).where( + ShrinkflationEvent.normalized_product_id == product.id + ) + ) + .scalars() + .all() + ) + assert count == 1 + + +# =========================================================================== +# Test class: Event bus pub/sub for pipeline stage transitions +# =========================================================================== + + +class TestEventBusE2E: + """Redis event publishing at each pipeline stage.""" + + def test_receipt_ingested_event(self, redis_mock, user: User, store: Store): + """publish_event sends a valid EventEnvelope for RECEIPTS_INGESTED.""" + purchase_schema = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + + subscribers = publish_event( + redis_mock, + EventType.RECEIPTS_INGESTED, + service="receiptwitness", + payload={ + "receipt_id": purchase_schema.receipt_id, + "user_id": str(user.id), + "store_slug": StoreSlug.MEIJER, + "item_count": len(purchase_schema.items), + "total": str(purchase_schema.total), + }, + ) + + assert subscribers == 1 + assert len(redis_mock._published) == 1 + channel, raw_msg = redis_mock._published[0] + assert channel == EventType.RECEIPTS_INGESTED.value + + # Deserialize and validate the envelope + envelope = EventEnvelope.model_validate_json(raw_msg) + assert envelope.event_type == 
EventType.RECEIPTS_INGESTED + assert envelope.service == "receiptwitness" + assert envelope.payload["receipt_id"] == "MJ-2026-03-15-00042" + assert envelope.payload["item_count"] == 5 + + def test_price_updated_event(self, redis_mock, user: User, store: Store): + """publish_event sends a valid envelope for PRICES_UPDATED.""" + subscribers = publish_event( + redis_mock, + EventType.PRICES_UPDATED, + service="cartsnitch-common", + payload={ + "product_id": str(uuid.uuid4()), + "store_slug": StoreSlug.MEIJER, + "old_price": "1.89", + "new_price": "1.99", + "change_percent": "5.29", + }, + ) + + assert subscribers == 1 + channel, raw_msg = redis_mock._published[0] + assert channel == EventType.PRICES_UPDATED.value + + envelope = EventEnvelope.model_validate_json(raw_msg) + assert envelope.event_type == EventType.PRICES_UPDATED + assert envelope.payload["old_price"] == "1.89" + + def test_products_normalized_event(self, redis_mock, user: User, store: Store): + """publish_event sends a valid envelope for PRODUCTS_NORMALIZED.""" + product_id = str(uuid.uuid4()) + subscribers = publish_event( + redis_mock, + EventType.PRODUCTS_NORMALIZED, + service="cartsnitch-common", + payload={ + "product_id": product_id, + "canonical_name": "Barilla Spaghetti", + "match_method": "upc", + "confidence": "high", + }, + ) + + assert subscribers == 1 + channel, raw_msg = redis_mock._published[0] + assert channel == EventType.PRODUCTS_NORMALIZED.value + envelope = EventEnvelope.model_validate_json(raw_msg) + assert envelope.payload["confidence"] == "high" + + def test_shrinkflation_alert_event(self, redis_mock, user: User, store: Store): + """publish_event sends a valid envelope for ALERT_SHRINKFLATION.""" + subscribers = publish_event( + redis_mock, + EventType.ALERT_SHRINKFLATION, + service="shrinkray", + payload={ + "product_id": str(uuid.uuid4()), + "product_name": "Cheerios Original", + "old_size": "12 oz", + "new_size": "10.8 oz", + "confidence": "0.95", + }, + ) + + assert subscribers 
== 1 + channel, raw_msg = redis_mock._published[0] + assert channel == EventType.ALERT_SHRINKFLATION.value + + def test_full_pipeline_emits_events_at_each_stage( + self, e2e_session: Session, redis_mock, user: User, store: Store + ): + """Full pipeline: ingest → match → record price → publish events at each stage.""" + # Stage 1: Normalize receipt + purchase_schema = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + + # Publish receipt ingested + publish_event( + redis_mock, + EventType.RECEIPTS_INGESTED, + service="receiptwitness", + payload={ + "receipt_id": purchase_schema.receipt_id, + "item_count": len(purchase_schema.items), + }, + ) + + # Stage 2: Match products + matcher = ProductMatcher(e2e_session, auto_create=True) + outcomes = matcher.match_items(purchase_schema.items) + + for i, outcome in enumerate(outcomes): + product = outcome.match.product if outcome.match else None + if product is None: + # Auto-created — look up by name + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == purchase_schema.items[i].product_name_raw: + product = p + break + if product is None: + continue + publish_event( + redis_mock, + EventType.PRODUCTS_NORMALIZED, + service="cartsnitch-common", + payload={ + "product_id": str(product.id), + "match_method": outcome.match.method.value if outcome.match else "auto_create", + "confidence": outcome.confidence_level.value, + }, + ) + + # Stage 3: Record prices + for i, item_schema in enumerate(purchase_schema.items): + product = outcomes[i].match.product if outcomes[i].match else None + if product is None: + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == item_schema.product_name_raw: + product = p + break + if product: + _, delta = record_price_from_item( + e2e_session, + product_id=product.id, + store_id=store.id, + 
observed_date=purchase_schema.purchase_date, + regular_price=item_schema.regular_price or item_schema.unit_price, + ) + if delta and delta.is_increase: + publish_event( + redis_mock, + EventType.ALERT_PRICE_INCREASE, + service="stickershock", + payload={ + "product_id": str(product.id), + "old_price": str(delta.old_price), + "new_price": str(delta.new_price), + }, + ) + + # Verify events published at each stage + channels = [ch for ch, _ in redis_mock._published] + assert EventType.RECEIPTS_INGESTED.value in channels + assert EventType.PRODUCTS_NORMALIZED.value in channels + # No price increases on first receipt, so no ALERT_PRICE_INCREASE expected + + # All messages are valid EventEnvelopes + for _, raw_msg in redis_mock._published: + envelope = EventEnvelope.model_validate_json(raw_msg) + assert envelope.timestamp is not None + assert envelope.service + + +# =========================================================================== +# Test class: Error handling for malformed scraper output +# =========================================================================== + + +class TestMalformedScraperOutput: + """Error handling for bad, partial, or unexpected scraper data.""" + + def test_missing_item_name_produces_empty_string(self): + """Item with no description/name field normalizes with empty product_name_raw.""" + item = parse_meijer_item({"unitPrice": "2.99"}) + assert item.product_name_raw == "" + assert item.unit_price == Decimal("2.99") + + def test_missing_price_defaults_to_zero(self): + """Item with no price fields defaults to zero.""" + item = parse_meijer_item({"description": "Mystery Product"}) + assert item.unit_price == Decimal("0") + assert item.extended_price == Decimal("0") + + def test_non_numeric_price_defaults_to_zero(self): + """Non-numeric price strings safely default to zero.""" + item = parse_meijer_item( + { + "description": "Bad Price Item", + "unitPrice": "not_a_number", + "extendedPrice": "$$$.xx", + } + ) + assert item.unit_price == 
Decimal("0") + assert item.extended_price == Decimal("0") + + def test_empty_receipt_produces_empty_items(self, user: User, store: Store): + """Receipt with no items normalizes cleanly.""" + raw = {"receiptId": "EMPTY-001", "date": "2026-03-15", "total": "0.00"} + purchase = normalize_receipt(raw, user_id=str(user.id), store_id=str(store.id)) + + assert purchase.receipt_id == "EMPTY-001" + assert purchase.total == Decimal("0.00") + assert len(purchase.items) == 0 + + def test_receipt_missing_date_defaults_to_today(self, user: User, store: Store): + """Receipt with no date field defaults to today.""" + raw = {"receiptId": "NO-DATE-001", "total": "5.00", "items": []} + purchase = normalize_receipt(raw, user_id=str(user.id), store_id=str(store.id)) + assert purchase.purchase_date == date.today() + + def test_receipt_missing_id_generates_uuid(self, user: User, store: Store): + """Receipt with no ID generates a UUID.""" + raw = {"date": "2026-03-15", "total": "10.00", "items": []} + purchase = normalize_receipt(raw, user_id=str(user.id), store_id=str(store.id)) + + # Should be a valid UUID string + uuid.UUID(purchase.receipt_id) + + def test_item_with_garbage_upc_preserves_it(self): + """UPC field with non-standard content is preserved as-is after strip.""" + item = parse_meijer_item( + { + "description": "Weird UPC Product", + "upc": " ABC-NOT-A-UPC ", + "unitPrice": "1.99", + } + ) + # lstrip("0") on "ABC-NOT-A-UPC" leaves it intact + assert item.upc == "ABC-NOT-A-UPC" + + def test_negative_prices_pass_through(self): + """Negative prices (refunds) are preserved, not zeroed.""" + item = parse_meijer_item( + { + "description": "Refund Item", + "unitPrice": "-5.99", + "extendedPrice": "-5.99", + } + ) + assert item.unit_price == Decimal("-5.99") + assert item.extended_price == Decimal("-5.99") + + def test_extended_price_auto_calculated(self): + """When extendedPrice is missing, it's calculated from unitPrice * quantity.""" + item = parse_meijer_item( + { + 
"description": "No Extended", + "unitPrice": "2.50", + "quantity": "3", + } + ) + assert item.extended_price == Decimal("7.50") + + def test_matching_with_malformed_items(self, e2e_session: Session): + """ProductMatcher handles items with missing/empty names gracefully.""" + matcher = ProductMatcher(e2e_session, auto_create=True) + + bad_items = [ + parse_meijer_item({"description": "", "unitPrice": "1.00"}), + parse_meijer_item({"unitPrice": "2.00"}), + ] + + outcomes = matcher.match_items(bad_items) + assert len(outcomes) == 2 + # Both should auto-create (no match possible for empty names) + assert all(o.created_new for o in outcomes) + + def test_completely_empty_receipt(self, user: User, store: Store): + """Totally empty dict produces a valid PurchaseCreate with defaults.""" + purchase = normalize_receipt({}, user_id=str(user.id), store_id=str(store.id)) + assert purchase.total == Decimal("0") + assert len(purchase.items) == 0 + assert purchase.purchase_date == date.today() + + def test_mixed_valid_and_malformed_items(self, user: User, store: Store): + """Receipt with a mix of good and bad items processes all of them.""" + raw = { + "receiptId": "MIX-001", + "date": "2026-03-15", + "total": "10.00", + "items": [ + { + "description": "Good Product 8 oz", + "upc": "1234567890", + "unitPrice": "3.99", + "extendedPrice": "3.99", + }, + { + "unitPrice": "not_a_price", + }, + { + "description": " *** Special Chars !!! 
", + "unitPrice": "2.50", + }, + ], + } + purchase = normalize_receipt(raw, user_id=str(user.id), store_id=str(store.id)) + assert len(purchase.items) == 3 + + # Good item + assert purchase.items[0].product_name_raw == "Good Product 8 oz" + assert purchase.items[0].upc == "1234567890" + + # Bad price item + assert purchase.items[1].unit_price == Decimal("0") + + # Special chars stripped + assert purchase.items[2].product_name_raw == "Special Chars" diff --git a/tests/test_pipeline_matching.py b/tests/test_pipeline_matching.py new file mode 100644 index 0000000..c73da9e --- /dev/null +++ b/tests/test_pipeline_matching.py @@ -0,0 +1,160 @@ +"""Tests for product matching & dedup pipeline.""" + +import uuid +from datetime import UTC, datetime +from decimal import Decimal + +from cartsnitch_common.constants import MatchConfidence +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.normalization import MatchMethod +from cartsnitch_common.pipeline.matching import ( + ProductMatcher, + classify_confidence, + match_purchase_item, +) +from cartsnitch_common.schemas.purchase import PurchaseItemCreate + + +class TestClassifyConfidence: + def test_upc_always_high(self): + assert classify_confidence(1.0, MatchMethod.UPC) == MatchConfidence.HIGH + assert classify_confidence(0.5, MatchMethod.UPC) == MatchConfidence.HIGH + + def test_name_high(self): + assert classify_confidence(0.9, MatchMethod.NAME) == MatchConfidence.HIGH + assert classify_confidence(0.8, MatchMethod.NAME) == MatchConfidence.HIGH + + def test_name_medium(self): + assert classify_confidence(0.6, MatchMethod.NAME) == MatchConfidence.MEDIUM + assert classify_confidence(0.5, MatchMethod.NAME) == MatchConfidence.MEDIUM + + def test_name_low(self): + assert classify_confidence(0.3, MatchMethod.NAME) == MatchConfidence.LOW + assert classify_confidence(0.0, MatchMethod.NAME) == MatchConfidence.LOW + + +class TestProductMatcher: + def _make_item(self, name: str, upc: str | None = None) 
-> PurchaseItemCreate: + return PurchaseItemCreate( + product_name_raw=name, + upc=upc, + unit_price=Decimal("3.99"), + extended_price=Decimal("3.99"), + ) + + def test_match_by_upc(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk Gallon", + upc_variants=["041250000001"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + matcher = ProductMatcher(session) + item = self._make_item("Kroger Milk", upc="041250000001") + prod, result, confidence = matcher.match_single(item) + + assert prod is not None + assert prod.id == product.id + assert result is not None + assert result.method == MatchMethod.UPC + assert confidence == MatchConfidence.HIGH + + def test_match_by_name(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk Gallon", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + matcher = ProductMatcher(session, name_threshold=0.3) + item = self._make_item("Whole Milk Gallon Size") + prod, result, confidence = matcher.match_single(item) + + assert prod is not None + assert result is not None + assert result.method == MatchMethod.NAME + + def test_auto_create_when_no_match(self, session): + matcher = ProductMatcher(session, auto_create=True) + item = self._make_item("Unique Product XYZ 16 oz") + prod, result, confidence = matcher.match_single(item) + + assert prod is not None + assert result is None # No match found, was created + assert confidence == MatchConfidence.LOW + assert prod.canonical_name == "Unique Product XYZ 16 oz" + assert prod.size == "16" + assert prod.size_unit == "oz" + + def test_no_create_when_disabled(self, session): + matcher = ProductMatcher(session, auto_create=False) + item = self._make_item("Nonexistent Product") + prod, result, confidence = matcher.match_single(item) + + assert prod is None + assert result is None + + def 
test_batch_match(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Large Eggs 12 Count", + upc_variants=["012345"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + matcher = ProductMatcher(session) + items = [ + self._make_item("Large Eggs", upc="012345"), + self._make_item("Brand New Never Seen Product"), + ] + outcomes = matcher.match_items(items) + + assert len(outcomes) == 2 + assert outcomes[0].match is not None + assert outcomes[0].confidence_level == MatchConfidence.HIGH + assert outcomes[0].created_new is False + assert outcomes[1].match is None + assert outcomes[1].created_new is True + + +class TestMatchPurchaseItem: + def test_convenience_function(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Ground Beef 80/20", + upc_variants=["999888"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + item = PurchaseItemCreate( + product_name_raw="Ground Beef", + upc="999888", + unit_price=Decimal("5.99"), + extended_price=Decimal("5.99"), + ) + prod, confidence = match_purchase_item(session, item) + assert prod is not None + assert confidence == MatchConfidence.HIGH + + def test_auto_create_default(self, session): + item = PurchaseItemCreate( + product_name_raw="Totally New Item", + unit_price=Decimal("1.00"), + extended_price=Decimal("1.00"), + ) + prod, confidence = match_purchase_item(session, item) + assert prod is not None + assert confidence == MatchConfidence.LOW diff --git a/tests/test_pipeline_price.py b/tests/test_pipeline_price.py new file mode 100644 index 0000000..63cae15 --- /dev/null +++ b/tests/test_pipeline_price.py @@ -0,0 +1,282 @@ +"""Tests for price history tracking pipeline.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +from cartsnitch_common.constants import PriceSource, StoreSlug +from 
cartsnitch_common.models.price import PriceHistory +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.models.store import Store +from cartsnitch_common.pipeline.price_tracking import ( + PriceDelta, + get_latest_price, + get_price_trend, + record_price_from_item, +) + + +def _make_store(session, slug=StoreSlug.MEIJER) -> Store: + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=slug, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.flush() + return store + + +def _make_product(session, name="Test Product") -> NormalizedProduct: + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name=name, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.flush() + return product + + +class TestGetLatestPrice: + def test_no_history(self, session): + product = _make_product(session) + store = _make_store(session) + result = get_latest_price(session, product.id, store.id) + assert result is None + + def test_returns_newest(self, session): + product = _make_product(session) + store = _make_store(session) + + # Add two entries + old = PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("3.99"), + source=PriceSource.RECEIPT, + ) + new = PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 10), + regular_price=Decimal("4.29"), + source=PriceSource.RECEIPT, + ) + session.add_all([old, new]) + session.flush() + + result = get_latest_price(session, product.id, store.id) + assert result is not None + assert result.regular_price == Decimal("4.29") + + +class TestRecordPriceFromItem: + def test_first_price_no_delta(self, session): + product = _make_product(session) + store = _make_store(session) + + entry, delta = record_price_from_item( + session, + product_id=product.id, + 
store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("3.99"), + ) + assert entry is not None + assert entry.regular_price == Decimal("3.99") + assert entry.source == PriceSource.RECEIPT + assert delta is None + + def test_price_increase_detected(self, session): + product = _make_product(session) + store = _make_store(session) + + # First price + record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("3.99"), + ) + + # Price increase + entry, delta = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.49"), + ) + + assert delta is not None + assert delta.old_price == Decimal("3.99") + assert delta.new_price == Decimal("4.49") + assert delta.change_amount == Decimal("0.50") + assert delta.is_increase is True + assert delta.is_decrease is False + assert delta.change_percent > Decimal("0") + + def test_price_decrease_detected(self, session): + product = _make_product(session) + store = _make_store(session) + + record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("5.00"), + ) + + _, delta = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.00"), + ) + + assert delta is not None + assert delta.is_decrease is True + assert delta.change_amount == Decimal("-1.00") + + def test_same_price_no_delta(self, session): + product = _make_product(session) + store = _make_store(session) + + record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("3.99"), + ) + + _, delta = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("3.99"), + ) + 
assert delta is None + + def test_sale_and_loyalty_prices_recorded(self, session): + product = _make_product(session) + store = _make_store(session) + + entry, _ = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("5.99"), + sale_price=Decimal("4.99"), + loyalty_price=Decimal("4.49"), + coupon_price=Decimal("3.99"), + ) + assert entry.sale_price == Decimal("4.99") + assert entry.loyalty_price == Decimal("4.49") + assert entry.coupon_price == Decimal("3.99") + + def test_custom_source(self, session): + product = _make_product(session) + store = _make_store(session) + + entry, _ = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("3.99"), + source=PriceSource.CATALOG, + ) + assert entry.source == PriceSource.CATALOG + + +class TestGetPriceTrend: + def test_empty_trend(self, session): + product = _make_product(session) + store = _make_store(session) + trend = get_price_trend(session, product.id, store.id) + assert trend == [] + + def test_returns_newest_first(self, session): + product = _make_product(session) + store = _make_store(session) + + for day in [1, 5, 10, 15]: + session.add( + PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, day), + regular_price=Decimal(str(3 + day * 0.1)), + source=PriceSource.RECEIPT, + ) + ) + session.flush() + + trend = get_price_trend(session, product.id, store.id) + assert len(trend) == 4 + assert trend[0].observed_date == date(2026, 3, 15) + assert trend[-1].observed_date == date(2026, 3, 1) + + def test_respects_limit(self, session): + product = _make_product(session) + store = _make_store(session) + + for day in range(1, 11): + session.add( + PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, day), + 
regular_price=Decimal("3.99"), + source=PriceSource.RECEIPT, + ) + ) + session.flush() + + trend = get_price_trend(session, product.id, store.id, limit=3) + assert len(trend) == 3 + + +class TestPriceDelta: + def test_delta_properties(self): + delta = PriceDelta( + product_id=uuid.uuid4(), + store_id=uuid.uuid4(), + old_price=Decimal("3.99"), + new_price=Decimal("4.49"), + change_amount=Decimal("0.50"), + change_percent=Decimal("12.53"), + old_date=date(2026, 3, 1), + new_date=date(2026, 3, 15), + ) + assert delta.is_increase is True + assert delta.is_decrease is False + + def test_decrease_properties(self): + delta = PriceDelta( + product_id=uuid.uuid4(), + store_id=uuid.uuid4(), + old_price=Decimal("4.49"), + new_price=Decimal("3.99"), + change_amount=Decimal("-0.50"), + change_percent=Decimal("-11.14"), + old_date=date(2026, 3, 1), + new_date=date(2026, 3, 15), + ) + assert delta.is_decrease is True + assert delta.is_increase is False diff --git a/tests/test_pipeline_receipt.py b/tests/test_pipeline_receipt.py new file mode 100644 index 0000000..d937d39 --- /dev/null +++ b/tests/test_pipeline_receipt.py @@ -0,0 +1,204 @@ +"""Tests for receipt normalization pipeline.""" + +import uuid +from datetime import date +from decimal import Decimal + +from cartsnitch_common.pipeline.receipt import ( + _clean_product_name, + _safe_decimal, + normalize_receipt, + parse_meijer_item, +) + + +class TestCleanProductName: + def test_strips_whitespace(self): + assert _clean_product_name(" Milk ") == "Milk" + + def test_removes_leading_punctuation(self): + assert _clean_product_name("---Milk---") == "Milk" + + def test_collapses_internal_whitespace(self): + assert _clean_product_name("Whole Milk Gallon") == "Whole Milk Gallon" + + def test_empty_string(self): + assert _clean_product_name("") == "" + + +class TestSafeDecimal: + def test_string_input(self): + assert _safe_decimal("3.99") == Decimal("3.99") + + def test_float_input(self): + assert _safe_decimal(3.99) == 
Decimal("3.99") + + def test_int_input(self): + assert _safe_decimal(4) == Decimal("4") + + def test_none_returns_default(self): + assert _safe_decimal(None) == Decimal("0") + + def test_none_custom_default(self): + assert _safe_decimal(None, Decimal("1")) == Decimal("1") + + def test_invalid_returns_default(self): + assert _safe_decimal("not-a-number") == Decimal("0") + + def test_decimal_passthrough(self): + assert _safe_decimal(Decimal("5.50")) == Decimal("5.50") + + +class TestParseMeijerItem: + def test_basic_item(self): + raw = { + "description": "Kroger Whole Milk 1 Gallon", + "upc": "0041250000001", + "quantity": 1, + "unitPrice": "3.99", + "extendedPrice": "3.99", + "category": "DAIRY", + } + item = parse_meijer_item(raw) + assert item.product_name_raw == "Kroger Whole Milk 1 Gallon" + assert item.upc == "41250000001" # leading zeros stripped + assert item.quantity == Decimal("1") + assert item.unit_price == Decimal("3.99") + assert item.extended_price == Decimal("3.99") + assert item.category_raw == "DAIRY" + + def test_alternate_field_names(self): + raw = { + "name": "Eggs Large 12 ct", + "upcCode": "012345", + "qty": 2, + "price": "4.50", + "totalPrice": "9.00", + "department": "EGGS", + } + item = parse_meijer_item(raw) + assert item.product_name_raw == "Eggs Large 12 ct" + assert item.upc == "12345" + assert item.quantity == Decimal("2") + assert item.unit_price == Decimal("4.50") + assert item.extended_price == Decimal("9.00") + assert item.category_raw == "EGGS" + + def test_calculates_extended_from_unit_price(self): + raw = { + "description": "Bananas", + "unitPrice": "0.59", + "quantity": 3, + } + item = parse_meijer_item(raw) + assert item.extended_price == Decimal("1.77") + + def test_discounts_parsed(self): + raw = { + "description": "Cereal", + "unitPrice": "4.99", + "extendedPrice": "4.99", + "regularPrice": "5.99", + "salePrice": "4.99", + "couponAmount": "1.00", + "loyaltyAmount": "0.50", + } + item = parse_meijer_item(raw) + assert 
item.regular_price == Decimal("5.99") + assert item.sale_price == Decimal("4.99") + assert item.coupon_discount == Decimal("1.00") + assert item.loyalty_discount == Decimal("0.50") + + def test_alternate_discount_names(self): + raw = { + "description": "Bread", + "unitPrice": "2.99", + "extendedPrice": "2.99", + "couponDiscount": "0.75", + "loyaltyDiscount": "0.25", + } + item = parse_meijer_item(raw) + assert item.coupon_discount == Decimal("0.75") + assert item.loyalty_discount == Decimal("0.25") + + def test_missing_fields_default_gracefully(self): + raw = {"description": "Mystery Item"} + item = parse_meijer_item(raw) + assert item.product_name_raw == "Mystery Item" + assert item.upc is None + assert item.quantity == Decimal("1") + assert item.unit_price == Decimal("0") + assert item.regular_price is None + assert item.category_raw is None + + def test_no_upc_returns_none(self): + raw = {"description": "Loose Bananas", "unitPrice": "1.00", "extendedPrice": "1.00"} + item = parse_meijer_item(raw) + assert item.upc is None + + +class TestNormalizeReceipt: + def test_full_receipt(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = { + "receiptId": "REC-001", + "date": "2026-03-15", + "total": "25.47", + "subtotal": "23.00", + "tax": "2.47", + "savings": "3.00", + "items": [ + {"description": "Milk", "unitPrice": "3.99", "extendedPrice": "3.99"}, + {"description": "Bread", "unitPrice": "2.50", "extendedPrice": "2.50"}, + ], + } + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.receipt_id == "REC-001" + assert purchase.purchase_date == date(2026, 3, 15) + assert purchase.total == Decimal("25.47") + assert purchase.subtotal == Decimal("23.00") + assert purchase.tax == Decimal("2.47") + assert purchase.savings_total == Decimal("3.00") + assert len(purchase.items) == 2 + assert purchase.items[0].product_name_raw == "Milk" + assert purchase.raw_data == raw + + def test_alternate_receipt_fields(self): + user_id = 
str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = { + "receipt_id": "REC-002", + "purchaseDate": "2026-03-14", + "totalAmount": "10.00", + "taxAmount": "0.75", + "totalSavings": "1.50", + "items": [], + } + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.receipt_id == "REC-002" + assert purchase.purchase_date == date(2026, 3, 14) + assert purchase.total == Decimal("10.00") + assert purchase.tax == Decimal("0.75") + assert purchase.savings_total == Decimal("1.50") + + def test_missing_date_defaults_to_today(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = {"total": "5.00", "items": []} + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.purchase_date == date.today() + + def test_generates_receipt_id_if_missing(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = {"total": "5.00", "date": "2026-03-15", "items": []} + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.receipt_id # Should be a generated UUID string + + def test_date_object_passthrough(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = {"date": date(2026, 1, 1), "total": "5.00", "items": []} + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.purchase_date == date(2026, 1, 1) diff --git a/tests/test_pipeline_shrinkflation.py b/tests/test_pipeline_shrinkflation.py new file mode 100644 index 0000000..9c1bd0c --- /dev/null +++ b/tests/test_pipeline_shrinkflation.py @@ -0,0 +1,233 @@ +"""Tests for shrinkflation detection pipeline.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +from cartsnitch_common.constants import SizeUnit +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.pipeline.shrinkflation import ( + _to_comparable, + _units_comparable, + detect_shrinkflation, +) + + +class TestToComparable: + def test_oz_to_grams(self): + result = _to_comparable("16", 
SizeUnit.OZ) + assert result is not None + assert result == Decimal("16") * Decimal("28.3495") + + def test_lb_to_grams(self): + result = _to_comparable("1", SizeUnit.LB) + assert result == Decimal("453.592") + + def test_ml_to_ml(self): + assert _to_comparable("500", SizeUnit.ML) == Decimal("500") + + def test_fl_oz_to_ml(self): + result = _to_comparable("12", SizeUnit.FL_OZ) + assert result is not None + assert result == Decimal("12") * Decimal("29.5735") + + def test_count_units(self): + assert _to_comparable("12", SizeUnit.CT) == Decimal("12") + assert _to_comparable("6", SizeUnit.PK) == Decimal("6") + + def test_invalid_size(self): + assert _to_comparable("abc", SizeUnit.OZ) is None + + +class TestUnitsComparable: + def test_weight_comparable(self): + assert _units_comparable(SizeUnit.OZ, SizeUnit.LB) is True + assert _units_comparable(SizeUnit.G, SizeUnit.KG) is True + + def test_volume_comparable(self): + assert _units_comparable(SizeUnit.ML, SizeUnit.L) is True + assert _units_comparable(SizeUnit.FL_OZ, SizeUnit.ML) is True + + def test_count_comparable(self): + assert _units_comparable(SizeUnit.CT, SizeUnit.PK) is True + + def test_not_comparable_across_systems(self): + assert _units_comparable(SizeUnit.OZ, SizeUnit.ML) is False + assert _units_comparable(SizeUnit.CT, SizeUnit.OZ) is False + assert _units_comparable(SizeUnit.LB, SizeUnit.L) is False + + +class TestDetectShrinkflation: + def _make_product(self, session, size: str, unit: SizeUnit, name: str = "Test Product"): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name=name, + size=size, + size_unit=unit, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.flush() + return product + + def test_detects_oz_decrease(self, session): + product = self._make_product(session, "16", SizeUnit.OZ) + event = detect_shrinkflation( + session, + product=product, + new_size="14", + new_unit=SizeUnit.OZ, + detected_date=date(2026, 3, 15), + ) + assert 
event is not None + assert event.old_size == "16" + assert event.new_size == "14" + assert "decreased" in event.notes.lower() + + def test_no_detection_when_size_increases(self, session): + product = self._make_product(session, "14", SizeUnit.OZ) + event = detect_shrinkflation( + session, + product=product, + new_size="16", + new_unit=SizeUnit.OZ, + ) + assert event is None + + def test_no_detection_same_size(self, session): + product = self._make_product(session, "16", SizeUnit.OZ) + event = detect_shrinkflation( + session, + product=product, + new_size="16", + new_unit=SizeUnit.OZ, + ) + assert event is None + + def test_no_detection_incompatible_units(self, session): + product = self._make_product(session, "16", SizeUnit.OZ) + event = detect_shrinkflation( + session, + product=product, + new_size="400", + new_unit=SizeUnit.ML, + ) + assert event is None + + def test_no_detection_without_existing_size(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="No Size Product", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.flush() + + event = detect_shrinkflation( + session, + product=product, + new_size="12", + new_unit=SizeUnit.OZ, + ) + assert event is None + + def test_cross_unit_detection_same_system(self, session): + # 1 lb = 453.592g, 14 oz = 396.893g → size decreased + product = self._make_product(session, "1", SizeUnit.LB) + event = detect_shrinkflation( + session, + product=product, + new_size="14", + new_unit=SizeUnit.OZ, + detected_date=date(2026, 3, 15), + ) + assert event is not None + + def test_count_decrease(self, session): + product = self._make_product(session, "12", SizeUnit.CT) + event = detect_shrinkflation( + session, + product=product, + new_size="10", + new_unit=SizeUnit.CT, + detected_date=date(2026, 3, 15), + ) + assert event is not None + assert event.old_size == "12" + assert event.new_size == "10" + + def test_dedup_existing_event(self, session): + product 
= self._make_product(session, "16", SizeUnit.OZ) + + # First detection + event1 = detect_shrinkflation( + session, + product=product, + new_size="14", + new_unit=SizeUnit.OZ, + detected_date=date(2026, 3, 15), + ) + + # Same detection again — should return existing + event2 = detect_shrinkflation( + session, + product=product, + new_size="14", + new_unit=SizeUnit.OZ, + detected_date=date(2026, 3, 16), + ) + + assert event1 is not None + assert event2 is not None + assert event1.id == event2.id + + def test_confidence_scaling(self, session): + # Small decrease (< 5%) → 0.70 + product1 = self._make_product(session, "100", SizeUnit.G, "Product A") + event1 = detect_shrinkflation( + session, + product=product1, + new_size="97", + new_unit=SizeUnit.G, + detected_date=date(2026, 3, 15), + ) + assert event1 is not None + assert event1.confidence == Decimal("0.70") + + # Medium decrease (5-10%) → 0.85 + product2 = self._make_product(session, "100", SizeUnit.G, "Product B") + event2 = detect_shrinkflation( + session, + product=product2, + new_size="93", + new_unit=SizeUnit.G, + detected_date=date(2026, 3, 15), + ) + assert event2 is not None + assert event2.confidence == Decimal("0.85") + + # Large decrease (>= 10%) → 0.95 + product3 = self._make_product(session, "100", SizeUnit.G, "Product C") + event3 = detect_shrinkflation( + session, + product=product3, + new_size="85", + new_unit=SizeUnit.G, + detected_date=date(2026, 3, 15), + ) + assert event3 is not None + assert event3.confidence == Decimal("0.95") + + def test_min_size_decrease_threshold(self, session): + product = self._make_product(session, "100", SizeUnit.G) + # 0.5% decrease — below default 1% threshold + event = detect_shrinkflation( + session, + product=product, + new_size="99.5", + new_unit=SizeUnit.G, + min_size_decrease_pct=Decimal("1"), + ) + assert event is None diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..564665e --- /dev/null +++ 
b/tests/test_schemas.py @@ -0,0 +1,225 @@ +"""Tests for Pydantic v2 schemas.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +import pytest +from pydantic import ValidationError + +from cartsnitch_common.constants import ( + AccountStatus, + DiscountType, + EventType, + PriceSource, + ProductCategory, + SizeUnit, + StoreSlug, +) +from cartsnitch_common.schemas import ( + CouponCreate, + EventEnvelope, + NormalizedProductCreate, + PriceHistoryCreate, + PurchaseCreate, + PurchaseItemCreate, + ShrinkflationEventCreate, + StoreCreate, + StoreLocationCreate, + StoreRead, + UserCreate, + UserStoreAccountCreate, +) + + +class TestStoreSchemas: + def test_store_create_valid(self): + s = StoreCreate(name="Meijer", slug=StoreSlug.MEIJER) + assert s.slug == StoreSlug.MEIJER + + def test_store_create_invalid_slug(self): + with pytest.raises(ValidationError): + StoreCreate(name="Walmart", slug="walmart") + + def test_store_read_from_attributes(self): + data = { + "id": uuid.uuid4(), + "name": "Kroger", + "slug": StoreSlug.KROGER, + "logo_url": None, + "website_url": None, + "created_at": datetime.now(UTC), + "updated_at": datetime.now(UTC), + } + s = StoreRead(**data) + assert s.slug == StoreSlug.KROGER + + +class TestStoreLocationSchemas: + def test_location_create(self): + loc = StoreLocationCreate( + store_id=uuid.uuid4(), + address="456 Oak Ave", + city="Detroit", + state="MI", + zip="48201", + ) + assert loc.city == "Detroit" + + +class TestUserSchemas: + def test_user_create_valid(self): + u = UserCreate(email="test@example.com", password="secret123") + assert u.email == "test@example.com" + + def test_user_create_invalid_email(self): + with pytest.raises(ValidationError): + UserCreate(email="not-an-email", password="secret123") + + +class TestUserStoreAccountSchemas: + def test_account_create_with_status(self): + a = UserStoreAccountCreate( + user_id=uuid.uuid4(), + store_id=uuid.uuid4(), + status=AccountStatus.EXPIRED, + ) + 
assert a.status == AccountStatus.EXPIRED + + def test_account_create_default_status(self): + a = UserStoreAccountCreate( + user_id=uuid.uuid4(), + store_id=uuid.uuid4(), + ) + assert a.status == AccountStatus.ACTIVE + + def test_account_create_invalid_status(self): + with pytest.raises(ValidationError): + UserStoreAccountCreate( + user_id=uuid.uuid4(), + store_id=uuid.uuid4(), + status="invalid_status", + ) + + +class TestPurchaseSchemas: + def test_purchase_create_with_items(self): + p = PurchaseCreate( + user_id=uuid.uuid4(), + store_id=uuid.uuid4(), + receipt_id="RCP-001", + purchase_date=date(2026, 3, 15), + total=Decimal("42.50"), + items=[ + PurchaseItemCreate( + product_name_raw="Milk", + unit_price=Decimal("3.49"), + extended_price=Decimal("3.49"), + ), + ], + ) + assert len(p.items) == 1 + assert p.items[0].quantity == Decimal("1") + + +class TestNormalizedProductSchemas: + def test_product_create_with_enums(self): + p = NormalizedProductCreate( + canonical_name="Whole Milk, 1 Gallon", + category=ProductCategory.DAIRY, + size_unit=SizeUnit.FL_OZ, + upc_variants=["0041250000001"], + ) + assert p.category == ProductCategory.DAIRY + + def test_product_create_invalid_category(self): + with pytest.raises(ValidationError): + NormalizedProductCreate( + canonical_name="Test", + category="invalid_category", + ) + + +class TestPriceHistorySchemas: + def test_price_create(self): + p = PriceHistoryCreate( + normalized_product_id=uuid.uuid4(), + store_id=uuid.uuid4(), + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.99"), + source=PriceSource.RECEIPT, + ) + assert p.source == PriceSource.RECEIPT + + def test_price_create_invalid_source(self): + with pytest.raises(ValidationError): + PriceHistoryCreate( + normalized_product_id=uuid.uuid4(), + store_id=uuid.uuid4(), + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.99"), + source="invalid_source", + ) + + +class TestCouponSchemas: + def test_coupon_create(self): + c = CouponCreate( + 
store_id=uuid.uuid4(), + title="BOGO Chips", + discount_type=DiscountType.BOGO, + ) + assert c.discount_type == DiscountType.BOGO + + def test_coupon_create_invalid_discount_type(self): + with pytest.raises(ValidationError): + CouponCreate( + store_id=uuid.uuid4(), + title="Test", + discount_type="free_stuff", + ) + + +class TestShrinkflationEventSchemas: + def test_shrinkflation_create(self): + s = ShrinkflationEventCreate( + normalized_product_id=uuid.uuid4(), + detected_date=date(2026, 3, 10), + old_size="18", + new_size="15.4", + old_unit=SizeUnit.OZ, + new_unit=SizeUnit.OZ, + confidence=Decimal("0.95"), + ) + assert s.old_unit == SizeUnit.OZ + + def test_shrinkflation_create_invalid_unit(self): + with pytest.raises(ValidationError): + ShrinkflationEventCreate( + normalized_product_id=uuid.uuid4(), + detected_date=date(2026, 3, 10), + old_size="18", + new_size="15.4", + old_unit="bushels", + new_unit=SizeUnit.OZ, + ) + + +class TestEventEnvelope: + def test_valid_event(self): + e = EventEnvelope( + event_type=EventType.RECEIPTS_INGESTED, + timestamp=datetime.now(UTC), + service="receiptwitness", + payload={"receipt_id": "RCP-001"}, + ) + assert e.event_type == EventType.RECEIPTS_INGESTED + + def test_invalid_event_type(self): + with pytest.raises(ValidationError): + EventEnvelope( + event_type="invalid.event", + timestamp=datetime.now(UTC), + service="test", + payload={}, + ) diff --git a/tests/test_seed.py b/tests/test_seed.py new file mode 100644 index 0000000..8007f9c --- /dev/null +++ b/tests/test_seed.py @@ -0,0 +1,357 @@ +"""Tests for the seed data generator.""" + +import random + +from faker import Faker + +from cartsnitch_common.seed.config import ( + NUM_ACTIVE_USERS, + NUM_COUPONS, + NUM_PRICE_HISTORY, + NUM_PRODUCTS, + NUM_PURCHASE_ITEMS, + NUM_PURCHASES, + NUM_SHRINKFLATION_EVENTS, + NUM_STORES, + NUM_USERS, + SEED_END_DATE, + SEED_START_DATE, + SEED_VALUE, +) +from cartsnitch_common.seed.generators.coupons import generate_coupons +from 
cartsnitch_common.seed.generators.prices import generate_price_history +from cartsnitch_common.seed.generators.products import generate_products +from cartsnitch_common.seed.generators.purchases import generate_purchase_items, generate_purchases +from cartsnitch_common.seed.generators.shrinkflation import generate_shrinkflation_events +from cartsnitch_common.seed.generators.stores import generate_store_locations, generate_stores +from cartsnitch_common.seed.generators.users import generate_users + + +def _seed() -> None: + random.seed(SEED_VALUE) + Faker.seed(SEED_VALUE) + + +def _make_fake() -> Faker: + return Faker() + + +# --------------------------------------------------------------------------- +# Stores +# --------------------------------------------------------------------------- + + +def test_generate_stores_count() -> None: + stores = generate_stores() + assert len(stores) == NUM_STORES + + +def test_generate_stores_deterministic() -> None: + stores_a = generate_stores() + stores_b = generate_stores() + # Stores are fixed (no RNG), so slugs are stable + slugs_a = {s["slug"] for s in stores_a} + slugs_b = {s["slug"] for s in stores_b} + assert slugs_a == slugs_b + + +def test_generate_store_locations_count() -> None: + stores = generate_stores() + locs = generate_store_locations(stores) + assert len(locs) == 15 # 3 stores * 5 locations + + +def test_generate_store_locations_fk() -> None: + stores = generate_stores() + locs = generate_store_locations(stores) + store_ids = {s["id"] for s in stores} + for loc in locs: + assert loc["store_id"] in store_ids + + +# --------------------------------------------------------------------------- +# Users +# --------------------------------------------------------------------------- + + +def test_generate_users_count() -> None: + _seed() + fake = _make_fake() + users = generate_users(fake) + assert len(users) == NUM_USERS + + +def test_generate_users_active_count() -> None: + _seed() + fake = _make_fake() + users = 
generate_users(fake) + active = [u for u in users if u["_active"]] + assert len(active) == NUM_ACTIVE_USERS + + +def test_generate_users_deterministic() -> None: + _seed() + fake_a = _make_fake() + users_a = generate_users(fake_a) + + _seed() + fake_b = _make_fake() + users_b = generate_users(fake_b) + + # Emails should match (same seed → same Faker output) + emails_a = [u["email"] for u in users_a] + emails_b = [u["email"] for u in users_b] + assert emails_a == emails_b + + +def test_generate_users_unique_emails() -> None: + _seed() + fake = _make_fake() + users = generate_users(fake) + emails = [u["email"] for u in users] + assert len(emails) == len(set(emails)) + + +# --------------------------------------------------------------------------- +# Products +# --------------------------------------------------------------------------- + + +def test_generate_products_count() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + assert len(products) == NUM_PRODUCTS + + +def test_generate_products_deterministic() -> None: + _seed() + fake_a = _make_fake() + products_a = generate_products(fake_a) + + _seed() + fake_b = _make_fake() + products_b = generate_products(fake_b) + + names_a = [p["canonical_name"] for p in products_a] + names_b = [p["canonical_name"] for p in products_b] + assert names_a == names_b + + +def test_generate_products_have_categories() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + for product in products: + assert product["category"] is not None + + +def test_generate_products_have_upc_variants() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + for product in products: + assert product["upc_variants"] + assert isinstance(product["upc_variants"], list) + assert len(product["upc_variants"]) >= 1 + + +# --------------------------------------------------------------------------- +# Purchases +# 
--------------------------------------------------------------------------- + + +def test_generate_purchases_count() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + assert len(purchases) == NUM_PURCHASES + + +def test_generate_purchases_fk() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + + user_ids = {u["id"] for u in users} + store_ids = {s["id"] for s in stores} + for p in purchases: + assert p["user_id"] in user_ids + assert p["store_id"] in store_ids + + +def test_generate_purchase_items_count() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + # Should be close to target (within 20%) + assert abs(len(items) - NUM_PURCHASE_ITEMS) < NUM_PURCHASE_ITEMS * 0.20 + + +def test_generate_purchase_items_fk() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + + purchase_ids = {p["id"] for p in purchases} + product_ids = {p["id"] for p in products} + for item in items: + assert item["purchase_id"] in purchase_ids + assert item["normalized_product_id"] in product_ids + + +# --------------------------------------------------------------------------- +# Price History +# --------------------------------------------------------------------------- + + +def 
test_generate_price_history_count() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + prices = generate_price_history(products, stores, items) + # Should be within 10% of target + assert abs(len(prices) - NUM_PRICE_HISTORY) < NUM_PRICE_HISTORY * 0.10 + + +def test_generate_price_history_fk() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + prices = generate_price_history(products, stores, items) + + product_ids = {p["id"] for p in products} + store_ids = {s["id"] for s in stores} + for ph in prices: + assert ph["normalized_product_id"] in product_ids + assert ph["store_id"] in store_ids + assert ph["regular_price"] > 0 + + +def test_price_history_dates_in_range() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + prices = generate_price_history(products, stores, items) + + for ph in prices: + assert SEED_START_DATE <= ph["observed_date"] <= SEED_END_DATE + + +# --------------------------------------------------------------------------- +# Coupons +# --------------------------------------------------------------------------- + + +def test_generate_coupons_count() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + products = generate_products(fake) + coupons = generate_coupons(fake, products, stores) + 
assert len(coupons) == NUM_COUPONS + + +def test_generate_coupons_mix() -> None: + """Verify ~60% expired and ~40% active.""" + _seed() + fake = _make_fake() + stores = generate_stores() + products = generate_products(fake) + coupons = generate_coupons(fake, products, stores) + + expired = [c for c in coupons if c["valid_to"] < SEED_END_DATE] + active = [c for c in coupons if c["valid_to"] >= SEED_END_DATE] + # Allow ±15% variance from target + assert len(expired) / NUM_COUPONS > 0.45 + assert len(active) / NUM_COUPONS > 0.25 + + +# --------------------------------------------------------------------------- +# Shrinkflation +# --------------------------------------------------------------------------- + + +def test_generate_shrinkflation_count() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + events = generate_shrinkflation_events(products) + assert len(events) == NUM_SHRINKFLATION_EVENTS + + +def test_generate_shrinkflation_fk() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + events = generate_shrinkflation_events(products) + product_ids = {p["id"] for p in products} + for event in events: + assert event["normalized_product_id"] in product_ids + + +def test_generate_shrinkflation_price_held_or_increased() -> None: + """Validate shrinkflation: new_size < old_size, price maintained or up.""" + _seed() + fake = _make_fake() + products = generate_products(fake) + events = generate_shrinkflation_events(products) + for event in events: + old_size = float(event["old_size"]) + new_size = float(event["new_size"]) + assert new_size < old_size, f"Expected size reduction: {old_size} -> {new_size}" + if event["price_at_old_size"] and event["price_at_new_size"]: + # Price should be maintained or increased (not significantly dropped) + assert float(event["price_at_new_size"]) >= float(event["price_at_old_size"]) * 0.95 + + +def test_generate_shrinkflation_confidence_range() -> None: + _seed() + fake = _make_fake() + 
products = generate_products(fake) + events = generate_shrinkflation_events(products) + for event in events: + assert 0 <= float(event["confidence"]) <= 1.0 + + +# --------------------------------------------------------------------------- +# Dry-run smoke test +# --------------------------------------------------------------------------- + + +def test_dry_run_does_not_raise() -> None: + """Smoke test the full run_seed in dry-run mode.""" + from cartsnitch_common.seed.runner import run_seed + + run_seed(dry_run=True, seed_value=SEED_VALUE) From 342906c9d178923d462a08aec35e486703366eba Mon Sep 17 00:00:00 2001 From: Coupon Carl Date: Sat, 28 Mar 2026 02:24:22 +0000 Subject: [PATCH 3/4] Squashed 'receiptwitness/' content from commit e8d374a git-subtree-dir: receiptwitness git-subtree-split: e8d374a89ed8978f429598e02d31b1c5963efe22 --- .dockerignore | 12 + .github/workflows/ci.yml | 168 +++++ .gitignore | 7 + CLAUDE.md | 227 +++++++ Dockerfile | 67 ++ pyproject.toml | 54 ++ renovate.json | 4 + src/receiptwitness/__init__.py | 1 + src/receiptwitness/api/__init__.py | 1 + src/receiptwitness/api/routes.py | 10 + src/receiptwitness/config.py | 26 + src/receiptwitness/events.py | 75 +++ src/receiptwitness/main.py | 8 + src/receiptwitness/parsers/__init__.py | 1 + src/receiptwitness/parsers/kroger.py | 148 +++++ src/receiptwitness/parsers/meijer.py | 138 +++++ src/receiptwitness/parsers/target.py | 191 ++++++ src/receiptwitness/pipeline/__init__.py | 30 + src/receiptwitness/pipeline/matching.py | 136 ++++ src/receiptwitness/pipeline/normalization.py | 155 +++++ src/receiptwitness/pipeline/receipt.py | 144 +++++ src/receiptwitness/scrapers/__init__.py | 1 + src/receiptwitness/scrapers/base.py | 72 +++ src/receiptwitness/scrapers/kroger.py | 344 ++++++++++ src/receiptwitness/scrapers/meijer.py | 301 +++++++++ src/receiptwitness/scrapers/target.py | 326 ++++++++++ src/receiptwitness/session/__init__.py | 1 + src/receiptwitness/session/encryption.py | 52 ++ 
src/receiptwitness/session/manager.py | 81 +++ tests/conftest.py | 29 + tests/fixtures/kroger_receipt.json | 131 ++++ tests/fixtures/meijer_receipt.json | 85 +++ tests/fixtures/target_receipt.json | 140 +++++ tests/test_parsers/__init__.py | 0 tests/test_parsers/test_kroger_parser.py | 399 ++++++++++++ tests/test_parsers/test_meijer_parser.py | 174 ++++++ tests/test_parsers/test_target_parser.py | 471 ++++++++++++++ tests/test_pipeline/__init__.py | 0 tests/test_pipeline/conftest.py | 23 + tests/test_pipeline/test_matching.py | 161 +++++ tests/test_pipeline/test_normalization.py | 158 +++++ tests/test_pipeline/test_receipt.py | 204 ++++++ tests/test_regression/__init__.py | 0 tests/test_regression/test_layout_changes.py | 435 +++++++++++++ tests/test_regression/test_rate_limiting.py | 365 +++++++++++ .../test_regression/test_schema_validation.py | 364 +++++++++++ tests/test_scrapers/__init__.py | 0 tests/test_scrapers/test_base.py | 58 ++ tests/test_scrapers/test_kroger_scraper.py | 574 +++++++++++++++++ tests/test_scrapers/test_meijer_scraper.py | 585 ++++++++++++++++++ tests/test_session/__init__.py | 0 tests/test_session/test_encryption.py | 61 ++ tests/test_session/test_manager.py | 102 +++ 53 files changed, 7300 insertions(+) create mode 100644 .dockerignore create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 Dockerfile create mode 100644 pyproject.toml create mode 100644 renovate.json create mode 100644 src/receiptwitness/__init__.py create mode 100644 src/receiptwitness/api/__init__.py create mode 100644 src/receiptwitness/api/routes.py create mode 100644 src/receiptwitness/config.py create mode 100644 src/receiptwitness/events.py create mode 100644 src/receiptwitness/main.py create mode 100644 src/receiptwitness/parsers/__init__.py create mode 100644 src/receiptwitness/parsers/kroger.py create mode 100644 src/receiptwitness/parsers/meijer.py create mode 100644 
src/receiptwitness/parsers/target.py create mode 100644 src/receiptwitness/pipeline/__init__.py create mode 100644 src/receiptwitness/pipeline/matching.py create mode 100644 src/receiptwitness/pipeline/normalization.py create mode 100644 src/receiptwitness/pipeline/receipt.py create mode 100644 src/receiptwitness/scrapers/__init__.py create mode 100644 src/receiptwitness/scrapers/base.py create mode 100644 src/receiptwitness/scrapers/kroger.py create mode 100644 src/receiptwitness/scrapers/meijer.py create mode 100644 src/receiptwitness/scrapers/target.py create mode 100644 src/receiptwitness/session/__init__.py create mode 100644 src/receiptwitness/session/encryption.py create mode 100644 src/receiptwitness/session/manager.py create mode 100644 tests/conftest.py create mode 100644 tests/fixtures/kroger_receipt.json create mode 100644 tests/fixtures/meijer_receipt.json create mode 100644 tests/fixtures/target_receipt.json create mode 100644 tests/test_parsers/__init__.py create mode 100644 tests/test_parsers/test_kroger_parser.py create mode 100644 tests/test_parsers/test_meijer_parser.py create mode 100644 tests/test_parsers/test_target_parser.py create mode 100644 tests/test_pipeline/__init__.py create mode 100644 tests/test_pipeline/conftest.py create mode 100644 tests/test_pipeline/test_matching.py create mode 100644 tests/test_pipeline/test_normalization.py create mode 100644 tests/test_pipeline/test_receipt.py create mode 100644 tests/test_regression/__init__.py create mode 100644 tests/test_regression/test_layout_changes.py create mode 100644 tests/test_regression/test_rate_limiting.py create mode 100644 tests/test_regression/test_schema_validation.py create mode 100644 tests/test_scrapers/__init__.py create mode 100644 tests/test_scrapers/test_base.py create mode 100644 tests/test_scrapers/test_kroger_scraper.py create mode 100644 tests/test_scrapers/test_meijer_scraper.py create mode 100644 tests/test_session/__init__.py create mode 100644 
tests/test_session/test_encryption.py create mode 100644 tests/test_session/test_manager.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..289a751 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,12 @@ +__pycache__/ +*.pyc +.pytest_cache/ +*.egg-info/ +dist/ +.venv/ +.env +.git/ +.github/ +tests/ +*.md +renovate.json diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..785af69 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,168 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ci-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: write + packages: write + +env: + REGISTRY: ghcr.io + IMAGE_NAME: cartsnitch/receiptwitness + +jobs: + lint: + runs-on: runners-cartsnitch + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - name: Install cartsnitch-common from GitHub + run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git@76685ed0384103228cd670b477b967e7752ebe6b" + - run: pip install ruff + - name: Ruff lint + run: ruff check . + - name: Ruff format check + run: ruff format --check . 
+ + typecheck: + runs-on: runners-cartsnitch + continue-on-error: true + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - name: Install cartsnitch-common from GitHub + run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git@76685ed0384103228cd670b477b967e7752ebe6b" + - run: pip install -e ".[dev]" mypy + - name: Type check + run: mypy src/receiptwitness + + test: + runs-on: runners-cartsnitch + services: + postgres: + image: postgres:15-alpine + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + env: + POSTGRES_USER: cartsnitch + POSTGRES_PASSWORD: cartsnitch_test + POSTGRES_DB: cartsnitch_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + redis: + image: redis:7-alpine + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + env: + DATABASE_URL: postgresql://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test + REDIS_URL: redis://localhost:6379/0 + ENCRYPTION_KEY: dGVzdC1lbmNyeXB0aW9uLWtleS0xMjM0NTY3ODk= + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - name: Install cartsnitch-common from GitHub + run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git@76685ed0384103228cd670b477b967e7752ebe6b" + - run: pip install -e ".[dev]" + - name: Install Playwright browsers + run: playwright install chromium --with-deps + - name: Run tests + run: pytest --tb=short -q + + build-and-push: + runs-on: runners-cartsnitch + needs: [lint, test] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Generate CalVer tag + id: calver + if: 
github.event_name == 'push' && github.ref == 'refs/heads/main' + run: | + DATE_TAG=$(date -u +%Y.%m.%d) + EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1) + if [ -z "$EXISTING" ]; then + VERSION="$DATE_TAG" + elif [ "$EXISTING" = "v${DATE_TAG}" ]; then + VERSION="${DATE_TAG}.2" + else + BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//") + VERSION="${DATE_TAG}.$((BUILD_NUM + 1))" + fi + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + echo "CalVer tag: $VERSION" + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=sha,prefix=sha- + type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }} + type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }} + + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: . 
+ push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + target: prod + cache-from: type=gha + cache-to: type=gha,mode=max + + - name: Create git tag + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: | + git tag "v${{ steps.calver.outputs.version }}" + git push origin "v${{ steps.calver.outputs.version }}" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..687387e --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +__pycache__/ +*.pyc +.pytest_cache/ +*.egg-info/ +dist/ +.venv/ +.env diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..255b742 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,227 @@ +# ReceiptWitness — CartSnitch Receipt Ingestion Service + +## Project Context + +CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/receiptwitness`) is the receipt/purchase history ingestion service. + +**GitHub org:** github.com/cartsnitch +**Domain:** cartsnitch.com + +### CartSnitch Services + +| Repo | Service | Purpose | +|------|---------|---------| +| `cartsnitch/common` | — | Shared models, schemas, utilities | +| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers (this repo) | +| `cartsnitch/api` | API Gateway | Frontend-facing REST API | +| `cartsnitch/cartsnitch` | Frontend | React PWA (mobile-first) | +| `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison | +| `cartsnitch/shrinkray` | ShrinkRay | Shrinkflation monitoring | +| `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization | +| `cartsnitch/infra` | — | K8s manifests, Flux kustomizations | + +### Architecture Decisions + +- **Polyrepo:** Each service has its own repo, Dockerfile, CI/CD pipeline. +- **Shared DB:** One PostgreSQL cluster. 
This service writes to `purchases`, `purchase_items`, `price_history` tables. Models come from `cartsnitch-common`. +- **Inter-service comms:** REST (synchronous) + Redis pub/sub (async events). +- **Target scale:** 500–1,000 users. Each user has their own authenticated sessions to up to 3 retailers. + +## What This Service Does + +ReceiptWitness authenticates with grocery retailer web portals using per-user sessions, scrapes purchase history / receipt data, parses it into structured records, and writes it to the shared database. After ingestion, it publishes a `cartsnitch.receipts.ingested` event so downstream services (StickerShock, ClipArtist) can react. + +### Target Retailers (MVP) + +#### Meijer (mPerks) +- **Auth:** No public API. Session cookie-based auth on mperks.meijer.com. +- **Receipt location:** meijer.com/mperks/receipts-savings.html (or underlying XHR endpoints) +- **Approach:** Playwright login → capture session → hit receipt XHR endpoints directly. Map the API calls the frontend makes via browser dev tools network tab. +- **Prior art:** `dapperfu/python_Meijer` (requires MITM proxy for auth — avoid this pattern, prefer direct browser automation). +- **Data available:** Digital receipts appear ~15 minutes after purchase if mPerks ID was used at checkout. Includes item names, prices, discounts, savings. + +#### Kroger +- **Auth:** No public API for purchase history (that's behind Partner API). Session cookie-based auth on kroger.com. +- **Receipt location:** kroger.com/mypurchases +- **Approach:** Playwright login → scrape purchase history pages or intercept XHR endpoints. +- **Anti-bot:** Kroger uses Akamai Bot Manager. Aggressive headless browser detection. Need Playwright stealth, realistic fingerprinting, human-like interaction pacing. +- **Prior art:** `phyllis-vance/KrogerScrape` (.NET, old), `callaginn/kroger-sweeper` (Puppeteer/Node), `ThermoMan/Get-Kroger-Grocery-List` (Greasemonkey userscript). 
+- **Kroger public API:** Free developer account at developer.kroger.com provides product catalog data (`product.compact` scope) — useful for enriching scraped receipt data with UPCs, categories, product images. NOT useful for purchase history. +- **Data available:** Purchase history tied to Kroger Plus loyalty card. Shows items, prices, quantities. + +#### Target (Circle) +- **Auth:** Session-based auth on target.com. +- **Receipt location:** target.com account → Orders → In-store tab, or target.com/account/orders +- **Approach:** Playwright login → scrape in-store purchase history. +- **Data available:** ~1 year of history if user paid with a linked card, used the Target app wallet, or entered their Target Circle phone number at checkout. Includes item names, prices. + +## Tech Stack + +- Python 3.12+ +- Playwright (Python async API) for headless browser automation +- FastAPI (lightweight internal API for triggering scrapes, health checks, status) +- SQLAlchemy 2.0 (via `cartsnitch-common`) +- Redis (pub/sub event publishing) +- APScheduler or Celery (for scheduled scraping jobs) +- cryptography / Fernet (encrypting stored session data) + +## Repo Structure + +``` +receiptwitness/ +├── CLAUDE.md +├── README.md +├── pyproject.toml +├── Dockerfile # Playwright + Chromium headless +├── docker-compose.yml # Local dev (Postgres, Redis, this service) +├── src/ +│ └── receiptwitness/ +│ ├── __init__.py +│ ├── config.py # Service-specific settings +│ ├── main.py # FastAPI app + scheduler bootstrap +│ ├── scrapers/ +│ │ ├── __init__.py +│ │ ├── base.py # Abstract BaseScraper class +│ │ ├── meijer.py # Meijer/mPerks scraper +│ │ ├── kroger.py # Kroger scraper +│ │ └── target.py # Target/Circle scraper +│ ├── parsers/ +│ │ ├── __init__.py +│ │ ├── meijer.py # Parse raw Meijer receipt data → PurchaseItem records +│ │ ├── kroger.py +│ │ └── target.py +│ ├── session/ +│ │ ├── __init__.py +│ │ ├── manager.py # Session storage, retrieval, refresh logic +│ │ └── encryption.py # 
Encrypt/decrypt session cookies at rest +│ ├── scheduler.py # Scrape scheduling (per-user cron jobs) +│ ├── events.py # Publish receipt.ingested events to Redis +│ ├── api/ +│ │ ├── __init__.py +│ │ ├── routes.py # Internal API: trigger scrape, check status, health +│ │ └── auth.py # Internal service auth (API key or JWT) +│ └── enrichment.py # Optional: enrich receipt data via Kroger public API +└── tests/ + ├── conftest.py + ├── fixtures/ # Sample receipt HTML/JSON for testing parsers + │ ├── meijer_receipt.json + │ ├── kroger_receipt.json + │ └── target_receipt.json + ├── test_scrapers/ + ├── test_parsers/ + └── test_session/ +``` + +## Scraper Architecture + +### Base Scraper Pattern + +```python +class BaseScraper(ABC): + """All retailer scrapers implement this interface.""" + + @abstractmethod + async def login(self, credentials: UserStoreAccount) -> SessionData: ... + + @abstractmethod + async def check_session(self, session: SessionData) -> bool: ... + + @abstractmethod + async def scrape_receipts(self, session: SessionData, since: datetime | None) -> list[RawReceipt]: ... + + @abstractmethod + def parse_receipt(self, raw: RawReceipt) -> tuple[Purchase, list[PurchaseItem]]: ... +``` + +## Scraping Flow + +1. **Scheduler fires** for a user+store combination +2. **Load session** from `user_store_accounts` table (encrypted) +3. **Check session validity** — quick lightweight request to verify auth +4. **If expired:** launch Playwright, re-authenticate, save new session +5. **Scrape receipts** since `last_sync_at` timestamp +6. **Parse** raw data into `Purchase` and `PurchaseItem` records +7. **Deduplicate** — skip receipts already in DB (match on `receipt_id` per store) +8. **Write to DB** — insert new purchases and items +9. **Derive price_history** entries from purchase_items +10. **Publish event** — `cartsnitch.receipts.ingested` to Redis +11.
**Update** `user_store_accounts.last_sync_at` + +### Session Management + +- Sessions (cookies, tokens) are encrypted at rest using Fernet symmetric encryption. +- The encryption key is provided via environment variable, not stored in the DB. +- Sessions are stored in the `user_store_accounts` table as encrypted JSONB. +- Each scrape attempt first checks if the existing session is valid before launching a full Playwright browser instance. +- When a session expires, the service needs the user's stored credentials OR a manual re-auth flow (the user logs in via the frontend, and we capture the session). + +### Anti-Bot Considerations + +- Use `playwright-stealth` or equivalent to mask automation signals. +- Set realistic viewport sizes, user agents, and locale settings. +- Add human-like delays between page navigations (randomized 1-5 seconds). +- For Kroger specifically (Akamai Bot Manager): may need to use non-headless mode on initial auth, or route through a persistent browser profile that has established trust. +- Rate limit scraping: no more than 1 scrape per user per store per hour. Default cadence: once daily. +- Store and reuse browser profiles/cookies to minimize fresh logins. + +### Dockerfile + +The Dockerfile must include Playwright and Chromium. Base image pattern: + +```dockerfile +FROM mcr.microsoft.com/playwright/python:v1.49.0-noble +# Install deps, copy code, etc. +``` + +This is a large image (~2GB) due to Chromium. Consider multi-stage builds if the final image can be slimmed down. 
+ +## Internal API Endpoints + +This service exposes a lightweight internal API (not public-facing): + +- `GET /health` — health check +- `GET /status/{user_id}` — sync status per store for a user +- `POST /scrape/{user_id}/{store_slug}` — trigger an immediate scrape for a user+store +- `POST /scrape/{user_id}/all` — trigger scrape across all configured stores +- `GET /sessions/{user_id}` — list configured store sessions and their status + +The public-facing API gateway (`cartsnitch/api`) proxies user-facing requests to this service's internal API. + +## Events Published + +### `cartsnitch.receipts.ingested` + +Published after new receipt data is successfully written to the DB. + +```json +{ + "event_type": "cartsnitch.receipts.ingested", + "timestamp": "2026-03-15T12:00:00Z", + "service": "receiptwitness", + "payload": { + "user_id": "uuid", + "store_slug": "meijer", + "purchase_id": "uuid", + "purchase_date": "2026-03-14", + "item_count": 23, + "total": 87.42 + } +} +``` + +## Development Workflow + +- **Never push directly to main.** Always create feature branches and open PRs. +- Branch naming: `feature//` or `fix/` +- Use conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `chore:` +- Test parsers with fixture data (sample receipts in `tests/fixtures/`). Scraper integration tests require real credentials and should be tagged/skipped in CI. +- Local dev: `docker-compose up` starts Postgres, Redis, and the service. Playwright runs inside the container. + +## Important Notes + +- The Playwright container image is large. On K8s, consider using a dedicated node or tolerating scheduling delays. +- Each user needs their own authenticated sessions. At 1,000 users × 3 stores = 3,000 sessions to manage. Sessions expire at different rates per retailer. +- Scraping must be respectful: randomized intervals, rate limiting, no parallel scraping of the same store for the same user. +- Receipt data structure varies significantly between retailers. 
The parsers must be robust and handle edge cases (returns, voided items, weighted produce, BOGO items, coupon stacking). +- Kroger's public API (`product.compact` scope) can be used to enrich scraped data with UPCs and product metadata after receipt parsing. This is optional but improves product normalization downstream. +- Store credentials for users should ideally NOT be stored by CartSnitch. Prefer a flow where the user authenticates in a controlled browser session, and we capture/store only the resulting session cookies. If credential storage is necessary, use strong encryption and make the tradeoffs clear to users. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..bb6300d --- /dev/null +++ b/Dockerfile @@ -0,0 +1,67 @@ +# Stage 1: Build dependencies +FROM python:3.12-slim AS build + +WORKDIR /app + +# git is required to install cartsnitch-common from GitHub; build-essential and +# libpq-dev are needed to compile any C-extension wheels (e.g. psycopg2 fallback) +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + libpq-dev \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +COPY pyproject.toml ./ +COPY src/ ./src/ + +# cartsnitch-common is not on PyPI — install it directly from GitHub, then +# install the rest of the package dependencies in a single resolver pass so +# pip can satisfy the cartsnitch-common>=0.1.0 constraint declared in +# pyproject.toml without hitting PyPI for it. +RUN pip install --no-cache-dir --prefix=/install \ + "cartsnitch-common @ git+https://github.com/cartsnitch/common.git@76685ed0384103228cd670b477b967e7752ebe6b" \ + . 
+ +# Stage 2: Production image with Playwright + Chromium +FROM python:3.12-slim AS prod + +WORKDIR /app + +# Install Playwright system dependencies for Chromium +RUN apt-get update && apt-get install -y --no-install-recommends \ + libnss3 \ + libatk1.0-0 \ + libatk-bridge2.0-0 \ + libcups2 \ + libdrm2 \ + libxkbcommon0 \ + libxcomposite1 \ + libxdamage1 \ + libxrandr2 \ + libgbm1 \ + libpango-1.0-0 \ + libcairo2 \ + libasound2 \ + libxshmfence1 \ + libx11-xcb1 \ + libxcb-dri3-0 \ + fonts-liberation \ + && rm -rf /var/lib/apt/lists/* + +RUN adduser --system --group --uid 1000 app + +COPY --from=build /install /usr/local +COPY src/ ./src/ + +# Install Playwright Chromium browser (runs as root; /opt/playwright is world-readable) +RUN PLAYWRIGHT_BROWSERS_PATH=/opt/playwright playwright install chromium + +ENV PLAYWRIGHT_BROWSERS_PATH=/opt/playwright + +USER 1000 +EXPOSE 8000 + +HEALTHCHECK --interval=30s --timeout=3s \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" + +CMD ["uvicorn", "receiptwitness.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f32acfc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,54 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "receiptwitness" +version = "0.1.0" +description = "CartSnitch receipt/purchase history ingestion service" +requires-python = ">=3.12" +dependencies = [ + "cartsnitch-common>=0.1.0", + "playwright>=1.49,<2.0", + "playwright-stealth>=1.0,<2.0", + "cryptography>=42.0,<44.0", + "fastapi>=0.115,<1.0", + "uvicorn[standard]>=0.30,<1.0", + "redis>=5.0,<6.0", + "pydantic>=2.0,<3.0", + "pydantic-settings>=2.0,<3.0", + "sqlalchemy[asyncio]>=2.0,<3.0", + "asyncpg>=0.29,<1.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "ruff>=0.3", + "pytest-cov>=5.0", +] + +[tool.hatch.build.targets.wheel] +packages = 
@router.get("/health")
async def health():
    """Report that the ReceiptWitness API process is up.

    Returns a small static JSON document; no database or Redis access is
    performed, so the endpoint stays cheap enough for frequent polling.
    """
    status_payload = {"status": "ok", "service": "receiptwitness"}
    return status_payload
async def publish_receipt_ingested(
    user_id: str,
    store_slug: str,
    purchase_id: str,
    purchase_date: str,
    item_count: int,
    total: Decimal | float,
) -> None:
    """Publish a cartsnitch.receipts.ingested event after successful ingestion.

    Args:
        user_id: Identifier of the user whose receipt was ingested.
        store_slug: Retailer slug, e.g. ``"meijer"``.
        purchase_id: Identifier of the stored purchase.
        purchase_date: Purchase date, already formatted as a string.
        item_count: Number of line items on the purchase.
        total: Receipt total; a ``Decimal`` is converted to ``float`` for JSON.

    Raises:
        redis.asyncio.ConnectionError: Re-raised after logging when the
            Redis/DragonflyDB connection fails.
    """
    event = {
        "event_type": CHANNEL_RECEIPTS_INGESTED,
        "timestamp": datetime.now(UTC).isoformat(),
        "service": "receiptwitness",
        "payload": {
            "user_id": user_id,
            "store_slug": store_slug,
            "purchase_id": purchase_id,
            "purchase_date": purchase_date,
            "item_count": item_count,
            "total": float(total) if isinstance(total, Decimal) else total,
        },
    }

    # Serialize before touching the network so that an encoding problem can
    # never be confused with a connection failure.
    message = json.dumps(event, cls=_DecimalEncoder)

    try:
        client = await get_redis_client()
        await client.publish(CHANNEL_RECEIPTS_INGESTED, message)
    except aioredis.ConnectionError:
        # logger.exception keeps the traceback; the purchase id makes the
        # failed publish traceable to a specific ingestion.
        logger.exception(
            "Failed to publish %s event for purchase %s — Redis/DragonflyDB connection error",
            CHANNEL_RECEIPTS_INGESTED,
            purchase_id,
        )
        raise

    logger.info(
        "Published %s event for purchase %s",
        CHANNEL_RECEIPTS_INGESTED,
        purchase_id,
    )
+""" + +import logging +from decimal import Decimal, InvalidOperation + +from receiptwitness.scrapers.base import RawReceipt + +logger = logging.getLogger(__name__) + + +def _to_decimal(value, default: str = "0") -> Decimal: + """Safely convert a value to Decimal.""" + if value is None: + return Decimal(default) + try: + return Decimal(str(value)) + except (InvalidOperation, ValueError, TypeError): + return Decimal(default) + + +def _parse_item(item: dict) -> dict: + """Parse a single line item from a Kroger receipt. + + Kroger items typically include fields like: + - description / itemDescription / productName + - upc / krogerProductId + - quantity / qty + - basePrice / unitPrice / price + - totalPrice / extendedAmount / lineTotal + - regularPrice / originalPrice + - salePrice / promoPrice + - couponAmount / couponSavings + - loyaltyDiscount / fuelPointsDiscount / plusCardSavings + - department / category / aisle + """ + description = ( + item.get("description") + or item.get("itemDescription") + or item.get("productName") + or item.get("name") + or "UNKNOWN ITEM" + ) + + quantity = _to_decimal(item.get("quantity", item.get("qty", item.get("quantitySold", 1))), "1") + unit_price = _to_decimal(item.get("basePrice", item.get("unitPrice", item.get("price", 0)))) + extended_price = _to_decimal( + item.get("totalPrice", item.get("extendedAmount", item.get("lineTotal"))) + ) + + # Compute extended_price if not provided + if extended_price == Decimal("0") and unit_price != Decimal("0"): + extended_price = unit_price * quantity + + regular_price = item.get("regularPrice", item.get("originalPrice")) + sale_price = item.get("salePrice", item.get("promoPrice")) + coupon_discount = item.get( + "couponAmount", item.get("couponSavings", item.get("couponDiscount")) + ) + loyalty_discount = item.get( + "plusCardSavings", + item.get("loyaltyDiscount", item.get("fuelPointsDiscount")), + ) + + # UPC handling — Kroger may use krogerProductId or upc + upc = item.get("upc", 
item.get("UPC", item.get("krogerProductId"))) + if upc: + upc = str(upc).strip().lstrip("0") or None + + category = item.get("department", item.get("category", item.get("aisle"))) + + # Weight info for produce/deli items + weight = item.get("weight", item.get("netWeight")) + extra = {} + if weight is not None: + extra["weight"] = str(weight) + weight_uom = item.get("weightUom", item.get("unitOfMeasure")) + if weight_uom: + extra["weight_uom"] = weight_uom + + result = { + "product_name_raw": description.strip(), + "upc": upc, + "quantity": quantity, + "unit_price": unit_price, + "extended_price": extended_price, + "regular_price": (_to_decimal(regular_price) if regular_price is not None else None), + "sale_price": (_to_decimal(sale_price) if sale_price is not None else None), + "coupon_discount": (_to_decimal(coupon_discount) if coupon_discount is not None else None), + "loyalty_discount": ( + _to_decimal(loyalty_discount) if loyalty_discount is not None else None + ), + "category_raw": category.strip() if category else None, + } + + return result + + +def parse_kroger_receipt(raw: RawReceipt) -> dict: + """Parse a RawReceipt from Kroger into a PurchaseCreate-compatible dict.""" + data = raw.raw_data + detail = data.get("detail", {}) + + # Parse items — Kroger uses "items" or "lineItems" or "receiptItems" + raw_items = detail.get("items", detail.get("lineItems", detail.get("receiptItems", []))) + items = [] + for raw_item in raw_items: + # Skip voided / returned items + if raw_item.get("voided") or raw_item.get("status") in ( + "VOIDED", + "RETURNED", + ): + logger.debug("Skipping voided/returned item: %s", raw_item.get("description")) + continue + if raw_item.get("returnFlag") or raw_item.get("isReturn"): + logger.debug("Skipping returned item: %s", raw_item.get("description")) + continue + items.append(_parse_item(raw_item)) + + # Parse totals — Kroger uses various field names + total = _to_decimal( + detail.get( + "total", + data.get("total", 
data.get("orderTotal", data.get("grandTotal", 0))), + ) + ) + subtotal = detail.get("subtotal", data.get("subtotal", data.get("subTotal"))) + tax = detail.get("tax", data.get("tax", data.get("salesTax"))) + savings = detail.get( + "totalSavings", + data.get("savings", data.get("totalDiscount", data.get("youSaved"))), + ) + + return { + "receipt_id": raw.receipt_id, + "purchase_date": raw.purchase_date, + "total": total, + "subtotal": _to_decimal(subtotal) if subtotal is not None else None, + "tax": _to_decimal(tax) if tax is not None else None, + "savings_total": _to_decimal(savings) if savings is not None else None, + "source_url": raw.source_url, + "raw_data": data, + "items": items, + } diff --git a/src/receiptwitness/parsers/meijer.py b/src/receiptwitness/parsers/meijer.py new file mode 100644 index 0000000..d1960d0 --- /dev/null +++ b/src/receiptwitness/parsers/meijer.py @@ -0,0 +1,138 @@ +"""Parse raw Meijer mPerks receipt data into PurchaseCreate-compatible dicts. + +The mPerks receipt JSON structure (reverse-engineered from their SPA) +typically looks like: + +Transaction listing: +{ + "transactions": [ + { + "transactionId": "12345", + "transactionDate": "2026-03-10T14:30:00Z", + "storeNumber": "123", + "total": 87.42, + "savings": 12.50 + } + ] +} + +Receipt detail: +{ + "receiptId": "12345", + "items": [ + { + "description": "ORGANIC BANANAS", + "upc": "0000000004011", + "quantity": 1, + "price": 0.69, + "extendedPrice": 0.69, + "regularPrice": 0.79, + "salePrice": 0.69, + "couponDiscount": 0.0, + "mperksDiscount": 0.10, + "category": "PRODUCE" + } + ], + "subtotal": 74.92, + "tax": 5.24, + "total": 87.42, + "totalSavings": 12.50 +} +""" + +import logging +from decimal import Decimal, InvalidOperation + +from receiptwitness.scrapers.base import RawReceipt + +logger = logging.getLogger(__name__) + + +def _to_decimal(value, default: str = "0") -> Decimal: + """Safely convert a value to Decimal.""" + if value is None: + return Decimal(default) + try: + 
return Decimal(str(value)) + except (InvalidOperation, ValueError, TypeError): + return Decimal(default) + + +def _parse_item(item: dict) -> dict: + """Parse a single line item from Meijer receipt detail.""" + description = ( + item.get("description") or item.get("itemDescription") or item.get("name") or "UNKNOWN ITEM" + ) + + quantity = _to_decimal(item.get("quantity", item.get("qty", 1)), "1") + unit_price = _to_decimal(item.get("price", item.get("unitPrice", 0))) + extended_price = _to_decimal(item.get("extendedPrice", item.get("totalPrice"))) + + # If extended_price wasn't provided, compute it + if extended_price == Decimal("0") and unit_price != Decimal("0"): + extended_price = unit_price * quantity + + regular_price = item.get("regularPrice") + sale_price = item.get("salePrice") + coupon_discount = item.get("couponDiscount", item.get("couponSavings")) + loyalty_discount = item.get("mperksDiscount", item.get("loyaltyDiscount")) + + upc = item.get("upc", item.get("UPC")) + if upc: + upc = str(upc).strip().lstrip("0") or None + + category = item.get("category", item.get("departmentDescription")) + + return { + "product_name_raw": description.strip(), + "upc": upc, + "quantity": quantity, + "unit_price": unit_price, + "extended_price": extended_price, + "regular_price": _to_decimal(regular_price) if regular_price is not None else None, + "sale_price": _to_decimal(sale_price) if sale_price is not None else None, + "coupon_discount": (_to_decimal(coupon_discount) if coupon_discount is not None else None), + "loyalty_discount": ( + _to_decimal(loyalty_discount) if loyalty_discount is not None else None + ), + "category_raw": category.strip() if category else None, + } + + +def parse_meijer_receipt(raw: RawReceipt) -> dict: + """Parse a RawReceipt from Meijer into a PurchaseCreate-compatible dict. + + Returns a dict with keys matching PurchaseCreate schema fields. + The caller is responsible for setting store_id and store_location_id + from the store registry. 
+ """ + data = raw.raw_data + detail = data.get("detail", {}) + + # Parse items from the detail response + raw_items = detail.get("items", detail.get("lineItems", [])) + items = [] + for raw_item in raw_items: + # Skip voided items + if raw_item.get("voided") or raw_item.get("status") == "VOIDED": + logger.debug("Skipping voided item: %s", raw_item.get("description")) + continue + items.append(_parse_item(raw_item)) + + # Parse totals + total = _to_decimal(detail.get("total", data.get("total", data.get("transactionTotal", 0)))) + subtotal = detail.get("subtotal", data.get("subtotal")) + tax = detail.get("tax", data.get("tax")) + savings = detail.get("totalSavings", data.get("savings", data.get("totalDiscount"))) + + return { + "receipt_id": raw.receipt_id, + "purchase_date": raw.purchase_date, + "total": total, + "subtotal": _to_decimal(subtotal) if subtotal is not None else None, + "tax": _to_decimal(tax) if tax is not None else None, + "savings_total": _to_decimal(savings) if savings is not None else None, + "source_url": raw.source_url, + "raw_data": data, + "items": items, + } diff --git a/src/receiptwitness/parsers/target.py b/src/receiptwitness/parsers/target.py new file mode 100644 index 0000000..25b4204 --- /dev/null +++ b/src/receiptwitness/parsers/target.py @@ -0,0 +1,191 @@ +"""Target Circle receipt parser. + +Transforms raw Target in-store receipt JSON into the common PurchaseCreate schema. +Target receipt data includes Circle pricing, BOGO deals, and Circle rewards +discounts that need special handling. 
+ +Target receipt detail structure (reverse-engineered from target.com SPA): + +{ + "orderId": "TGT-2026-0315-7890", + "items": [ + { + "description": "GOOD & GATHER WHOLE MILK GAL", + "tcin": "14767459", + "upc": "0085239100123", + "quantity": 1, + "unitPrice": 3.89, + "totalPrice": 3.89, + "regularPrice": 4.19, + "circlePrice": 3.89, + "couponDiscount": 0.0, + "circleRewardsDiscount": 0.30, + "promoDescription": "Circle offer: Save 30c", + "department": "GROCERY" + } + ], + "subtotal": 78.32, + "tax": 4.89, + "total": 83.21, + "totalSavings": 11.45 +} +""" + +import logging +from decimal import Decimal, InvalidOperation + +from receiptwitness.scrapers.base import RawReceipt + +logger = logging.getLogger(__name__) + + +def _to_decimal(value, default: str = "0") -> Decimal: + """Safely convert a value to Decimal.""" + if value is None: + return Decimal(default) + try: + return Decimal(str(value)) + except (InvalidOperation, ValueError, TypeError): + return Decimal(default) + + +def _parse_item(item: dict) -> dict: + """Parse a single line item from a Target receipt. + + Target items may include fields like: + - description / itemDescription / productName + - tcin (Target internal product ID) / upc / dpci + - quantity / qty + - unitPrice / price + - totalPrice / extendedPrice / lineTotal + - regularPrice / originalPrice + - circlePrice / salePrice / promoPrice + - couponDiscount / couponSavings + - circleRewardsDiscount / circleDiscount / loyaltyDiscount + - promoDescription / offerDescription (e.g. 
"BOGO 50% off", "Circle offer") + - department / category + """ + description = ( + item.get("description") + or item.get("itemDescription") + or item.get("productName") + or item.get("name") + or "UNKNOWN ITEM" + ) + + quantity = _to_decimal(item.get("quantity", item.get("qty", item.get("quantitySold", 1))), "1") + unit_price = _to_decimal(item.get("unitPrice", item.get("price", item.get("basePrice", 0)))) + extended_price = _to_decimal( + item.get("totalPrice", item.get("extendedPrice", item.get("lineTotal"))) + ) + + # Compute extended_price if not provided + if extended_price == Decimal("0") and unit_price != Decimal("0"): + extended_price = unit_price * quantity + + regular_price = item.get("regularPrice", item.get("originalPrice")) + # Target Circle pricing — circlePrice takes precedence over generic salePrice + sale_price = item.get("circlePrice", item.get("salePrice", item.get("promoPrice"))) + coupon_discount = item.get( + "couponDiscount", item.get("couponSavings", item.get("couponAmount")) + ) + # Circle rewards / loyalty discount + loyalty_discount = item.get( + "circleRewardsDiscount", + item.get("circleDiscount", item.get("loyaltyDiscount")), + ) + + # UPC handling — Target may use tcin, upc, or dpci + upc = item.get("upc", item.get("UPC")) + if upc: + upc = str(upc).strip().lstrip("0") or None + + # Target also has TCIN (Target.com Item Number) and DPCI (Department/Class/Item) + tcin = item.get("tcin", item.get("TCIN")) + dpci = item.get("dpci", item.get("DPCI")) + + category = item.get("department", item.get("category")) + + # Capture promo/deal description for BOGO and Circle offers + promo_description = item.get("promoDescription", item.get("offerDescription")) + + # Weight info for produce/deli items + weight = item.get("weight", item.get("netWeight")) + extra: dict = {} + if weight is not None: + extra["weight"] = str(weight) + weight_uom = item.get("weightUom", item.get("unitOfMeasure")) + if weight_uom: + extra["weight_uom"] = weight_uom + if 
tcin: + extra["tcin"] = str(tcin) + if dpci: + extra["dpci"] = str(dpci) + if promo_description: + extra["promo_description"] = promo_description + + result: dict = { + "product_name_raw": description.strip(), + "upc": upc, + "quantity": quantity, + "unit_price": unit_price, + "extended_price": extended_price, + "regular_price": _to_decimal(regular_price) if regular_price is not None else None, + "sale_price": _to_decimal(sale_price) if sale_price is not None else None, + "coupon_discount": (_to_decimal(coupon_discount) if coupon_discount is not None else None), + "loyalty_discount": ( + _to_decimal(loyalty_discount) if loyalty_discount is not None else None + ), + "category_raw": category.strip() if category else None, + } + + return result + + +def parse_target_receipt(raw: RawReceipt) -> dict: + """Parse a RawReceipt from Target into a PurchaseCreate-compatible dict.""" + data = raw.raw_data + detail = data.get("detail", {}) + + # Parse items — Target uses "items" or "lineItems" + raw_items = detail.get("items", detail.get("lineItems", [])) + items = [] + for raw_item in raw_items: + # Skip voided / returned items + if raw_item.get("voided") or raw_item.get("status") in ( + "VOIDED", + "RETURNED", + "CANCELLED", + ): + logger.debug("Skipping voided/returned item: %s", raw_item.get("description")) + continue + if raw_item.get("returnFlag") or raw_item.get("isReturn"): + logger.debug("Skipping returned item: %s", raw_item.get("description")) + continue + items.append(_parse_item(raw_item)) + + # Parse totals + total = _to_decimal( + detail.get( + "total", + data.get("total", data.get("orderTotal", data.get("grandTotal", 0))), + ) + ) + subtotal = detail.get("subtotal", data.get("subtotal", data.get("subTotal"))) + tax = detail.get("tax", data.get("tax", data.get("salesTax"))) + savings = detail.get( + "totalSavings", + data.get("savings", data.get("totalDiscount", data.get("circleSavings"))), + ) + + return { + "receipt_id": raw.receipt_id, + "purchase_date": 
raw.purchase_date, + "total": total, + "subtotal": _to_decimal(subtotal) if subtotal is not None else None, + "tax": _to_decimal(tax) if tax is not None else None, + "savings_total": _to_decimal(savings) if savings is not None else None, + "source_url": raw.source_url, + "raw_data": data, + "items": items, + } diff --git a/src/receiptwitness/pipeline/__init__.py b/src/receiptwitness/pipeline/__init__.py new file mode 100644 index 0000000..e590387 --- /dev/null +++ b/src/receiptwitness/pipeline/__init__.py @@ -0,0 +1,30 @@ +"""Receipt & product matching pipeline — receipt normalization and product dedup.""" + +from receiptwitness.pipeline.matching import ( + ConfidenceLevel, + ProductMatcher, + match_purchase_item, +) +from receiptwitness.pipeline.normalization import ( + MatchMethod, + MatchResult, + clean_name, + extract_size_info, + jaccard_similarity, + normalize_product, +) +from receiptwitness.pipeline.receipt import normalize_receipt, parse_meijer_item + +__all__ = [ + "ConfidenceLevel", + "MatchMethod", + "MatchResult", + "ProductMatcher", + "clean_name", + "extract_size_info", + "jaccard_similarity", + "match_purchase_item", + "normalize_product", + "normalize_receipt", + "parse_meijer_item", +] diff --git a/src/receiptwitness/pipeline/matching.py b/src/receiptwitness/pipeline/matching.py new file mode 100644 index 0000000..7e71039 --- /dev/null +++ b/src/receiptwitness/pipeline/matching.py @@ -0,0 +1,136 @@ +"""Product matching & dedup — UPC primary, fuzzy name fallback, confidence scoring. + +Wraps the Phase 1 normalization module with confidence-level classification +and batch matching for purchase ingestion. 
def classify_confidence(score: float, method: MatchMethod) -> MatchConfidence:
    """Map a match score and method onto a high/medium/low confidence bucket.

    UPC matches are always HIGH and their score is not consulted. For any
    other method the score decides: >= 0.8 is HIGH, >= 0.5 is MEDIUM,
    everything below is LOW.
    """
    # Short-circuit keeps the score untouched for UPC matches.
    if method == MatchMethod.UPC or score >= 0.8:
        return MatchConfidence.HIGH
    return MatchConfidence.MEDIUM if score >= 0.5 else MatchConfidence.LOW
+ + Usage: + matcher = ProductMatcher(session) + outcomes = matcher.match_items(items) + """ + + def __init__( + self, + session: Session, + name_threshold: float = 0.4, + auto_create: bool = True, + ): + self.session = session + self.name_threshold = name_threshold + self.auto_create = auto_create + + def match_single( + self, + item: PurchaseItemCreate, + ) -> tuple[NormalizedProduct | None, MatchResult | None, MatchConfidence]: + """Match a single purchase item to a normalized product. + + Returns (product, match_result, confidence_level). + If auto_create is True and no match found, creates a new product. + """ + result = normalize_product( + self.session, + item.product_name_raw, + upc=item.upc, + name_threshold=self.name_threshold, + ) + + if result: + confidence = classify_confidence(result.confidence, result.method) + return result.product, result, confidence + + if self.auto_create: + product = _create_product_from_item(self.session, item) + return product, None, MatchConfidence.LOW + + return None, None, MatchConfidence.LOW + + def match_items(self, items: list[PurchaseItemCreate]) -> list[MatchOutcome]: + """Match a batch of purchase items. 
Returns outcomes in order.""" + outcomes: list[MatchOutcome] = [] + for idx, item in enumerate(items): + product, result, confidence = self.match_single(item) + created = result is None and product is not None + outcomes.append( + MatchOutcome( + item_index=idx, + match=result, + confidence_level=confidence, + created_new=created, + ) + ) + return outcomes + + +def match_purchase_item( + session: Session, + item: PurchaseItemCreate, + name_threshold: float = 0.4, + auto_create: bool = True, +) -> tuple[NormalizedProduct | None, MatchConfidence]: + """Convenience function: match a single item, return (product, confidence).""" + matcher = ProductMatcher(session, name_threshold=name_threshold, auto_create=auto_create) + product, _, confidence = matcher.match_single(item) + return product, confidence diff --git a/src/receiptwitness/pipeline/normalization.py b/src/receiptwitness/pipeline/normalization.py new file mode 100644 index 0000000..c1fade9 --- /dev/null +++ b/src/receiptwitness/pipeline/normalization.py @@ -0,0 +1,155 @@ +"""Product normalization — Phase 1: UPC matching + fuzzy name matching. + +Matches products across retailers by: +1. Exact UPC match (highest confidence) +2. 
Fuzzy name matching via token-based Jaccard similarity (lower confidence) +""" + +import re +from dataclasses import dataclass +from enum import StrEnum + +from cartsnitch_common.models.product import NormalizedProduct +from sqlalchemy import select +from sqlalchemy.orm import Session + + +class MatchMethod(StrEnum): + """How a product match was determined.""" + + UPC = "upc" + NAME = "name" + + +@dataclass(frozen=True) +class MatchResult: + """Result of a product normalization attempt.""" + + product: NormalizedProduct + confidence: float + method: MatchMethod + + +# Noise words stripped during name cleaning +_NOISE_WORDS = frozenset( + { + "the", + "a", + "an", + "and", + "or", + "of", + "with", + "in", + "for", + "to", + "brand", + "original", + "classic", + "new", + "improved", + } +) + +# Regex for extracting size info (e.g., "16 oz", "1.5 lb", "12 ct") +_SIZE_PATTERN = re.compile( + r"(\d+(?:\.\d+)?)\s*(oz|fl\s*oz|lb|lbs|g|kg|ml|l|ct|pk|count|pack)\b", + re.IGNORECASE, +) + + +def clean_name(name: str) -> str: + """Normalize a product name for comparison. 
+ + - Lowercase + - Remove size info (e.g., "16 oz") + - Strip noise words + - Collapse whitespace + """ + cleaned = name.lower() + cleaned = _SIZE_PATTERN.sub("", cleaned) + cleaned = re.sub(r"[^\w\s]", " ", cleaned) + tokens = cleaned.split() + tokens = [t for t in tokens if t not in _NOISE_WORDS] + return " ".join(tokens) + + +def extract_size_info(name: str) -> tuple[str, str] | None: + """Extract (size, unit) from a product name, if present.""" + match = _SIZE_PATTERN.search(name) + if match: + return match.group(1), match.group(2).lower().replace(" ", "_") + return None + + +def jaccard_similarity(a: str, b: str) -> float: + """Token-based Jaccard similarity between two cleaned names.""" + tokens_a = set(a.split()) + tokens_b = set(b.split()) + if not tokens_a or not tokens_b: + return 0.0 + intersection = tokens_a & tokens_b + union = tokens_a | tokens_b + return len(intersection) / len(union) + + +def match_by_upc(session: Session, upc: str) -> MatchResult | None: + """Find a normalized product by exact UPC match. + + Loads products with upc_variants and checks membership in Python + for cross-database compatibility (works on both PostgreSQL and SQLite). + """ + # TODO: Use PostgreSQL JSON containment query (@>) for production. + # Current approach loads all products into memory — acceptable for tests + # and small datasets, but will not scale. + stmt = select(NormalizedProduct).where(NormalizedProduct.upc_variants.is_not(None)) + products = session.execute(stmt).scalars().all() + for product in products: + if product.upc_variants and upc in product.upc_variants: + return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC) + return None + + +def match_by_name( + session: Session, + name: str, + threshold: float = 0.5, +) -> MatchResult | None: + """Find the best normalized product by fuzzy name matching. + + Loads all normalized products and computes Jaccard similarity. + Returns the best match above the threshold, or None. 
+ """ + # TODO: Use pg_trgm similarity index for production. + # Current approach loads all products into memory — acceptable for tests + # and small datasets, but will not scale. + cleaned = clean_name(name) + stmt = select(NormalizedProduct) + products = session.execute(stmt).scalars().all() + + best_match: NormalizedProduct | None = None + best_score = 0.0 + + for product in products: + score = jaccard_similarity(cleaned, clean_name(product.canonical_name)) + if score > best_score and score >= threshold: + best_score = score + best_match = product + + if best_match: + return MatchResult(product=best_match, confidence=best_score, method=MatchMethod.NAME) + return None + + +def normalize_product( + session: Session, + name: str, + upc: str | None = None, + name_threshold: float = 0.5, +) -> MatchResult | None: + """Full normalization pipeline: UPC first, then fuzzy name fallback.""" + if upc: + result = match_by_upc(session, upc) + if result: + return result + return match_by_name(session, name, threshold=name_threshold) diff --git a/src/receiptwitness/pipeline/receipt.py b/src/receiptwitness/pipeline/receipt.py new file mode 100644 index 0000000..7d3e863 --- /dev/null +++ b/src/receiptwitness/pipeline/receipt.py @@ -0,0 +1,144 @@ +"""Receipt normalization — parse raw Meijer scraper output into purchase records. + +Maps raw receipt fields, cleans product names, extracts quantities/units. 
+""" + +import re +from datetime import date +from decimal import Decimal, InvalidOperation + +from cartsnitch_common.schemas.purchase import PurchaseCreate, PurchaseItemCreate + + +def _clean_product_name(raw: str) -> str: + """Clean raw product name from scraper output.""" + cleaned = raw.strip() + # Remove leading/trailing non-alphanumeric chars + cleaned = re.sub(r"^\W+|\W+$", "", cleaned) + # Collapse internal whitespace + cleaned = re.sub(r"\s+", " ", cleaned) + return cleaned + + +def _safe_decimal( + value: str | float | int | Decimal | None, + default: Decimal = Decimal("0"), +) -> Decimal: + """Safely convert a value to Decimal.""" + if value is None: + return default + try: + return Decimal(str(value)) + except (InvalidOperation, ValueError): + return default + + +def parse_meijer_item(raw_item: dict) -> PurchaseItemCreate: + """Parse a single Meijer scraper line item into a PurchaseItemCreate. + + Expected raw_item keys (from Meijer scraper): + - description / name: product name + - upc / upcCode: UPC barcode + - quantity / qty: number of units + - unitPrice / price: per-unit price + - extendedPrice / totalPrice: line total + - regularPrice: shelf price before discounts + - salePrice: sale price if applicable + - couponAmount / couponDiscount: coupon savings + - loyaltyAmount / loyaltyDiscount: loyalty savings + - category / department: raw category + """ + name = raw_item.get("description") or raw_item.get("name") or "" + cleaned_name = _clean_product_name(name) + + upc = raw_item.get("upc") or raw_item.get("upcCode") + if upc: + upc = str(upc).strip().lstrip("0") or str(upc).strip() + + qty = _safe_decimal( + raw_item.get("quantity") or raw_item.get("qty"), + default=Decimal("1"), + ) + + unit_price = _safe_decimal(raw_item.get("unitPrice") or raw_item.get("price")) + extended = _safe_decimal(raw_item.get("extendedPrice") or raw_item.get("totalPrice")) + if extended == Decimal("0") and unit_price > 0: + extended = unit_price * qty + + regular = 
raw_item.get("regularPrice") + sale = raw_item.get("salePrice") + coupon = raw_item.get("couponAmount") or raw_item.get("couponDiscount") + loyalty = raw_item.get("loyaltyAmount") or raw_item.get("loyaltyDiscount") + category = raw_item.get("category") or raw_item.get("department") + + return PurchaseItemCreate( + product_name_raw=cleaned_name, + upc=upc, + quantity=qty, + unit_price=unit_price, + extended_price=extended, + regular_price=_safe_decimal(regular) if regular is not None else None, + sale_price=_safe_decimal(sale) if sale is not None else None, + coupon_discount=_safe_decimal(coupon) if coupon is not None else None, + loyalty_discount=_safe_decimal(loyalty) if loyalty is not None else None, + category_raw=str(category).strip() if category else None, + ) + + +def normalize_receipt( + raw_receipt: dict, + user_id: str, + store_id: str, +) -> PurchaseCreate: + """Parse a complete Meijer raw receipt into a PurchaseCreate. + + Expected raw_receipt keys: + - receiptId / receipt_id / id: unique receipt identifier + - date / purchaseDate / purchase_date: purchase date (YYYY-MM-DD or similar) + - total / totalAmount: receipt total + - subtotal: pre-tax subtotal + - tax / taxAmount: tax amount + - savings / totalSavings: total discount savings + - items: list of raw line item dicts + """ + import uuid + + receipt_id = str( + raw_receipt.get("receiptId") + or raw_receipt.get("receipt_id") + or raw_receipt.get("id") + or uuid.uuid4() + ) + + raw_date = ( + raw_receipt.get("date") + or raw_receipt.get("purchaseDate") + or raw_receipt.get("purchase_date") + ) + if isinstance(raw_date, str): + purchase_date = date.fromisoformat(raw_date[:10]) + elif isinstance(raw_date, date): + purchase_date = raw_date + else: + purchase_date = date.today() + + total = _safe_decimal(raw_receipt.get("total") or raw_receipt.get("totalAmount")) + subtotal = raw_receipt.get("subtotal") + tax = raw_receipt.get("tax") or raw_receipt.get("taxAmount") + savings = raw_receipt.get("savings") 
@dataclass
class SessionData:
    """Holds session cookies and metadata for a retailer login."""

    cookies: list[dict]
    user_agent: str
    created_at: datetime
    expires_at: datetime | None = None
    extra: dict = field(default_factory=dict)


@dataclass
class RawReceipt:
    """Raw receipt data before parsing."""

    receipt_id: str
    purchase_date: str
    store_number: str | None = None
    raw_data: dict = field(default_factory=dict)
    source_url: str | None = None


class BaseScraper(ABC):
    """All retailer scrapers implement this interface.

    Provides common functionality: human-like delays, rate limiting guards,
    and the abstract methods each retailer scraper must implement.
    """

    @abstractmethod
    async def login(self, username: str, password: str) -> SessionData:
        """Authenticate with the retailer portal and return session data."""
        ...

    @abstractmethod
    async def check_session(self, session: SessionData) -> bool:
        """Verify if an existing session is still valid."""
        ...

    @abstractmethod
    async def scrape_receipts(
        self, session: SessionData, since: datetime | None = None
    ) -> list[RawReceipt]:
        """Scrape receipt data from the retailer portal."""
        ...

    @abstractmethod
    def parse_receipt(self, raw: RawReceipt) -> dict:
        """Parse a raw receipt into structured data.

        Returns a dict with keys matching PurchaseCreate schema fields,
        including an 'items' list matching PurchaseItemCreate fields.
        """
        ...

    async def human_delay(self, min_ms: int | None = None, max_ms: int | None = None) -> None:
        """Sleep for a randomized human-like interval.

        Args:
            min_ms: Lower bound in milliseconds; ``None`` uses
                ``settings.min_request_delay_ms``.
            max_ms: Upper bound in milliseconds; ``None`` uses
                ``settings.max_request_delay_ms``.

        An explicit ``0`` is honored rather than replaced by the configured
        default, and the bounds are swapped if given in the wrong order so a
        one-sided override can never make ``random.randint`` raise.
        """
        lo = settings.min_request_delay_ms if min_ms is None else min_ms
        hi = settings.max_request_delay_ms if max_ms is None else max_ms
        if lo > hi:
            lo, hi = hi, lo
        delay = random.randint(lo, hi) / 1000.0
        await asyncio.sleep(delay)
logger = logging.getLogger(__name__)

# Kroger endpoints
KROGER_BASE = "https://www.kroger.com"
KROGER_LOGIN_PAGE = f"{KROGER_BASE}/signin"
KROGER_PURCHASE_HISTORY = f"{KROGER_BASE}/mypurchases"
KROGER_RECEIPT_API = f"{KROGER_BASE}/atlas/v1/purchase-history/api"
KROGER_RECEIPT_DETAIL_API = f"{KROGER_BASE}/atlas/v1/receipt/api"
KROGER_ACCOUNT_PAGE = f"{KROGER_BASE}/account/dashboard"

# Realistic browser fingerprint — Chrome on Windows (matches Kroger's typical audience)
DEFAULT_USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/131.0.0.0 Safari/537.36"
)
DEFAULT_VIEWPORT = {"width": 1920, "height": 1080}
DEFAULT_LOCALE = "en-US"
DEFAULT_TIMEZONE = "America/New_York"


def _purchased_before(purchase_date: object, since: datetime) -> bool:
    """Return True only if ``purchase_date`` parses and is earlier than ``since``.

    Anything that is not a parseable ISO-8601 string yields False, so an
    unreadable date never causes a receipt to be dropped.  Naive and aware
    datetimes cannot be ordered against each other, so a naive value on either
    side is treated as UTC.
    NOTE(review): assumes the portal's zone-less timestamps are UTC — confirm.
    """
    if not isinstance(purchase_date, str) or not purchase_date:
        return False
    try:
        txn_dt = datetime.fromisoformat(purchase_date.replace("Z", "+00:00"))
    except ValueError:
        return False
    if txn_dt.tzinfo is None:
        txn_dt = txn_dt.replace(tzinfo=UTC)
    cutoff = since if since.tzinfo is not None else since.replace(tzinfo=UTC)
    return txn_dt < cutoff


class KrogerScraper(BaseScraper):
    """Scraper for Kroger loyalty purchase history.

    Kroger uses Akamai Bot Manager which aggressively detects headless
    browsers. This scraper employs enhanced stealth measures:
    - Masks webdriver/automation signals
    - Sets realistic browser fingerprint
    - Uses human-like interaction pacing
    - Preserves browser context across sessions
    """

    async def _create_stealth_context(
        self, playwright_instance: Playwright, cookies: list[dict] | None = None
    ) -> BrowserContext:
        """Create a browser context with enhanced stealth for Akamai evasion.

        Launches Chromium, opens a context with the module's fingerprint
        constants, installs the init script below and, when ``cookies`` is
        given, preloads them.  The caller closes ``context.browser``.
        """
        browser = await playwright_instance.chromium.launch(
            headless=settings.headless,
            args=[
                "--disable-blink-features=AutomationControlled",
                "--no-sandbox",
                "--disable-dev-shm-usage",
                "--disable-infobars",
                "--window-size=1920,1080",
            ],
        )
        context = await browser.new_context(
            user_agent=DEFAULT_USER_AGENT,
            viewport=DEFAULT_VIEWPORT,  # type: ignore[arg-type]
            locale=DEFAULT_LOCALE,
            timezone_id=DEFAULT_TIMEZONE,
            java_script_enabled=True,
            bypass_csp=False,
            color_scheme="light",
            has_touch=False,
        )

        # Enhanced stealth script targeting Akamai Bot Manager detection vectors
        await context.add_init_script(
            """
            // Mask webdriver flag
            Object.defineProperty(navigator, 'webdriver', {
                get: () => undefined
            });

            // Chrome runtime object
            window.chrome = {
                runtime: {},
                loadTimes: function() {},
                csi: function() {},
                app: { isInstalled: false }
            };

            // Realistic plugin array
            Object.defineProperty(navigator, 'plugins', {
                get: () => [1, 2, 3, 4, 5]
            });

            // Languages
            Object.defineProperty(navigator, 'languages', {
                get: () => ['en-US', 'en']
            });

            // Platform
            Object.defineProperty(navigator, 'platform', {
                get: () => 'Win32'
            });

            // Hardware concurrency
            Object.defineProperty(navigator, 'hardwareConcurrency', {
                get: () => 8
            });

            // Device memory
            Object.defineProperty(navigator, 'deviceMemory', {
                get: () => 8
            });

            // Permissions query override (Akamai checks this)
            const originalQuery = window.navigator.permissions.query;
            window.navigator.permissions.query = (parameters) =>
                parameters.name === 'notifications'
                    ? Promise.resolve({ state: Notification.permission })
                    : originalQuery(parameters);

            // WebGL vendor/renderer (avoid "Google Inc." / "ANGLE" tells)
            const getParameter = WebGLRenderingContext.prototype.getParameter;
            WebGLRenderingContext.prototype.getParameter = function(parameter) {
                if (parameter === 37445) return 'Intel Inc.';
                if (parameter === 37446) return 'Intel Iris OpenGL Engine';
                return getParameter.call(this, parameter);
            };
            """
        )

        if cookies:
            await context.add_cookies(cookies)  # type: ignore[arg-type]

        return cast(BrowserContext, context)

    async def login(self, username: str, password: str) -> SessionData:
        """Log in to Kroger and capture session cookies."""
        async with async_playwright() as p:
            context = await self._create_stealth_context(p)
            page = await context.new_page()
            try:
                return await self._perform_login(page, context, username, password)
            finally:
                if context.browser:
                    await context.browser.close()

    async def _perform_login(
        self, page: Page, context: BrowserContext, username: str, password: str
    ) -> SessionData:
        """Execute the Kroger login flow.

        Fills the e-mail and password fields with pauses in between, submits,
        waits until the URL no longer contains "signin", then snapshots the
        context's cookies into a SessionData that expires in two hours.
        """
        logger.info("Navigating to Kroger sign-in page")
        await page.goto(KROGER_LOGIN_PAGE, wait_until="networkidle")
        await self.human_delay(2000, 4000)

        # Kroger login form — email/username field
        email_input = page.locator(
            'input[id="SignIn-emailInput"], '
            'input[name="email"], '
            'input[type="email"], '
            'input[data-testid="SignIn-emailInput"]'
        )
        await email_input.wait_for(state="visible", timeout=settings.browser_timeout_ms)
        await email_input.click()
        await self.human_delay(300, 700)
        await email_input.fill(username)
        await self.human_delay(800, 1500)

        # Password field
        password_input = page.locator(
            'input[id="SignIn-passwordInput"], '
            'input[name="password"], '
            'input[type="password"], '
            'input[data-testid="SignIn-passwordInput"]'
        )
        await password_input.wait_for(state="visible", timeout=settings.browser_timeout_ms)
        await password_input.click()
        await self.human_delay(300, 700)
        await password_input.fill(password)
        await self.human_delay(1000, 2000)

        # Sign-in button
        sign_in_btn = page.locator(
            'button[id="SignIn-submitButton"], '
            'button[data-testid="SignIn-submitButton"], '
            'button[type="submit"]:has-text("Sign In")'
        )
        await sign_in_btn.click()

        # Wait for redirect away from sign-in page
        await page.wait_for_url(
            lambda url: "signin" not in url.lower(),
            timeout=settings.browser_timeout_ms,
        )
        await self.human_delay(1500, 3000)

        # Capture cookies
        raw_cookies = await context.cookies()
        cookies = [dict(c) for c in raw_cookies]
        now = datetime.now(UTC)

        logger.info("Kroger login successful, captured %d cookies", len(cookies))
        return SessionData(
            cookies=cookies,
            user_agent=DEFAULT_USER_AGENT,
            created_at=now,
            expires_at=now + timedelta(hours=2),
            extra={"retailer": "kroger"},
        )

    async def check_session(self, session: SessionData) -> bool:
        """Check if the Kroger session is still valid.

        A session past ``expires_at`` is rejected without opening a browser.
        Otherwise the account page is loaded with the stored cookies; the
        session is valid when the response is OK and the final URL does not
        contain "signin".  Any exception counts as invalid.
        """
        if session.expires_at and datetime.now(UTC) > session.expires_at:
            logger.info("Kroger session expired based on timestamp")
            return False

        async with async_playwright() as p:
            context = await self._create_stealth_context(p, cookies=session.cookies)
            page = await context.new_page()
            try:
                response = await page.goto(KROGER_ACCOUNT_PAGE, wait_until="networkidle")
                current_url = page.url.lower()
                is_valid = "signin" not in current_url and response is not None and response.ok
                logger.info("Kroger session check: valid=%s (url=%s)", is_valid, page.url)
                return is_valid
            except Exception:
                logger.exception("Kroger session check failed")
                return False
            finally:
                if context.browser:
                    await context.browser.close()

    async def scrape_receipts(
        self, session: SessionData, since: datetime | None = None
    ) -> list[RawReceipt]:
        """Scrape purchase history from Kroger."""
        async with async_playwright() as p:
            context = await self._create_stealth_context(p, cookies=session.cookies)
            page = await context.new_page()
            try:
                return await self._fetch_receipts(page, since)
            finally:
                if context.browser:
                    await context.browser.close()

    async def _fetch_receipts(self, page: Page, since: datetime | None) -> list[RawReceipt]:
        """Fetch receipt list and details from Kroger purchase history.

        Orders dated before ``since`` are skipped.  Orders whose date is
        missing or unparseable are kept, and entries that are not JSON objects
        or have no id are skipped.
        """
        # Navigate to purchase history to establish context
        await page.goto(KROGER_PURCHASE_HISTORY, wait_until="networkidle")
        await self.human_delay(1500, 3000)

        receipts: list[RawReceipt] = []

        # Kroger purchase history API endpoint
        api_response = await page.request.get(KROGER_RECEIPT_API)
        if not api_response.ok:
            logger.warning(
                "Kroger purchase history request failed: %d %s",
                api_response.status,
                api_response.status_text,
            )
            return []

        response = await api_response.json()
        if not isinstance(response, dict):
            logger.warning("Unexpected purchase history response type: %s", type(response))
            return []

        # Handle Kroger's response structure
        orders = response.get("orders", response.get("purchases", []))
        if not isinstance(orders, list):
            logger.warning("No orders found in Kroger purchase history response")
            return []

        logger.info("Found %d orders in Kroger purchase history", len(orders))

        for order in orders:
            if not isinstance(order, dict):
                # One malformed entry must not abort the whole scrape.
                logger.warning("Skipping malformed Kroger order entry: %s", type(order))
                continue

            raw_id = order.get("orderId") or order.get("receiptId") or order.get("id") or ""
            order_id = str(raw_id)
            purchase_date = order.get(
                "purchaseDate", order.get("transactionDate", order.get("date", ""))
            )

            # Filter by date if 'since' is provided
            if since and _purchased_before(purchase_date, since):
                continue

            if not order_id:
                continue

            await self.human_delay(1000, 2500)

            # Fetch receipt detail
            detail = await self._fetch_receipt_detail(page, order_id)

            raw_store = (
                order.get("storeNumber")
                or order.get("divisionNumber")
                or order.get("storeId")
                or ""
            )
            store_number = str(raw_store)

            receipts.append(
                RawReceipt(
                    receipt_id=order_id,
                    purchase_date=purchase_date,
                    store_number=store_number,
                    raw_data={**order, "detail": detail},
                    source_url=f"{KROGER_RECEIPT_DETAIL_API}?orderId={order_id}",
                )
            )

        logger.info("Scraped %d receipts from Kroger", len(receipts))
        return receipts

    async def _fetch_receipt_detail(self, page: Page, order_id: str) -> dict:
        """Fetch detailed receipt data for a single Kroger order.

        Best-effort: returns an empty dict on a non-OK response, a non-object
        payload, or any exception (which is logged).
        """
        try:
            url = f"{KROGER_RECEIPT_DETAIL_API}?orderId={order_id}"
            api_response = await page.request.get(url)
            if not api_response.ok:
                logger.warning(
                    "Kroger receipt detail request failed for %s: %d",
                    order_id,
                    api_response.status,
                )
                return {}
            detail = await api_response.json()
            return detail if isinstance(detail, dict) else {}
        except Exception:
            logger.exception("Failed to fetch Kroger receipt detail for %s", order_id)
            return {}

    def parse_receipt(self, raw: RawReceipt) -> dict:
        """Parse raw Kroger receipt into structured purchase data."""
        from receiptwitness.parsers.kroger import parse_kroger_receipt

        return parse_kroger_receipt(raw)
Parse receipt JSON into structured PurchaseCreate records + +Key endpoints (reverse-engineered from mPerks SPA): +- Login: POST https://www.meijer.com/bin/meijer/account/login +- Receipts: GET https://www.meijer.com/bin/meijer/profile/purchasehistory +- Receipt detail: GET https://www.meijer.com/bin/meijer/profile/receipt?receiptId=... +""" + +import logging +from datetime import UTC, datetime, timedelta +from typing import cast + +from playwright.async_api import BrowserContext, Page, Playwright, async_playwright + +from receiptwitness.config import settings +from receiptwitness.scrapers.base import BaseScraper, RawReceipt, SessionData + +logger = logging.getLogger(__name__) + +# Meijer mPerks URLs +MEIJER_BASE = "https://www.meijer.com" +MEIJER_LOGIN_PAGE = f"{MEIJER_BASE}/shopping/login.html" +MEIJER_LOGIN_API = f"{MEIJER_BASE}/bin/meijer/account/login" +MEIJER_PURCHASE_HISTORY = f"{MEIJER_BASE}/bin/meijer/profile/purchasehistory" +MEIJER_RECEIPT_DETAIL = f"{MEIJER_BASE}/bin/meijer/profile/receipt" +MEIJER_MPERKS_HOME = f"{MEIJER_BASE}/mperks.html" + +# Realistic browser fingerprint +DEFAULT_USER_AGENT = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/131.0.0.0 Safari/537.36" +) +DEFAULT_VIEWPORT = {"width": 1920, "height": 1080} +DEFAULT_LOCALE = "en-US" +DEFAULT_TIMEZONE = "America/Detroit" # Meijer HQ is in Grand Rapids, MI + + +class MeijerScraper(BaseScraper): + """Scraper for Meijer mPerks purchase history.""" + + async def _create_stealth_context( + self, playwright_instance: Playwright, cookies: list[dict] | None = None + ) -> BrowserContext: + """Create a browser context with stealth settings.""" + browser = await playwright_instance.chromium.launch( + headless=settings.headless, + args=[ + "--disable-blink-features=AutomationControlled", + "--no-sandbox", + ], + ) + context = await browser.new_context( + user_agent=DEFAULT_USER_AGENT, + viewport=DEFAULT_VIEWPORT, # type: ignore[arg-type] + 
locale=DEFAULT_LOCALE, + timezone_id=DEFAULT_TIMEZONE, + java_script_enabled=True, + bypass_csp=False, + ) + # Mask webdriver flag + await context.add_init_script( + """ + Object.defineProperty(navigator, 'webdriver', { + get: () => undefined + }); + // Mask chrome automation indicators + window.chrome = { runtime: {} }; + Object.defineProperty(navigator, 'plugins', { + get: () => [1, 2, 3, 4, 5] + }); + Object.defineProperty(navigator, 'languages', { + get: () => ['en-US', 'en'] + }); + """ + ) + if cookies: + await context.add_cookies(cookies) # type: ignore[arg-type] + return cast(BrowserContext, context) + + async def login(self, username: str, password: str) -> SessionData: + """Log in to Meijer mPerks and capture session cookies. + + The mPerks login flow: + 1. Navigate to login page + 2. Fill email and password fields + 3. Click sign-in button + 4. Wait for redirect to mPerks dashboard + 5. Extract session cookies + """ + async with async_playwright() as p: + context = await self._create_stealth_context(p) + page = await context.new_page() + try: + return await self._perform_login(page, context, username, password) + finally: + if context.browser: + await context.browser.close() + + async def _perform_login( + self, page: Page, context: BrowserContext, username: str, password: str + ) -> SessionData: + """Execute the login flow on the mPerks portal.""" + logger.info("Navigating to Meijer login page") + await page.goto(MEIJER_LOGIN_PAGE, wait_until="networkidle") + await self.human_delay(1500, 3000) + + # Fill email field + email_input = page.locator('input[type="email"], input[name="email"], #email') + await email_input.wait_for(state="visible", timeout=settings.browser_timeout_ms) + await email_input.click() + await self.human_delay(200, 500) + await email_input.fill(username) + await self.human_delay(500, 1000) + + # Fill password field + password_input = page.locator('input[type="password"], input[name="password"], #password') + await 
password_input.wait_for(state="visible", timeout=settings.browser_timeout_ms) + await password_input.click() + await self.human_delay(200, 500) + await password_input.fill(password) + await self.human_delay(500, 1500) + + # Click sign-in button + sign_in_btn = page.locator( + 'button[type="submit"], button:has-text("Sign In"), button:has-text("Log In")' + ) + await sign_in_btn.click() + + # Wait for navigation after login + await page.wait_for_url( + lambda url: "login" not in url.lower(), + timeout=settings.browser_timeout_ms, + ) + await self.human_delay(1000, 2000) + + # Capture cookies + raw_cookies = await context.cookies() + cookies = [dict(c) for c in raw_cookies] + now = datetime.now(UTC) + + logger.info("Meijer login successful, captured %d cookies", len(cookies)) + return SessionData( + cookies=cookies, + user_agent=DEFAULT_USER_AGENT, + created_at=now, + expires_at=now + timedelta(hours=4), + ) + + async def check_session(self, session: SessionData) -> bool: + """Check if the mPerks session is still valid. + + Makes a lightweight request to the mPerks home page and checks + if we get redirected to login (session expired) or not. 
+ """ + if session.expires_at and datetime.now(UTC) > session.expires_at: + logger.info("Meijer session expired based on timestamp") + return False + + async with async_playwright() as p: + context = await self._create_stealth_context(p, cookies=session.cookies) + page = await context.new_page() + try: + response = await page.goto(MEIJER_MPERKS_HOME, wait_until="networkidle") + current_url = page.url.lower() + is_valid = "login" not in current_url and response is not None and response.ok + logger.info("Meijer session check: valid=%s (url=%s)", is_valid, page.url) + return is_valid + except Exception: + logger.exception("Meijer session check failed") + return False + finally: + if context.browser: + await context.browser.close() + + async def scrape_receipts( + self, session: SessionData, since: datetime | None = None + ) -> list[RawReceipt]: + """Scrape purchase history from Meijer mPerks. + + Uses the XHR endpoints the mPerks SPA calls to fetch receipt data. + The purchase history endpoint returns a list of recent transactions, + and we can fetch individual receipt details for line items. + """ + async with async_playwright() as p: + context = await self._create_stealth_context(p, cookies=session.cookies) + page = await context.new_page() + try: + return await self._fetch_receipts(page, since) + finally: + if context.browser: + await context.browser.close() + + async def _fetch_receipts(self, page: Page, since: datetime | None) -> list[RawReceipt]: + """Fetch receipt list and detail via mPerks XHR endpoints. + + Uses Playwright's page.request API (APIRequestContext) instead of + page.evaluate(fetch(...)) for better observability — requests show up + in Playwright traces and can be intercepted by route handlers. 
+ """ + # Navigate to mPerks to establish context (cookies need domain context) + await page.goto(MEIJER_MPERKS_HOME, wait_until="networkidle") + await self.human_delay(1000, 2000) + + receipts: list[RawReceipt] = [] + + # Fetch purchase history listing via page.request (APIRequestContext) + api_response = await page.request.get(MEIJER_PURCHASE_HISTORY) + if not api_response.ok: + logger.warning( + "Purchase history request failed: %d %s", + api_response.status, + api_response.status_text, + ) + return [] + + response = await api_response.json() + + if not isinstance(response, dict): + logger.warning("Unexpected purchase history response type: %s", type(response)) + return [] + + transactions = response.get("transactions", response.get("purchaseHistory", [])) + if not isinstance(transactions, list): + logger.warning("No transactions found in purchase history response") + return [] + + logger.info("Found %d transactions in Meijer purchase history", len(transactions)) + + for txn in transactions: + receipt_id = str(txn.get("transactionId", txn.get("receiptId", ""))) + purchase_date = txn.get("transactionDate", txn.get("purchaseDate", "")) + + # Filter by date if 'since' is provided + if since and purchase_date: + try: + txn_dt = datetime.fromisoformat(purchase_date.replace("Z", "+00:00")) + if txn_dt < since: + continue + except (ValueError, TypeError): + pass + + if not receipt_id: + continue + + await self.human_delay(800, 2000) + + # Fetch receipt detail + detail = await self._fetch_receipt_detail(page, receipt_id) + + receipts.append( + RawReceipt( + receipt_id=receipt_id, + purchase_date=purchase_date, + store_number=str(txn.get("storeNumber", txn.get("storeId", ""))), + raw_data={**txn, "detail": detail}, + source_url=f"{MEIJER_RECEIPT_DETAIL}?receiptId={receipt_id}", + ) + ) + + logger.info("Scraped %d receipts from Meijer", len(receipts)) + return receipts + + async def _fetch_receipt_detail(self, page: Page, receipt_id: str) -> dict: + """Fetch detailed 
def parse_receipt(self, raw: RawReceipt) -> dict:
    """Turn a raw Meijer receipt into structured purchase data.

    All parsing is handed to ``parse_meijer_receipt`` in the dedicated
    parser module, which is imported when this method is called.
    """
    from receiptwitness.parsers import meijer as meijer_parser

    structured = meijer_parser.parse_meijer_receipt(raw)
    return structured
class TargetScraper(BaseScraper):
    """Scraper for Target Circle in-store purchase history.

    Target's order history SPA loads purchase data from internal API
    endpoints. This scraper authenticates via the web login flow,
    captures session cookies, and uses those to hit the order history
    API for in-store receipt data.
    """

    async def _create_stealth_context(
        self, playwright_instance: Playwright, cookies: list[dict] | None = None
    ) -> BrowserContext:
        """Create a browser context with stealth settings for Target.

        Launches Chromium, opens a context with the module's default
        fingerprint, installs an init script that masks automation signals,
        and adds ``cookies`` to the context when any are given.
        """
        browser = await playwright_instance.chromium.launch(
            headless=settings.headless,
            args=[
                "--disable-blink-features=AutomationControlled",
                "--no-sandbox",
                "--disable-dev-shm-usage",
            ],
        )
        context = await browser.new_context(
            user_agent=DEFAULT_USER_AGENT,
            viewport=DEFAULT_VIEWPORT,  # type: ignore[arg-type]
            locale=DEFAULT_LOCALE,
            timezone_id=DEFAULT_TIMEZONE,
            java_script_enabled=True,
            bypass_csp=False,
            color_scheme="light",
            has_touch=False,
        )
        # Mask webdriver and automation signals
        await context.add_init_script(
            """
            Object.defineProperty(navigator, 'webdriver', {
                get: () => undefined
            });

            window.chrome = {
                runtime: {},
                loadTimes: function() {},
                csi: function() {},
                app: { isInstalled: false }
            };

            Object.defineProperty(navigator, 'plugins', {
                get: () => [1, 2, 3, 4, 5]
            });

            Object.defineProperty(navigator, 'languages', {
                get: () => ['en-US', 'en']
            });

            Object.defineProperty(navigator, 'platform', {
                get: () => 'Win32'
            });

            Object.defineProperty(navigator, 'hardwareConcurrency', {
                get: () => 8
            });

            Object.defineProperty(navigator, 'deviceMemory', {
                get: () => 8
            });
            """
        )
        if cookies:
            await context.add_cookies(cookies)  # type: ignore[arg-type]
        return cast(BrowserContext, context)

    async def login(self, username: str, password: str) -> SessionData:
        """Log in to Target and capture session cookies.

        The browser is closed before returning, whether or not the login
        flow succeeds.
        """
        async with async_playwright() as p:
            context = await self._create_stealth_context(p)
            page = await context.new_page()
            try:
                return await self._perform_login(page, context, username, password)
            finally:
                if context.browser:
                    await context.browser.close()

    async def _perform_login(
        self, page: Page, context: BrowserContext, username: str, password: str
    ) -> SessionData:
        """Execute the Target login flow.

        Fills the username and password fields with randomized pauses
        between actions, submits the form, waits until the URL no longer
        contains "login", then returns the context's cookies as a
        ``SessionData`` that expires two hours after capture.
        """
        logger.info("Navigating to Target sign-in page")
        await page.goto(TARGET_LOGIN_PAGE, wait_until="networkidle")
        await self.human_delay(2000, 4000)

        # Target login form — email/username field
        email_input = page.locator(
            'input[id="username"], '
            'input[name="username"], '
            'input[type="email"], '
            'input[data-test="username"]'
        )
        await email_input.wait_for(state="visible", timeout=settings.browser_timeout_ms)
        await email_input.click()
        await self.human_delay(300, 700)
        await email_input.fill(username)
        await self.human_delay(800, 1500)

        # Password field
        password_input = page.locator(
            'input[id="password"], '
            'input[name="password"], '
            'input[type="password"], '
            'input[data-test="password"]'
        )
        await password_input.wait_for(state="visible", timeout=settings.browser_timeout_ms)
        await password_input.click()
        await self.human_delay(300, 700)
        await password_input.fill(password)
        await self.human_delay(1000, 2000)

        # Sign-in button
        sign_in_btn = page.locator(
            'button[id="login"], '
            'button[data-test="login-button"], '
            'button[type="submit"]:has-text("Sign in")'
        )
        await sign_in_btn.click()

        # Wait for redirect away from login page
        await page.wait_for_url(
            lambda url: "login" not in url.lower(),
            timeout=settings.browser_timeout_ms,
        )
        await self.human_delay(1500, 3000)

        # Capture cookies
        raw_cookies = await context.cookies()
        cookies = [dict(c) for c in raw_cookies]
        now = datetime.now(UTC)

        logger.info("Target login successful, captured %d cookies", len(cookies))
        return SessionData(
            cookies=cookies,
            user_agent=DEFAULT_USER_AGENT,
            created_at=now,
            expires_at=now + timedelta(hours=2),
            extra={"retailer": "target"},
        )

    async def check_session(self, session: SessionData) -> bool:
        """Check if the Target session is still valid.

        A session whose ``expires_at`` has passed is rejected without
        opening a browser. Otherwise the account page is loaded with the
        stored cookies; the session counts as valid when the final URL does
        not contain "login" and the response is OK. Any exception during
        the check is logged and reported as an invalid session.
        """
        if session.expires_at and datetime.now(UTC) > session.expires_at:
            logger.info("Target session expired based on timestamp")
            return False

        async with async_playwright() as p:
            context = await self._create_stealth_context(p, cookies=session.cookies)
            page = await context.new_page()
            try:
                response = await page.goto(TARGET_ACCOUNT_PAGE, wait_until="networkidle")
                current_url = page.url.lower()
                is_valid = "login" not in current_url and response is not None and response.ok
                logger.info("Target session check: valid=%s (url=%s)", is_valid, page.url)
                return is_valid
            except Exception:
                logger.exception("Target session check failed")
                return False
            finally:
                if context.browser:
                    await context.browser.close()

    async def scrape_receipts(
        self, session: SessionData, since: datetime | None = None
    ) -> list[RawReceipt]:
        """Scrape in-store purchase history from Target Circle.

        Opens a browser context seeded with the session's cookies, delegates
        to ``_fetch_receipts``, and always closes the browser afterwards.
        """
        async with async_playwright() as p:
            context = await self._create_stealth_context(p, cookies=session.cookies)
            page = await context.new_page()
            try:
                return await self._fetch_receipts(page, since)
            finally:
                if context.browser:
                    await context.browser.close()

    @staticmethod
    def _purchased_before(purchase_date: object, since: datetime) -> bool:
        """Report whether ``purchase_date`` is strictly earlier than ``since``.

        Returns ``False`` (i.e. keep the order) whenever the date is missing,
        is not a string, cannot be parsed as ISO-8601, or cannot be compared
        with ``since``.
        """
        if not isinstance(purchase_date, str) or not purchase_date:
            return False
        try:
            purchased_at = datetime.fromisoformat(purchase_date.replace("Z", "+00:00"))
        except ValueError:
            return False
        if purchased_at.tzinfo is not None and since.tzinfo is None:
            # A naive cutoff is read as UTC, the clock this class uses for its
            # own timestamps. Without this the comparison raises TypeError and
            # the filter would silently let every order through.
            since = since.replace(tzinfo=UTC)
        try:
            return purchased_at < since
        except TypeError:
            # Naive order date vs. aware cutoff: the source zone is unknown,
            # so keep the order rather than guess.
            return False

    async def _fetch_receipts(self, page: Page, since: datetime | None) -> list[RawReceipt]:
        """Fetch receipt list and details from Target order history.

        Target's order history page has separate tabs for online and in-store
        purchases. We target the in-store tab which shows Circle-linked
        transactions.

        Args:
            page: Page belonging to a context that already carries the session cookies.
            since: Optional cutoff; orders dated strictly before it are skipped.

        Returns:
            One ``RawReceipt`` per order that has an id and passes the date
            filter. An empty list is returned when the listing request fails
            or the payload does not have the expected shape.
        """
        # Navigate to order history to establish context
        await page.goto(TARGET_ORDER_HISTORY, wait_until="networkidle")
        await self.human_delay(1500, 3000)

        receipts: list[RawReceipt] = []

        # Target order history API — filter for in-store purchases
        api_response = await page.request.get(
            TARGET_ORDER_API,
            params={"channel": "in_store", "limit": "50"},
        )
        if not api_response.ok:
            logger.warning(
                "Target order history request failed: %d %s",
                api_response.status,
                api_response.status_text,
            )
            return []

        response = await api_response.json()
        if not isinstance(response, dict):
            logger.warning("Unexpected order history response type: %s", type(response))
            return []

        # Target uses "orders" key for in-store purchase list
        orders = response.get("orders", response.get("transactions", []))
        if not isinstance(orders, list):
            logger.warning("No orders found in Target order history response")
            return []

        logger.info("Found %d in-store orders in Target history", len(orders))

        for order in orders:
            if not isinstance(order, dict):
                # One malformed entry must not abort the whole scrape.
                logger.warning("Skipping malformed Target order entry: %s", type(order))
                continue

            raw_id = order.get("orderId") or order.get("transactionId") or order.get("id") or ""
            order_id = str(raw_id)
            # Same `or` chain as the id lookup, so a key that is present but
            # null falls through to the next candidate.
            purchase_date = (
                order.get("purchaseDate")
                or order.get("transactionDate")
                or order.get("date")
                or ""
            )

            # Filter by date if 'since' is provided
            if since and self._purchased_before(purchase_date, since):
                continue

            if not order_id:
                continue

            await self.human_delay(1000, 2500)

            # Fetch receipt detail
            detail = await self._fetch_receipt_detail(page, order_id)

            raw_store = (
                order.get("storeNumber") or order.get("storeId") or order.get("locationId") or ""
            )
            store_number = str(raw_store)

            receipts.append(
                RawReceipt(
                    receipt_id=order_id,
                    purchase_date=purchase_date,
                    store_number=store_number,
                    raw_data={**order, "detail": detail},
                    source_url=f"{TARGET_RECEIPT_API}/{order_id}",
                )
            )

        logger.info("Scraped %d receipts from Target", len(receipts))
        return receipts

    async def _fetch_receipt_detail(self, page: Page, order_id: str) -> dict:
        """Fetch detailed receipt data for a single Target order.

        Best-effort: a non-OK response, a non-dict payload, or any exception
        yields an empty dict so the caller can still record the order.
        """
        try:
            url = f"{TARGET_RECEIPT_API}/{order_id}"
            api_response = await page.request.get(url)
            if not api_response.ok:
                logger.warning(
                    "Target receipt detail request failed for %s: %d",
                    order_id,
                    api_response.status,
                )
                return {}
            detail = await api_response.json()
            return detail if isinstance(detail, dict) else {}
        except Exception:
            logger.exception("Failed to fetch Target receipt detail for %s", order_id)
            return {}

    def parse_receipt(self, raw: RawReceipt) -> dict:
        """Parse raw Target receipt into structured purchase data.

        Delegates to ``parse_target_receipt``, imported at call time.
        """
        from receiptwitness.parsers.target import parse_target_receipt

        return parse_target_receipt(raw)
def _get_fernet() -> Fernet:
    """Build a Fernet cipher from ``settings.session_encryption_key``.

    Raises:
        ValueError: If no encryption key is configured.
    """
    configured_key = settings.session_encryption_key
    if configured_key:
        if isinstance(configured_key, str):
            return Fernet(configured_key.encode())
        return Fernet(configured_key)
    raise ValueError(
        "RW_SESSION_ENCRYPTION_KEY is not set. "
        "Generate one with: "
        "python -c 'from cryptography.fernet import Fernet; "
        "print(Fernet.generate_key().decode())'"
    )


def encrypt_session_data(data: dict) -> str:
    """Serialize ``data`` to JSON and encrypt it into a Fernet token string.

    Values JSON cannot encode natively are stringified with ``str``. The
    returned token is URL-safe base64 text suitable for storing in JSONB.
    """
    cipher = _get_fernet()
    serialized = json.dumps(data, default=str)
    token = cipher.encrypt(serialized.encode("utf-8"))
    return token.decode("utf-8")


def decrypt_session_data(encrypted: str) -> dict:
    """Decrypt a Fernet token string and decode its JSON payload.

    Raises:
        InvalidToken: If the token is invalid or was produced with a
            different key; the failure is logged before being re-raised.
    """
    cipher = _get_fernet()
    try:
        token = encrypted.encode("utf-8")
        decoded: dict = json.loads(cipher.decrypt(token))
    except InvalidToken:
        logger.error("Failed to decrypt session data — invalid token or wrong key")
        raise
    return decoded
def session_from_db_record(session_data_encrypted: str | None) -> SessionData | None:
    """Decrypt a stored session value and rebuild the ``SessionData``.

    The argument is the Fernet-encrypted JSON of the SessionData fields as
    kept in the ``session_data`` column of ``user_store_accounts``.

    Returns:
        The reconstructed session, or ``None`` when the value is empty or
        anything goes wrong while decrypting or rebuilding it (the failure
        is logged).
    """
    if not session_data_encrypted:
        return None

    try:
        fields = decrypt_session_data(session_data_encrypted)
        cookies = fields["cookies"]
        user_agent = fields["user_agent"]
        created_at = datetime.fromisoformat(fields["created_at"])
        expires_at = None
        if fields.get("expires_at"):
            expires_at = datetime.fromisoformat(fields["expires_at"])
        return SessionData(
            cookies=cookies,
            user_agent=user_agent,
            created_at=created_at,
            expires_at=expires_at,
            extra=fields.get("extra", {}),
        )
    except Exception:
        logger.exception("Failed to load session from DB record")
        return None


def session_to_db_value(session: SessionData) -> str:
    """Flatten a session to a dict, ISO-format its datetimes, and encrypt it."""
    payload = asdict(session)
    # JSON has no datetime type, so timestamps are stored as ISO strings.
    payload["created_at"] = session.created_at.isoformat()
    expiry = session.expires_at
    if expiry:
        payload["expires_at"] = expiry.isoformat()
    return encrypt_session_data(payload)
def _load_fixture(filename: str) -> dict:
    """Read and decode one JSON file from ``FIXTURES_DIR``.

    The file is read as UTF-8 explicitly so the result does not depend on
    the locale's default encoding.
    """
    return json.loads((FIXTURES_DIR / filename).read_text(encoding="utf-8"))


@pytest.fixture
def meijer_receipt_data() -> dict:
    """Load the sample Meijer receipt fixture."""
    return _load_fixture("meijer_receipt.json")


@pytest.fixture
def kroger_receipt_data() -> dict:
    """Load the sample Kroger receipt fixture."""
    return _load_fixture("kroger_receipt.json")


@pytest.fixture
def target_receipt_data() -> dict:
    """Load the sample Target receipt fixture."""
    return _load_fixture("target_receipt.json")
"regularPrice": 4.29, + "salePrice": 3.99, + "couponAmount": 0.0, + "plusCardSavings": 0.30, + "department": "DAIRY" + }, + { + "description": "BANANAS", + "upc": "0000000004011", + "quantity": 1, + "basePrice": 0.59, + "totalPrice": 0.59, + "regularPrice": 0.59, + "salePrice": null, + "couponAmount": null, + "plusCardSavings": null, + "department": "PRODUCE" + }, + { + "description": "SIMPLE TRUTH ORG EGGS 12CT", + "upc": "0001111087840", + "quantity": 2, + "basePrice": 5.49, + "totalPrice": 10.98, + "regularPrice": 5.99, + "salePrice": 5.49, + "couponAmount": 0.0, + "plusCardSavings": 1.00, + "department": "DAIRY" + }, + { + "description": "KROGER DELI TURKEY BREAST", + "upc": null, + "quantity": 0.68, + "basePrice": 9.99, + "totalPrice": 6.79, + "regularPrice": 9.99, + "salePrice": null, + "weight": 0.68, + "weightUom": "LB", + "department": "DELI" + }, + { + "description": "TIDE PODS 42CT", + "upc": "0003700096223", + "quantity": 1, + "basePrice": 13.99, + "totalPrice": 13.99, + "regularPrice": 15.99, + "salePrice": 13.99, + "couponAmount": 2.00, + "plusCardSavings": 0.0, + "department": "HOUSEHOLD" + }, + { + "description": "VOIDED DORITOS NACHO", + "upc": "0002840032505", + "quantity": 1, + "basePrice": 4.79, + "totalPrice": 4.79, + "voided": true, + "department": "SNACKS" + }, + { + "description": "RETURNED GATORADE 8PK", + "upc": "0005200012505", + "quantity": 1, + "basePrice": 7.99, + "totalPrice": 7.99, + "status": "RETURNED", + "department": "BEVERAGES" + }, + { + "description": "KROGER SHARP CHEDDAR 8OZ", + "upc": "0001111060930", + "quantity": 1, + "basePrice": 3.49, + "totalPrice": 3.49, + "regularPrice": 3.49, + "salePrice": null, + "couponAmount": null, + "plusCardSavings": null, + "department": "DAIRY" + }, + { + "description": "PRIVATE SELECTION PASTA", + "upc": "0001111085612", + "quantity": 3, + "basePrice": 2.49, + "totalPrice": 7.47, + "regularPrice": 2.99, + "salePrice": 2.49, + "couponAmount": 0.0, + "plusCardSavings": 1.50, + "department": 
"GROCERY" + }, + { + "description": "KROGER GROUND BEEF 80/20", + "upc": null, + "quantity": 1.23, + "basePrice": 5.99, + "totalPrice": 7.37, + "regularPrice": 6.99, + "salePrice": 5.99, + "weight": 1.23, + "weightUom": "LB", + "department": "MEAT" + } + ], + "subtotal": 78.47, + "tax": 5.50, + "total": 94.17, + "totalSavings": 15.30 + } +} diff --git a/tests/fixtures/meijer_receipt.json b/tests/fixtures/meijer_receipt.json new file mode 100644 index 0000000..a733215 --- /dev/null +++ b/tests/fixtures/meijer_receipt.json @@ -0,0 +1,85 @@ +{ + "transactionId": "TXN-2026-0310-001", + "transactionDate": "2026-03-10T14:30:00Z", + "storeNumber": "42", + "total": 87.42, + "savings": 12.50, + "detail": { + "receiptId": "TXN-2026-0310-001", + "items": [ + { + "description": "ORGANIC BANANAS", + "upc": "0000000004011", + "quantity": 1, + "price": 0.69, + "extendedPrice": 0.69, + "regularPrice": 0.79, + "salePrice": 0.69, + "couponDiscount": 0.0, + "mperksDiscount": 0.10, + "category": "PRODUCE" + }, + { + "description": "MEIJER 2% MILK GAL", + "upc": "0041250000123", + "quantity": 2, + "price": 3.49, + "extendedPrice": 6.98, + "regularPrice": 3.79, + "salePrice": 3.49, + "couponDiscount": 0.0, + "mperksDiscount": 0.0, + "category": "DAIRY" + }, + { + "description": "CHEERIOS 18OZ", + "upc": "0016000275614", + "quantity": 1, + "price": 4.99, + "extendedPrice": 4.99, + "regularPrice": 5.49, + "salePrice": null, + "couponDiscount": 0.50, + "mperksDiscount": 0.0, + "category": "CEREAL" + }, + { + "description": "WEIGHTED DELI TURKEY", + "upc": null, + "quantity": 0.75, + "price": 8.99, + "extendedPrice": 6.74, + "regularPrice": 8.99, + "salePrice": null, + "couponDiscount": null, + "mperksDiscount": null, + "category": "DELI" + }, + { + "description": "VOIDED SODA 12PK", + "upc": "0004900005678", + "quantity": 1, + "price": 5.99, + "extendedPrice": 5.99, + "voided": true, + "category": "BEVERAGES" + }, + { + "description": "MEIJER PAPER TOWELS 6PK", + "upc": "0041250099001", + 
"quantity": 1, + "price": 7.99, + "extendedPrice": 7.99, + "regularPrice": 9.99, + "salePrice": 7.99, + "couponDiscount": 1.00, + "mperksDiscount": 1.00, + "category": "HOUSEHOLD" + } + ], + "subtotal": 74.92, + "tax": 5.24, + "total": 87.42, + "totalSavings": 12.50 + } +} diff --git a/tests/fixtures/target_receipt.json b/tests/fixtures/target_receipt.json new file mode 100644 index 0000000..c76bb5b --- /dev/null +++ b/tests/fixtures/target_receipt.json @@ -0,0 +1,140 @@ +{ + "orderId": "TGT-2026-0315-7890", + "purchaseDate": "2026-03-15T11:23:00Z", + "storeNumber": "2774", + "total": 83.21, + "savings": 11.45, + "detail": { + "receiptId": "TGT-2026-0315-7890", + "items": [ + { + "description": "GOOD & GATHER WHOLE MILK GAL", + "tcin": "14767459", + "upc": "0085239100123", + "quantity": 1, + "unitPrice": 3.89, + "totalPrice": 3.89, + "regularPrice": 4.19, + "circlePrice": 3.89, + "couponDiscount": 0.0, + "circleRewardsDiscount": 0.30, + "promoDescription": "Circle offer: Save 30c", + "department": "GROCERY" + }, + { + "description": "BANANAS", + "upc": "0000000004011", + "quantity": 1, + "unitPrice": 0.25, + "totalPrice": 0.25, + "regularPrice": 0.25, + "circlePrice": null, + "couponDiscount": null, + "circleRewardsDiscount": null, + "department": "PRODUCE" + }, + { + "description": "MARKET PANTRY LARGE EGGS 18CT", + "tcin": "13292174", + "upc": "0085239206753", + "quantity": 2, + "unitPrice": 4.99, + "totalPrice": 9.98, + "regularPrice": 5.49, + "circlePrice": 4.99, + "couponDiscount": 0.0, + "circleRewardsDiscount": 1.00, + "promoDescription": "Circle offer: 2 for $10", + "department": "GROCERY" + }, + { + "description": "DELI SLICED TURKEY BREAST", + "upc": null, + "quantity": 0.72, + "unitPrice": 10.99, + "totalPrice": 7.91, + "regularPrice": 10.99, + "weight": 0.72, + "weightUom": "LB", + "department": "DELI" + }, + { + "description": "TIDE PODS 42CT", + "tcin": "76150253", + "upc": "0003700096223", + "quantity": 1, + "unitPrice": 13.49, + "totalPrice": 13.49, 
+ "regularPrice": 15.99, + "circlePrice": 13.49, + "couponDiscount": 2.50, + "circleRewardsDiscount": 0.0, + "promoDescription": "Circle offer + mfr coupon", + "department": "HOUSEHOLD" + }, + { + "description": "UP&UP PAPER TOWELS 6PK", + "tcin": "52493117", + "upc": "0085239401567", + "quantity": 1, + "unitPrice": 8.99, + "totalPrice": 8.99, + "regularPrice": 8.99, + "circlePrice": null, + "couponDiscount": null, + "circleRewardsDiscount": null, + "department": "HOUSEHOLD" + }, + { + "description": "VOIDED COCA-COLA 12PK", + "upc": "0004900002521", + "quantity": 1, + "unitPrice": 7.49, + "totalPrice": 7.49, + "voided": true, + "department": "BEVERAGES" + }, + { + "description": "RETURNED OLAY MOISTURIZER", + "upc": "0007560402118", + "quantity": 1, + "unitPrice": 12.99, + "totalPrice": 12.99, + "status": "RETURNED", + "department": "BEAUTY" + }, + { + "description": "FAVOURITE DAY TRAIL MIX", + "tcin": "83921045", + "dpci": "271-09-0142", + "upc": "0085239700891", + "quantity": 1, + "unitPrice": 5.49, + "totalPrice": 5.49, + "regularPrice": 5.49, + "circlePrice": null, + "couponDiscount": null, + "circleRewardsDiscount": null, + "department": "SNACKS" + }, + { + "description": "BOGO GOOD & GATHER PASTA", + "tcin": "78114326", + "upc": "0085239300456", + "quantity": 2, + "unitPrice": 1.79, + "totalPrice": 1.79, + "regularPrice": 1.79, + "circlePrice": 0.895, + "couponDiscount": 0.0, + "circleRewardsDiscount": 1.79, + "promoDescription": "Buy 1 get 1 free", + "department": "GROCERY" + } + ], + "subtotal": 78.32, + "tax": 4.89, + "total": 83.21, + "totalSavings": 11.45 + } +} diff --git a/tests/test_parsers/__init__.py b/tests/test_parsers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_parsers/test_kroger_parser.py b/tests/test_parsers/test_kroger_parser.py new file mode 100644 index 0000000..001d205 --- /dev/null +++ b/tests/test_parsers/test_kroger_parser.py @@ -0,0 +1,399 @@ +"""Tests for the Kroger receipt parser.""" + +from 
class TestToDecimal:
    """Checks ``_to_decimal`` on numeric, string, missing and invalid input."""

    def test_from_int(self):
        converted = _to_decimal(42)
        assert converted == Decimal("42")

    def test_from_float(self):
        converted = _to_decimal(3.99)
        assert converted == Decimal("3.99")

    def test_from_string(self):
        converted = _to_decimal("7.49")
        assert converted == Decimal("7.49")

    def test_none_returns_default(self):
        # With no explicit default, a missing value becomes zero.
        converted = _to_decimal(None)
        assert converted == Decimal("0")

    def test_none_custom_default(self):
        converted = _to_decimal(None, "1")
        assert converted == Decimal("1")

    def test_invalid_string_returns_default(self):
        converted = _to_decimal("not-a-number")
        assert converted == Decimal("0")

    def test_empty_string_returns_default(self):
        converted = _to_decimal("")
        assert converted == Decimal("0")
result["unit_price"] == Decimal("9.99") + assert result["extended_price"] == Decimal("6.79") + + def test_missing_extended_price_computed(self): + raw = { + "description": "TEST ITEM", + "quantity": 3, + "basePrice": 2.49, + } + result = _parse_item(raw) + assert result["extended_price"] == Decimal("2.49") * Decimal("3") + + def test_item_with_coupon(self): + raw = { + "description": "TIDE PODS 42CT", + "upc": "0003700096223", + "quantity": 1, + "basePrice": 13.99, + "totalPrice": 13.99, + "couponAmount": 2.00, + } + result = _parse_item(raw) + assert result["coupon_discount"] == Decimal("2.00") + + def test_missing_description_fallback(self): + raw = {"basePrice": 1.00, "totalPrice": 1.00} + result = _parse_item(raw) + assert result["product_name_raw"] == "UNKNOWN ITEM" + + def test_alternative_field_names_product_name(self): + raw = { + "productName": "ALT NAME ITEM", + "unitPrice": 5.00, + "extendedAmount": 5.00, + "qty": 1, + "krogerProductId": "123456789", + "category": "GROCERY", + } + result = _parse_item(raw) + assert result["product_name_raw"] == "ALT NAME ITEM" + assert result["unit_price"] == Decimal("5.00") + assert result["extended_price"] == Decimal("5.00") + assert result["upc"] == "123456789" + assert result["category_raw"] == "GROCERY" + + def test_item_description_field_name(self): + raw = { + "itemDescription": "ITEM DESC FIELD", + "price": 3.00, + "lineTotal": 3.00, + } + result = _parse_item(raw) + assert result["product_name_raw"] == "ITEM DESC FIELD" + assert result["unit_price"] == Decimal("3.00") + assert result["extended_price"] == Decimal("3.00") + + def test_null_optional_fields(self): + raw = { + "description": "BANANAS", + "upc": "0000000004011", + "quantity": 1, + "basePrice": 0.59, + "totalPrice": 0.59, + "salePrice": None, + "couponAmount": None, + "plusCardSavings": None, + } + result = _parse_item(raw) + assert result["sale_price"] is None + assert result["coupon_discount"] is None + assert result["loyalty_discount"] is None + + 
def test_upc_leading_zeros_stripped(self): + raw = { + "description": "TEST", + "upc": "0000000004011", + "basePrice": 1.00, + "totalPrice": 1.00, + } + result = _parse_item(raw) + assert result["upc"] == "4011" + + def test_upc_from_kroger_product_id(self): + raw = { + "description": "TEST", + "krogerProductId": "987654321", + "basePrice": 1.00, + "totalPrice": 1.00, + } + result = _parse_item(raw) + assert result["upc"] == "987654321" + + def test_description_whitespace_stripped(self): + raw = { + "description": " EXTRA SPACES ", + "basePrice": 1.00, + "totalPrice": 1.00, + } + result = _parse_item(raw) + assert result["product_name_raw"] == "EXTRA SPACES" + + def test_promo_price_field(self): + raw = { + "description": "PROMO ITEM", + "promoPrice": 2.99, + "originalPrice": 4.99, + "basePrice": 2.99, + "totalPrice": 2.99, + } + result = _parse_item(raw) + assert result["sale_price"] == Decimal("2.99") + assert result["regular_price"] == Decimal("4.99") + + def test_loyalty_discount_from_fuel_points(self): + raw = { + "description": "FUEL DISC ITEM", + "fuelPointsDiscount": 0.50, + "basePrice": 3.00, + "totalPrice": 3.00, + } + result = _parse_item(raw) + assert result["loyalty_discount"] == Decimal("0.50") + + def test_multi_quantity_item(self): + raw = { + "description": "PRIVATE SELECTION PASTA", + "quantity": 3, + "basePrice": 2.49, + "totalPrice": 7.47, + "department": "GROCERY", + } + result = _parse_item(raw) + assert result["quantity"] == Decimal("3") + assert result["unit_price"] == Decimal("2.49") + assert result["extended_price"] == Decimal("7.47") + + def test_aisle_as_category(self): + raw = { + "description": "AISLE ITEM", + "aisle": "FROZEN FOODS", + "basePrice": 4.00, + "totalPrice": 4.00, + } + result = _parse_item(raw) + assert result["category_raw"] == "FROZEN FOODS" + + +class TestParseKrogerReceipt: + def test_full_receipt(self, kroger_receipt_data): + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12T16:45:00Z", 
+ store_number="00357", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + + assert result["receipt_id"] == "KR-2026-0312-4471" + assert result["purchase_date"] == "2026-03-12T16:45:00Z" + assert result["total"] == Decimal("94.17") + assert result["subtotal"] == Decimal("78.47") + assert result["tax"] == Decimal("5.50") + assert result["savings_total"] == Decimal("15.30") + + # Should have 8 items (voided + returned items excluded) + assert len(result["items"]) == 8 + + # Verify first item + milk = result["items"][0] + assert milk["product_name_raw"] == "KROGER WHOLE MILK GAL" + assert milk["upc"] == "1111041700" + + def test_voided_items_excluded(self, kroger_receipt_data): + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + + item_names = [i["product_name_raw"] for i in result["items"]] + assert "VOIDED DORITOS NACHO" not in item_names + + def test_returned_items_excluded(self, kroger_receipt_data): + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + + item_names = [i["product_name_raw"] for i in result["items"]] + assert "RETURNED GATORADE 8PK" not in item_names + + def test_return_flag_items_excluded(self): + data = { + "detail": { + "items": [ + { + "description": "NORMAL ITEM", + "basePrice": 5.00, + "totalPrice": 5.00, + }, + { + "description": "RETURNED VIA FLAG", + "basePrice": 3.00, + "totalPrice": 3.00, + "returnFlag": True, + }, + { + "description": "IS RETURN ITEM", + "basePrice": 2.00, + "totalPrice": 2.00, + "isReturn": True, + }, + ], + "total": 5.00, + } + } + raw = RawReceipt( + receipt_id="RET-001", + purchase_date="2026-03-12", + raw_data=data, + ) + result = parse_kroger_receipt(raw) + assert len(result["items"]) == 1 + assert result["items"][0]["product_name_raw"] == "NORMAL ITEM" + + def 
test_empty_receipt(self): + raw = RawReceipt( + receipt_id="EMPTY-001", + purchase_date="2026-03-12", + raw_data={"detail": {"items": [], "total": 0}}, + ) + result = parse_kroger_receipt(raw) + assert result["items"] == [] + assert result["total"] == Decimal("0") + + def test_receipt_with_no_detail(self): + raw = RawReceipt( + receipt_id="NO-DETAIL-001", + purchase_date="2026-03-12", + raw_data={"total": 50.00}, + ) + result = parse_kroger_receipt(raw) + assert result["items"] == [] + assert result["total"] == Decimal("50.00") + + def test_raw_data_preserved(self, kroger_receipt_data): + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + assert result["raw_data"] is kroger_receipt_data + + def test_alternative_total_field_names(self): + raw = RawReceipt( + receipt_id="ALT-001", + purchase_date="2026-03-12", + raw_data={ + "orderTotal": 42.00, + "subTotal": 35.00, + "salesTax": 3.50, + "youSaved": 5.00, + "detail": {"items": []}, + }, + ) + result = parse_kroger_receipt(raw) + assert result["total"] == Decimal("42.00") + assert result["subtotal"] == Decimal("35.00") + assert result["tax"] == Decimal("3.50") + assert result["savings_total"] == Decimal("5.00") + + def test_receipt_items_alternative_key(self): + data = { + "detail": { + "receiptItems": [ + { + "description": "ALT KEY ITEM", + "basePrice": 3.00, + "totalPrice": 3.00, + } + ], + "total": 3.00, + } + } + raw = RawReceipt( + receipt_id="ALT-KEY-001", + purchase_date="2026-03-12", + raw_data=data, + ) + result = parse_kroger_receipt(raw) + assert len(result["items"]) == 1 + assert result["items"][0]["product_name_raw"] == "ALT KEY ITEM" + + def test_source_url_preserved(self): + raw = RawReceipt( + receipt_id="URL-001", + purchase_date="2026-03-12", + raw_data={"detail": {"items": [], "total": 0}}, + source_url="https://www.kroger.com/atlas/v1/receipt/api?orderId=URL-001", + ) + result = 
parse_kroger_receipt(raw) + assert result["source_url"] == "https://www.kroger.com/atlas/v1/receipt/api?orderId=URL-001" + + def test_weighted_items_in_full_receipt(self, kroger_receipt_data): + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + + # Find the weighted turkey item + turkey = next(i for i in result["items"] if "TURKEY" in i["product_name_raw"]) + assert turkey["quantity"] == Decimal("0.68") + assert turkey["unit_price"] == Decimal("9.99") + assert turkey["extended_price"] == Decimal("6.79") + + def test_grand_total_field(self): + raw = RawReceipt( + receipt_id="GT-001", + purchase_date="2026-03-12", + raw_data={"grandTotal": 99.99, "detail": {"items": []}}, + ) + result = parse_kroger_receipt(raw) + assert result["total"] == Decimal("99.99") diff --git a/tests/test_parsers/test_meijer_parser.py b/tests/test_parsers/test_meijer_parser.py new file mode 100644 index 0000000..47a5fa9 --- /dev/null +++ b/tests/test_parsers/test_meijer_parser.py @@ -0,0 +1,174 @@ +"""Tests for the Meijer receipt parser.""" + +from decimal import Decimal + +from receiptwitness.parsers.meijer import _parse_item, _to_decimal, parse_meijer_receipt +from receiptwitness.scrapers.base import RawReceipt + + +class TestToDecimal: + def test_from_int(self): + assert _to_decimal(42) == Decimal("42") + + def test_from_float(self): + assert _to_decimal(3.49) == Decimal("3.49") + + def test_from_string(self): + assert _to_decimal("7.99") == Decimal("7.99") + + def test_none_returns_default(self): + assert _to_decimal(None) == Decimal("0") + + def test_none_custom_default(self): + assert _to_decimal(None, "1") == Decimal("1") + + def test_invalid_string_returns_default(self): + assert _to_decimal("not-a-number") == Decimal("0") + + +class TestParseItem: + def test_standard_item(self): + raw = { + "description": "ORGANIC BANANAS", + "upc": "0000000004011", + "quantity": 1, + "price": 
0.69, + "extendedPrice": 0.69, + "regularPrice": 0.79, + "salePrice": 0.69, + "couponDiscount": 0.0, + "mperksDiscount": 0.10, + "category": "PRODUCE", + } + result = _parse_item(raw) + assert result["product_name_raw"] == "ORGANIC BANANAS" + assert result["upc"] == "4011" + assert result["quantity"] == Decimal("1") + assert result["unit_price"] == Decimal("0.69") + assert result["extended_price"] == Decimal("0.69") + assert result["regular_price"] == Decimal("0.79") + assert result["sale_price"] == Decimal("0.69") + assert result["loyalty_discount"] == Decimal("0.10") + assert result["category_raw"] == "PRODUCE" + + def test_weighted_item(self): + raw = { + "description": "WEIGHTED DELI TURKEY", + "quantity": 0.75, + "price": 8.99, + "extendedPrice": 6.74, + "category": "DELI", + } + result = _parse_item(raw) + assert result["product_name_raw"] == "WEIGHTED DELI TURKEY" + assert result["upc"] is None + assert result["quantity"] == Decimal("0.75") + assert result["unit_price"] == Decimal("8.99") + assert result["extended_price"] == Decimal("6.74") + + def test_missing_extended_price_computed(self): + raw = { + "description": "TEST ITEM", + "quantity": 3, + "price": 2.50, + } + result = _parse_item(raw) + assert result["extended_price"] == Decimal("2.50") * Decimal("3") + + def test_item_with_coupon_discount(self): + raw = { + "description": "CHEERIOS 18OZ", + "upc": "0016000275614", + "quantity": 1, + "price": 4.99, + "extendedPrice": 4.99, + "couponDiscount": 0.50, + } + result = _parse_item(raw) + assert result["coupon_discount"] == Decimal("0.50") + + def test_missing_description_fallback(self): + raw = {"price": 1.00, "extendedPrice": 1.00} + result = _parse_item(raw) + assert result["product_name_raw"] == "UNKNOWN ITEM" + + def test_alternative_field_names(self): + raw = { + "itemDescription": "ALT NAME ITEM", + "unitPrice": 5.00, + "totalPrice": 5.00, + "qty": 1, + "UPC": "123456789", + "departmentDescription": "GROCERY", + } + result = _parse_item(raw) + 
assert result["product_name_raw"] == "ALT NAME ITEM" + assert result["unit_price"] == Decimal("5.00") + assert result["upc"] == "123456789" + assert result["category_raw"] == "GROCERY" + + +class TestParseMeijerReceipt: + def test_full_receipt(self, meijer_receipt_data): + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + purchase_date="2026-03-10T14:30:00Z", + store_number="42", + raw_data=meijer_receipt_data, + ) + result = parse_meijer_receipt(raw) + + assert result["receipt_id"] == "TXN-2026-0310-001" + assert result["purchase_date"] == "2026-03-10T14:30:00Z" + assert result["total"] == Decimal("87.42") + assert result["subtotal"] == Decimal("74.92") + assert result["tax"] == Decimal("5.24") + assert result["savings_total"] == Decimal("12.50") + + # Should have 5 items (voided item excluded) + assert len(result["items"]) == 5 + + # Verify first item + bananas = result["items"][0] + assert bananas["product_name_raw"] == "ORGANIC BANANAS" + assert bananas["upc"] == "4011" + + def test_voided_items_excluded(self, meijer_receipt_data): + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + purchase_date="2026-03-10", + raw_data=meijer_receipt_data, + ) + result = parse_meijer_receipt(raw) + + item_names = [i["product_name_raw"] for i in result["items"]] + assert "VOIDED SODA 12PK" not in item_names + + def test_empty_receipt(self): + raw = RawReceipt( + receipt_id="EMPTY-001", + purchase_date="2026-03-10", + raw_data={"detail": {"items": [], "total": 0}}, + ) + result = parse_meijer_receipt(raw) + assert result["items"] == [] + assert result["total"] == Decimal("0") + + def test_receipt_with_no_detail(self): + raw = RawReceipt( + receipt_id="NO-DETAIL-001", + purchase_date="2026-03-10", + raw_data={"total": 50.00}, + ) + result = parse_meijer_receipt(raw) + assert result["items"] == [] + assert result["total"] == Decimal("50.00") + + def test_raw_data_preserved(self, meijer_receipt_data): + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + 
purchase_date="2026-03-10", + raw_data=meijer_receipt_data, + ) + result = parse_meijer_receipt(raw) + assert result["raw_data"] is meijer_receipt_data diff --git a/tests/test_parsers/test_target_parser.py b/tests/test_parsers/test_target_parser.py new file mode 100644 index 0000000..8f197ac --- /dev/null +++ b/tests/test_parsers/test_target_parser.py @@ -0,0 +1,471 @@ +"""Tests for the Target receipt parser.""" + +from decimal import Decimal + +from receiptwitness.parsers.target import _parse_item, _to_decimal, parse_target_receipt +from receiptwitness.scrapers.base import RawReceipt + + +class TestToDecimal: + def test_from_int(self): + assert _to_decimal(42) == Decimal("42") + + def test_from_float(self): + assert _to_decimal(3.89) == Decimal("3.89") + + def test_from_string(self): + assert _to_decimal("8.99") == Decimal("8.99") + + def test_none_returns_default(self): + assert _to_decimal(None) == Decimal("0") + + def test_none_custom_default(self): + assert _to_decimal(None, "1") == Decimal("1") + + def test_invalid_string_returns_default(self): + assert _to_decimal("not-a-number") == Decimal("0") + + def test_empty_string_returns_default(self): + assert _to_decimal("") == Decimal("0") + + +class TestParseItem: + def test_standard_item(self): + raw = { + "description": "GOOD & GATHER WHOLE MILK GAL", + "tcin": "14767459", + "upc": "0085239100123", + "quantity": 1, + "unitPrice": 3.89, + "totalPrice": 3.89, + "regularPrice": 4.19, + "circlePrice": 3.89, + "couponDiscount": 0.0, + "circleRewardsDiscount": 0.30, + "department": "GROCERY", + } + result = _parse_item(raw) + assert result["product_name_raw"] == "GOOD & GATHER WHOLE MILK GAL" + assert result["upc"] == "85239100123" + assert result["quantity"] == Decimal("1") + assert result["unit_price"] == Decimal("3.89") + assert result["extended_price"] == Decimal("3.89") + assert result["regular_price"] == Decimal("4.19") + assert result["sale_price"] == Decimal("3.89") + assert result["loyalty_discount"] == 
Decimal("0.30") + assert result["category_raw"] == "GROCERY" + + def test_weighted_item(self): + raw = { + "description": "DELI SLICED TURKEY BREAST", + "quantity": 0.72, + "unitPrice": 10.99, + "totalPrice": 7.91, + "weight": 0.72, + "weightUom": "LB", + "department": "DELI", + } + result = _parse_item(raw) + assert result["product_name_raw"] == "DELI SLICED TURKEY BREAST" + assert result["upc"] is None + assert result["quantity"] == Decimal("0.72") + assert result["unit_price"] == Decimal("10.99") + assert result["extended_price"] == Decimal("7.91") + + def test_missing_extended_price_computed(self): + raw = { + "description": "TEST ITEM", + "quantity": 3, + "unitPrice": 2.49, + } + result = _parse_item(raw) + assert result["extended_price"] == Decimal("2.49") * Decimal("3") + + def test_item_with_coupon(self): + raw = { + "description": "TIDE PODS 42CT", + "upc": "0003700096223", + "quantity": 1, + "unitPrice": 13.49, + "totalPrice": 13.49, + "couponDiscount": 2.50, + } + result = _parse_item(raw) + assert result["coupon_discount"] == Decimal("2.50") + + def test_missing_description_fallback(self): + raw = {"unitPrice": 1.00, "totalPrice": 1.00} + result = _parse_item(raw) + assert result["product_name_raw"] == "UNKNOWN ITEM" + + def test_alternative_field_names(self): + raw = { + "productName": "ALT NAME ITEM", + "price": 5.00, + "extendedPrice": 5.00, + "qty": 1, + "UPC": "123456789", + "category": "FROZEN", + } + result = _parse_item(raw) + assert result["product_name_raw"] == "ALT NAME ITEM" + assert result["unit_price"] == Decimal("5.00") + assert result["extended_price"] == Decimal("5.00") + assert result["upc"] == "123456789" + assert result["category_raw"] == "FROZEN" + + def test_item_description_field_name(self): + raw = { + "itemDescription": "ITEM DESC FIELD", + "price": 3.00, + "lineTotal": 3.00, + } + result = _parse_item(raw) + assert result["product_name_raw"] == "ITEM DESC FIELD" + assert result["unit_price"] == Decimal("3.00") + assert 
result["extended_price"] == Decimal("3.00") + + def test_null_optional_fields(self): + raw = { + "description": "BANANAS", + "upc": "0000000004011", + "quantity": 1, + "unitPrice": 0.25, + "totalPrice": 0.25, + "circlePrice": None, + "couponDiscount": None, + "circleRewardsDiscount": None, + } + result = _parse_item(raw) + assert result["sale_price"] is None + assert result["coupon_discount"] is None + assert result["loyalty_discount"] is None + + def test_upc_leading_zeros_stripped(self): + raw = { + "description": "TEST", + "upc": "0000000004011", + "unitPrice": 1.00, + "totalPrice": 1.00, + } + result = _parse_item(raw) + assert result["upc"] == "4011" + + def test_description_whitespace_stripped(self): + raw = { + "description": " EXTRA SPACES ", + "unitPrice": 1.00, + "totalPrice": 1.00, + } + result = _parse_item(raw) + assert result["product_name_raw"] == "EXTRA SPACES" + + def test_circle_price_preferred_over_sale_price(self): + raw = { + "description": "CIRCLE ITEM", + "circlePrice": 2.99, + "salePrice": 3.49, + "unitPrice": 2.99, + "totalPrice": 2.99, + } + result = _parse_item(raw) + assert result["sale_price"] == Decimal("2.99") + + def test_sale_price_fallback_when_no_circle_price(self): + raw = { + "description": "SALE ITEM", + "salePrice": 3.49, + "unitPrice": 3.49, + "totalPrice": 3.49, + } + result = _parse_item(raw) + assert result["sale_price"] == Decimal("3.49") + + def test_circle_rewards_discount(self): + raw = { + "description": "CIRCLE REWARDS ITEM", + "circleRewardsDiscount": 1.50, + "unitPrice": 5.00, + "totalPrice": 5.00, + } + result = _parse_item(raw) + assert result["loyalty_discount"] == Decimal("1.50") + + def test_circle_discount_fallback(self): + raw = { + "description": "CIRCLE DISC ITEM", + "circleDiscount": 0.75, + "unitPrice": 3.00, + "totalPrice": 3.00, + } + result = _parse_item(raw) + assert result["loyalty_discount"] == Decimal("0.75") + + def test_bogo_item(self): + raw = { + "description": "BOGO GOOD & GATHER PASTA", + 
"upc": "0085239300456", + "quantity": 2, + "unitPrice": 1.79, + "totalPrice": 1.79, + "regularPrice": 1.79, + "circlePrice": 0.895, + "circleRewardsDiscount": 1.79, + "promoDescription": "Buy 1 get 1 free", + "department": "GROCERY", + } + result = _parse_item(raw) + assert result["quantity"] == Decimal("2") + assert result["unit_price"] == Decimal("1.79") + assert result["extended_price"] == Decimal("1.79") + assert result["sale_price"] == Decimal("0.895") + assert result["loyalty_discount"] == Decimal("1.79") + + def test_multi_quantity_item(self): + raw = { + "description": "MARKET PANTRY EGGS", + "quantity": 2, + "unitPrice": 4.99, + "totalPrice": 9.98, + "department": "GROCERY", + } + result = _parse_item(raw) + assert result["quantity"] == Decimal("2") + assert result["unit_price"] == Decimal("4.99") + assert result["extended_price"] == Decimal("9.98") + + def test_coupon_savings_field(self): + raw = { + "description": "COUPON ITEM", + "couponSavings": 1.00, + "unitPrice": 5.00, + "totalPrice": 5.00, + } + result = _parse_item(raw) + assert result["coupon_discount"] == Decimal("1.00") + + +class TestParseTargetReceipt: + def test_full_receipt(self, target_receipt_data): + raw = RawReceipt( + receipt_id="TGT-2026-0315-7890", + purchase_date="2026-03-15T11:23:00Z", + store_number="2774", + raw_data=target_receipt_data, + ) + result = parse_target_receipt(raw) + + assert result["receipt_id"] == "TGT-2026-0315-7890" + assert result["purchase_date"] == "2026-03-15T11:23:00Z" + assert result["total"] == Decimal("83.21") + assert result["subtotal"] == Decimal("78.32") + assert result["tax"] == Decimal("4.89") + assert result["savings_total"] == Decimal("11.45") + + # Should have 8 items (voided + returned items excluded) + assert len(result["items"]) == 8 + + # Verify first item + milk = result["items"][0] + assert milk["product_name_raw"] == "GOOD & GATHER WHOLE MILK GAL" + assert milk["upc"] == "85239100123" + + def test_voided_items_excluded(self, 
target_receipt_data): + raw = RawReceipt( + receipt_id="TGT-2026-0315-7890", + purchase_date="2026-03-15", + raw_data=target_receipt_data, + ) + result = parse_target_receipt(raw) + + item_names = [i["product_name_raw"] for i in result["items"]] + assert "VOIDED COCA-COLA 12PK" not in item_names + + def test_returned_items_excluded(self, target_receipt_data): + raw = RawReceipt( + receipt_id="TGT-2026-0315-7890", + purchase_date="2026-03-15", + raw_data=target_receipt_data, + ) + result = parse_target_receipt(raw) + + item_names = [i["product_name_raw"] for i in result["items"]] + assert "RETURNED OLAY MOISTURIZER" not in item_names + + def test_return_flag_items_excluded(self): + data = { + "detail": { + "items": [ + { + "description": "NORMAL ITEM", + "unitPrice": 5.00, + "totalPrice": 5.00, + }, + { + "description": "RETURNED VIA FLAG", + "unitPrice": 3.00, + "totalPrice": 3.00, + "returnFlag": True, + }, + { + "description": "IS RETURN ITEM", + "unitPrice": 2.00, + "totalPrice": 2.00, + "isReturn": True, + }, + ], + "total": 5.00, + } + } + raw = RawReceipt( + receipt_id="RET-001", + purchase_date="2026-03-15", + raw_data=data, + ) + result = parse_target_receipt(raw) + assert len(result["items"]) == 1 + assert result["items"][0]["product_name_raw"] == "NORMAL ITEM" + + def test_cancelled_items_excluded(self): + data = { + "detail": { + "items": [ + { + "description": "NORMAL ITEM", + "unitPrice": 5.00, + "totalPrice": 5.00, + }, + { + "description": "CANCELLED ITEM", + "unitPrice": 3.00, + "totalPrice": 3.00, + "status": "CANCELLED", + }, + ], + "total": 5.00, + } + } + raw = RawReceipt( + receipt_id="CAN-001", + purchase_date="2026-03-15", + raw_data=data, + ) + result = parse_target_receipt(raw) + assert len(result["items"]) == 1 + assert result["items"][0]["product_name_raw"] == "NORMAL ITEM" + + def test_empty_receipt(self): + raw = RawReceipt( + receipt_id="EMPTY-001", + purchase_date="2026-03-15", + raw_data={"detail": {"items": [], "total": 0}}, + ) + 
result = parse_target_receipt(raw) + assert result["items"] == [] + assert result["total"] == Decimal("0") + + def test_receipt_with_no_detail(self): + raw = RawReceipt( + receipt_id="NO-DETAIL-001", + purchase_date="2026-03-15", + raw_data={"total": 50.00}, + ) + result = parse_target_receipt(raw) + assert result["items"] == [] + assert result["total"] == Decimal("50.00") + + def test_raw_data_preserved(self, target_receipt_data): + raw = RawReceipt( + receipt_id="TGT-2026-0315-7890", + purchase_date="2026-03-15", + raw_data=target_receipt_data, + ) + result = parse_target_receipt(raw) + assert result["raw_data"] is target_receipt_data + + def test_alternative_total_field_names(self): + raw = RawReceipt( + receipt_id="ALT-001", + purchase_date="2026-03-15", + raw_data={ + "orderTotal": 42.00, + "subTotal": 35.00, + "salesTax": 3.50, + "circleSavings": 5.00, + "detail": {"items": []}, + }, + ) + result = parse_target_receipt(raw) + assert result["total"] == Decimal("42.00") + assert result["subtotal"] == Decimal("35.00") + assert result["tax"] == Decimal("3.50") + assert result["savings_total"] == Decimal("5.00") + + def test_receipt_items_alternative_key(self): + data = { + "detail": { + "lineItems": [ + { + "description": "ALT KEY ITEM", + "unitPrice": 3.00, + "totalPrice": 3.00, + } + ], + "total": 3.00, + } + } + raw = RawReceipt( + receipt_id="ALT-KEY-001", + purchase_date="2026-03-15", + raw_data=data, + ) + result = parse_target_receipt(raw) + assert len(result["items"]) == 1 + assert result["items"][0]["product_name_raw"] == "ALT KEY ITEM" + + def test_source_url_preserved(self): + raw = RawReceipt( + receipt_id="URL-001", + purchase_date="2026-03-15", + raw_data={"detail": {"items": [], "total": 0}}, + source_url="https://api.target.com/order_history/v1/orders/URL-001", + ) + result = parse_target_receipt(raw) + assert result["source_url"] == "https://api.target.com/order_history/v1/orders/URL-001" + + def test_weighted_items_in_full_receipt(self, 
target_receipt_data): + raw = RawReceipt( + receipt_id="TGT-2026-0315-7890", + purchase_date="2026-03-15", + raw_data=target_receipt_data, + ) + result = parse_target_receipt(raw) + + # Find the weighted turkey item + turkey = next(i for i in result["items"] if "TURKEY" in i["product_name_raw"]) + assert turkey["quantity"] == Decimal("0.72") + assert turkey["unit_price"] == Decimal("10.99") + assert turkey["extended_price"] == Decimal("7.91") + + def test_bogo_items_in_full_receipt(self, target_receipt_data): + raw = RawReceipt( + receipt_id="TGT-2026-0315-7890", + purchase_date="2026-03-15", + raw_data=target_receipt_data, + ) + result = parse_target_receipt(raw) + + # Find the BOGO pasta item + pasta = next(i for i in result["items"] if "BOGO" in i["product_name_raw"]) + assert pasta["quantity"] == Decimal("2") + assert pasta["extended_price"] == Decimal("1.79") + assert pasta["loyalty_discount"] == Decimal("1.79") + + def test_grand_total_field(self): + raw = RawReceipt( + receipt_id="GT-001", + purchase_date="2026-03-15", + raw_data={"grandTotal": 99.99, "detail": {"items": []}}, + ) + result = parse_target_receipt(raw) + assert result["total"] == Decimal("99.99") diff --git a/tests/test_pipeline/__init__.py b/tests/test_pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_pipeline/conftest.py b/tests/test_pipeline/conftest.py new file mode 100644 index 0000000..693366f --- /dev/null +++ b/tests/test_pipeline/conftest.py @@ -0,0 +1,23 @@ +"""Shared test fixtures for pipeline tests.""" + +import pytest +from cartsnitch_common.models.base import Base +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + + +@pytest.fixture +def engine(): + """In-memory SQLite engine for unit tests.""" + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def session(engine): + """SQLAlchemy session bound to in-memory SQLite.""" + factory = 
sessionmaker(bind=engine) + with factory() as sess: + yield sess diff --git a/tests/test_pipeline/test_matching.py b/tests/test_pipeline/test_matching.py new file mode 100644 index 0000000..408153c --- /dev/null +++ b/tests/test_pipeline/test_matching.py @@ -0,0 +1,161 @@ +"""Tests for product matching & dedup pipeline.""" + +import uuid +from datetime import UTC, datetime +from decimal import Decimal + +from cartsnitch_common.constants import MatchConfidence +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.schemas.purchase import PurchaseItemCreate + +from receiptwitness.pipeline.matching import ( + ProductMatcher, + classify_confidence, + match_purchase_item, +) +from receiptwitness.pipeline.normalization import MatchMethod + + +class TestClassifyConfidence: + def test_upc_always_high(self): + assert classify_confidence(1.0, MatchMethod.UPC) == MatchConfidence.HIGH + assert classify_confidence(0.5, MatchMethod.UPC) == MatchConfidence.HIGH + + def test_name_high(self): + assert classify_confidence(0.9, MatchMethod.NAME) == MatchConfidence.HIGH + assert classify_confidence(0.8, MatchMethod.NAME) == MatchConfidence.HIGH + + def test_name_medium(self): + assert classify_confidence(0.6, MatchMethod.NAME) == MatchConfidence.MEDIUM + assert classify_confidence(0.5, MatchMethod.NAME) == MatchConfidence.MEDIUM + + def test_name_low(self): + assert classify_confidence(0.3, MatchMethod.NAME) == MatchConfidence.LOW + assert classify_confidence(0.0, MatchMethod.NAME) == MatchConfidence.LOW + + +class TestProductMatcher: + def _make_item(self, name: str, upc: str | None = None) -> PurchaseItemCreate: + return PurchaseItemCreate( + product_name_raw=name, + upc=upc, + unit_price=Decimal("3.99"), + extended_price=Decimal("3.99"), + ) + + def test_match_by_upc(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk Gallon", + upc_variants=["041250000001"], + created_at=datetime.now(UTC), + 
updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + matcher = ProductMatcher(session) + item = self._make_item("Kroger Milk", upc="041250000001") + prod, result, confidence = matcher.match_single(item) + + assert prod is not None + assert prod.id == product.id + assert result is not None + assert result.method == MatchMethod.UPC + assert confidence == MatchConfidence.HIGH + + def test_match_by_name(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk Gallon", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + matcher = ProductMatcher(session, name_threshold=0.3) + item = self._make_item("Whole Milk Gallon Size") + prod, result, confidence = matcher.match_single(item) + + assert prod is not None + assert result is not None + assert result.method == MatchMethod.NAME + + def test_auto_create_when_no_match(self, session): + matcher = ProductMatcher(session, auto_create=True) + item = self._make_item("Unique Product XYZ 16 oz") + prod, result, confidence = matcher.match_single(item) + + assert prod is not None + assert result is None # No match found, was created + assert confidence == MatchConfidence.LOW + assert prod.canonical_name == "Unique Product XYZ 16 oz" + assert prod.size == "16" + assert prod.size_unit == "oz" + + def test_no_create_when_disabled(self, session): + matcher = ProductMatcher(session, auto_create=False) + item = self._make_item("Nonexistent Product") + prod, result, confidence = matcher.match_single(item) + + assert prod is None + assert result is None + + def test_batch_match(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Large Eggs 12 Count", + upc_variants=["012345"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + matcher = ProductMatcher(session) + items = [ + self._make_item("Large Eggs", upc="012345"), + 
self._make_item("Brand New Never Seen Product"), + ] + outcomes = matcher.match_items(items) + + assert len(outcomes) == 2 + assert outcomes[0].match is not None + assert outcomes[0].confidence_level == MatchConfidence.HIGH + assert outcomes[0].created_new is False + assert outcomes[1].match is None + assert outcomes[1].created_new is True + + +class TestMatchPurchaseItem: + def test_convenience_function(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Ground Beef 80/20", + upc_variants=["999888"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + item = PurchaseItemCreate( + product_name_raw="Ground Beef", + upc="999888", + unit_price=Decimal("5.99"), + extended_price=Decimal("5.99"), + ) + prod, confidence = match_purchase_item(session, item) + assert prod is not None + assert confidence == MatchConfidence.HIGH + + def test_auto_create_default(self, session): + item = PurchaseItemCreate( + product_name_raw="Totally New Item", + unit_price=Decimal("1.00"), + extended_price=Decimal("1.00"), + ) + prod, confidence = match_purchase_item(session, item) + assert prod is not None + assert confidence == MatchConfidence.LOW diff --git a/tests/test_pipeline/test_normalization.py b/tests/test_pipeline/test_normalization.py new file mode 100644 index 0000000..de1d566 --- /dev/null +++ b/tests/test_pipeline/test_normalization.py @@ -0,0 +1,158 @@ +"""Tests for product normalization module.""" + +import uuid +from datetime import UTC, datetime + +from cartsnitch_common.models.product import NormalizedProduct + +from receiptwitness.pipeline.normalization import ( + MatchMethod, + clean_name, + extract_size_info, + jaccard_similarity, + match_by_name, + match_by_upc, + normalize_product, +) + + +class TestCleanName: + def test_lowercase(self): + assert clean_name("Kroger WHOLE MILK") == "kroger whole milk" + + def test_removes_size_info(self): + assert "oz" not in clean_name("Milk 
16 oz Whole") + + def test_removes_noise_words(self): + cleaned = clean_name("The Original Brand Milk") + assert "the" not in cleaned.split() + assert "original" not in cleaned.split() + assert "brand" not in cleaned.split() + + def test_collapses_whitespace(self): + assert " " not in clean_name("Milk Whole Gallon") + + def test_removes_punctuation(self): + cleaned = clean_name("Meijer's Best (Organic) Milk!") + assert "'" not in cleaned + assert "(" not in cleaned + + +class TestExtractSizeInfo: + def test_extracts_oz(self): + result = extract_size_info("Cereal 18 oz box") + assert result == ("18", "oz") + + def test_extracts_fl_oz(self): + result = extract_size_info("Juice 64 fl oz") + assert result == ("64", "fl_oz") + + def test_extracts_lb(self): + result = extract_size_info("Ground Beef 1.5 lb") + assert result == ("1.5", "lb") + + def test_extracts_ct(self): + result = extract_size_info("Eggs Large 12 ct") + assert result == ("12", "ct") + + def test_no_size_returns_none(self): + assert extract_size_info("Bananas") is None + + +class TestJaccardSimilarity: + def test_identical_strings(self): + assert jaccard_similarity("whole milk gallon", "whole milk gallon") == 1.0 + + def test_completely_different(self): + assert jaccard_similarity("apple juice", "ground beef") == 0.0 + + def test_partial_overlap(self): + score = jaccard_similarity("kroger whole milk", "meijer whole milk") + assert 0.4 < score < 0.8 # "whole" and "milk" overlap + + def test_empty_strings(self): + assert jaccard_similarity("", "") == 0.0 + assert jaccard_similarity("milk", "") == 0.0 + + +class TestMatchByUPC: + def test_match_found(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk, Gallon", + upc_variants=["0041250000001", "0041250000002"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + # NOTE(review): SQLite may lack JSONB containment, yet the asserts below expect a match on this session — confirm. + # In production (PostgreSQL),
this would work + result = match_by_upc(session, "0041250000001") + assert result is not None + assert result.method == MatchMethod.UPC + assert result.confidence == 1.0 + + def test_no_match(self, session): + result = match_by_upc(session, "9999999999999") + assert result is None + + +class TestMatchByName: + def test_exact_name_match(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk, Gallon", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + result = match_by_name(session, "Whole Milk Gallon") + assert result is not None + assert result.method == MatchMethod.NAME + assert result.confidence > 0.5 + + def test_fuzzy_match(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Kroger Whole Milk, 1 Gallon", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + result = match_by_name(session, "Meijer Whole Milk 1 Gallon", threshold=0.3) + assert result is not None + assert result.confidence > 0.3 + + def test_no_match_below_threshold(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Ground Beef 80/20", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + result = match_by_name(session, "Apple Juice 64 oz", threshold=0.5) + assert result is None + + +class TestNormalizeProduct: + def test_name_fallback(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Large Eggs, 12 count", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + result = normalize_product(session, "Large Eggs 12 ct", upc=None) + assert result is not None + assert result.method == MatchMethod.NAME + + def test_no_match(self, session): + result = normalize_product(session, "Nonexistent Product XYZ", upc=None) + assert result is None 
diff --git a/tests/test_pipeline/test_receipt.py b/tests/test_pipeline/test_receipt.py new file mode 100644 index 0000000..8210713 --- /dev/null +++ b/tests/test_pipeline/test_receipt.py @@ -0,0 +1,204 @@ +"""Tests for receipt normalization pipeline.""" + +import uuid +from datetime import date +from decimal import Decimal + +from receiptwitness.pipeline.receipt import ( + _clean_product_name, + _safe_decimal, + normalize_receipt, + parse_meijer_item, +) + + +class TestCleanProductName: + def test_strips_whitespace(self): + assert _clean_product_name(" Milk ") == "Milk" + + def test_removes_leading_punctuation(self): + assert _clean_product_name("---Milk---") == "Milk" + + def test_collapses_internal_whitespace(self): + assert _clean_product_name("Whole Milk Gallon") == "Whole Milk Gallon" + + def test_empty_string(self): + assert _clean_product_name("") == "" + + +class TestSafeDecimal: + def test_string_input(self): + assert _safe_decimal("3.99") == Decimal("3.99") + + def test_float_input(self): + assert _safe_decimal(3.99) == Decimal("3.99") + + def test_int_input(self): + assert _safe_decimal(4) == Decimal("4") + + def test_none_returns_default(self): + assert _safe_decimal(None) == Decimal("0") + + def test_none_custom_default(self): + assert _safe_decimal(None, Decimal("1")) == Decimal("1") + + def test_invalid_returns_default(self): + assert _safe_decimal("not-a-number") == Decimal("0") + + def test_decimal_passthrough(self): + assert _safe_decimal(Decimal("5.50")) == Decimal("5.50") + + +class TestParseMeijerItem: + def test_basic_item(self): + raw = { + "description": "Kroger Whole Milk 1 Gallon", + "upc": "0041250000001", + "quantity": 1, + "unitPrice": "3.99", + "extendedPrice": "3.99", + "category": "DAIRY", + } + item = parse_meijer_item(raw) + assert item.product_name_raw == "Kroger Whole Milk 1 Gallon" + assert item.upc == "41250000001" # leading zeros stripped + assert item.quantity == Decimal("1") + assert item.unit_price == Decimal("3.99") + 
assert item.extended_price == Decimal("3.99") + assert item.category_raw == "DAIRY" + + def test_alternate_field_names(self): + raw = { + "name": "Eggs Large 12 ct", + "upcCode": "012345", + "qty": 2, + "price": "4.50", + "totalPrice": "9.00", + "department": "EGGS", + } + item = parse_meijer_item(raw) + assert item.product_name_raw == "Eggs Large 12 ct" + assert item.upc == "12345" + assert item.quantity == Decimal("2") + assert item.unit_price == Decimal("4.50") + assert item.extended_price == Decimal("9.00") + assert item.category_raw == "EGGS" + + def test_calculates_extended_from_unit_price(self): + raw = { + "description": "Bananas", + "unitPrice": "0.59", + "quantity": 3, + } + item = parse_meijer_item(raw) + assert item.extended_price == Decimal("1.77") + + def test_discounts_parsed(self): + raw = { + "description": "Cereal", + "unitPrice": "4.99", + "extendedPrice": "4.99", + "regularPrice": "5.99", + "salePrice": "4.99", + "couponAmount": "1.00", + "loyaltyAmount": "0.50", + } + item = parse_meijer_item(raw) + assert item.regular_price == Decimal("5.99") + assert item.sale_price == Decimal("4.99") + assert item.coupon_discount == Decimal("1.00") + assert item.loyalty_discount == Decimal("0.50") + + def test_alternate_discount_names(self): + raw = { + "description": "Bread", + "unitPrice": "2.99", + "extendedPrice": "2.99", + "couponDiscount": "0.75", + "loyaltyDiscount": "0.25", + } + item = parse_meijer_item(raw) + assert item.coupon_discount == Decimal("0.75") + assert item.loyalty_discount == Decimal("0.25") + + def test_missing_fields_default_gracefully(self): + raw = {"description": "Mystery Item"} + item = parse_meijer_item(raw) + assert item.product_name_raw == "Mystery Item" + assert item.upc is None + assert item.quantity == Decimal("1") + assert item.unit_price == Decimal("0") + assert item.regular_price is None + assert item.category_raw is None + + def test_no_upc_returns_none(self): + raw = {"description": "Loose Bananas", "unitPrice": 
"1.00", "extendedPrice": "1.00"} + item = parse_meijer_item(raw) + assert item.upc is None + + +class TestNormalizeReceipt: + def test_full_receipt(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = { + "receiptId": "REC-001", + "date": "2026-03-15", + "total": "25.47", + "subtotal": "23.00", + "tax": "2.47", + "savings": "3.00", + "items": [ + {"description": "Milk", "unitPrice": "3.99", "extendedPrice": "3.99"}, + {"description": "Bread", "unitPrice": "2.50", "extendedPrice": "2.50"}, + ], + } + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.receipt_id == "REC-001" + assert purchase.purchase_date == date(2026, 3, 15) + assert purchase.total == Decimal("25.47") + assert purchase.subtotal == Decimal("23.00") + assert purchase.tax == Decimal("2.47") + assert purchase.savings_total == Decimal("3.00") + assert len(purchase.items) == 2 + assert purchase.items[0].product_name_raw == "Milk" + assert purchase.raw_data == raw + + def test_alternate_receipt_fields(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = { + "receipt_id": "REC-002", + "purchaseDate": "2026-03-14", + "totalAmount": "10.00", + "taxAmount": "0.75", + "totalSavings": "1.50", + "items": [], + } + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.receipt_id == "REC-002" + assert purchase.purchase_date == date(2026, 3, 14) + assert purchase.total == Decimal("10.00") + assert purchase.tax == Decimal("0.75") + assert purchase.savings_total == Decimal("1.50") + + def test_missing_date_defaults_to_today(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = {"total": "5.00", "items": []} + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.purchase_date == date.today() + + def test_generates_receipt_id_if_missing(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = {"total": "5.00", "date": "2026-03-15", "items": []} + purchase = normalize_receipt(raw, 
user_id, store_id) + assert purchase.receipt_id # Should be a generated UUID string + + def test_date_object_passthrough(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = {"date": date(2026, 1, 1), "total": "5.00", "items": []} + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.purchase_date == date(2026, 1, 1) diff --git a/tests/test_regression/__init__.py b/tests/test_regression/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_regression/test_layout_changes.py b/tests/test_regression/test_layout_changes.py new file mode 100644 index 0000000..7843c43 --- /dev/null +++ b/tests/test_regression/test_layout_changes.py @@ -0,0 +1,435 @@ +"""Regression tests: graceful handling of page layout changes. + +Retailers frequently change their API response structures, field names, +and nesting. These tests verify that both parsers degrade gracefully when +encountering alternative or missing fields — producing valid output +instead of crashing. 
+""" + +from decimal import Decimal + +from receiptwitness.parsers.kroger import parse_kroger_receipt +from receiptwitness.parsers.meijer import parse_meijer_receipt +from receiptwitness.scrapers.base import RawReceipt + + +class TestKrogerFieldNameVariations: + """Kroger changes field names between app versions and API revisions.""" + + def test_alternative_item_key_line_items(self): + raw = RawReceipt( + receipt_id="KR-ALT-1", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "lineItems": [{"description": "MILK", "basePrice": 3.99, "totalPrice": 3.99}], + "total": 3.99, + } + }, + ) + result = parse_kroger_receipt(raw) + assert len(result["items"]) == 1 + assert result["items"][0]["product_name_raw"] == "MILK" + + def test_alternative_item_key_receipt_items(self): + raw = RawReceipt( + receipt_id="KR-ALT-2", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "receiptItems": [ + {"description": "EGGS", "basePrice": 5.49, "totalPrice": 5.49} + ], + "total": 5.49, + } + }, + ) + result = parse_kroger_receipt(raw) + assert len(result["items"]) == 1 + assert result["items"][0]["product_name_raw"] == "EGGS" + + def test_alternative_description_fields(self): + """Test productName and itemDescription fallbacks.""" + for field in ("productName", "itemDescription", "name"): + raw = RawReceipt( + receipt_id="KR-DESC", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [{field: "TEST PRODUCT", "basePrice": 1.00, "totalPrice": 1.00}], + "total": 1.00, + } + }, + ) + result = parse_kroger_receipt(raw) + assert result["items"][0]["product_name_raw"] == "TEST PRODUCT" + + def test_alternative_price_fields(self): + """Test unitPrice and price fallbacks for basePrice.""" + raw = RawReceipt( + receipt_id="KR-PRICE-1", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [{"description": "ITEM A", "unitPrice": 2.50, "totalPrice": 2.50}], + "total": 2.50, + } + }, + ) + result = parse_kroger_receipt(raw) + assert 
result["items"][0]["unit_price"] == Decimal("2.50") + + raw2 = RawReceipt( + receipt_id="KR-PRICE-2", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [{"description": "ITEM B", "price": 4.00, "totalPrice": 4.00}], + "total": 4.00, + } + }, + ) + result2 = parse_kroger_receipt(raw2) + assert result2["items"][0]["unit_price"] == Decimal("4.00") + + def test_alternative_total_fields(self): + """Test orderTotal, grandTotal fallbacks.""" + for field in ("orderTotal", "grandTotal"): + raw = RawReceipt( + receipt_id="KR-TOT", + purchase_date="2026-03-12", + raw_data={field: 42.50, "detail": {}}, + ) + result = parse_kroger_receipt(raw) + assert result["total"] == Decimal("42.50") + + def test_alternative_savings_fields(self): + """Test youSaved and totalDiscount fallbacks.""" + raw = RawReceipt( + receipt_id="KR-SAV-1", + purchase_date="2026-03-12", + raw_data={"youSaved": 5.00, "detail": {}}, + ) + result = parse_kroger_receipt(raw) + assert result["savings_total"] == Decimal("5.00") + + def test_alternative_tax_field(self): + raw = RawReceipt( + receipt_id="KR-TAX", + purchase_date="2026-03-12", + raw_data={"salesTax": 3.25, "detail": {}}, + ) + result = parse_kroger_receipt(raw) + assert result["tax"] == Decimal("3.25") + + def test_alternative_quantity_field_qty(self): + raw = RawReceipt( + receipt_id="KR-QTY", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [ + {"description": "APPLES", "qty": 5, "basePrice": 1.00, "totalPrice": 5.00} + ], + "total": 5.00, + } + }, + ) + result = parse_kroger_receipt(raw) + assert result["items"][0]["quantity"] == Decimal("5") + + def test_alternative_upc_field_kroger_product_id(self): + raw = RawReceipt( + receipt_id="KR-UPC", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [ + { + "description": "ITEM", + "krogerProductId": "12345678", + "basePrice": 1.00, + "totalPrice": 1.00, + } + ], + "total": 1.00, + } + }, + ) + result = parse_kroger_receipt(raw) + assert 
result["items"][0]["upc"] == "12345678" + + def test_missing_extended_price_computed(self): + """When totalPrice is missing, extended_price = unit_price * quantity.""" + raw = RawReceipt( + receipt_id="KR-CALC", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [{"description": "EGGS", "basePrice": 5.49, "quantity": 2}], + "total": 10.98, + } + }, + ) + result = parse_kroger_receipt(raw) + assert result["items"][0]["extended_price"] == Decimal("5.49") * Decimal("2") + + +class TestMeijerFieldNameVariations: + """Meijer XHR endpoints may change field names between SPA versions.""" + + def test_alternative_item_key_line_items(self): + raw = RawReceipt( + receipt_id="MJ-ALT-1", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "lineItems": [{"description": "BANANAS", "price": 0.69, "extendedPrice": 0.69}], + "total": 0.69, + } + }, + ) + result = parse_meijer_receipt(raw) + assert len(result["items"]) == 1 + assert result["items"][0]["product_name_raw"] == "BANANAS" + + def test_alternative_description_fields(self): + for field in ("itemDescription", "name"): + raw = RawReceipt( + receipt_id="MJ-DESC", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [{field: "TEST ITEM", "price": 1.00, "extendedPrice": 1.00}], + "total": 1.00, + } + }, + ) + result = parse_meijer_receipt(raw) + assert result["items"][0]["product_name_raw"] == "TEST ITEM" + + def test_alternative_price_field_unit_price(self): + raw = RawReceipt( + receipt_id="MJ-PRICE", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [{"description": "MILK", "unitPrice": 3.49, "totalPrice": 3.49}], + "total": 3.49, + } + }, + ) + result = parse_meijer_receipt(raw) + assert result["items"][0]["unit_price"] == Decimal("3.49") + + def test_alternative_extended_price_field_total_price(self): + raw = RawReceipt( + receipt_id="MJ-EXT", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [{"description": "CEREAL", "price": 4.99, "totalPrice": 
4.99}], + "total": 4.99, + } + }, + ) + result = parse_meijer_receipt(raw) + assert result["items"][0]["extended_price"] == Decimal("4.99") + + def test_alternative_total_field_transaction_total(self): + raw = RawReceipt( + receipt_id="MJ-TOT", + purchase_date="2026-03-10", + raw_data={"transactionTotal": 55.00, "detail": {}}, + ) + result = parse_meijer_receipt(raw) + assert result["total"] == Decimal("55.00") + + def test_alternative_loyalty_field(self): + raw = RawReceipt( + receipt_id="MJ-LOY", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [ + { + "description": "ITEM", + "price": 5.00, + "extendedPrice": 5.00, + "loyaltyDiscount": 0.50, + } + ], + "total": 5.00, + } + }, + ) + result = parse_meijer_receipt(raw) + assert result["items"][0]["loyalty_discount"] == Decimal("0.50") + + def test_alternative_upc_field_uppercase(self): + raw = RawReceipt( + receipt_id="MJ-UPC", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [ + { + "description": "ITEM", + "UPC": "0012345678", + "price": 1.00, + "extendedPrice": 1.00, + } + ], + "total": 1.00, + } + }, + ) + result = parse_meijer_receipt(raw) + assert result["items"][0]["upc"] == "12345678" + + def test_alternative_category_field(self): + raw = RawReceipt( + receipt_id="MJ-CAT", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [ + { + "description": "ITEM", + "price": 1.00, + "extendedPrice": 1.00, + "departmentDescription": "FROZEN", + } + ], + "total": 1.00, + } + }, + ) + result = parse_meijer_receipt(raw) + assert result["items"][0]["category_raw"] == "FROZEN" + + def test_missing_extended_price_computed(self): + raw = RawReceipt( + receipt_id="MJ-CALC", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [{"description": "MILK", "price": 3.49, "quantity": 2}], + "total": 6.98, + } + }, + ) + result = parse_meijer_receipt(raw) + assert result["items"][0]["extended_price"] == Decimal("3.49") * Decimal("2") + + def 
test_missing_description_fallback(self): + raw = RawReceipt( + receipt_id="MJ-NODESC", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [{"price": 1.00, "extendedPrice": 1.00}], + "total": 1.00, + } + }, + ) + result = parse_meijer_receipt(raw) + assert result["items"][0]["product_name_raw"] == "UNKNOWN ITEM" + + +class TestMixedFieldVersions: + """Test receipts that mix field naming conventions (happens during rollouts).""" + + def test_kroger_mixed_item_fields(self): + """Some items use old names, some use new names in same receipt.""" + raw = RawReceipt( + receipt_id="KR-MIX", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [ + {"description": "OLD STYLE", "basePrice": 2.00, "totalPrice": 2.00}, + {"productName": "NEW STYLE", "unitPrice": 3.00, "extendedAmount": 3.00}, + ], + "total": 5.00, + } + }, + ) + result = parse_kroger_receipt(raw) + assert len(result["items"]) == 2 + assert result["items"][0]["product_name_raw"] == "OLD STYLE" + assert result["items"][0]["unit_price"] == Decimal("2.00") + assert result["items"][1]["product_name_raw"] == "NEW STYLE" + assert result["items"][1]["unit_price"] == Decimal("3.00") + + def test_kroger_completely_unknown_structure_no_crash(self): + """Receipt with unrecognized structure should return empty items.""" + raw = RawReceipt( + receipt_id="KR-UNKNOWN", + purchase_date="2026-03-12", + raw_data={"something_unexpected": [1, 2, 3], "detail": {"foo": "bar"}}, + ) + result = parse_kroger_receipt(raw) + assert result["receipt_id"] == "KR-UNKNOWN" + assert result["items"] == [] + + def test_meijer_completely_unknown_structure_no_crash(self): + raw = RawReceipt( + receipt_id="MJ-UNKNOWN", + purchase_date="2026-03-10", + raw_data={"something_unexpected": [1, 2, 3], "detail": {"foo": "bar"}}, + ) + result = parse_meijer_receipt(raw) + assert result["receipt_id"] == "MJ-UNKNOWN" + assert result["items"] == [] + + def test_kroger_null_fields_no_crash(self): + """Fields with None values should 
be handled gracefully.""" + raw = RawReceipt( + receipt_id="KR-NULL", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [ + { + "description": "ITEM", + "basePrice": None, + "totalPrice": None, + "quantity": None, + "upc": None, + "department": None, + } + ], + "total": None, + "subtotal": None, + "tax": None, + } + }, + ) + result = parse_kroger_receipt(raw) + assert result["items"][0]["product_name_raw"] == "ITEM" + assert result["items"][0]["unit_price"] == Decimal("0") + + def test_meijer_null_fields_no_crash(self): + raw = RawReceipt( + receipt_id="MJ-NULL", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [ + { + "description": "ITEM", + "price": None, + "extendedPrice": None, + "quantity": None, + "upc": None, + "category": None, + } + ], + "total": None, + } + }, + ) + result = parse_meijer_receipt(raw) + assert result["items"][0]["product_name_raw"] == "ITEM" + assert result["items"][0]["unit_price"] == Decimal("0") diff --git a/tests/test_regression/test_rate_limiting.py b/tests/test_regression/test_rate_limiting.py new file mode 100644 index 0000000..1c55495 --- /dev/null +++ b/tests/test_regression/test_rate_limiting.py @@ -0,0 +1,365 @@ +"""Regression tests: rate limiting and retry behavior. + +Validates that scrapers enforce human-like delays between requests +and handle rate-limit/error responses gracefully without infinite retries. 
+""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, patch + +import pytest + +from receiptwitness.scrapers.base import SessionData +from receiptwitness.scrapers.kroger import DEFAULT_USER_AGENT, KrogerScraper +from receiptwitness.scrapers.meijer import MeijerScraper + + +class TestHumanDelayBehavior: + """Verify that human_delay respects configured bounds.""" + + @pytest.mark.asyncio + async def test_delay_within_bounds(self): + """human_delay should sleep between min_ms/1000 and max_ms/1000 seconds.""" + scraper = KrogerScraper() + sleep_path = "receiptwitness.scrapers.base.asyncio.sleep" + with patch(sleep_path, new_callable=AsyncMock) as mock_sleep: + await scraper.human_delay(100, 200) + mock_sleep.assert_called_once() + delay = mock_sleep.call_args[0][0] + assert 0.1 <= delay <= 0.2 + + @pytest.mark.asyncio + async def test_delay_uses_settings_defaults(self): + """Without explicit args, should use settings.min/max_request_delay_ms.""" + scraper = MeijerScraper() + sleep_path = "receiptwitness.scrapers.base.asyncio.sleep" + with ( + patch("receiptwitness.scrapers.base.settings") as mock_settings, + patch(sleep_path, new_callable=AsyncMock) as mock_sleep, + ): + mock_settings.min_request_delay_ms = 1000 + mock_settings.max_request_delay_ms = 5000 + await scraper.human_delay() + mock_sleep.assert_called_once() + delay = mock_sleep.call_args[0][0] + assert 1.0 <= delay <= 5.0 + + @pytest.mark.asyncio + async def test_delay_is_randomized(self): + """Multiple calls should produce different delays (probabilistic).""" + scraper = KrogerScraper() + delays = [] + sleep_path2 = "receiptwitness.scrapers.base.asyncio.sleep" + with patch(sleep_path2, new_callable=AsyncMock) as mock_sleep: + for _ in range(20): + await scraper.human_delay(100, 5000) + delays.append(mock_sleep.call_args[0][0]) + # With range 100-5000ms, 20 calls should have at least 2 distinct values + assert len(set(delays)) >= 2 + + +class TestKrogerRateLimiting: + 
"""Verify Kroger scraper calls human_delay between receipt fetches.""" + + @pytest.mark.asyncio + async def test_delay_called_between_receipts(self): + """Scraper must call human_delay for each receipt detail fetch.""" + scraper = KrogerScraper() + valid_session = SessionData( + cookies=[{"name": "s", "value": "v", "domain": ".kroger.com", "path": "/"}], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=2), + ) + + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "orders": [ + { + "orderId": f"KR-{i}", + "purchaseDate": "2026-03-10T14:00:00Z", + "storeNumber": "357", + } + for i in range(3) + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response] + [mock_detail_response] * 3) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock) as mock_delay, + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + + assert len(receipts) == 3 + # human_delay called at least once per receipt (after initial page nav) + # Plus 
once for the initial navigation delay + assert mock_delay.call_count >= 3 + + +class TestMeijerRateLimiting: + """Verify Meijer scraper calls human_delay between receipt fetches.""" + + @pytest.mark.asyncio + async def test_delay_called_between_receipts(self): + scraper = MeijerScraper() + valid_session = SessionData( + cookies=[{"name": "s", "value": "v", "domain": ".meijer.com", "path": "/"}], + user_agent="test", + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=4), + ) + + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "transactions": [ + { + "transactionId": f"TXN-{i}", + "transactionDate": "2026-03-10T14:00:00Z", + "storeNumber": "42", + } + for i in range(3) + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response] + [mock_detail_response] * 3) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock) as mock_delay, + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + + assert len(receipts) == 3 + assert mock_delay.call_count >= 3 + + 
+class TestGracefulErrorRecovery: + """Scrapers should not retry endlessly on errors.""" + + @pytest.mark.asyncio + async def test_kroger_api_500_returns_empty_not_retry(self): + """500 error should return empty list, not retry.""" + scraper = KrogerScraper() + valid_session = SessionData( + cookies=[{"name": "s", "value": "v", "domain": ".kroger.com", "path": "/"}], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=2), + ) + + mock_api_response = AsyncMock() + mock_api_response.ok = False + mock_api_response.status = 500 + mock_api_response.status_text = "Internal Server Error" + + mock_request = AsyncMock() + mock_request.get = AsyncMock(return_value=mock_api_response) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert receipts == [] + # Should only call the API once — no retries + assert mock_request.get.call_count == 1 + + @pytest.mark.asyncio + async def test_kroger_429_returns_empty_not_retry(self): + """Rate limit (429) should return empty, not retry.""" + scraper = KrogerScraper() + valid_session = SessionData( + cookies=[{"name": "s", "value": "v", "domain": ".kroger.com", 
"path": "/"}], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=2), + ) + + mock_api_response = AsyncMock() + mock_api_response.ok = False + mock_api_response.status = 429 + mock_api_response.status_text = "Too Many Requests" + + mock_request = AsyncMock() + mock_request.get = AsyncMock(return_value=mock_api_response) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert receipts == [] + assert mock_request.get.call_count == 1 + + @pytest.mark.asyncio + async def test_meijer_detail_exception_continues(self): + """Exception fetching one receipt detail should not abort remaining receipts.""" + scraper = MeijerScraper() + valid_session = SessionData( + cookies=[{"name": "s", "value": "v", "domain": ".meijer.com", "path": "/"}], + user_agent="test", + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=4), + ) + + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "transactions": [ + { + "transactionId": "TXN-1", + "transactionDate": "2026-03-10T14:00:00Z", + "storeNumber": "42", + }, + { + 
"transactionId": "TXN-2", + "transactionDate": "2026-03-11T10:00:00Z", + "storeNumber": "42", + }, + ] + } + ) + + # First detail call raises exception, second succeeds + mock_detail_fail = AsyncMock() + mock_detail_fail.ok = False + mock_detail_fail.status = 500 + + mock_detail_ok = AsyncMock() + mock_detail_ok.ok = True + mock_detail_ok.json = AsyncMock(return_value={"items": []}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock( + side_effect=[mock_api_response, mock_detail_fail, mock_detail_ok] + ) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + + # Both receipts should be returned — the first with empty detail + assert len(receipts) == 2 + assert receipts[0].raw_data.get("detail") == {} + assert receipts[1].receipt_id == "TXN-2" diff --git a/tests/test_regression/test_schema_validation.py b/tests/test_regression/test_schema_validation.py new file mode 100644 index 0000000..8dfb10e --- /dev/null +++ b/tests/test_regression/test_schema_validation.py @@ -0,0 +1,364 @@ +"""Regression tests: scraper output matches expected schema. + +Validates that parsed receipts from both Kroger and Meijer conform to the +PurchaseCreate schema contract. 
Uses recorded fixtures to ensure outputs +remain stable across code changes. +""" + +from decimal import Decimal + +from receiptwitness.parsers.kroger import parse_kroger_receipt +from receiptwitness.parsers.meijer import parse_meijer_receipt +from receiptwitness.scrapers.base import RawReceipt + +# Required top-level keys in a parsed receipt +RECEIPT_REQUIRED_KEYS = {"receipt_id", "purchase_date", "total", "items", "raw_data"} +RECEIPT_OPTIONAL_KEYS = {"subtotal", "tax", "savings_total", "source_url"} + +# Required keys in each parsed item +ITEM_REQUIRED_KEYS = { + "product_name_raw", + "upc", + "quantity", + "unit_price", + "extended_price", +} +ITEM_OPTIONAL_KEYS = { + "regular_price", + "sale_price", + "coupon_discount", + "loyalty_discount", + "category_raw", +} + + +def _validate_receipt_schema(result: dict) -> None: + """Assert that a parsed receipt dict conforms to the expected schema.""" + # All required keys present + for key in RECEIPT_REQUIRED_KEYS: + assert key in result, f"Missing required key: {key}" + + # Types + assert isinstance(result["receipt_id"], str) + assert isinstance(result["purchase_date"], str) + assert isinstance(result["total"], Decimal) + assert isinstance(result["items"], list) + assert isinstance(result["raw_data"], dict) + + # Optional keys should be correct types when present + if result.get("subtotal") is not None: + assert isinstance(result["subtotal"], Decimal) + if result.get("tax") is not None: + assert isinstance(result["tax"], Decimal) + if result.get("savings_total") is not None: + assert isinstance(result["savings_total"], Decimal) + if result.get("source_url") is not None: + assert isinstance(result["source_url"], str) + + # No unexpected keys + all_keys = RECEIPT_REQUIRED_KEYS | RECEIPT_OPTIONAL_KEYS + for key in result: + assert key in all_keys, f"Unexpected key in receipt: {key}" + + +def _validate_item_schema(item: dict) -> None: + """Assert that a parsed item dict conforms to the expected schema.""" + for key in 
ITEM_REQUIRED_KEYS: + assert key in item, f"Missing required item key: {key}" + + assert isinstance(item["product_name_raw"], str) + assert len(item["product_name_raw"]) > 0 + assert isinstance(item["quantity"], Decimal) + assert isinstance(item["unit_price"], Decimal) + assert isinstance(item["extended_price"], Decimal) + + # UPC can be None or str + if item["upc"] is not None: + assert isinstance(item["upc"], str) + # UPC should not have leading zeros (stripped during parsing) + assert not item["upc"].startswith("0"), f"UPC has leading zeros: {item['upc']}" + + # Optional Decimal fields + for opt_key in ("regular_price", "sale_price", "coupon_discount", "loyalty_discount"): + if item.get(opt_key) is not None: + assert isinstance(item[opt_key], Decimal), f"{opt_key} should be Decimal" + + if item.get("category_raw") is not None: + assert isinstance(item["category_raw"], str) + + # No unexpected keys + all_keys = ITEM_REQUIRED_KEYS | ITEM_OPTIONAL_KEYS + for key in item: + assert key in all_keys, f"Unexpected key in item: {key}" + + +class TestKrogerSchemaValidation: + def test_full_receipt_schema(self, kroger_receipt_data): + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12T16:45:00Z", + store_number="00357", + raw_data=kroger_receipt_data, + source_url="https://www.kroger.com/atlas/v1/receipt/api?orderId=KR-2026-0312-4471", + ) + result = parse_kroger_receipt(raw) + _validate_receipt_schema(result) + for item in result["items"]: + _validate_item_schema(item) + + def test_item_count_excludes_voided_and_returned(self, kroger_receipt_data): + """Fixture has 10 items, 2 should be excluded (voided + returned).""" + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12T16:45:00Z", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + assert len(result["items"]) == 8 + + def test_totals_are_positive_decimals(self, kroger_receipt_data): + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + 
purchase_date="2026-03-12T16:45:00Z", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + assert result["total"] > Decimal("0") + assert result["subtotal"] > Decimal("0") + assert result["tax"] > Decimal("0") + assert result["savings_total"] > Decimal("0") + + def test_receipt_id_preserved(self, kroger_receipt_data): + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12T16:45:00Z", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + assert result["receipt_id"] == "KR-2026-0312-4471" + + def test_known_product_prices(self, kroger_receipt_data): + """Verify specific products produce correct price extraction.""" + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12T16:45:00Z", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + items_by_name = {i["product_name_raw"]: i for i in result["items"]} + + # Milk: $3.99, regular $4.29 + milk = items_by_name["KROGER WHOLE MILK GAL"] + assert milk["unit_price"] == Decimal("3.99") + assert milk["regular_price"] == Decimal("4.29") + assert milk["sale_price"] == Decimal("3.99") + + # Eggs: qty 2, $5.49 each, total $10.98 + eggs = items_by_name["SIMPLE TRUTH ORG EGGS 12CT"] + assert eggs["quantity"] == Decimal("2") + assert eggs["unit_price"] == Decimal("5.49") + assert eggs["extended_price"] == Decimal("10.98") + + # Deli turkey: weighted item, 0.68 lb + turkey = items_by_name["KROGER DELI TURKEY BREAST"] + assert turkey["quantity"] == Decimal("0.68") + assert turkey["upc"] is None + + def test_multi_quantity_item_correct(self, kroger_receipt_data): + """Pasta is qty=3, unit=$2.49, total=$7.47.""" + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12T16:45:00Z", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + pasta = [i for i in result["items"] if "PASTA" in i["product_name_raw"]][0] + assert pasta["quantity"] == Decimal("3") + assert pasta["unit_price"] 
== Decimal("2.49") + assert pasta["extended_price"] == Decimal("7.47") + + def test_coupon_discount_captured(self, kroger_receipt_data): + """Tide Pods has $2.00 coupon.""" + raw = RawReceipt( + receipt_id="KR-2026-0312-4471", + purchase_date="2026-03-12T16:45:00Z", + raw_data=kroger_receipt_data, + ) + result = parse_kroger_receipt(raw) + tide = [i for i in result["items"] if "TIDE" in i["product_name_raw"]][0] + assert tide["coupon_discount"] == Decimal("2.00") + + +class TestMeijerSchemaValidation: + def test_full_receipt_schema(self, meijer_receipt_data): + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + purchase_date="2026-03-10T14:30:00Z", + store_number="42", + raw_data=meijer_receipt_data, + source_url="https://www.meijer.com/bin/meijer/profile/receipt?receiptId=TXN-2026-0310-001", + ) + result = parse_meijer_receipt(raw) + _validate_receipt_schema(result) + for item in result["items"]: + _validate_item_schema(item) + + def test_item_count_excludes_voided(self, meijer_receipt_data): + """Fixture has 6 items, 1 should be excluded (voided soda).""" + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + purchase_date="2026-03-10T14:30:00Z", + raw_data=meijer_receipt_data, + ) + result = parse_meijer_receipt(raw) + assert len(result["items"]) == 5 + + def test_totals_are_positive_decimals(self, meijer_receipt_data): + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + purchase_date="2026-03-10T14:30:00Z", + raw_data=meijer_receipt_data, + ) + result = parse_meijer_receipt(raw) + assert result["total"] > Decimal("0") + assert result["subtotal"] > Decimal("0") + assert result["tax"] > Decimal("0") + assert result["savings_total"] > Decimal("0") + + def test_receipt_id_preserved(self, meijer_receipt_data): + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + purchase_date="2026-03-10T14:30:00Z", + raw_data=meijer_receipt_data, + ) + result = parse_meijer_receipt(raw) + assert result["receipt_id"] == "TXN-2026-0310-001" + + def 
test_known_product_prices(self, meijer_receipt_data): + """Verify specific Meijer products produce correct price extraction.""" + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + purchase_date="2026-03-10T14:30:00Z", + raw_data=meijer_receipt_data, + ) + result = parse_meijer_receipt(raw) + items_by_name = {i["product_name_raw"]: i for i in result["items"]} + + # Bananas: $0.69 + bananas = items_by_name["ORGANIC BANANAS"] + assert bananas["unit_price"] == Decimal("0.69") + assert bananas["mperks_discount"] if "mperks_discount" in bananas else True + assert bananas["loyalty_discount"] == Decimal("0.10") + + # Milk: qty 2, $3.49 each, total $6.98 + milk = items_by_name["MEIJER 2% MILK GAL"] + assert milk["quantity"] == Decimal("2") + assert milk["unit_price"] == Decimal("3.49") + assert milk["extended_price"] == Decimal("6.98") + + # Weighted deli turkey: 0.75 lb at $8.99/lb + turkey = items_by_name["WEIGHTED DELI TURKEY"] + assert turkey["quantity"] == Decimal("0.75") + assert turkey["upc"] is None + + def test_mperks_discount_captured(self, meijer_receipt_data): + """Paper towels has $1.00 mPerks discount.""" + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + purchase_date="2026-03-10T14:30:00Z", + raw_data=meijer_receipt_data, + ) + result = parse_meijer_receipt(raw) + towels = [i for i in result["items"] if "PAPER TOWELS" in i["product_name_raw"]][0] + assert towels["loyalty_discount"] == Decimal("1.00") + assert towels["coupon_discount"] == Decimal("1.00") + + def test_cheerios_coupon_discount(self, meijer_receipt_data): + """Cheerios has $0.50 coupon.""" + raw = RawReceipt( + receipt_id="TXN-2026-0310-001", + purchase_date="2026-03-10T14:30:00Z", + raw_data=meijer_receipt_data, + ) + result = parse_meijer_receipt(raw) + cheerios = [i for i in result["items"] if "CHEERIOS" in i["product_name_raw"]][0] + assert cheerios["coupon_discount"] == Decimal("0.50") + + +class TestEmptyAndEdgeCaseSchemas: + """Regression tests for edge-case receipts that should 
not crash.""" + + def test_kroger_empty_receipt(self): + raw = RawReceipt(receipt_id="KR-EMPTY", purchase_date="2026-03-12", raw_data={}) + result = parse_kroger_receipt(raw) + _validate_receipt_schema(result) + assert result["items"] == [] + assert result["total"] == Decimal("0") + + def test_meijer_empty_receipt(self): + raw = RawReceipt(receipt_id="MJ-EMPTY", purchase_date="2026-03-10", raw_data={}) + result = parse_meijer_receipt(raw) + _validate_receipt_schema(result) + assert result["items"] == [] + assert result["total"] == Decimal("0") + + def test_kroger_receipt_no_detail(self): + raw = RawReceipt( + receipt_id="KR-NODET", + purchase_date="2026-03-12", + raw_data={"total": 50.00}, + ) + result = parse_kroger_receipt(raw) + _validate_receipt_schema(result) + assert result["items"] == [] + assert result["total"] == Decimal("50.00") + + def test_meijer_receipt_no_detail(self): + raw = RawReceipt( + receipt_id="MJ-NODET", + purchase_date="2026-03-10", + raw_data={"total": 30.00}, + ) + result = parse_meijer_receipt(raw) + _validate_receipt_schema(result) + assert result["items"] == [] + assert result["total"] == Decimal("30.00") + + def test_kroger_receipt_all_voided(self): + """A receipt where every item is voided should have 0 items.""" + raw = RawReceipt( + receipt_id="KR-ALLVOID", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [ + {"description": "VOIDED A", "basePrice": 5.0, "voided": True}, + {"description": "VOIDED B", "basePrice": 3.0, "status": "VOIDED"}, + {"description": "RETURNED C", "basePrice": 7.0, "status": "RETURNED"}, + {"description": "RETURNED D", "basePrice": 2.0, "returnFlag": True}, + ], + "total": 0, + } + }, + ) + result = parse_kroger_receipt(raw) + _validate_receipt_schema(result) + assert result["items"] == [] + + def test_meijer_receipt_all_voided(self): + raw = RawReceipt( + receipt_id="MJ-ALLVOID", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [ + {"description": "VOIDED A", "price": 
5.0, "voided": True}, + {"description": "VOIDED B", "price": 3.0, "status": "VOIDED"}, + ], + "total": 0, + } + }, + ) + result = parse_meijer_receipt(raw) + _validate_receipt_schema(result) + assert result["items"] == [] diff --git a/tests/test_scrapers/__init__.py b/tests/test_scrapers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_scrapers/test_base.py b/tests/test_scrapers/test_base.py new file mode 100644 index 0000000..d0cabac --- /dev/null +++ b/tests/test_scrapers/test_base.py @@ -0,0 +1,58 @@ +"""Tests for the base scraper class.""" + +from datetime import datetime +from unittest.mock import patch + +import pytest + +from receiptwitness.scrapers.base import BaseScraper, RawReceipt, SessionData + + +class ConcreteScraper(BaseScraper): + """Concrete implementation for testing the abstract base.""" + + async def login(self, username, password): + return SessionData( + cookies=[], + user_agent="test", + created_at=datetime.now(), + ) + + async def check_session(self, session): + return True + + async def scrape_receipts(self, session, since=None): + return [] + + def parse_receipt(self, raw): + return {} + + +class TestBaseScraper: + @pytest.mark.asyncio + async def test_human_delay_respects_bounds(self): + scraper = ConcreteScraper() + with patch("receiptwitness.scrapers.base.asyncio.sleep") as mock_sleep: + mock_sleep.return_value = None + await scraper.human_delay(min_ms=100, max_ms=200) + call_args = mock_sleep.call_args[0][0] + assert 0.1 <= call_args <= 0.2 + + def test_raw_receipt_dataclass(self): + receipt = RawReceipt( + receipt_id="test-123", + purchase_date="2026-03-10", + store_number="42", + raw_data={"key": "value"}, + ) + assert receipt.receipt_id == "test-123" + assert receipt.raw_data == {"key": "value"} + + def test_session_data_defaults(self): + session = SessionData( + cookies=[], + user_agent="test", + created_at=datetime.now(), + ) + assert session.expires_at is None + assert session.extra == {} diff --git 
a/tests/test_scrapers/test_kroger_scraper.py b/tests/test_scrapers/test_kroger_scraper.py new file mode 100644 index 0000000..3a88516 --- /dev/null +++ b/tests/test_scrapers/test_kroger_scraper.py @@ -0,0 +1,574 @@ +"""Tests for the Kroger scraper. + +These tests mock Playwright to avoid requiring real Kroger credentials +or network access. They verify the scraper's control flow, session handling, +date filtering, and error resilience. +""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from receiptwitness.scrapers.base import RawReceipt, SessionData +from receiptwitness.scrapers.kroger import ( + DEFAULT_TIMEZONE, + DEFAULT_USER_AGENT, + DEFAULT_VIEWPORT, + KROGER_BASE, + KROGER_LOGIN_PAGE, + KROGER_PURCHASE_HISTORY, + KrogerScraper, +) + + +@pytest.fixture +def scraper(): + return KrogerScraper() + + +@pytest.fixture +def valid_session(): + return SessionData( + cookies=[{"name": "session", "value": "abc123", "domain": ".kroger.com", "path": "/"}], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=2), + extra={"retailer": "kroger"}, + ) + + +@pytest.fixture +def expired_session(): + return SessionData( + cookies=[{"name": "session", "value": "expired", "domain": ".kroger.com", "path": "/"}], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC) - timedelta(hours=4), + expires_at=datetime.now(UTC) - timedelta(hours=2), + ) + + +class TestKrogerScraperConstants: + def test_base_url(self): + assert KROGER_BASE == "https://www.kroger.com" + + def test_login_page(self): + assert KROGER_LOGIN_PAGE == "https://www.kroger.com/signin" + + def test_purchase_history_page(self): + assert KROGER_PURCHASE_HISTORY == "https://www.kroger.com/mypurchases" + + def test_default_user_agent_is_chrome(self): + assert "Chrome" in DEFAULT_USER_AGENT + assert "Windows" in DEFAULT_USER_AGENT + + def test_default_viewport_hd(self): + assert 
DEFAULT_VIEWPORT == {"width": 1920, "height": 1080} + + def test_default_timezone(self): + assert DEFAULT_TIMEZONE == "America/New_York" + + +class TestCheckSession: + @pytest.mark.asyncio + async def test_expired_session_returns_false(self, scraper, expired_session): + result = await scraper.check_session(expired_session) + assert result is False + + @pytest.mark.asyncio + async def test_no_expiry_checks_via_browser(self, scraper): + session = SessionData( + cookies=[], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC), + expires_at=None, + ) + mock_page = AsyncMock() + mock_page.url = "https://www.kroger.com/account/dashboard" + mock_response = MagicMock() + mock_response.ok = True + mock_page.goto = AsyncMock(return_value=mock_response) + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw: + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + result = await scraper.check_session(session) + assert result is True + + @pytest.mark.asyncio + async def test_session_redirected_to_signin_returns_false(self, scraper): + session = SessionData( + cookies=[], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC), + expires_at=None, + ) + mock_page = AsyncMock() + mock_page.url = "https://www.kroger.com/signin?redirectUrl=account" + mock_response = MagicMock() + mock_response.ok = True + mock_page.goto = AsyncMock(return_value=mock_response) + + mock_context = AsyncMock() + mock_context.new_page = 
AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw: + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + result = await scraper.check_session(session) + assert result is False + + +class TestLogin: + @pytest.mark.asyncio + async def test_login_returns_session_data(self, scraper): + mock_page = AsyncMock() + mock_page.url = "https://www.kroger.com/" + + # Mock locator chain + mock_email = AsyncMock() + mock_password = AsyncMock() + mock_button = AsyncMock() + mock_page.locator = MagicMock(side_effect=[mock_email, mock_password, mock_button]) + mock_page.wait_for_url = AsyncMock() + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.cookies = AsyncMock( + return_value=[ + {"name": "kroger_session", "value": "test123", "domain": ".kroger.com", "path": "/"} + ] + ) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + session = await scraper.login("user@test.com", 
"password123") + + assert isinstance(session, SessionData) + assert len(session.cookies) == 1 + assert session.cookies[0]["name"] == "kroger_session" + assert session.user_agent == DEFAULT_USER_AGENT + assert session.expires_at is not None + assert session.extra == {"retailer": "kroger"} + + +class TestScrapeReceipts: + @pytest.mark.asyncio + async def test_scrape_returns_receipts(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.status = 200 + mock_api_response.json = AsyncMock( + return_value={ + "orders": [ + { + "orderId": "KR-001", + "purchaseDate": "2026-03-10T14:00:00Z", + "storeNumber": "357", + }, + { + "orderId": "KR-002", + "purchaseDate": "2026-03-11T10:00:00Z", + "storeNumber": "357", + }, + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={"items": []}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock( + side_effect=[mock_api_response, mock_detail_response, mock_detail_response] + ) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + + assert len(receipts) == 2 + assert 
receipts[0].receipt_id == "KR-001" + assert receipts[1].receipt_id == "KR-002" + assert isinstance(receipts[0], RawReceipt) + + @pytest.mark.asyncio + async def test_scrape_filters_by_date(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "orders": [ + { + "orderId": "KR-OLD", + "purchaseDate": "2026-01-01T10:00:00Z", + "storeNumber": "357", + }, + { + "orderId": "KR-NEW", + "purchaseDate": "2026-03-15T10:00:00Z", + "storeNumber": "357", + }, + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response, mock_detail_response]) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + since = datetime(2026, 3, 1, tzinfo=UTC) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session, since=since) + + assert len(receipts) == 1 + assert receipts[0].receipt_id == "KR-NEW" + + @pytest.mark.asyncio + async def test_scrape_handles_api_failure(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = False + 
mock_api_response.status = 500 + mock_api_response.status_text = "Internal Server Error" + + mock_request = AsyncMock() + mock_request.get = AsyncMock(return_value=mock_api_response) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert receipts == [] + + @pytest.mark.asyncio + async def test_scrape_handles_unexpected_response(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock(return_value="not a dict") + + mock_request = AsyncMock() + mock_request.get = AsyncMock(return_value=mock_api_response) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + 
patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert receipts == [] + + @pytest.mark.asyncio + async def test_scrape_alternative_field_names(self, scraper, valid_session): + """Kroger may use 'purchases' instead of 'orders'.""" + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "purchases": [ + { + "receiptId": "KR-ALT-001", + "transactionDate": "2026-03-10T14:00:00Z", + "divisionNumber": "014", + } + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response, mock_detail_response]) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + + assert len(receipts) == 1 + assert receipts[0].receipt_id == "KR-ALT-001" + + @pytest.mark.asyncio + async def 
test_scrape_skips_orders_without_id(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "orders": [ + {"purchaseDate": "2026-03-10T14:00:00Z"}, # no id + {"orderId": "KR-VALID", "purchaseDate": "2026-03-10T14:00:00Z"}, + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response, mock_detail_response]) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert len(receipts) == 1 + assert receipts[0].receipt_id == "KR-VALID" + + @pytest.mark.asyncio + async def test_scrape_skips_orders_with_null_id(self, scraper, valid_session): + """Ensure orderId: null doesn't produce receipt_id='None' (str(None) bug).""" + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "orders": [ + {"orderId": None, "receiptId": None, "purchaseDate": "2026-03-10T14:00:00Z"}, + {"orderId": "KR-REAL", "purchaseDate": "2026-03-10T14:00:00Z"}, + ] + } + ) 
+ + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response, mock_detail_response]) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.kroger.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert len(receipts) == 1 + assert receipts[0].receipt_id == "KR-REAL" + # Verify no receipt has the string "None" as its ID + assert all(r.receipt_id != "None" for r in receipts) + + +class TestParseReceipt: + def test_parse_receipt_delegates_to_parser(self, scraper): + raw = RawReceipt( + receipt_id="KR-001", + purchase_date="2026-03-12", + raw_data={ + "detail": { + "items": [ + { + "description": "TEST ITEM", + "basePrice": 5.00, + "totalPrice": 5.00, + } + ], + "total": 5.00, + } + }, + ) + result = scraper.parse_receipt(raw) + assert result["receipt_id"] == "KR-001" + assert len(result["items"]) == 1 + + def test_receipt_detail_failure_returns_empty(self, scraper): + """Verify receipt detail failures produce empty detail.""" + raw = RawReceipt( + receipt_id="KR-FAIL", + purchase_date="2026-03-12", + raw_data={"total": 10.00, "detail": {}}, + ) + result = 
scraper.parse_receipt(raw) + assert result["receipt_id"] == "KR-FAIL" + assert result["items"] == [] diff --git a/tests/test_scrapers/test_meijer_scraper.py b/tests/test_scrapers/test_meijer_scraper.py new file mode 100644 index 0000000..05664e1 --- /dev/null +++ b/tests/test_scrapers/test_meijer_scraper.py @@ -0,0 +1,585 @@ +"""Tests for the Meijer scraper. + +These tests mock Playwright to avoid requiring real Meijer credentials +or network access. They verify the scraper's control flow, session handling, +date filtering, and error resilience. +""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from receiptwitness.scrapers.base import RawReceipt, SessionData +from receiptwitness.scrapers.meijer import ( + DEFAULT_TIMEZONE, + DEFAULT_USER_AGENT, + DEFAULT_VIEWPORT, + MEIJER_BASE, + MEIJER_LOGIN_PAGE, + MEIJER_MPERKS_HOME, + MEIJER_PURCHASE_HISTORY, + MeijerScraper, +) + + +@pytest.fixture +def scraper(): + return MeijerScraper() + + +@pytest.fixture +def valid_session(): + return SessionData( + cookies=[ + {"name": "meijer_session", "value": "abc123", "domain": ".meijer.com", "path": "/"} + ], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=4), + ) + + +@pytest.fixture +def expired_session(): + return SessionData( + cookies=[ + {"name": "meijer_session", "value": "expired", "domain": ".meijer.com", "path": "/"} + ], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC) - timedelta(hours=8), + expires_at=datetime.now(UTC) - timedelta(hours=4), + ) + + +class TestMeijerScraperConstants: + def test_base_url(self): + assert MEIJER_BASE == "https://www.meijer.com" + + def test_login_page(self): + assert MEIJER_LOGIN_PAGE == "https://www.meijer.com/shopping/login.html" + + def test_mperks_home(self): + assert MEIJER_MPERKS_HOME == "https://www.meijer.com/mperks.html" + + def test_purchase_history_url(self): + assert ( + 
MEIJER_PURCHASE_HISTORY == "https://www.meijer.com/bin/meijer/profile/purchasehistory" + ) + + def test_default_user_agent_is_chrome(self): + assert "Chrome" in DEFAULT_USER_AGENT + assert "Windows" in DEFAULT_USER_AGENT + + def test_default_viewport_hd(self): + assert DEFAULT_VIEWPORT == {"width": 1920, "height": 1080} + + def test_default_timezone(self): + assert DEFAULT_TIMEZONE == "America/Detroit" + + +class TestCheckSession: + @pytest.mark.asyncio + async def test_expired_session_returns_false(self, scraper, expired_session): + result = await scraper.check_session(expired_session) + assert result is False + + @pytest.mark.asyncio + async def test_no_expiry_checks_via_browser(self, scraper): + session = SessionData( + cookies=[], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC), + expires_at=None, + ) + mock_page = AsyncMock() + mock_page.url = "https://www.meijer.com/mperks.html" + mock_response = MagicMock() + mock_response.ok = True + mock_page.goto = AsyncMock(return_value=mock_response) + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw: + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + result = await scraper.check_session(session) + assert result is True + + @pytest.mark.asyncio + async def test_session_redirected_to_login_returns_false(self, scraper): + session = SessionData( + cookies=[], + user_agent=DEFAULT_USER_AGENT, + created_at=datetime.now(UTC), + expires_at=None, + ) + mock_page = AsyncMock() 
+ mock_page.url = "https://www.meijer.com/shopping/login.html?redirect=mperks" + mock_response = MagicMock() + mock_response.ok = True + mock_page.goto = AsyncMock(return_value=mock_response) + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw: + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + result = await scraper.check_session(session) + assert result is False + + +class TestLogin: + @pytest.mark.asyncio + async def test_login_returns_session_data(self, scraper): + mock_page = AsyncMock() + mock_page.url = "https://www.meijer.com/mperks.html" + + # Mock locator chain + mock_email = AsyncMock() + mock_password = AsyncMock() + mock_button = AsyncMock() + mock_page.locator = MagicMock(side_effect=[mock_email, mock_password, mock_button]) + mock_page.wait_for_url = AsyncMock() + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.cookies = AsyncMock( + return_value=[ + {"name": "meijer_session", "value": "test456", "domain": ".meijer.com", "path": "/"} + ] + ) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + patch.object(scraper, 
"human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + session = await scraper.login("user@test.com", "password123") + + assert isinstance(session, SessionData) + assert len(session.cookies) == 1 + assert session.cookies[0]["name"] == "meijer_session" + assert session.user_agent == DEFAULT_USER_AGENT + assert session.expires_at is not None + # Meijer sessions last 4 hours + assert session.expires_at > session.created_at + timedelta(hours=3) + + +class TestScrapeReceipts: + @pytest.mark.asyncio + async def test_scrape_returns_receipts(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.status = 200 + mock_api_response.json = AsyncMock( + return_value={ + "transactions": [ + { + "transactionId": "TXN-001", + "transactionDate": "2026-03-10T14:00:00Z", + "storeNumber": "42", + }, + { + "transactionId": "TXN-002", + "transactionDate": "2026-03-11T10:00:00Z", + "storeNumber": "42", + }, + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={"items": []}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock( + side_effect=[mock_api_response, mock_detail_response, mock_detail_response] + ) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + 
patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + + assert len(receipts) == 2 + assert receipts[0].receipt_id == "TXN-001" + assert receipts[1].receipt_id == "TXN-002" + assert isinstance(receipts[0], RawReceipt) + + @pytest.mark.asyncio + async def test_scrape_filters_by_date(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "transactions": [ + { + "transactionId": "TXN-OLD", + "transactionDate": "2026-01-01T10:00:00Z", + "storeNumber": "42", + }, + { + "transactionId": "TXN-NEW", + "transactionDate": "2026-03-15T10:00:00Z", + "storeNumber": "42", + }, + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response, mock_detail_response]) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + since = datetime(2026, 3, 1, tzinfo=UTC) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + 
mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session, since=since) + + assert len(receipts) == 1 + assert receipts[0].receipt_id == "TXN-NEW" + + @pytest.mark.asyncio + async def test_scrape_handles_api_failure(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = False + mock_api_response.status = 500 + mock_api_response.status_text = "Internal Server Error" + + mock_request = AsyncMock() + mock_request.get = AsyncMock(return_value=mock_api_response) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert receipts == [] + + @pytest.mark.asyncio + async def test_scrape_handles_unexpected_response(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock(return_value="not a dict") + + mock_request = AsyncMock() + mock_request.get = AsyncMock(return_value=mock_api_response) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + 
mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert receipts == [] + + @pytest.mark.asyncio + async def test_scrape_alternative_field_names(self, scraper, valid_session): + """Meijer may use 'purchaseHistory' instead of 'transactions'.""" + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "purchaseHistory": [ + { + "receiptId": "MJ-ALT-001", + "purchaseDate": "2026-03-10T14:00:00Z", + "storeId": "99", + } + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response, mock_detail_response]) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = 
AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + + assert len(receipts) == 1 + assert receipts[0].receipt_id == "MJ-ALT-001" + + @pytest.mark.asyncio + async def test_scrape_skips_transactions_without_id(self, scraper, valid_session): + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "transactions": [ + {"transactionDate": "2026-03-10T14:00:00Z"}, # no id + {"transactionId": "TXN-VALID", "transactionDate": "2026-03-10T14:00:00Z"}, + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = True + mock_detail_response.json = AsyncMock(return_value={}) + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response, mock_detail_response]) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert len(receipts) == 1 + assert receipts[0].receipt_id == "TXN-VALID" + + @pytest.mark.asyncio + async def test_scrape_receipt_detail_failure_returns_empty_detail(self, scraper, 
valid_session): + """Receipt detail API failure should not crash the scraper.""" + mock_api_response = AsyncMock() + mock_api_response.ok = True + mock_api_response.json = AsyncMock( + return_value={ + "transactions": [ + { + "transactionId": "TXN-DETAIL-FAIL", + "transactionDate": "2026-03-10T14:00:00Z", + "storeNumber": "42", + } + ] + } + ) + + mock_detail_response = AsyncMock() + mock_detail_response.ok = False + mock_detail_response.status = 404 + + mock_request = AsyncMock() + mock_request.get = AsyncMock(side_effect=[mock_api_response, mock_detail_response]) + + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.request = mock_request + + mock_context = AsyncMock() + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_context.add_cookies = AsyncMock() + mock_context.add_init_script = AsyncMock() + mock_browser = AsyncMock() + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.browser = mock_browser + + mock_pw = AsyncMock() + mock_pw.chromium.launch = AsyncMock(return_value=mock_browser) + + with ( + patch("receiptwitness.scrapers.meijer.async_playwright") as mock_apw, + patch.object(scraper, "human_delay", new_callable=AsyncMock), + ): + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_pw) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_apw.return_value = mock_cm + + receipts = await scraper.scrape_receipts(valid_session) + assert len(receipts) == 1 + assert receipts[0].receipt_id == "TXN-DETAIL-FAIL" + assert receipts[0].raw_data.get("detail") == {} + + +class TestParseReceipt: + def test_parse_receipt_delegates_to_parser(self, scraper): + raw = RawReceipt( + receipt_id="TXN-001", + purchase_date="2026-03-10", + raw_data={ + "detail": { + "items": [ + { + "description": "TEST ITEM", + "price": 5.00, + "extendedPrice": 5.00, + } + ], + "total": 5.00, + } + }, + ) + result = scraper.parse_receipt(raw) + assert result["receipt_id"] == "TXN-001" + assert 
len(result["items"]) == 1 + + def test_receipt_detail_failure_returns_empty(self, scraper): + raw = RawReceipt( + receipt_id="TXN-FAIL", + purchase_date="2026-03-10", + raw_data={"total": 10.00, "detail": {}}, + ) + result = scraper.parse_receipt(raw) + assert result["receipt_id"] == "TXN-FAIL" + assert result["items"] == [] diff --git a/tests/test_session/__init__.py b/tests/test_session/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_session/test_encryption.py b/tests/test_session/test_encryption.py new file mode 100644 index 0000000..59a57fa --- /dev/null +++ b/tests/test_session/test_encryption.py @@ -0,0 +1,61 @@ +"""Tests for session encryption/decryption.""" + +from unittest.mock import patch + +import pytest +from cryptography.fernet import Fernet, InvalidToken + +from receiptwitness.session.encryption import decrypt_session_data, encrypt_session_data + +TEST_KEY = Fernet.generate_key().decode() + + +@pytest.fixture(autouse=True) +def _mock_encryption_key(): + with patch("receiptwitness.session.encryption.settings") as mock_settings: + mock_settings.session_encryption_key = TEST_KEY + yield + + +class TestEncryptDecrypt: + def test_roundtrip(self): + data = { + "cookies": [{"name": "session", "value": "abc123", "domain": ".meijer.com"}], + "user_agent": "Mozilla/5.0", + } + encrypted = encrypt_session_data(data) + assert isinstance(encrypted, str) + assert encrypted != str(data) + + decrypted = decrypt_session_data(encrypted) + assert decrypted == data + + def test_different_data_different_ciphertext(self): + data1 = {"key": "value1"} + data2 = {"key": "value2"} + enc1 = encrypt_session_data(data1) + enc2 = encrypt_session_data(data2) + assert enc1 != enc2 + + def test_decrypt_with_wrong_key_fails(self): + data = {"cookies": []} + encrypted = encrypt_session_data(data) + + wrong_key = Fernet.generate_key().decode() + with patch("receiptwitness.session.encryption.settings") as mock_settings: + 
mock_settings.session_encryption_key = wrong_key + with pytest.raises(InvalidToken): + decrypt_session_data(encrypted) + + def test_decrypt_tampered_data_fails(self): + data = {"cookies": []} + encrypted = encrypt_session_data(data) + tampered = encrypted[:-5] + "XXXXX" + with pytest.raises(Exception): + decrypt_session_data(tampered) + + def test_no_key_raises_error(self): + with patch("receiptwitness.session.encryption.settings") as mock_settings: + mock_settings.session_encryption_key = "" + with pytest.raises(ValueError, match="RW_SESSION_ENCRYPTION_KEY"): + encrypt_session_data({"test": True}) diff --git a/tests/test_session/test_manager.py b/tests/test_session/test_manager.py new file mode 100644 index 0000000..68e1015 --- /dev/null +++ b/tests/test_session/test_manager.py @@ -0,0 +1,102 @@ +"""Tests for session manager logic.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, patch + +import pytest +from cryptography.fernet import Fernet + +from receiptwitness.scrapers.base import SessionData +from receiptwitness.session.manager import ( + get_valid_session, + session_from_db_record, + session_to_db_value, +) + +TEST_KEY = Fernet.generate_key().decode() + + +@pytest.fixture(autouse=True) +def _mock_encryption_key(): + with patch("receiptwitness.session.encryption.settings") as mock_settings: + mock_settings.session_encryption_key = TEST_KEY + yield + + +def _make_session(hours_until_expire: int = 4) -> SessionData: + now = datetime.now(UTC) + return SessionData( + cookies=[{"name": "sid", "value": "test", "domain": ".meijer.com"}], + user_agent="Mozilla/5.0", + created_at=now, + expires_at=now + timedelta(hours=hours_until_expire), + ) + + +class TestSessionSerialization: + def test_roundtrip(self): + session = _make_session() + db_value = session_to_db_value(session) + restored = session_from_db_record(db_value) + + assert restored is not None + assert restored.cookies == session.cookies + assert restored.user_agent == 
session.user_agent + + def test_none_returns_none(self): + assert session_from_db_record(None) is None + + def test_invalid_encrypted_returns_none(self): + assert session_from_db_record("garbage-data") is None + + +class TestGetValidSession: + @pytest.mark.asyncio + async def test_valid_existing_session(self): + session = _make_session() + db_value = session_to_db_value(session) + + scraper = AsyncMock() + scraper.check_session.return_value = True + + result, was_refreshed = await get_valid_session(scraper, db_value, "user", "pass") + assert not was_refreshed + assert result.cookies == session.cookies + scraper.login.assert_not_called() + + @pytest.mark.asyncio + async def test_expired_session_triggers_login(self): + session = _make_session(hours_until_expire=-1) # already expired + db_value = session_to_db_value(session) + + new_session = _make_session() + scraper = AsyncMock() + scraper.login.return_value = new_session + + result, was_refreshed = await get_valid_session(scraper, db_value, "user", "pass") + assert was_refreshed + scraper.login.assert_called_once_with("user", "pass") + + @pytest.mark.asyncio + async def test_no_existing_session_triggers_login(self): + new_session = _make_session() + scraper = AsyncMock() + scraper.login.return_value = new_session + + result, was_refreshed = await get_valid_session(scraper, None, "user", "pass") + assert was_refreshed + scraper.login.assert_called_once() + + @pytest.mark.asyncio + async def test_failed_session_check_triggers_login(self): + session = _make_session() + db_value = session_to_db_value(session) + + new_session = _make_session() + scraper = AsyncMock() + scraper.check_session.return_value = False + scraper.login.return_value = new_session + + result, was_refreshed = await get_valid_session(scraper, db_value, "user", "pass") + assert was_refreshed + scraper.login.assert_called_once() From cc0957fc92c4c70772ae07af690e094480254339 Mon Sep 17 00:00:00 2001 From: Coupon Carl Date: Sat, 28 Mar 2026 02:25:07 
+0000 Subject: [PATCH 4/4] docs: update README and CLAUDE.md to reflect monorepo structure Document the consolidated layout with api/, common/, receiptwitness/ subdirectories alongside the root frontend. Co-Authored-By: Paperclip --- CLAUDE.md | 19 +++++++---- README.md | 96 ++++++++++++++++++++----------------------------------- 2 files changed, 46 insertions(+), 69 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 98dd729..623f979 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,20 +1,25 @@ -# CartSnitch Frontend +# CartSnitch Monorepo ## Project Context -CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/cartsnitch`) is the mobile-first Progressive Web App — the flagship repo and primary user interface. +CartSnitch is a self-hosted grocery price intelligence platform. This repo (`cartsnitch/cartsnitch`) is the **monorepo** containing the flagship frontend PWA and core backend services. **GitHub org:** github.com/cartsnitch **Domain:** cartsnitch.com -### CartSnitch Services +### Monorepo Layout + +| Directory | Service | Purpose | +|-----------|---------|---------| +| `/` (root) | Frontend | React PWA, mobile-first (this directory) | +| `api/` | API Gateway | Frontend-facing REST API | +| `common/` | Common | Shared Python models, schemas, Alembic migrations | +| `receiptwitness/` | ReceiptWitness | Purchase data ingestion via retailer scrapers | + +### Other CartSnitch Repos (still separate) | Repo | Service | Purpose | |------|---------|---------| -| `cartsnitch/common` | — | Shared models, schemas, utilities | -| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers | -| `cartsnitch/api` | API Gateway | Frontend-facing REST API | -| `cartsnitch/cartsnitch` | Frontend | React PWA, mobile-first (this repo) | | `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison | | `cartsnitch/shrinkray` | ShrinkRay | 
Shrinkflation monitoring | | `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization | diff --git a/README.md b/README.md index 7dbf7eb..6b626cf 100644 --- a/README.md +++ b/README.md @@ -1,73 +1,45 @@ -# React + TypeScript + Vite +# CartSnitch Monorepo -This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules. +CartSnitch is a self-hosted grocery price intelligence platform. This repo consolidates the core services and the flagship frontend PWA. -Currently, two official plugins are available: +## Services -- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Oxc](https://oxc.rs) -- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) +| Directory | Service | Purpose | +|-----------|---------|---------| +| `/` (root) | **Frontend** | React 18 PWA — mobile-first price intelligence UI | +| `api/` | **API Gateway** | FastAPI — frontend-facing REST API | +| `common/` | **Common** | Shared Python models, schemas, Alembic migrations | +| `receiptwitness/` | **ReceiptWitness** | Purchase ingestion via retailer scrapers | -## React Compiler +## Quick Start -The React Compiler is not enabled on this template because of its impact on dev & build performances. To add it, see [this documentation](https://react.dev/learn/react-compiler/installation). +### Frontend (root) -## Expanding the ESLint configuration - -If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules: - -```js -export default defineConfig([ - globalIgnores(['dist']), - { - files: ['**/*.{ts,tsx}'], - extends: [ - // Other configs... 
- - // Remove tseslint.configs.recommended and replace with this - tseslint.configs.recommendedTypeChecked, - // Alternatively, use this for stricter rules - tseslint.configs.strictTypeChecked, - // Optionally, add this for stylistic rules - tseslint.configs.stylisticTypeChecked, - - // Other configs... - ], - languageOptions: { - parserOptions: { - project: ['./tsconfig.node.json', './tsconfig.app.json'], - tsconfigRootDir: import.meta.dirname, - }, - // other options... - }, - }, -]) +```bash +npm install +npm run dev # http://localhost:5173 +npm run build # production build +npm run test # unit tests (Vitest) ``` -You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules: +### Python Services -```js -// eslint.config.js -import reactX from 'eslint-plugin-react-x' -import reactDom from 'eslint-plugin-react-dom' +Each Python service uses [uv](https://github.com/astral-sh/uv) and has its own `pyproject.toml`: -export default defineConfig([ - globalIgnores(['dist']), - { - files: ['**/*.{ts,tsx}'], - extends: [ - // Other configs... - // Enable lint rules for React - reactX.configs['recommended-typescript'], - // Enable lint rules for React DOM - reactDom.configs.recommended, - ], - languageOptions: { - parserOptions: { - project: ['./tsconfig.node.json', './tsconfig.app.json'], - tsconfigRootDir: import.meta.dirname, - }, - // other options... - }, - }, -]) +```bash +cd api # or common / receiptwitness +uv sync +uv run pytest ``` + +## Development Workflow + +- **Never push directly to main.** Always open a PR from a feature branch. 
+- Branch naming: prefix branch names with `feature/` or `fix/` +- Conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `chore:` + +## Architecture + +For full details see [CLAUDE.md](./CLAUDE.md) or the per-service `CLAUDE.md` in each subdirectory. + +CartSnitch is a polyrepo-style monorepo: each service can be built and deployed independently, while the code in `common/` is shared with the other Python services via local path dependencies in `pyproject.toml`.