commit 4cf6f91e954b770198578bcb8db5d98ac964bfed Author: Coupon Carl Date: Sat Mar 28 02:24:14 2026 +0000 Squashed 'common/' content from commit 28b2939 git-subtree-dir: common git-subtree-split: 28b2939037b5932ca5d5a6c734b292c012ac675f diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..7483099 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,88 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ci-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + lint: + runs-on: runners-cartsnitch + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - run: pip install ruff + - name: Ruff lint + run: ruff check . + - name: Ruff format check + run: ruff format --check . + + typecheck: + runs-on: runners-cartsnitch + continue-on-error: true + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - run: pip install -e ".[dev]" + - name: Type check + run: mypy src/cartsnitch_common + + test: + runs-on: runners-cartsnitch + services: + postgres: + image: postgres:15-alpine + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + env: + POSTGRES_USER: cartsnitch + POSTGRES_PASSWORD: cartsnitch_test + POSTGRES_DB: cartsnitch_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + env: + DATABASE_URL: postgresql://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test + CARTSNITCH_DATABASE_URL_SYNC: postgresql://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - run: pip install -e ".[dev]" + - name: Run migrations + run: alembic upgrade head + - name: 
Run tests + run: pytest --tb=short -q + + build: + runs-on: runners-cartsnitch + needs: [lint, test] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install build + - name: Build package + run: python -m build diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e044f1f --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.pyc +*.egg-info/ +dist/ +build/ +.pytest_cache/ +*.egg +.venv/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..ef248cd --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,185 @@ +# CartSnitch Common Library + +## Project Context + +CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/common`) is the shared Python library that all CartSnitch services depend on. + +**GitHub org:** github.com/cartsnitch +**Domain:** cartsnitch.com + +### CartSnitch Services + +| Repo | Service | Purpose | +|------|---------|---------| +| `cartsnitch/common` | — | Shared models, schemas, utilities (this repo) | +| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers | +| `cartsnitch/api` | API Gateway | Frontend-facing REST API | +| `cartsnitch/cartsnitch` | Frontend | React PWA (mobile-first) | +| `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison | +| `cartsnitch/shrinkray` | ShrinkRay | Shrinkflation monitoring | +| `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization | +| `cartsnitch/infra` | — | K8s manifests, Flux kustomizations | + +### Architecture Decisions + +- **Polyrepo:** Each service has its own repo, Dockerfile, CI/CD pipeline. +- **Shared DB:** One PostgreSQL cluster (CNPG on K8s, docker-compose locally). Each service owns its tables but shares the database. Services access other services' data via REST API, not direct cross-service DB queries. 
+- **Inter-service comms:** REST (synchronous) + Redis pub/sub (async events). +- **Target scale:** 500–1,000 users initially. +- **Target retailers (MVP):** Meijer (mPerks), Kroger, Target (Circle) in Southeast Michigan. + +## What This Repo Contains + +This is a Python package (`cartsnitch-common`) that provides: + +1. **SQLAlchemy ORM models** — the canonical database schema shared across services +2. **Pydantic schemas** — request/response models for inter-service API contracts +3. **Database utilities** — engine/session factory, connection management +4. **Configuration** — shared settings via pydantic-settings (DB URL, Redis URL, etc.) +5. **Event definitions** — Redis pub/sub event types and payloads +6. **Constants** — store slugs, category enums, etc. + +## Tech Stack + +- Python 3.12+ +- SQLAlchemy 2.0 (async support) +- Alembic (migrations live in this repo since it owns the schema) +- Pydantic v2 +- pydantic-settings (env-based configuration) +- Redis (redis-py for pub/sub event definitions) + +## Database Schema + +All migrations are managed from this repo via Alembic. Services depend on `cartsnitch-common` to get the models. 
+ +### Core Tables + +``` +stores + id (PK), name, slug (meijer|kroger|target), logo_url, website_url, created_at + +store_locations + id (PK), store_id (FK), address, city, state, zip, lat, lng + +users + id (PK), email, hashed_password, display_name, created_at, updated_at + +user_store_accounts + id (PK), user_id (FK), store_id (FK), session_data (encrypted JSONB), session_expires_at, last_sync_at, status (active|expired|error) + +purchases + id (PK), user_id (FK), store_id (FK), store_location_id (FK), receipt_id (unique per store), purchase_date, total, subtotal, tax, savings_total, source_url, raw_data (JSONB), ingested_at + +purchase_items + id (PK), purchase_id (FK), product_name_raw, upc, quantity, unit_price, extended_price, regular_price, sale_price, coupon_discount, loyalty_discount, category_raw, normalized_product_id (FK, nullable) + +normalized_products + id (PK), canonical_name, category, subcategory, brand, size, size_unit, upc_variants (JSONB), created_at, updated_at + +price_history + id (PK), normalized_product_id (FK), store_id (FK), observed_date, regular_price, sale_price, loyalty_price, coupon_price, source (receipt|catalog|weekly_ad), purchase_item_id (FK, nullable) + +coupons + id (PK), store_id (FK), normalized_product_id (FK, nullable), title, description, discount_type (percent|fixed|bogo|buy_x_get_y), discount_value, min_purchase, valid_from, valid_to, requires_clip, coupon_code, source_url, scraped_at + +shrinkflation_events + id (PK), normalized_product_id (FK), detected_date, old_size, new_size, old_unit, new_unit, price_at_old_size, price_at_new_size, confidence, notes +``` + +## Repo Structure + +``` +cartsnitch-common/ +├── CLAUDE.md +├── README.md +├── pyproject.toml # Package definition, installable via pip +├── alembic.ini +├── alembic/ +│ ├── env.py +│ └── versions/ +├── src/ +│ └── cartsnitch_common/ +│ ├── __init__.py +│ ├── config.py # Shared settings (DB_URL, REDIS_URL, etc.) 
+│ ├── database.py # Engine, session factory, async support +│ ├── models/ +│ │ ├── __init__.py # Re-exports all models +│ │ ├── base.py # DeclarativeBase, common mixins (timestamps, etc.) +│ │ ├── store.py # Store, StoreLocation +│ │ ├── user.py # User, UserStoreAccount +│ │ ├── purchase.py # Purchase, PurchaseItem +│ │ ├── product.py # NormalizedProduct +│ │ ├── price.py # PriceHistory +│ │ ├── coupon.py # Coupon +│ │ └── shrinkflation.py # ShrinkflationEvent +│ ├── schemas/ +│ │ ├── __init__.py +│ │ ├── purchase.py # Pydantic request/response schemas +│ │ ├── product.py +│ │ ├── price.py +│ │ ├── coupon.py +│ │ └── events.py # Redis pub/sub event payloads +│ ├── events.py # Event bus helpers (publish/subscribe) +│ └── constants.py # Store slugs, enums +└── tests/ + ├── conftest.py + ├── test_models.py + └── test_schemas.py +``` + +## Packaging + +This package is published as `cartsnitch-common` and installed by other services via: + +``` +# In each service's pyproject.toml +dependencies = [ + "cartsnitch-common @ git+https://github.com/cartsnitch/common.git@main", +] +``` + +Or if using a private PyPI registry, publish there. For local dev, install in editable mode: + +```bash +pip install -e /path/to/common +``` + +## Development Workflow + +- **Never push directly to main.** Always create feature branches and open PRs. +- Branch naming: `feature/` or `fix/` +- Use conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `chore:` +- Alembic migrations must be reviewed carefully — they affect all services. +- Bump the version in `pyproject.toml` when changing schemas or models so downstream services can pin versions. +- Run `alembic upgrade head` in local dev after pulling changes. + +## Event Bus (Redis Pub/Sub) + +Events are the primary async communication mechanism between services. Event types are defined in this repo so all services share the same contract. 
+ +### Event Channels + +- `cartsnitch.receipts.ingested` — ReceiptWitness publishes when new receipt data is saved +- `cartsnitch.prices.updated` — Published when new price data points are recorded +- `cartsnitch.products.normalized` — Published when product normalization resolves a match +- `cartsnitch.coupons.updated` — ClipArtist publishes when coupon data refreshes +- `cartsnitch.alerts.price_increase` — StickerShock publishes when a significant price increase is detected +- `cartsnitch.alerts.shrinkflation` — ShrinkRay publishes when shrinkflation is detected + +### Event Payload Structure + +```json +{ + "event_type": "cartsnitch.receipts.ingested", + "timestamp": "2026-03-15T12:00:00Z", + "service": "receiptwitness", + "payload": { ... } +} +``` + +## Important Notes + +- This is the schema owner. All Alembic migrations live here. No other service runs its own migrations. +- When adding new models or changing existing ones, always create a migration and bump the package version. +- Pydantic schemas in `schemas/` define the API contracts between services. These are the source of truth for inter-service communication. +- The `database.py` module should support both sync and async sessions since different services may use different patterns. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2ce4733 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,19 @@ +# syntax=docker/dockerfile:1 +FROM python:3.12-slim AS base + +WORKDIR /app + +COPY pyproject.toml ./ +RUN pip install --no-cache-dir . 
+ +COPY src/ src/ +COPY alembic/ alembic/ +COPY alembic.ini ./ + +FROM base AS test +RUN pip install --no-cache-dir ".[dev]" +COPY tests/ tests/ +CMD ["pytest", "--tb=short", "-q"] + +FROM base AS prod +CMD ["python", "-c", "import cartsnitch_common; print(f'cartsnitch-common ready')"] diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..00a0b14 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = alembic +sqlalchemy.url = postgresql://localhost:5432/cartsnitch + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..cd893e0 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,51 @@ +"""Alembic environment configuration for CartSnitch.""" + +import os +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool + +from alembic import context +from cartsnitch_common.models.base import Base + +config = context.config +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +db_url = os.environ.get("CARTSNITCH_DATABASE_URL_SYNC") +if db_url: + config.set_main_option("sqlalchemy.url", db_url) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode.""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + with context.begin_transaction(): + context.run_migrations() + + +def 
run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..fe3b097 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,25 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +${imports if imports else ""} + +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ee348c5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,47 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "cartsnitch-common" +version = "2026.321.0" +description = "Shared models, schemas, and utilities for CartSnitch services" +requires-python = ">=3.12" +dependencies = [ + "sqlalchemy[asyncio]>=2.0,<3.0", + "alembic>=1.13,<2.0", + "pydantic[email]>=2.0,<3.0", + "pydantic-settings>=2.0,<3.0", + "asyncpg>=0.29,<1.0", + "redis>=5.0,<6.0", + "psycopg2-binary>=2.9,<3.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "ruff>=0.4", + 
"mypy>=1.10", + "faker>=33.0,<34.0", +] +seed = [ + "faker>=33.0,<34.0", +] + +[project.scripts] +cartsnitch-seed = "cartsnitch_common.seed.__main__:main" + +[tool.hatch.build.targets.wheel] +packages = ["src/cartsnitch_common"] + +[tool.ruff] +target-version = "py312" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I", "UP", "B", "SIM"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" diff --git a/renovate.json b/renovate.json new file mode 100644 index 0000000..833ba3b --- /dev/null +++ b/renovate.json @@ -0,0 +1,4 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": ["local>cartsnitch/.github:renovate-config"] +} diff --git a/scripts/stats/README.md b/scripts/stats/README.md new file mode 100644 index 0000000..86d424f --- /dev/null +++ b/scripts/stats/README.md @@ -0,0 +1,60 @@ +# Launch Announcement Validation Queries + +Scripts to validate the two statistics cited in the CartSnitch launch announcement: + +1. **847 products that shrank in the past 12 months** +2. **$336/year potential savings from buying the same items at the cheapest store** + +## Status + +These queries are written against the production data model but **cannot be run yet** — production infrastructure (CAR-99, CAR-104) is still being deployed. Once production data is available, run these scripts to confirm the cited numbers. + +## Queries + +### Stat 1: Shrinkflation count (`shrinkflation_count.sql`) + +Counts distinct `normalized_product_id` values with at least one `ShrinkflationEvent` where `detected_date` falls within the past 12 months. + +**Key assumptions:** +- "Past 12 months" is relative to query execution date (`CURRENT_DATE - INTERVAL '12 months'`). +- A product counts once even if it has multiple shrinkflation events in the window. +- The 847 figure was generated from a specific date — re-running will drift as the window slides. 
+ +### Stat 2: Annual savings potential (`savings_potential.sql`) + +**Methodology:** + +For each `normalized_product_id` with price observations from **two or more distinct stores** in the past 90 days: + +1. Take the **most recent `regular_price`** per `(normalized_product_id, store_id)` pair. +2. Compute `cheapest_price` = MIN across stores, `avg_price` = AVG across stores. +3. `savings_per_purchase` = `avg_price - cheapest_price`. + +To arrive at **annual** savings per family: + +- Assume a **typical family purchases each product ~N times per year** (see `DEFAULT_PURCHASE_FREQUENCY_PER_YEAR` constant in `validate_launch_stats.py`). +- Default assumption: products purchased on average 26×/year (~every 2 weeks for regularly bought items). +- Sum across all eligible products: `Σ(savings_per_purchase × frequency)`. + +**Sensitivity knobs:** +- `DEFAULT_PURCHASE_FREQUENCY_PER_YEAR` — adjust purchase cadence assumption (override at runtime with `--freq`) +- `PRICE_LOOKBACK_DAYS` — how recent a price observation must be to be "current" (default: 90 days) +- `MIN_STORES_FOR_COMPARISON` — minimum number of stores a product must appear at (default: 2) + +The $336 figure assumes the defaults above. If actual purchase frequencies differ significantly, rerun `validate_launch_stats.py --freq N` with the observed frequency `N`. + +## Running + +```bash +# Requires DATABASE_URL env var pointing at production Postgres via the asyncpg driver (postgresql+asyncpg://user:pass@host/db) +python scripts/stats/validate_launch_stats.py + +# Adjust purchase frequency assumption (default: 26 times/year) +python scripts/stats/validate_launch_stats.py --freq 20 + +# Run just stat 1 or stat 2 +python scripts/stats/validate_launch_stats.py --stat 1 +python scripts/stats/validate_launch_stats.py --stat 2 +``` + +Raw SQL files (`shrinkflation_count.sql`, `savings_potential.sql`) can also be run directly with `psql`. 
diff --git a/scripts/stats/savings_potential.sql b/scripts/stats/savings_potential.sql new file mode 100644 index 0000000..f575bcc --- /dev/null +++ b/scripts/stats/savings_potential.sql @@ -0,0 +1,121 @@ +-- ============================================================================= +-- Stat 2: Annual savings potential from cross-store price comparison +-- Validates: "$336/year potential savings from buying the same items +-- at the cheapest store" (launch announcement) +-- +-- Methodology: +-- 1. For each (normalized_product_id, store_id), take the MOST RECENT +-- regular_price within the past 90 days ("current" price). +-- 2. Keep only products observed at 2+ distinct stores. +-- 3. For each product: savings_per_purchase = avg_price - min_price across stores. +-- 4. Annualise: multiply by an assumed purchase frequency of 26x/year +-- (~every 2 weeks for regularly purchased grocery items). +-- 5. Sum across all eligible products to get total annual savings potential. +-- +-- Sensitivity: +-- Change the frequency constant (26) and lookback interval (90 days) to +-- explore how sensitive the $336 figure is to these assumptions. +-- +-- Run against production Postgres once infrastructure is available. 
+-- ============================================================================= + +-- Step 1: most-recent price per (product, store) within the past 90 days +WITH latest_prices AS ( + SELECT DISTINCT ON (ph.normalized_product_id, ph.store_id) + ph.normalized_product_id, + ph.store_id, + s.slug AS store_slug, + ph.regular_price AS current_price, + ph.observed_date + FROM price_history ph + JOIN stores s ON s.id = ph.store_id + WHERE ph.observed_date >= CURRENT_DATE - INTERVAL '90 days' + AND ph.regular_price > 0 + ORDER BY + ph.normalized_product_id, + ph.store_id, + ph.observed_date DESC +), + +-- Step 2: aggregate per product — only keep products seen at 2+ stores +product_price_spread AS ( + SELECT + lp.normalized_product_id, + COUNT(DISTINCT lp.store_id) AS store_count, + MIN(lp.current_price) AS cheapest_price, + AVG(lp.current_price) AS avg_price, + MAX(lp.current_price) AS most_expensive_price, + MAX(lp.current_price) - MIN(lp.current_price) AS price_range + FROM latest_prices lp + GROUP BY lp.normalized_product_id + HAVING COUNT(DISTINCT lp.store_id) >= 2 +), + +-- Step 3: compute savings_per_purchase and annualise +-- Purchase frequency assumption: 26 purchases/year per product (~every 2 weeks) +-- This is a conservative estimate for regularly purchased grocery items. 
+savings_per_product AS ( + SELECT + pps.normalized_product_id, + np.canonical_name, + np.category, + pps.store_count, + pps.cheapest_price, + pps.avg_price, + pps.price_range, + ROUND(pps.avg_price - pps.cheapest_price, 2) AS savings_per_purchase, + ROUND((pps.avg_price - pps.cheapest_price) * 26, 2) AS annual_savings_at_26x + FROM product_price_spread pps + JOIN normalized_products np ON np.id = pps.normalized_product_id +) + +-- Final summary: total annual savings potential +SELECT + COUNT(*) AS eligible_product_count, + ROUND(AVG(savings_per_purchase), 4) AS avg_savings_per_purchase, + ROUND(SUM(annual_savings_at_26x), 2) AS total_annual_savings_26x_freq, + -- Sensitivity: alternative frequencies + ROUND(SUM(savings_per_purchase) * 20, 2) AS total_annual_savings_20x_freq, + ROUND(SUM(savings_per_purchase) * 52, 2) AS total_annual_savings_52x_freq +FROM savings_per_product; + + +-- Per-product detail (top 50 by annual savings opportunity) +WITH latest_prices AS ( + SELECT DISTINCT ON (ph.normalized_product_id, ph.store_id) + ph.normalized_product_id, + ph.store_id, + s.slug AS store_slug, + ph.regular_price AS current_price, + ph.observed_date + FROM price_history ph + JOIN stores s ON s.id = ph.store_id + WHERE ph.observed_date >= CURRENT_DATE - INTERVAL '90 days' + AND ph.regular_price > 0 + ORDER BY ph.normalized_product_id, ph.store_id, ph.observed_date DESC +), +product_price_spread AS ( + SELECT + lp.normalized_product_id, + COUNT(DISTINCT lp.store_id) AS store_count, + MIN(lp.current_price) AS cheapest_price, + AVG(lp.current_price) AS avg_price + FROM latest_prices lp + GROUP BY lp.normalized_product_id + HAVING COUNT(DISTINCT lp.store_id) >= 2 +) +SELECT + np.canonical_name, + np.category, + np.brand, + np.size, + np.size_unit, + pps.store_count, + pps.cheapest_price, + ROUND(pps.avg_price, 2) AS avg_price, + ROUND(pps.avg_price - pps.cheapest_price, 2) AS savings_per_purchase, + ROUND((pps.avg_price - pps.cheapest_price) * 26, 2) AS 
annual_savings_at_26x +FROM product_price_spread pps +JOIN normalized_products np ON np.id = pps.normalized_product_id +ORDER BY annual_savings_at_26x DESC +LIMIT 50; diff --git a/scripts/stats/shrinkflation_count.sql b/scripts/stats/shrinkflation_count.sql new file mode 100644 index 0000000..7312190 --- /dev/null +++ b/scripts/stats/shrinkflation_count.sql @@ -0,0 +1,39 @@ +-- ============================================================================= +-- Stat 1: Products that shrank in the past 12 months +-- Validates: "847 products that shrank in the past 12 months" (launch announcement) +-- +-- Run against production Postgres once infrastructure is available. +-- Results will drift as the 12-month window slides forward from execution date. +-- ============================================================================= + +-- Primary count: distinct products with ≥1 shrinkflation event in the past year +SELECT + COUNT(DISTINCT se.normalized_product_id) AS shrinkflation_product_count +FROM shrinkflation_events se +WHERE se.detected_date >= CURRENT_DATE - INTERVAL '12 months'; + + +-- Breakdown by product category (for deeper reporting) +SELECT + COALESCE(np.category, 'unknown') AS category, + COUNT(DISTINCT se.normalized_product_id) AS products_with_shrinkflation +FROM shrinkflation_events se +JOIN normalized_products np ON np.id = se.normalized_product_id +WHERE se.detected_date >= CURRENT_DATE - INTERVAL '12 months' +GROUP BY np.category +ORDER BY products_with_shrinkflation DESC; + + +-- Breakdown by confidence band (high/medium/low events) +-- Confidence >= 0.80 = "clear" shrinkflation signal +SELECT + CASE + WHEN se.confidence >= 0.80 THEN 'high (>=0.80)' + WHEN se.confidence >= 0.50 THEN 'medium (0.50-0.79)' + ELSE 'low (<0.50)' + END AS confidence_band, + COUNT(DISTINCT se.normalized_product_id) AS products +FROM shrinkflation_events se +WHERE se.detected_date >= CURRENT_DATE - INTERVAL '12 months' +GROUP BY confidence_band +ORDER BY MIN(se.confidence) 
DESC; diff --git a/scripts/stats/validate_launch_stats.py b/scripts/stats/validate_launch_stats.py new file mode 100644 index 0000000..b0e152f --- /dev/null +++ b/scripts/stats/validate_launch_stats.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +""" +validate_launch_stats.py — Validate CartSnitch launch announcement statistics. + +Validates two statistics from content/marketing/launch-announcement.md: + 1. "847 products that shrank in the past 12 months" + 2. "$336/year potential savings from buying the same items at the cheapest store" + +Usage: + DATABASE_URL=postgresql+asyncpg://... python scripts/stats/validate_launch_stats.py + python scripts/stats/validate_launch_stats.py --freq 20 # change purchase frequency + python scripts/stats/validate_launch_stats.py --stat 1 # run stat 1 only + python scripts/stats/validate_launch_stats.py --stat 2 # run stat 2 only + +NOTE: Production infrastructure is not yet deployed (CAR-99, CAR-104). This script +cannot be run against real data until those are complete. The data model has been +verified to support both queries. + +Ref: CAR-162 +""" + +from __future__ import annotations + +import argparse +import asyncio +import os +import sys +from decimal import Decimal + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + +# ────────────────────────────────────────────────────────────────────────────── +# Configuration / assumptions +# ────────────────────────────────────────────────────────────────────────────── + +DEFAULT_PURCHASE_FREQUENCY_PER_YEAR: int = 26 +"""Default purchase frequency assumption. + +26 = roughly every 2 weeks for a typical grocery staple. +Adjust with --freq to explore sensitivity. 
+""" + +PRICE_LOOKBACK_DAYS: int = 90 +"""How many days back to look for a "current" price observation.""" + +MIN_STORES_FOR_COMPARISON: int = 2 +"""Minimum number of distinct stores a product must appear at to be eligible.""" + + +# ────────────────────────────────────────────────────────────────────────────── +# Stat 1: shrinkflation count +# ────────────────────────────────────────────────────────────────────────────── + +SHRINKFLATION_COUNT_SQL = sa.text(""" + SELECT COUNT(DISTINCT se.normalized_product_id) AS shrinkflation_product_count + FROM shrinkflation_events se + WHERE se.detected_date >= CURRENT_DATE - INTERVAL '12 months' +""") + +SHRINKFLATION_BY_CATEGORY_SQL = sa.text(""" + SELECT + COALESCE(np.category, 'unknown') AS category, + COUNT(DISTINCT se.normalized_product_id) AS product_count + FROM shrinkflation_events se + JOIN normalized_products np ON np.id = se.normalized_product_id + WHERE se.detected_date >= CURRENT_DATE - INTERVAL '12 months' + GROUP BY np.category + ORDER BY product_count DESC +""") + + +async def run_stat_1(session: AsyncSession) -> None: + """Validate: 847 products shrank in the past 12 months.""" + print("\n" + "=" * 70) + print("STAT 1: Products with shrinkflation events in the past 12 months") + print("Expected: ~847") + print("=" * 70) + + result = await session.execute(SHRINKFLATION_COUNT_SQL) + row = result.fetchone() + count = row[0] if row else 0 + print(f"\n Distinct products: {count:,}") + + announced = 847 + delta = count - announced + pct = (abs(delta) / announced * 100) if announced else 0 + status = "✓ MATCHES" if abs(delta) <= 10 else f"⚠ DIFFERS by {delta:+d} ({pct:.1f}%)" + print(f" Announced value: {announced:,}") + print(f" Status: {status}") + + print("\n Breakdown by category:") + cat_result = await session.execute(SHRINKFLATION_BY_CATEGORY_SQL) + for cat_row in cat_result.fetchall(): + print(f" {cat_row[0]:<20s} {cat_row[1]:>5,}") + + +# 
────────────────────────────────────────────────────────────────────────────── +# Stat 2: annual savings potential +# ────────────────────────────────────────────────────────────────────────────── + + +def savings_summary_sql(freq: int, lookback_days: int, min_stores: int) -> sa.TextClause: + """Build the savings summary query with runtime parameters.""" + return sa.text(f""" + WITH latest_prices AS ( + SELECT DISTINCT ON (ph.normalized_product_id, ph.store_id) + ph.normalized_product_id, + ph.store_id, + ph.regular_price AS current_price + FROM price_history ph + WHERE ph.observed_date >= CURRENT_DATE - INTERVAL '{lookback_days} days' + AND ph.regular_price > 0 + ORDER BY ph.normalized_product_id, ph.store_id, ph.observed_date DESC + ), + product_price_spread AS ( + SELECT + lp.normalized_product_id, + COUNT(DISTINCT lp.store_id) AS store_count, + MIN(lp.current_price) AS cheapest_price, + AVG(lp.current_price) AS avg_price + FROM latest_prices lp + GROUP BY lp.normalized_product_id + HAVING COUNT(DISTINCT lp.store_id) >= {min_stores} + ) + SELECT + COUNT(*) AS eligible_products, + ROUND(AVG(avg_price - cheapest_price)::numeric, 4) AS avg_savings_per_purchase, + ROUND(SUM((avg_price - cheapest_price) * {freq})::numeric, 2) + AS total_annual_savings + FROM product_price_spread + """) + + +def savings_top_products_sql(freq: int, lookback_days: int, min_stores: int) -> sa.TextClause: + """Top 20 products by annual savings opportunity.""" + return sa.text(f""" + WITH latest_prices AS ( + SELECT DISTINCT ON (ph.normalized_product_id, ph.store_id) + ph.normalized_product_id, + ph.store_id, + ph.regular_price AS current_price + FROM price_history ph + WHERE ph.observed_date >= CURRENT_DATE - INTERVAL '{lookback_days} days' + AND ph.regular_price > 0 + ORDER BY ph.normalized_product_id, ph.store_id, ph.observed_date DESC + ), + product_price_spread AS ( + SELECT + lp.normalized_product_id, + COUNT(DISTINCT lp.store_id) AS store_count, + MIN(lp.current_price) AS 
cheapest_price, + AVG(lp.current_price) AS avg_price + FROM latest_prices lp + GROUP BY lp.normalized_product_id + HAVING COUNT(DISTINCT lp.store_id) >= {min_stores} + ) + SELECT + np.canonical_name, + np.brand, + np.category, + ROUND((pps.avg_price - pps.cheapest_price)::numeric, 2) AS savings_per_purchase, + ROUND(((pps.avg_price - pps.cheapest_price) * {freq})::numeric, 2) AS annual_savings + FROM product_price_spread pps + JOIN normalized_products np ON np.id = pps.normalized_product_id + ORDER BY annual_savings DESC + LIMIT 20 + """) + + +async def run_stat_2(session: AsyncSession, freq: int) -> None: + """Validate: $336/year potential savings from cross-store price comparison.""" + print("\n" + "=" * 70) + print("STAT 2: Annual savings potential from buying at cheapest store") + print( + f"Assumptions: purchase freq={freq}x/year, price lookback={PRICE_LOOKBACK_DAYS}d, " + f"min_stores={MIN_STORES_FOR_COMPARISON}" + ) + print("Expected: ~$336/year") + print("=" * 70) + + result = await session.execute( + savings_summary_sql(freq, PRICE_LOOKBACK_DAYS, MIN_STORES_FOR_COMPARISON) + ) + row = result.fetchone() + if not row or row[0] == 0: + print("\n No eligible products found. 
Is production data loaded?") + return + + eligible, avg_save, total_annual = row + print(f"\n Eligible products (in 2+ stores): {eligible:,}") + print(f" Avg savings per purchase: ${avg_save:.4f}") + print(f" Estimated annual savings: ${total_annual:,.2f}") + + announced = Decimal("336.00") + delta = total_annual - announced + pct = abs(delta) / announced * 100 + # Allow ±10% tolerance for frequency assumption variance + status = "✓ WITHIN 10%" if pct <= 10 else f"⚠ DIFFERS by ${delta:+.2f} ({pct:.1f}%)" + print(f" Announced value: ${announced:,.2f}") + print(f" Status: {status}") + + print("\n Sensitivity (same data, different frequency assumptions):") + for alt_freq in (13, 20, 26, 40, 52): + alt = float(avg_save) * int(eligible) * alt_freq + marker = " ← default" if alt_freq == freq else "" + print(f" {alt_freq:>2}x/year: ${alt:>8,.2f}{marker}") + + print("\n Top 20 products by annual savings opportunity:") + top_result = await session.execute( + savings_top_products_sql(freq, PRICE_LOOKBACK_DAYS, MIN_STORES_FOR_COMPARISON) + ) + print(f" {'Product':<40s} {'Brand':<20s} {'Save/Buy':>8} {'Annual':>8}") + print(f" {'-' * 40} {'-' * 20} {'-' * 8} {'-' * 8}") + for r in top_result.fetchall(): + name = (r[0] or "")[:39] + brand = (r[1] or "")[:19] + print(f" {name:<40s} {brand:<20s} ${r[3]:>7.2f} ${r[4]:>7.2f}") + + +# ────────────────────────────────────────────────────────────────────────────── +# Entry point +# ────────────────────────────────────────────────────────────────────────────── + + +async def main(stat: int | None, freq: int) -> None: + db_url = os.getenv("DATABASE_URL") + if not db_url: + print("ERROR: DATABASE_URL environment variable is not set.", file=sys.stderr) + print("Set it to your production Postgres URL, e.g.:", file=sys.stderr) + print(" export DATABASE_URL=postgresql+asyncpg://user:pass@host/db", file=sys.stderr) + sys.exit(1) + + engine = create_async_engine(db_url, echo=False) + async with AsyncSession(engine) as session: + if stat is 
None or stat == 1: + await run_stat_1(session) + if stat is None or stat == 2: + await run_stat_2(session, freq) + + await engine.dispose() + print("\nDone.\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--stat", + type=int, + choices=[1, 2], + default=None, + help="Run only stat 1 or stat 2 (default: both)", + ) + parser.add_argument( + "--freq", + type=int, + default=DEFAULT_PURCHASE_FREQUENCY_PER_YEAR, + help=( + "Purchase frequency per product per year " + f"(default: {DEFAULT_PURCHASE_FREQUENCY_PER_YEAR})" + ), + ) + args = parser.parse_args() + asyncio.run(main(stat=args.stat, freq=args.freq)) diff --git a/src/cartsnitch_common/__init__.py b/src/cartsnitch_common/__init__.py new file mode 100644 index 0000000..5ee899a --- /dev/null +++ b/src/cartsnitch_common/__init__.py @@ -0,0 +1,3 @@ +"""CartSnitch Common Library — shared models, schemas, and utilities.""" + +__version__ = "0.3.0" diff --git a/src/cartsnitch_common/config.py b/src/cartsnitch_common/config.py new file mode 100644 index 0000000..70b4153 --- /dev/null +++ b/src/cartsnitch_common/config.py @@ -0,0 +1,18 @@ +"""Shared configuration for CartSnitch services via pydantic-settings.""" + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Environment-driven settings shared by all CartSnitch services.""" + + model_config = SettingsConfigDict(env_prefix="CARTSNITCH_", env_file=".env") + + database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch" + database_url_sync: str = "postgresql+psycopg2://cartsnitch:cartsnitch@localhost:5432/cartsnitch" + redis_url: str = "redis://localhost:6379/0" + debug: bool = False + log_level: str = "INFO" + + +settings = Settings() diff --git a/src/cartsnitch_common/constants.py b/src/cartsnitch_common/constants.py new file mode 100644 index 
0000000..b7a716c --- /dev/null +++ b/src/cartsnitch_common/constants.py @@ -0,0 +1,85 @@ +"""Constants and enums shared across CartSnitch services.""" + +from enum import StrEnum + + +class StoreSlug(StrEnum): + """Supported retailer slugs.""" + + MEIJER = "meijer" + KROGER = "kroger" + TARGET = "target" + + +class AccountStatus(StrEnum): + """User store account link status.""" + + ACTIVE = "active" + EXPIRED = "expired" + ERROR = "error" + + +class DiscountType(StrEnum): + """Coupon discount type.""" + + PERCENT = "percent" + FIXED = "fixed" + BOGO = "bogo" + BUY_X_GET_Y = "buy_x_get_y" + + +class PriceSource(StrEnum): + """Source of a price observation.""" + + RECEIPT = "receipt" + CATALOG = "catalog" + WEEKLY_AD = "weekly_ad" + + +class EventType(StrEnum): + """Redis pub/sub event types.""" + + RECEIPTS_INGESTED = "cartsnitch.receipts.ingested" + PRICES_UPDATED = "cartsnitch.prices.updated" + PRODUCTS_NORMALIZED = "cartsnitch.products.normalized" + COUPONS_UPDATED = "cartsnitch.coupons.updated" + ALERT_PRICE_INCREASE = "cartsnitch.alerts.price_increase" + ALERT_SHRINKFLATION = "cartsnitch.alerts.shrinkflation" + + +class ProductCategory(StrEnum): + """Top-level product categories.""" + + PRODUCE = "produce" + DAIRY = "dairy" + MEAT = "meat" + BAKERY = "bakery" + FROZEN = "frozen" + PANTRY = "pantry" + BEVERAGES = "beverages" + SNACKS = "snacks" + HOUSEHOLD = "household" + PERSONAL_CARE = "personal_care" + OTHER = "other" + + +class MatchConfidence(StrEnum): + """Confidence level for product matching.""" + + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + +class SizeUnit(StrEnum): + """Standardized product size units.""" + + OZ = "oz" + FL_OZ = "fl_oz" + LB = "lb" + G = "g" + KG = "kg" + ML = "ml" + L = "l" + CT = "ct" + PK = "pk" diff --git a/src/cartsnitch_common/database.py b/src/cartsnitch_common/database.py new file mode 100644 index 0000000..76a4f35 --- /dev/null +++ b/src/cartsnitch_common/database.py @@ -0,0 +1,45 @@ +"""Database engine and session 
factories for sync and async usage.""" + +from collections.abc import AsyncGenerator, Generator + +from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import Session, sessionmaker + +from cartsnitch_common.config import settings + + +def get_async_engine(url: str | None = None): + """Create an async SQLAlchemy engine.""" + return create_async_engine(url or settings.database_url, echo=settings.debug) + + +def get_sync_engine(url: str | None = None): + """Create a sync SQLAlchemy engine.""" + return create_engine(url or settings.database_url_sync, echo=settings.debug) + + +def get_async_session_factory(url: str | None = None) -> async_sessionmaker[AsyncSession]: + """Create an async session factory.""" + engine = get_async_engine(url) + return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + +def get_sync_session_factory(url: str | None = None) -> sessionmaker[Session]: + """Create a sync session factory.""" + engine = get_sync_engine(url) + return sessionmaker(engine, expire_on_commit=False) + + +async def get_async_session(url: str | None = None) -> AsyncGenerator[AsyncSession, None]: + """Dependency for async session injection.""" + factory = get_async_session_factory(url) + async with factory() as session: + yield session + + +def get_sync_session(url: str | None = None) -> Generator[Session, None, None]: + """Dependency for sync session injection.""" + factory = get_sync_session_factory(url) + with factory() as session: + yield session diff --git a/src/cartsnitch_common/events.py b/src/cartsnitch_common/events.py new file mode 100644 index 0000000..986362f --- /dev/null +++ b/src/cartsnitch_common/events.py @@ -0,0 +1,28 @@ +"""Event bus helpers for Redis pub/sub.""" + +from datetime import UTC, datetime +from typing import Any, cast + +from redis import Redis + +from cartsnitch_common.constants import EventType +from 
cartsnitch_common.schemas.events import EventEnvelope + + +def publish_event( + redis_client: Redis, + event_type: EventType, + service: str, + payload: dict[str, Any], +) -> int: + """Publish an event to the Redis pub/sub channel. + + Returns the number of subscribers that received the message. + """ + envelope = EventEnvelope( + event_type=event_type, + timestamp=datetime.now(UTC), + service=service, + payload=payload, + ) + return cast(int, redis_client.publish(event_type.value, envelope.model_dump_json())) diff --git a/src/cartsnitch_common/models/__init__.py b/src/cartsnitch_common/models/__init__.py new file mode 100644 index 0000000..fb12ee3 --- /dev/null +++ b/src/cartsnitch_common/models/__init__.py @@ -0,0 +1,25 @@ +"""SQLAlchemy ORM models — re-exports all models for convenience.""" + +from cartsnitch_common.models.base import Base, TimestampMixin +from cartsnitch_common.models.coupon import Coupon +from cartsnitch_common.models.price import PriceHistory +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.models.purchase import Purchase, PurchaseItem +from cartsnitch_common.models.shrinkflation import ShrinkflationEvent +from cartsnitch_common.models.store import Store, StoreLocation +from cartsnitch_common.models.user import User, UserStoreAccount + +__all__ = [ + "Base", + "TimestampMixin", + "Store", + "StoreLocation", + "User", + "UserStoreAccount", + "Purchase", + "PurchaseItem", + "NormalizedProduct", + "PriceHistory", + "Coupon", + "ShrinkflationEvent", +] diff --git a/src/cartsnitch_common/models/base.py b/src/cartsnitch_common/models/base.py new file mode 100644 index 0000000..f93cf79 --- /dev/null +++ b/src/cartsnitch_common/models/base.py @@ -0,0 +1,30 @@ +"""Base model and mixins for all CartSnitch ORM models.""" + +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, func +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + +class Base(DeclarativeBase): + """Base 
class for all CartSnitch models.""" + + +class TimestampMixin: + """Mixin providing created_at / updated_at columns.""" + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), nullable=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False + ) + + +class UUIDPrimaryKeyMixin: + """Mixin providing a UUID primary key.""" + + id: Mapped[uuid.UUID] = mapped_column( + primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid() + ) diff --git a/src/cartsnitch_common/models/coupon.py b/src/cartsnitch_common/models/coupon.py new file mode 100644 index 0000000..6e6305d --- /dev/null +++ b/src/cartsnitch_common/models/coupon.py @@ -0,0 +1,42 @@ +"""Coupon model.""" + +import uuid +from datetime import date, datetime +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_common.constants import DiscountType +from cartsnitch_common.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_common.models.product import NormalizedProduct + from cartsnitch_common.models.store import Store + + +class Coupon(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """A coupon or deal for a product at a store.""" + + __tablename__ = "coupons" + + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + normalized_product_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("normalized_products.id") + ) + title: Mapped[str] = mapped_column(String(300), nullable=False) + description: Mapped[str | None] = mapped_column(String(1000)) + discount_type: Mapped[DiscountType] = mapped_column(String(20), nullable=False) + discount_value: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + min_purchase: 
Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + valid_from: Mapped[date | None] = mapped_column(Date) + valid_to: Mapped[date | None] = mapped_column(Date) + requires_clip: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + coupon_code: Mapped[str | None] = mapped_column(String(100)) + source_url: Mapped[str | None] = mapped_column(String(500)) + scraped_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + + # Relationships + store: Mapped["Store"] = relationship(back_populates="coupons") + normalized_product: Mapped["NormalizedProduct | None"] = relationship(back_populates="coupons") diff --git a/src/cartsnitch_common/models/price.py b/src/cartsnitch_common/models/price.py new file mode 100644 index 0000000..5814aac --- /dev/null +++ b/src/cartsnitch_common/models/price.py @@ -0,0 +1,50 @@ +"""PriceHistory model — tracks product prices over time.""" + +import uuid +from datetime import date +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import Date, ForeignKey, Index, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_common.constants import PriceSource +from cartsnitch_common.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_common.models.product import NormalizedProduct + from cartsnitch_common.models.purchase import PurchaseItem + from cartsnitch_common.models.store import Store + + +class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """A single price observation for a product at a store on a date.""" + + __tablename__ = "price_history" + __table_args__ = ( + Index( + "ix_price_history_product_store_date", + "normalized_product_id", + "store_id", + "observed_date", + ), + ) + + normalized_product_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("normalized_products.id"), nullable=False + ) + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), 
nullable=False) + observed_date: Mapped[date] = mapped_column(Date, nullable=False) + regular_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + loyalty_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + coupon_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + source: Mapped[PriceSource] = mapped_column(String(20), nullable=False) + purchase_item_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("purchase_items.id")) + + # Relationships + normalized_product: Mapped["NormalizedProduct"] = relationship(back_populates="price_histories") + store: Mapped["Store"] = relationship(back_populates="price_histories") + purchase_item: Mapped["PurchaseItem | None"] = relationship( + back_populates="price_history_entries" + ) diff --git a/src/cartsnitch_common/models/product.py b/src/cartsnitch_common/models/product.py new file mode 100644 index 0000000..215e57e --- /dev/null +++ b/src/cartsnitch_common/models/product.py @@ -0,0 +1,39 @@ +"""NormalizedProduct model — the canonical product identity.""" + +from typing import TYPE_CHECKING + +from sqlalchemy import JSON, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_common.constants import ProductCategory, SizeUnit +from cartsnitch_common.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_common.models.coupon import Coupon + from cartsnitch_common.models.price import PriceHistory + from cartsnitch_common.models.purchase import PurchaseItem + from cartsnitch_common.models.shrinkflation import ShrinkflationEvent + + +class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Canonical product identity — matches products across retailers.""" + + __tablename__ = "normalized_products" + + canonical_name: Mapped[str] = mapped_column(String(300), nullable=False) + category: Mapped[ProductCategory | None] = 
mapped_column(String(50)) + subcategory: Mapped[str | None] = mapped_column(String(100)) + brand: Mapped[str | None] = mapped_column(String(200)) + size: Mapped[str | None] = mapped_column(String(50)) + size_unit: Mapped[SizeUnit | None] = mapped_column(String(10)) + upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list) + + # Relationships + purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product") + price_histories: Mapped[list["PriceHistory"]] = relationship( + back_populates="normalized_product" + ) + coupons: Mapped[list["Coupon"]] = relationship(back_populates="normalized_product") + shrinkflation_events: Mapped[list["ShrinkflationEvent"]] = relationship( + back_populates="normalized_product" + ) diff --git a/src/cartsnitch_common/models/purchase.py b/src/cartsnitch_common/models/purchase.py new file mode 100644 index 0000000..3797ef2 --- /dev/null +++ b/src/cartsnitch_common/models/purchase.py @@ -0,0 +1,91 @@ +"""Purchase and PurchaseItem models.""" + +import uuid +from datetime import date, datetime +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import ( + JSON, + Date, + DateTime, + ForeignKey, + Index, + Numeric, + String, + UniqueConstraint, + func, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_common.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_common.models.price import PriceHistory + from cartsnitch_common.models.product import NormalizedProduct + from cartsnitch_common.models.store import Store, StoreLocation + from cartsnitch_common.models.user import User + + +class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """A single shopping trip / receipt.""" + + __tablename__ = "purchases" + + user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False) + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) 
+ store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id")) + receipt_id: Mapped[str] = mapped_column(String(200), nullable=False) + purchase_date: Mapped[date] = mapped_column(Date, nullable=False) + total: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + subtotal: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + tax: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + savings_total: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + source_url: Mapped[str | None] = mapped_column(String(500)) + raw_data: Mapped[dict | None] = mapped_column(JSON) + ingested_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + ) + + # Relationships + user: Mapped["User"] = relationship(back_populates="purchases") + store: Mapped["Store"] = relationship(back_populates="purchases") + store_location: Mapped["StoreLocation | None"] = relationship(back_populates="purchases") + items: Mapped[list["PurchaseItem"]] = relationship(back_populates="purchase") + + __table_args__ = ( + Index("ix_purchases_user_store", "user_id", "store_id"), + UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"), + ) + + +class PurchaseItem(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Individual line item on a receipt.""" + + __tablename__ = "purchase_items" + + purchase_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("purchases.id"), nullable=False) + product_name_raw: Mapped[str] = mapped_column(String(300), nullable=False) + upc: Mapped[str | None] = mapped_column(String(20)) + quantity: Mapped[Decimal] = mapped_column(Numeric(10, 3), nullable=False, default=1) + unit_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + extended_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) + regular_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + sale_price: Mapped[Decimal | None] = 
mapped_column(Numeric(10, 2)) + coupon_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + loyalty_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + category_raw: Mapped[str | None] = mapped_column(String(100)) + normalized_product_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("normalized_products.id") + ) + + # Relationships + purchase: Mapped["Purchase"] = relationship(back_populates="items") + normalized_product: Mapped["NormalizedProduct | None"] = relationship( + back_populates="purchase_items" + ) + price_history_entries: Mapped[list["PriceHistory"]] = relationship( + back_populates="purchase_item" + ) diff --git a/src/cartsnitch_common/models/shrinkflation.py b/src/cartsnitch_common/models/shrinkflation.py new file mode 100644 index 0000000..d198713 --- /dev/null +++ b/src/cartsnitch_common/models/shrinkflation.py @@ -0,0 +1,41 @@ +"""ShrinkflationEvent model.""" + +import uuid +from datetime import date +from decimal import Decimal +from typing import TYPE_CHECKING + +from sqlalchemy import Date, ForeignKey, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_common.constants import SizeUnit +from cartsnitch_common.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_common.models.product import NormalizedProduct + + +class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Detected shrinkflation event — product size changed while price held or rose.""" + + __tablename__ = "shrinkflation_events" + + normalized_product_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("normalized_products.id"), nullable=False + ) + detected_date: Mapped[date] = mapped_column(Date, nullable=False) + old_size: Mapped[str] = mapped_column(String(50), nullable=False) + new_size: Mapped[str] = mapped_column(String(50), nullable=False) + old_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False) + new_unit: 
Mapped[SizeUnit] = mapped_column(String(10), nullable=False) + price_at_old_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + price_at_new_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + confidence: Mapped[Decimal] = mapped_column( + Numeric(3, 2), nullable=False, default=Decimal("1.00") + ) + notes: Mapped[str | None] = mapped_column(String(1000)) + + # Relationships + normalized_product: Mapped["NormalizedProduct"] = relationship( + back_populates="shrinkflation_events" + ) diff --git a/src/cartsnitch_common/models/store.py b/src/cartsnitch_common/models/store.py new file mode 100644 index 0000000..cde7760 --- /dev/null +++ b/src/cartsnitch_common/models/store.py @@ -0,0 +1,52 @@ +"""Store and StoreLocation models.""" + +import uuid +from typing import TYPE_CHECKING + +from sqlalchemy import Float, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_common.constants import StoreSlug +from cartsnitch_common.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_common.models.coupon import Coupon + from cartsnitch_common.models.price import PriceHistory + from cartsnitch_common.models.purchase import Purchase + from cartsnitch_common.models.user import UserStoreAccount + + +class Store(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Supported retailer.""" + + __tablename__ = "stores" + + name: Mapped[str] = mapped_column(String(100), nullable=False) + slug: Mapped[StoreSlug] = mapped_column(String(20), nullable=False, unique=True) + logo_url: Mapped[str | None] = mapped_column(String(500)) + website_url: Mapped[str | None] = mapped_column(String(500)) + + # Relationships + locations: Mapped[list["StoreLocation"]] = relationship(back_populates="store") + purchases: Mapped[list["Purchase"]] = relationship(back_populates="store") + user_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="store") + price_histories: 
Mapped[list["PriceHistory"]] = relationship(back_populates="store") + coupons: Mapped[list["Coupon"]] = relationship(back_populates="store") + + +class StoreLocation(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Physical store location.""" + + __tablename__ = "store_locations" + + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + address: Mapped[str] = mapped_column(String(300), nullable=False) + city: Mapped[str] = mapped_column(String(100), nullable=False) + state: Mapped[str] = mapped_column(String(2), nullable=False) + zip: Mapped[str] = mapped_column(String(10), nullable=False) + lat: Mapped[float | None] = mapped_column(Float) + lng: Mapped[float | None] = mapped_column(Float) + + # Relationships + store: Mapped["Store"] = relationship(back_populates="locations") + purchases: Mapped[list["Purchase"]] = relationship(back_populates="store_location") diff --git a/src/cartsnitch_common/models/user.py b/src/cartsnitch_common/models/user.py new file mode 100644 index 0000000..e2b1bfb --- /dev/null +++ b/src/cartsnitch_common/models/user.py @@ -0,0 +1,51 @@ +"""User and UserStoreAccount models.""" + +import uuid +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import JSON, DateTime, ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from cartsnitch_common.constants import AccountStatus +from cartsnitch_common.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin + +if TYPE_CHECKING: + from cartsnitch_common.models.purchase import Purchase + from cartsnitch_common.models.store import Store + + +class User(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Application user.""" + + __tablename__ = "users" + + email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) + hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) + display_name: Mapped[str | None] = mapped_column(String(100)) + + # 
Relationships + store_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="user") + purchases: Mapped[list["Purchase"]] = relationship(back_populates="user") + + +class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base): + """Link between a user and their retailer account credentials.""" + + __tablename__ = "user_store_accounts" + __table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),) + + user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False) + store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False) + # WARNING: Contains retailer session cookies/tokens. Encryption-at-rest + # required before production deployment (e.g., pgcrypto or app-level encryption). + session_data: Mapped[dict | None] = mapped_column(JSON) + session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + last_sync_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + status: Mapped[AccountStatus] = mapped_column( + String(20), nullable=False, default=AccountStatus.ACTIVE + ) + + # Relationships + user: Mapped["User"] = relationship(back_populates="store_accounts") + store: Mapped["Store"] = relationship(back_populates="user_accounts") diff --git a/src/cartsnitch_common/normalization.py b/src/cartsnitch_common/normalization.py new file mode 100644 index 0000000..d448fa9 --- /dev/null +++ b/src/cartsnitch_common/normalization.py @@ -0,0 +1,156 @@ +"""Product normalization — Phase 1: UPC matching + fuzzy name matching. + +Matches products across retailers by: +1. Exact UPC match (highest confidence) +2. 
Fuzzy name matching via token-based Jaccard similarity (lower confidence) +""" + +import re +from dataclasses import dataclass +from enum import StrEnum + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from cartsnitch_common.models.product import NormalizedProduct + + +class MatchMethod(StrEnum): + """How a product match was determined.""" + + UPC = "upc" + NAME = "name" + + +@dataclass(frozen=True) +class MatchResult: + """Result of a product normalization attempt.""" + + product: NormalizedProduct + confidence: float + method: MatchMethod + + +# Noise words stripped during name cleaning +_NOISE_WORDS = frozenset( + { + "the", + "a", + "an", + "and", + "or", + "of", + "with", + "in", + "for", + "to", + "brand", + "original", + "classic", + "new", + "improved", + } +) + +# Regex for extracting size info (e.g., "16 oz", "1.5 lb", "12 ct") +_SIZE_PATTERN = re.compile( + r"(\d+(?:\.\d+)?)\s*(oz|fl\s*oz|lb|lbs|g|kg|ml|l|ct|pk|count|pack)\b", + re.IGNORECASE, +) + + +def clean_name(name: str) -> str: + """Normalize a product name for comparison. 
+ + - Lowercase + - Remove size info (e.g., "16 oz") + - Strip noise words + - Collapse whitespace + """ + cleaned = name.lower() + cleaned = _SIZE_PATTERN.sub("", cleaned) + cleaned = re.sub(r"[^\w\s]", " ", cleaned) + tokens = cleaned.split() + tokens = [t for t in tokens if t not in _NOISE_WORDS] + return " ".join(tokens) + + +def extract_size_info(name: str) -> tuple[str, str] | None: + """Extract (size, unit) from a product name, if present.""" + match = _SIZE_PATTERN.search(name) + if match: + return match.group(1), match.group(2).lower().replace(" ", "_") + return None + + +def jaccard_similarity(a: str, b: str) -> float: + """Token-based Jaccard similarity between two cleaned names.""" + tokens_a = set(a.split()) + tokens_b = set(b.split()) + if not tokens_a or not tokens_b: + return 0.0 + intersection = tokens_a & tokens_b + union = tokens_a | tokens_b + return len(intersection) / len(union) + + +def match_by_upc(session: Session, upc: str) -> MatchResult | None: + """Find a normalized product by exact UPC match. + + Loads products with upc_variants and checks membership in Python + for cross-database compatibility (works on both PostgreSQL and SQLite). + """ + # TODO: Use PostgreSQL JSON containment query (@>) for production. + # Current approach loads all products into memory — acceptable for tests + # and small datasets, but will not scale. + stmt = select(NormalizedProduct).where(NormalizedProduct.upc_variants.is_not(None)) + products = session.execute(stmt).scalars().all() + for product in products: + if product.upc_variants and upc in product.upc_variants: + return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC) + return None + + +def match_by_name( + session: Session, + name: str, + threshold: float = 0.5, +) -> MatchResult | None: + """Find the best normalized product by fuzzy name matching. + + Loads all normalized products and computes Jaccard similarity. + Returns the best match above the threshold, or None. 
+ """ + # TODO: Use pg_trgm similarity index for production. + # Current approach loads all products into memory — acceptable for tests + # and small datasets, but will not scale. + cleaned = clean_name(name) + stmt = select(NormalizedProduct) + products = session.execute(stmt).scalars().all() + + best_match: NormalizedProduct | None = None + best_score = 0.0 + + for product in products: + score = jaccard_similarity(cleaned, clean_name(product.canonical_name)) + if score > best_score and score >= threshold: + best_score = score + best_match = product + + if best_match: + return MatchResult(product=best_match, confidence=best_score, method=MatchMethod.NAME) + return None + + +def normalize_product( + session: Session, + name: str, + upc: str | None = None, + name_threshold: float = 0.5, +) -> MatchResult | None: + """Full normalization pipeline: UPC first, then fuzzy name fallback.""" + if upc: + result = match_by_upc(session, upc) + if result: + return result + return match_by_name(session, name, threshold=name_threshold) diff --git a/src/cartsnitch_common/pipeline/__init__.py b/src/cartsnitch_common/pipeline/__init__.py new file mode 100644 index 0000000..82b586b --- /dev/null +++ b/src/cartsnitch_common/pipeline/__init__.py @@ -0,0 +1,26 @@ +"""Data pipeline — receipt normalization, product matching, price tracking, shrinkflation.""" + +from cartsnitch_common.pipeline.matching import ( + ConfidenceLevel, + ProductMatcher, + match_purchase_item, +) +from cartsnitch_common.pipeline.price_tracking import ( + PriceDelta, + get_price_trend, + record_price_from_item, +) +from cartsnitch_common.pipeline.receipt import normalize_receipt, parse_meijer_item +from cartsnitch_common.pipeline.shrinkflation import detect_shrinkflation + +__all__ = [ + "ConfidenceLevel", + "PriceDelta", + "ProductMatcher", + "detect_shrinkflation", + "get_price_trend", + "match_purchase_item", + "normalize_receipt", + "parse_meijer_item", + "record_price_from_item", +] diff --git 
a/src/cartsnitch_common/pipeline/matching.py b/src/cartsnitch_common/pipeline/matching.py new file mode 100644 index 0000000..ef4512f --- /dev/null +++ b/src/cartsnitch_common/pipeline/matching.py @@ -0,0 +1,136 @@ +"""Product matching & dedup — UPC primary, fuzzy name fallback, confidence scoring. + +Wraps the Phase 1 normalization module with confidence-level classification +and batch matching for purchase ingestion. +""" + +import uuid +from dataclasses import dataclass + +from sqlalchemy.orm import Session + +from cartsnitch_common.constants import MatchConfidence +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.normalization import ( + MatchMethod, + MatchResult, + extract_size_info, + normalize_product, +) +from cartsnitch_common.schemas.purchase import PurchaseItemCreate + +# Re-export for convenience +ConfidenceLevel = MatchConfidence + + +@dataclass(frozen=True) +class MatchOutcome: + """Result of matching a single purchase item to a normalized product.""" + + item_index: int + match: MatchResult | None + confidence_level: MatchConfidence + created_new: bool = False + + +def classify_confidence(score: float, method: MatchMethod) -> MatchConfidence: + """Classify a match score into high/medium/low confidence.""" + if method == MatchMethod.UPC: + return MatchConfidence.HIGH + # Name-based matching thresholds + if score >= 0.8: + return MatchConfidence.HIGH + if score >= 0.5: + return MatchConfidence.MEDIUM + return MatchConfidence.LOW + + +def _create_product_from_item( + session: Session, + item: PurchaseItemCreate, +) -> NormalizedProduct: + """Create a new NormalizedProduct from a purchase item that had no match.""" + size_info = extract_size_info(item.product_name_raw) + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name=item.product_name_raw, + size=size_info[0] if size_info else None, + size_unit=size_info[1] if size_info else None, + upc_variants=[item.upc] if item.upc else [], + ) + 
class ProductMatcher:
    """Batch product matcher for purchase ingestion.

    Usage:
        matcher = ProductMatcher(session)
        outcomes = matcher.match_items(items)
    """

    def __init__(
        self,
        session: Session,
        name_threshold: float = 0.4,
        auto_create: bool = True,
    ):
        """Store the session and the matching policy used by every call."""
        self.session = session
        self.name_threshold = name_threshold
        self.auto_create = auto_create

    def match_single(
        self,
        item: PurchaseItemCreate,
    ) -> tuple[NormalizedProduct | None, MatchResult | None, MatchConfidence]:
        """Match one purchase item against the normalized products.

        Returns a ``(product, match_result, confidence_level)`` triple.
        Without a match the result slot is ``None`` and the confidence is LOW;
        the product slot is then a newly created product when ``auto_create``
        is enabled, otherwise ``None``.
        """
        hit = normalize_product(
            self.session,
            item.product_name_raw,
            upc=item.upc,
            name_threshold=self.name_threshold,
        )

        if not hit:
            # Nothing matched: optionally mint a fresh product for this item.
            fallback = _create_product_from_item(self.session, item) if self.auto_create else None
            return fallback, None, MatchConfidence.LOW

        return hit.product, hit, classify_confidence(hit.confidence, hit.method)

    def match_items(self, items: list[PurchaseItemCreate]) -> list[MatchOutcome]:
        """Match every item in ``items`` and return one outcome per item, in input order."""

        def outcome_for(position: int, entry: PurchaseItemCreate) -> MatchOutcome:
            product, hit, level = self.match_single(entry)
            return MatchOutcome(
                item_index=position,
                match=hit,
                confidence_level=level,
                # A product without a match result can only come from auto-creation.
                created_new=hit is None and product is not None,
            )

        return [outcome_for(position, entry) for position, entry in enumerate(items)]
def record_price_from_item(
    session: Session,
    product_id: uuid.UUID,
    store_id: uuid.UUID,
    observed_date: date,
    regular_price: Decimal,
    sale_price: Decimal | None = None,
    loyalty_price: Decimal | None = None,
    coupon_price: Decimal | None = None,
    purchase_item_id: uuid.UUID | None = None,
    source: PriceSource = PriceSource.RECEIPT,
) -> tuple[PriceHistory, PriceDelta | None]:
    """Persist one price observation and report how it differs from the last one.

    Returns:
        ``(entry, delta)``. ``entry`` is the new, flushed ``PriceHistory`` row.
        ``delta`` is ``None`` when there is no earlier entry for this
        product and store, or when its regular price equals ``regular_price``.
        Otherwise it is a ``PriceDelta`` whose percentage is rounded to two
        decimal places (and is 0 when the earlier price is zero or missing).
    """
    # Fetch the baseline BEFORE inserting, so the new row cannot be its own predecessor.
    baseline = get_latest_price(session, product_id, store_id)

    observation = PriceHistory(
        id=uuid.uuid4(),
        normalized_product_id=product_id,
        store_id=store_id,
        observed_date=observed_date,
        regular_price=regular_price,
        sale_price=sale_price,
        loyalty_price=loyalty_price,
        coupon_price=coupon_price,
        source=source,
        purchase_item_id=purchase_item_id,
    )
    session.add(observation)
    session.flush()

    if not baseline or baseline.regular_price == regular_price:
        return observation, None

    prior_price = baseline.regular_price
    difference = regular_price - prior_price
    if prior_price:
        percent = difference / prior_price * 100
    else:
        # A zero/empty baseline has no meaningful percentage change.
        percent = Decimal("0")

    change = PriceDelta(
        product_id=product_id,
        store_id=store_id,
        old_price=prior_price,
        new_price=regular_price,
        change_amount=difference,
        change_percent=percent.quantize(Decimal("0.01")),
        old_date=baseline.observed_date,
        new_date=observed_date,
    )
    return observation, change
+""" + +import re +from datetime import date +from decimal import Decimal, InvalidOperation + +from cartsnitch_common.schemas.purchase import PurchaseCreate, PurchaseItemCreate + + +def _clean_product_name(raw: str) -> str: + """Clean raw product name from scraper output.""" + cleaned = raw.strip() + # Remove leading/trailing non-alphanumeric chars + cleaned = re.sub(r"^\W+|\W+$", "", cleaned) + # Collapse internal whitespace + cleaned = re.sub(r"\s+", " ", cleaned) + return cleaned + + +def _safe_decimal( + value: str | float | int | Decimal | None, + default: Decimal = Decimal("0"), +) -> Decimal: + """Safely convert a value to Decimal.""" + if value is None: + return default + try: + return Decimal(str(value)) + except (InvalidOperation, ValueError): + return default + + +def parse_meijer_item(raw_item: dict) -> PurchaseItemCreate: + """Parse a single Meijer scraper line item into a PurchaseItemCreate. + + Expected raw_item keys (from Meijer scraper): + - description / name: product name + - upc / upcCode: UPC barcode + - quantity / qty: number of units + - unitPrice / price: per-unit price + - extendedPrice / totalPrice: line total + - regularPrice: shelf price before discounts + - salePrice: sale price if applicable + - couponAmount / couponDiscount: coupon savings + - loyaltyAmount / loyaltyDiscount: loyalty savings + - category / department: raw category + """ + name = raw_item.get("description") or raw_item.get("name") or "" + cleaned_name = _clean_product_name(name) + + upc = raw_item.get("upc") or raw_item.get("upcCode") + if upc: + upc = str(upc).strip().lstrip("0") or str(upc).strip() + + qty = _safe_decimal( + raw_item.get("quantity") or raw_item.get("qty"), + default=Decimal("1"), + ) + + unit_price = _safe_decimal(raw_item.get("unitPrice") or raw_item.get("price")) + extended = _safe_decimal(raw_item.get("extendedPrice") or raw_item.get("totalPrice")) + if extended == Decimal("0") and unit_price > 0: + extended = unit_price * qty + + regular = 
def normalize_receipt(
    raw_receipt: dict,
    user_id: str,
    store_id: str,
) -> PurchaseCreate:
    """Parse a complete Meijer raw receipt into a PurchaseCreate.

    Expected raw_receipt keys:
        - receiptId / receipt_id / id: unique receipt identifier
          (a random UUID string is used when none is present)
        - date / purchaseDate / purchase_date: purchase date, either an ISO
          string starting with YYYY-MM-DD, a ``date``, or a ``datetime``
          (today's date is used when the key is missing or of another type)
        - total / totalAmount: receipt total
        - subtotal: pre-tax subtotal
        - tax / taxAmount: tax amount
        - savings / totalSavings: total discount savings
        - items: list of raw line item dicts

    Args:
        raw_receipt: Raw scraper output; stored unchanged as ``raw_data``.
        user_id: Owner of the purchase; a string is parsed as a UUID, any
            other value is passed through as-is.
        store_id: Store of the purchase; handled like ``user_id``.

    Raises:
        ValueError: If a date string does not start with a valid ISO date, or
            if ``user_id`` / ``store_id`` is a string that is not a UUID.
    """
    import uuid
    from datetime import datetime

    def first_present(*keys: str):
        """Return the first value under ``keys`` that is neither None nor ""."""
        # Unlike ``a or b``, this keeps an explicit 0 (e.g. zero tax) instead
        # of treating it as missing.
        for key in keys:
            candidate = raw_receipt.get(key)
            if candidate is not None and candidate != "":
                return candidate
        return None

    receipt_id = str(
        raw_receipt.get("receiptId")
        or raw_receipt.get("receipt_id")
        or raw_receipt.get("id")
        or uuid.uuid4()
    )

    raw_date = (
        raw_receipt.get("date")
        or raw_receipt.get("purchaseDate")
        or raw_receipt.get("purchase_date")
    )
    if isinstance(raw_date, datetime):
        # datetime is a subclass of date; reduce it to a plain date so a
        # non-zero time component cannot fail validation of the date field.
        purchase_date = raw_date.date()
    elif isinstance(raw_date, date):
        purchase_date = raw_date
    elif isinstance(raw_date, str):
        # Keep only the leading YYYY-MM-DD so full timestamps are accepted too.
        purchase_date = date.fromisoformat(raw_date.strip()[:10])
    else:
        purchase_date = date.today()

    total = _safe_decimal(first_present("total", "totalAmount"))
    subtotal = raw_receipt.get("subtotal")
    tax = first_present("tax", "taxAmount")
    savings = first_present("savings", "totalSavings")

    raw_items = raw_receipt.get("items") or []
    items = [parse_meijer_item(item) for item in raw_items]

    return PurchaseCreate(
        user_id=uuid.UUID(user_id) if isinstance(user_id, str) else user_id,
        store_id=uuid.UUID(store_id) if isinstance(store_id, str) else store_id,
        receipt_id=receipt_id,
        purchase_date=purchase_date,
        total=total,
        subtotal=_safe_decimal(subtotal) if subtotal is not None else None,
        tax=_safe_decimal(tax) if tax is not None else None,
        savings_total=_safe_decimal(savings) if savings is not None else None,
        raw_data=raw_receipt,
        items=items,
    )
+""" + +import uuid +from dataclasses import dataclass +from datetime import date +from decimal import Decimal + +from sqlalchemy import and_, select +from sqlalchemy.orm import Session + +from cartsnitch_common.constants import SizeUnit +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.models.shrinkflation import ShrinkflationEvent + +# Conversion factors to a common base unit (grams for weight, ml for volume, count for discrete) +_WEIGHT_TO_GRAMS: dict[SizeUnit, Decimal] = { + SizeUnit.G: Decimal("1"), + SizeUnit.KG: Decimal("1000"), + SizeUnit.OZ: Decimal("28.3495"), + SizeUnit.LB: Decimal("453.592"), +} + +_VOLUME_TO_ML: dict[SizeUnit, Decimal] = { + SizeUnit.ML: Decimal("1"), + SizeUnit.L: Decimal("1000"), + SizeUnit.FL_OZ: Decimal("29.5735"), +} + +_COUNT_UNITS: set[SizeUnit] = {SizeUnit.CT, SizeUnit.PK} + + +def _to_comparable(size: str, unit: SizeUnit) -> Decimal | None: + """Convert a size+unit to a comparable numeric value. + + Returns None if units are not comparable (different measurement systems). 
def detect_shrinkflation(
    session: Session,
    product: NormalizedProduct,
    new_size: str,
    new_unit: SizeUnit,
    new_price: Decimal | None = None,
    detected_date: date | None = None,
    min_size_decrease_pct: Decimal = Decimal("1"),
) -> ShrinkflationEvent | None:
    """Record a shrinkflation event when a product's size has decreased.

    The new size is compared against ``product.size`` / ``product.size_unit``
    after converting both to a common base unit.

    Args:
        session: Database session used for the duplicate lookup and the insert.
        product: Product whose recorded size is the baseline.
        new_size: Newly observed size, as a numeric string.
        new_unit: Unit of ``new_size``.
        new_price: Price observed with the new size, if known.
        detected_date: Date stored on the event; defaults to today.
        min_size_decrease_pct: Smallest percentage decrease worth recording.

    Returns:
        ``None`` when the product has no recorded size, the units are not
        comparable, either size cannot be converted, the size did not
        decrease, or the decrease is below ``min_size_decrease_pct``.
        Otherwise the existing event for the same product and size pair, or a
        newly added and flushed one.

        Confidence is 0.70, raised to 0.85 at a 5% decrease and 0.95 at 10%.
        It is lowered by 0.30 (floor 0.30) when ``new_price`` is below the
        product's latest regular price.
    """
    if not product.size or not product.size_unit:
        return None

    # NOTE(review): presumably raises ValueError when the stored unit is not a
    # SizeUnit member — confirm callers only store valid units.
    old_unit = SizeUnit(product.size_unit)
    if not _units_comparable(old_unit, new_unit):
        return None

    old_comparable = _to_comparable(product.size, old_unit)
    new_comparable = _to_comparable(new_size, new_unit)

    if old_comparable is None or new_comparable is None:
        return None

    if old_comparable <= 0:
        # A non-positive baseline cannot shrink, and would otherwise reach a
        # division by zero in the percentage below.
        return None

    if new_comparable >= old_comparable:
        return None  # Size didn't decrease

    size_change_pct = ((old_comparable - new_comparable) / old_comparable * 100).quantize(
        Decimal("0.01")
    )
    if size_change_pct < min_size_decrease_pct:
        return None

    # Check existing events to avoid duplicates. first() is used instead of
    # scalar_one_or_none() so that rows which are already duplicated return
    # one of them rather than raising MultipleResultsFound.
    existing = (
        session.execute(
            select(ShrinkflationEvent)
            .where(
                and_(
                    ShrinkflationEvent.normalized_product_id == product.id,
                    ShrinkflationEvent.old_size == product.size,
                    ShrinkflationEvent.new_size == new_size,
                )
            )
            .limit(1)
        )
        .scalars()
        .first()
    )

    if existing:
        return existing

    # Confidence: higher if size change is significant and price didn't drop
    confidence = Decimal("0.70")
    if size_change_pct >= Decimal("5"):
        confidence = Decimal("0.85")
    if size_change_pct >= Decimal("10"):
        confidence = Decimal("0.95")

    # Get the last known price for comparison
    old_price: Decimal | None = None
    if product.price_histories:
        latest = max(product.price_histories, key=lambda ph: ph.observed_date)
        old_price = latest.regular_price

    if old_price is not None and new_price is not None and new_price < old_price:
        # Price actually dropped — less likely to be shrinkflation
        confidence = max(Decimal("0.30"), confidence - Decimal("0.30"))

    event = ShrinkflationEvent(
        id=uuid.uuid4(),
        normalized_product_id=product.id,
        detected_date=detected_date or date.today(),
        old_size=product.size,
        new_size=new_size,
        old_unit=old_unit,
        new_unit=new_unit,
        price_at_old_size=old_price,
        price_at_new_size=new_price,
        confidence=confidence,
        notes=(
            f"Size decreased {size_change_pct}% ({product.size} {old_unit} → {new_size} {new_unit})"
        ),
    )
    session.add(event)
    session.flush()

    return event
class CouponCreate(BaseModel):
    """Creation schema for a coupon.

    Only ``store_id``, ``title`` and ``discount_type`` are required. Every
    other field defaults to ``None``, except ``requires_clip`` which defaults
    to ``False``.
    """

    store_id: uuid.UUID
    # None means the coupon is not tied to one specific normalized product.
    normalized_product_id: uuid.UUID | None = None
    title: str
    description: str | None = None
    discount_type: DiscountType
    # Optional on purpose: not every discount type carries a numeric value.
    discount_value: Decimal | None = None
    min_purchase: Decimal | None = None
    # NOTE(review): nothing here enforces valid_from <= valid_to — confirm
    # whether callers validate the range.
    valid_from: date | None = None
    valid_to: date | None = None
    requires_clip: bool = False
    coupon_code: str | None = None
    source_url: str | None = None
class PriceHistoryCreate(BaseModel):
    """Creation schema for one price observation of a product at a store.

    ``regular_price`` and ``source`` are required; the sale, loyalty and
    coupon prices and the originating purchase item are optional and default
    to ``None``.
    """

    normalized_product_id: uuid.UUID
    store_id: uuid.UUID
    observed_date: date
    regular_price: Decimal
    sale_price: Decimal | None = None
    loyalty_price: Decimal | None = None
    coupon_price: Decimal | None = None
    # No default: the caller must always state where the observation came from.
    source: PriceSource
    purchase_item_id: uuid.UUID | None = None
class PurchaseCreate(BaseModel):
    """Creation schema for a purchase (one receipt) together with its line items.

    ``user_id``, ``store_id``, ``receipt_id``, ``purchase_date`` and ``total``
    are required. The remaining fields default to ``None``, and ``items``
    defaults to an empty list.
    """

    user_id: uuid.UUID
    store_id: uuid.UUID
    store_location_id: uuid.UUID | None = None
    receipt_id: str
    purchase_date: date
    total: Decimal
    subtotal: Decimal | None = None
    tax: Decimal | None = None
    savings_total: Decimal | None = None
    source_url: str | None = None
    # Untyped dict: any payload shape is accepted here.
    raw_data: dict | None = None
    # Pydantic copies field defaults per instance, so this list literal is
    # not shared between models the way a plain mutable default would be.
    items: list[PurchaseItemCreate] = []
class ShrinkflationEventCreate(BaseModel):
    """Creation schema for a shrinkflation event.

    The product, the detection date and both size/unit pairs are required.
    Prices and notes default to ``None``; ``confidence`` defaults to ``1.00``.
    """

    normalized_product_id: uuid.UUID
    detected_date: date
    # Sizes are carried as strings; the unit is held separately.
    old_size: str
    new_size: str
    old_unit: SizeUnit
    new_unit: SizeUnit
    price_at_old_size: Decimal | None = None
    price_at_new_size: Decimal | None = None
    # Used when the caller does not supply a score.
    # NOTE(review): no range constraint is declared — confirm the expected scale.
    confidence: Decimal = Decimal("1.00")
    notes: str | None = None
class UserStoreAccountCreate(BaseModel):
    """Creation schema linking a user to a store account.

    ``user_id`` and ``store_id`` are required; ``session_data`` defaults to
    ``None`` and ``status`` to ``AccountStatus.ACTIVE``.
    """

    user_id: uuid.UUID
    store_id: uuid.UUID
    # Untyped dict: its schema is not defined here.
    # NOTE(review): presumably holds scraper session state — confirm, and
    # confirm whether it needs protection at rest.
    session_data: dict | None = None
    status: AccountStatus = AccountStatus.ACTIVE
def main() -> None:
    """Parse the seed CLI options and run the seeder.

    Any exception raised while importing or running the seeder is printed to
    stderr as ``ERROR: <message>`` and the process exits with status 1.
    """
    cli = argparse.ArgumentParser(
        prog="cartsnitch-seed",
        description="Generate deterministic seed data for the CartSnitch dev environment.",
    )

    # One (flag, keyword-arguments) pair per supported option, in display order.
    option_specs = (
        (
            "--database-url",
            {
                "default": None,
                "help": (
                    "PostgreSQL connection URL (sync driver). "
                    "Defaults to CARTSNITCH_DATABASE_URL_SYNC env var or built-in default."
                ),
            },
        ),
        (
            "--dry-run",
            {
                "action": "store_true",
                "help": "Print planned record counts without writing to the database.",
            },
        ),
        (
            "--seed",
            {
                "type": int,
                "default": SEED_VALUE,
                "help": f"Random seed for deterministic output (default: {SEED_VALUE}).",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()

    try:
        # Imported inside the try so an import failure is reported like any
        # other seeding error.
        from cartsnitch_common.seed.runner import run_seed

        run_seed(
            database_url=options.database_url,
            seed_value=options.seed,
            dry_run=options.dry_run,
        )
    except Exception as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(1)
def generate_coupons(
    fake: Faker,
    products: list[dict],
    stores: list[dict],
) -> list[dict]:
    """Return NUM_COUPONS coupon records with a fixed mix of expired and active coupons.

    Args:
        fake: Faker instance used for coupon codes and descriptions.
        products: Product records; each must provide ``"id"`` and
            ``"canonical_name"``.
        stores: Store records; each must provide ``"id"``.

    Returns:
        A shuffled list of coupon dicts. ``int(NUM_COUPONS * COUPON_EXPIRED_PCT)``
        of them are expired relative to SEED_END_DATE and the rest are active.

    The order of ``random`` calls is part of the seeded output, so statements
    that draw random values must not be reordered.
    """
    # NOTE(review): datetime.now() and uuid.uuid4() are not driven by the
    # `random` seed, so ids and created/updated timestamps differ between
    # runs — confirm this is acceptable for deterministic seed data.
    now = datetime.now(tz=UTC)
    today = SEED_END_DATE
    coupons = []

    num_expired = int(NUM_COUPONS * COUPON_EXPIRED_PCT)
    num_active = NUM_COUPONS - num_expired

    def make_coupon(is_active: bool) -> dict:
        """Build one coupon record; ``is_active`` decides the validity window."""
        store = random.choice(stores)
        # Roughly one coupon in ten is left without a product ("any item").
        product = random.choice(products) if random.random() > 0.1 else None
        # Keeps only the text after the first two space-separated tokens of
        # the canonical name (presumably a brand/size prefix — confirm).
        product_name = product["canonical_name"].split(" ", 2)[-1] if product else "any item"

        discount_type = random.choice(list(DiscountType))

        if discount_type == DiscountType.PERCENT:
            discount_value = _decimal(random.choice([5, 10, 15, 20, 25, 30]))
            title = f"Save {int(discount_value)}% on {product_name}"
        elif discount_type == DiscountType.FIXED:
            discount_value = _decimal(random.choice([0.50, 1.00, 1.50, 2.00, 2.50, 3.00, 5.00]))
            # Format with two decimals: the Decimal is built from str(float),
            # so it would otherwise render as "$1.0" or "$0.5".
            title = f"Save ${discount_value:.2f} on {product_name}"
        elif discount_type == DiscountType.BOGO:
            discount_value = None
            title = f"BOGO: Buy one {product_name}, get one free"
        else:  # BUY_X_GET_Y
            discount_value = None
            title = f"Buy 2 {product_name}, get 1 free"

        if is_active:
            # Started within the last 30 days, ends within the next 60.
            valid_from = today - timedelta(days=random.randint(1, 30))
            valid_to = today + timedelta(days=random.randint(1, 60))
        else:
            # Ended 1-180 days ago after running for 7-30 days.
            valid_to = today - timedelta(days=random.randint(1, 180))
            valid_from = valid_to - timedelta(days=random.randint(7, 30))

        requires_clip = random.random() > 0.5
        # Only coupons that do not require clipping get a code.
        coupon_code = fake.bothify(text="??##-??##").upper() if not requires_clip else None
        # Decimal zero is falsy, so a drawn 0 becomes "no minimum" (None).
        min_purchase = _decimal(random.choice([0, 0, 0, 5.00, 10.00, 15.00])) or None

        scraped_at = datetime(
            SEED_START_DATE.year, SEED_START_DATE.month, SEED_START_DATE.day, tzinfo=UTC
        ) + timedelta(days=random.randint(0, 180))

        return {
            "id": uuid.uuid4(),
            "store_id": store["id"],
            "normalized_product_id": product["id"] if product else None,
            "title": title,
            "description": fake.sentence(nb_words=10),
            "discount_type": discount_type,
            "discount_value": discount_value,
            "min_purchase": min_purchase,
            "valid_from": valid_from,
            "valid_to": valid_to,
            "requires_clip": requires_clip,
            "coupon_code": coupon_code,
            "source_url": None,
            "scraped_at": scraped_at,
            "created_at": now,
            "updated_at": now,
        }

    for _ in range(num_expired):
        coupons.append(make_coupon(is_active=False))
    for _ in range(num_active):
        coupons.append(make_coupon(is_active=True))

    random.shuffle(coupons)
    return coupons
ProductCategory.BEVERAGES: (0.99, 6.99), + ProductCategory.SNACKS: (2.49, 6.99), + ProductCategory.HOUSEHOLD: (3.99, 19.99), + ProductCategory.PERSONAL_CARE: (3.99, 14.99), + } + cat: ProductCategory | None = product.get("category") + lo, hi = category_ranges.get(cat, (1.99, 9.99)) if cat is not None else (1.99, 9.99) + return random.uniform(lo, hi) + + +def generate_price_history( + products: list[dict], + stores: list[dict], + purchase_items: list[dict], +) -> list[dict]: + """Return ~NUM_PRICE_HISTORY price history records with realistic patterns. + + Pattern types (assigned per product): + - sudden_jump: flat then >10% price increase at a random point + - gradual_creep: slow steady increase over the window + - stable: nearly flat price with small noise + - sale_driven: drops during holiday periods, returns after + - volatile: random walk + + 10% of products (NUM_PRICE_INCREASE_PRODUCTS) will show a detectable + price increase (>10%) that StickerShock can flag. + """ + now = datetime.now(tz=UTC) + records: list[dict] = [] + + # Build purchase-item lookup: (product_id, store_id) -> [purchase_item_id] + item_lookup: dict[tuple, list[uuid.UUID]] = {} + for item in purchase_items: + key = (item["normalized_product_id"], item.get("_store_id")) + item_lookup.setdefault(key, []).append(item["id"]) + + total = NUM_PRICE_HISTORY + per_product_per_store = total // (len(products) * len(stores)) + per_product_per_store = max(per_product_per_store, 1) + + # Assign patterns + product_patterns: list[str] = [] + price_increase_indices = set(random.sample(range(len(products)), NUM_PRICE_INCREASE_PRODUCTS)) + pattern_pool = ["sale_driven", "stable", "gradual_creep", "volatile"] + for i in range(len(products)): + if i in price_increase_indices: + product_patterns.append(random.choice(["sudden_jump", "gradual_creep"])) + else: + product_patterns.append(random.choice(pattern_pool)) + + for i, product in enumerate(products): + pattern = product_patterns[i] + base_price = 
_base_price_for_product(product) + + # Jump point for sudden_jump (50-80% through window) + jump_day = int(_DATE_RANGE_DAYS * random.uniform(0.5, 0.8)) + jump_factor = random.uniform(1.10, 1.25) # 10-25% increase + + for store in stores: + # Generate obs dates spread across the window + obs_days = sorted( + random.sample( + range(_DATE_RANGE_DAYS + 1), + min(per_product_per_store, _DATE_RANGE_DAYS + 1), + ) + ) + + for day_offset in obs_days: + obs_date = SEED_START_DATE + timedelta(days=day_offset) + progress = day_offset / max(_DATE_RANGE_DAYS, 1) + + # Compute regular price by pattern + if pattern == "sudden_jump": + if day_offset < jump_day: + price = base_price + random.uniform(-0.05, 0.05) + else: + price = base_price * jump_factor + random.uniform(-0.05, 0.05) + elif pattern == "gradual_creep": + price = base_price * (1 + 0.12 * progress) + random.uniform(-0.10, 0.10) + elif pattern == "stable": + price = base_price + random.uniform(-0.10, 0.10) + elif pattern == "volatile": + price = base_price * random.uniform(0.85, 1.15) + else: + price = base_price + random.uniform(-0.05, 0.05) + + price = max(0.99, price) + regular_price = _decimal(price) + + # Sale price during holiday periods + sale_price: Decimal | None = None + if _is_sale_period(obs_date): + sale_price = _decimal(price * random.uniform(0.75, 0.90)) + + records.append( + { + "id": uuid.uuid4(), + "normalized_product_id": product["id"], + "store_id": store["id"], + "observed_date": obs_date, + "regular_price": regular_price, + "sale_price": sale_price, + "loyalty_price": None, + "coupon_price": None, + "source": ( + PriceSource.RECEIPT if random.random() > 0.3 else PriceSource.CATALOG + ), + "purchase_item_id": None, + "created_at": now, + "updated_at": now, + } + ) + + if len(records) >= NUM_PRICE_HISTORY: + return records + + return records diff --git a/src/cartsnitch_common/seed/generators/products.py b/src/cartsnitch_common/seed/generators/products.py new file mode 100644 index 0000000..6fad535 
--- /dev/null +++ b/src/cartsnitch_common/seed/generators/products.py @@ -0,0 +1,253 @@ +"""Generate NormalizedProduct seed data.""" + +import random +import uuid +from datetime import UTC, datetime + +from faker import Faker + +from cartsnitch_common.constants import ProductCategory, SizeUnit +from cartsnitch_common.seed.config import NUM_PRODUCTS + +# Product templates per category: (category, brands, names, sizes, default_unit) +_PRODUCT_TEMPLATES: list[tuple[ProductCategory, list[str], list[str], list[str], SizeUnit]] = [ + ( + ProductCategory.PRODUCE, + ["Organic Valley", "Earthbound Farm", "Local Farm", "Fresh Farms"], + [ + "Bananas", + "Apples", + "Baby Carrots", + "Spinach", + "Broccoli", + "Strawberries", + "Blueberries", + "Grapes", + "Tomatoes", + "Lettuce", + ], + ["1 lb", "2 lb", "16 oz", "12 oz", "5 oz", "6 oz", "32 oz"], + SizeUnit.LB, + ), + ( + ProductCategory.DAIRY, + ["Kraft", "Tillamook", "Great Value", "Land O'Lakes", "Daisy", "Organic Valley"], + [ + "Whole Milk", + "2% Milk", + "Cheddar Cheese", + "Mozzarella", + "Greek Yogurt", + "Butter", + "Cream Cheese", + "Sour Cream", + "Heavy Cream", + "Cottage Cheese", + ], + ["16 oz", "32 oz", "64 oz", "1 gallon", "8 oz", "12 oz", "5 oz"], + SizeUnit.FL_OZ, + ), + ( + ProductCategory.MEAT, + ["Tyson", "Perdue", "Smithfield", "Oscar Mayer", "Applegate", "Kirkland"], + [ + "Chicken Breast", + "Ground Beef", + "Pork Chops", + "Bacon", + "Turkey", + "Salmon", + "Tilapia", + "Sausage", + "Hot Dogs", + "Deli Ham", + ], + ["1 lb", "2 lb", "3 lb", "12 oz", "16 oz", "24 oz"], + SizeUnit.LB, + ), + ( + ProductCategory.BAKERY, + ["Nature's Own", "Dave's Killer Bread", "Pepperidge Farm", "Sara Lee", "Arnold"], + [ + "White Bread", + "Whole Wheat Bread", + "Sourdough", + "Bagels", + "English Muffins", + "Croissants", + "Dinner Rolls", + "Hamburger Buns", + "Hot Dog Buns", + "Muffins", + ], + ["20 oz", "24 oz", "6 ct", "8 ct", "12 ct", "16 oz"], + SizeUnit.OZ, + ), + ( + ProductCategory.FROZEN, + ["Stouffer's", 
"Amy's", "Birds Eye", "Green Giant", "Totino's", "DiGiorno"], + [ + "Frozen Pizza", + "Mac and Cheese", + "Frozen Burritos", + "Chicken Nuggets", + "Fish Sticks", + "Frozen Vegetables", + "Ice Cream", + "Frozen Waffles", + "Tater Tots", + "Frozen Lasagna", + ], + ["12 oz", "16 oz", "24 oz", "32 oz", "4 ct", "8 ct"], + SizeUnit.OZ, + ), + ( + ProductCategory.PANTRY, + ["Campbell's", "Hunt's", "Kraft", "Heinz", "Del Monte", "General Mills", "Kellogg's"], + [ + "Pasta Sauce", + "Canned Tomatoes", + "Chicken Noodle Soup", + "Peanut Butter", + "Jelly", + "Olive Oil", + "Rice", + "Pasta", + "Oatmeal", + "Cereal", + ], + ["15 oz", "24 oz", "32 oz", "18 oz", "16 oz", "24 oz", "48 oz", "64 oz"], + SizeUnit.OZ, + ), + ( + ProductCategory.BEVERAGES, + ["Coca-Cola", "Pepsi", "Tropicana", "Minute Maid", "Gatorade", "LaCroix", "Nestle"], + [ + "Cola", + "Diet Cola", + "Orange Juice", + "Apple Juice", + "Sports Drink", + "Sparkling Water", + "Iced Coffee", + "Energy Drink", + "Lemonade", + "Green Tea", + ], + ["12 fl oz", "20 fl oz", "32 fl oz", "64 fl oz", "2 liter", "6 pk", "12 pk"], + SizeUnit.FL_OZ, + ), + ( + ProductCategory.SNACKS, + ["Frito-Lay", "Nabisco", "Kellogg's", "Pepperidge Farm", "Clif Bar", "KIND", "Planters"], + [ + "Potato Chips", + "Tortilla Chips", + "Pretzels", + "Crackers", + "Granola Bars", + "Trail Mix", + "Popcorn", + "Cookies", + "Nuts", + "Fruit Snacks", + ], + ["7 oz", "10 oz", "16 oz", "6 ct", "12 ct", "18 ct", "3.5 oz"], + SizeUnit.OZ, + ), + ( + ProductCategory.HOUSEHOLD, + ["Tide", "Dawn", "Bounty", "Charmin", "Clorox", "Method", "Seventh Generation"], + [ + "Laundry Detergent", + "Dish Soap", + "Paper Towels", + "Toilet Paper", + "Bleach", + "All-Purpose Cleaner", + "Fabric Softener", + "Dryer Sheets", + "Trash Bags", + "Sponges", + ], + ["32 oz", "64 oz", "100 oz", "6 pk", "12 pk", "24 ct", "2 pk"], + SizeUnit.OZ, + ), + ( + ProductCategory.PERSONAL_CARE, + ["Dove", "Pantene", "Colgate", "Crest", "Gillette", "L'Oreal", "Neutrogena"], + [ + 
"Shampoo", + "Conditioner", + "Body Wash", + "Toothpaste", + "Deodorant", + "Face Wash", + "Lotion", + "Razor", + "Shaving Cream", + "Hand Soap", + ], + ["12 oz", "24 oz", "32 oz", "3.4 oz", "6 oz", "8 oz", "2 pk"], + SizeUnit.OZ, + ), +] + + +def _generate_upc() -> str: + """Generate a fake 12-digit UPC.""" + digits = [random.randint(0, 9) for _ in range(11)] + odd_sum = sum(digits[i] for i in range(0, 11, 2)) + even_sum = sum(digits[i] for i in range(1, 11, 2)) + check = (10 - ((odd_sum * 3 + even_sum) % 10)) % 10 + digits.append(check) + return "".join(str(d) for d in digits) + + +def generate_products(fake: Faker) -> list[dict]: + """Return NUM_PRODUCTS normalized product records.""" + now = datetime.now(tz=UTC) + products = [] + used_upcs: set[str] = set() + + per_category = NUM_PRODUCTS // len(_PRODUCT_TEMPLATES) + remainder = NUM_PRODUCTS % len(_PRODUCT_TEMPLATES) + + for i, (category, brands, names, sizes, default_unit) in enumerate(_PRODUCT_TEMPLATES): + count = per_category + (1 if i < remainder else 0) + for _ in range(count): + brand = random.choice(brands) + product_name = random.choice(names) + size_str = random.choice(sizes) + canonical_name = f"{brand} {product_name} {size_str}" + + size_parts = size_str.split(" ", 1) + size_val = size_parts[0] + + num_upcs = random.randint(1, 3) + upcs: list[str] = [] + for _ in range(num_upcs): + upc = _generate_upc() + attempts = 0 + while upc in used_upcs and attempts < 10: + upc = _generate_upc() + attempts += 1 + used_upcs.add(upc) + upcs.append(upc) + + products.append( + { + "id": uuid.uuid4(), + "canonical_name": canonical_name, + "category": category, + "subcategory": product_name, + "brand": brand, + "size": size_val, + "size_unit": default_unit, + "upc_variants": upcs, + "created_at": now, + "updated_at": now, + } + ) + + return products diff --git a/src/cartsnitch_common/seed/generators/purchases.py b/src/cartsnitch_common/seed/generators/purchases.py new file mode 100644 index 0000000..d023c73 --- 
/dev/null +++ b/src/cartsnitch_common/seed/generators/purchases.py @@ -0,0 +1,156 @@ +"""Generate Purchase and PurchaseItem seed data.""" + +import random +import uuid +from datetime import UTC, date, datetime, timedelta +from decimal import Decimal + +from cartsnitch_common.seed.config import ( + NUM_PURCHASE_ITEMS, + NUM_PURCHASES, + SEED_END_DATE, + SEED_START_DATE, +) + +_DATE_RANGE_DAYS = (SEED_END_DATE - SEED_START_DATE).days + + +def _random_date() -> date: + return SEED_START_DATE + timedelta(days=random.randint(0, _DATE_RANGE_DAYS)) + + +def _decimal(val: float, places: int = 2) -> Decimal: + return Decimal(str(round(val, places))) + + +def generate_purchases( + users: list[dict], + stores: list[dict], + store_locations: list[dict], +) -> list[dict]: + """Return NUM_PURCHASES purchase records.""" + now = datetime.now(tz=UTC) + active_users = [u for u in users if u["_active"]] + inactive_users = [u for u in users if not u["_active"]] + + # Build location index by store_id + locs_by_store: dict = {} + for loc in store_locations: + locs_by_store.setdefault(loc["store_id"], []).append(loc) + + purchases = [] + seen_receipts: set[tuple] = set() + + # Active users get 80% of purchases + active_count = int(NUM_PURCHASES * 0.8) + inactive_count = NUM_PURCHASES - active_count + + def make_purchase(user: dict, store: dict) -> dict | None: + receipt_id = f"RCT-{random.randint(100000, 999999)}" + key = (user["id"], store["id"], receipt_id) + if key in seen_receipts: + return None + seen_receipts.add(key) + subtotal = _decimal(random.uniform(5.0, 150.0)) + tax = _decimal(float(subtotal) * 0.06) + savings = _decimal(random.uniform(0.0, float(subtotal) * 0.3)) + total = _decimal(float(subtotal) + float(tax) - float(savings)) + purchase_date = _random_date() + store_locs = locs_by_store.get(store["id"], []) + store_location_id = random.choice(store_locs)["id"] if store_locs else None + ingested_at = datetime( + purchase_date.year, purchase_date.month, purchase_date.day, 
tzinfo=UTC + ) + timedelta(hours=random.randint(1, 48)) + return { + "id": uuid.uuid4(), + "user_id": user["id"], + "store_id": store["id"], + "store_location_id": store_location_id, + "receipt_id": receipt_id, + "purchase_date": purchase_date, + "total": total, + "subtotal": subtotal, + "tax": tax, + "savings_total": savings if float(savings) > 0 else None, + "source_url": None, + "raw_data": None, + "ingested_at": ingested_at, + "created_at": now, + "updated_at": now, + } + + for _ in range(active_count): + user = random.choice(active_users) + store = random.choice(stores) + p = make_purchase(user, store) + if p: + purchases.append(p) + + for _ in range(inactive_count): + user = random.choice(inactive_users) + store = random.choice(stores) + p = make_purchase(user, store) + if p: + purchases.append(p) + + return purchases[:NUM_PURCHASES] + + +def generate_purchase_items( + purchases: list[dict], + products: list[dict], +) -> list[dict]: + """Return ~NUM_PURCHASE_ITEMS purchase item records distributed across purchases.""" + now = datetime.now(tz=UTC) + items: list[dict] = [] + total_target = NUM_PURCHASE_ITEMS + num_purchases = len(purchases) + + # Distribute items: avg 5 per purchase with variance + for i, purchase in enumerate(purchases): + # Remaining purchases get proportional share + remaining_purchases = num_purchases - i + remaining_items = total_target - len(items) + if remaining_purchases <= 0 or remaining_items <= 0: + break + avg = remaining_items / remaining_purchases + count = max(1, min(15, int(random.gauss(avg, 2)))) + count = min(count, remaining_items) + + for _ in range(count): + product = random.choice(products) + unit_price = _decimal(random.uniform(0.99, 25.99)) + quantity = Decimal("1.000") + extended_price = _decimal(float(unit_price) * float(quantity)) + has_sale = random.random() > 0.7 + sale_price = ( + _decimal(float(unit_price) * random.uniform(0.7, 0.95)) if has_sale else None + ) + has_coupon = random.random() > 0.85 + 
coupon_discount = _decimal(random.uniform(0.25, 2.00)) if has_coupon else None + + upc = None + if product["upc_variants"]: + upc = random.choice(product["upc_variants"]) + + items.append( + { + "id": uuid.uuid4(), + "purchase_id": purchase["id"], + "product_name_raw": product["canonical_name"], + "upc": upc, + "quantity": quantity, + "unit_price": unit_price, + "extended_price": extended_price, + "regular_price": unit_price, + "sale_price": sale_price, + "coupon_discount": coupon_discount, + "loyalty_discount": None, + "category_raw": product["category"].value if product["category"] else None, + "normalized_product_id": product["id"], + "created_at": now, + "updated_at": now, + } + ) + + return items diff --git a/src/cartsnitch_common/seed/generators/shrinkflation.py b/src/cartsnitch_common/seed/generators/shrinkflation.py new file mode 100644 index 0000000..9d833bb --- /dev/null +++ b/src/cartsnitch_common/seed/generators/shrinkflation.py @@ -0,0 +1,114 @@ +"""Generate ShrinkflationEvent seed data.""" + +import random +import uuid +from datetime import UTC, datetime, timedelta +from decimal import Decimal + +from cartsnitch_common.constants import SizeUnit +from cartsnitch_common.seed.config import ( + NUM_SHRINKFLATION_EVENTS, + SEED_END_DATE, + SEED_START_DATE, +) + +_DATE_RANGE_DAYS = (SEED_END_DATE - SEED_START_DATE).days + +# Shrinkflation patterns: (old_size, new_size, unit, size_reduction_pct) +_SHRINK_PATTERNS: list[tuple[str, str, SizeUnit, float]] = [ + ("16", "14", SizeUnit.OZ, 0.125), + ("32", "28", SizeUnit.OZ, 0.125), + ("64", "56", SizeUnit.FL_OZ, 0.125), + ("18", "16", SizeUnit.OZ, 0.111), + ("20", "18", SizeUnit.OZ, 0.10), + ("2", "1.75", SizeUnit.LB, 0.125), + ("24", "21", SizeUnit.OZ, 0.125), + ("12", "10.5", SizeUnit.OZ, 0.125), + ("48", "42", SizeUnit.OZ, 0.125), + ("8", "7", SizeUnit.OZ, 0.125), + ("1", "0.875", SizeUnit.LB, 0.125), + ("36", "32", SizeUnit.OZ, 0.111), + ("6", "5", SizeUnit.CT, 0.167), + ("12", "10", SizeUnit.CT, 0.167), + 
def _decimal(val: float) -> Decimal:
    """Round *val* to two places and convert the result to ``Decimal`` via ``str``."""
    return Decimal(str(round(val, 2)))


def generate_shrinkflation_events(products: list[dict]) -> list[dict]:
    """Return up to NUM_SHRINKFLATION_EVENTS shrinkflation event records.

    Each event pairs one distinct product with a size change from
    ``_SHRINK_PATTERNS`` (assigned round-robin) and a price that is roughly
    flat or slightly higher (x0.98 to x1.08) despite the smaller package.

    Products in the preferred categories are used first. Other products are
    drawn only when there are not enough preferred ones to reach
    NUM_SHRINKFLATION_EVENTS.

    Args:
        products: Product dicts; ``"id"`` is read and ``"category"`` is optional.

    Returns:
        A list of event dicts, at most one per product.
    """
    now = datetime.now(tz=UTC)
    events = []

    from cartsnitch_common.constants import ProductCategory

    preferred_cats = {
        ProductCategory.PANTRY,
        ProductCategory.SNACKS,
        ProductCategory.HOUSEHOLD,
        ProductCategory.PERSONAL_CARE,
        ProductCategory.FROZEN,
        ProductCategory.DAIRY,
        ProductCategory.BEVERAGES,
    }
    # Single pass on the category; the previous `p not in preferred` test
    # compared whole dicts against the list for every product.
    preferred: list[dict] = []
    fallback: list[dict] = []
    for product in products:
        bucket = preferred if product.get("category") in preferred_cats else fallback
        bucket.append(product)

    # Sample the preferred bucket first and top up from the fallback bucket.
    # Sampling from the concatenated list, as before, gave every product the
    # same chance and so expressed no preference at all.
    wanted = max(0, NUM_SHRINKFLATION_EVENTS)
    from_preferred = min(wanted, len(preferred))
    selected = random.sample(preferred, from_preferred)
    shortfall = min(wanted - from_preferred, len(fallback))
    if shortfall > 0:
        selected.extend(random.sample(fallback, shortfall))

    # Detection date: at least 60 days into the window so there is history
    # before it. Clamped because randint raises if the window is shorter.
    min_day = min(60, _DATE_RANGE_DAYS)

    for i, product in enumerate(selected):
        old_size, new_size, unit, reduction_pct = _SHRINK_PATTERNS[i % len(_SHRINK_PATTERNS)]

        detected_day = random.randint(min_day, _DATE_RANGE_DAYS)
        detected_date = SEED_START_DATE + timedelta(days=detected_day)

        # Price maintained or slightly increased despite size reduction
        base_price = random.uniform(2.99, 12.99)
        price_at_old_size = _decimal(base_price)
        price_at_new_size = _decimal(base_price * random.uniform(0.98, 1.08))

        confidence = _decimal(random.uniform(0.70, 0.99))

        notes = (
            f"Package reduced from {old_size}{unit} to {new_size}{unit} "
            f"({reduction_pct * 100:.1f}% reduction). "
            f"Price {'increased' if price_at_new_size > price_at_old_size else 'held steady'}."
        )

        events.append(
            {
                "id": uuid.uuid4(),
                "normalized_product_id": product["id"],
                "detected_date": detected_date,
                "old_size": old_size,
                "new_size": new_size,
                "old_unit": unit,
                "new_unit": unit,
                "price_at_old_size": price_at_old_size,
                "price_at_new_size": price_at_new_size,
                "confidence": confidence,
                "notes": notes,
                "created_at": now,
                "updated_at": now,
            }
        )

    return events
"city": "Dearborn", + "state": "MI", + "zip": "48126", + "lat": 42.3223, + "lng": -83.1952, + }, + { + "address": "15555 Northline Rd", + "city": "Southgate", + "state": "MI", + "zip": "48195", + "lat": 42.2089, + "lng": -83.1953, + }, + { + "address": "2855 Washtenaw Ave", + "city": "Ypsilanti", + "state": "MI", + "zip": "48197", + "lat": 42.2461, + "lng": -83.6388, + }, + ], + StoreSlug.KROGER: [ + { + "address": "2010 W Stadium Blvd", + "city": "Ann Arbor", + "state": "MI", + "zip": "48103", + "lat": 42.2706, + "lng": -83.7807, + }, + { + "address": "1100 S Main St", + "city": "Ann Arbor", + "state": "MI", + "zip": "48104", + "lat": 42.2555, + "lng": -83.7469, + }, + { + "address": "23650 Michigan Ave", + "city": "Dearborn", + "state": "MI", + "zip": "48124", + "lat": 42.3221, + "lng": -83.2135, + }, + { + "address": "14000 Michigan Ave", + "city": "Dearborn", + "state": "MI", + "zip": "48126", + "lat": 42.3281, + "lng": -83.1789, + }, + { + "address": "3965 Packard St", + "city": "Ann Arbor", + "state": "MI", + "zip": "48108", + "lat": 42.2298, + "lng": -83.7196, + }, + ], + StoreSlug.TARGET: [ + { + "address": "3165 Ann Arbor-Saline Rd", + "city": "Ann Arbor", + "state": "MI", + "zip": "48103", + "lat": 42.2431, + "lng": -83.8097, + }, + { + "address": "4001 Carpenter Rd", + "city": "Ypsilanti", + "state": "MI", + "zip": "48197", + "lat": 42.2373, + "lng": -83.6617, + }, + { + "address": "16000 Ford Rd", + "city": "Dearborn", + "state": "MI", + "zip": "48126", + "lat": 42.3312, + "lng": -83.2098, + }, + { + "address": "17300 Eureka Rd", + "city": "Southgate", + "state": "MI", + "zip": "48195", + "lat": 42.2001, + "lng": -83.2014, + }, + { + "address": "2400 E Stadium Blvd", + "city": "Ann Arbor", + "state": "MI", + "zip": "48104", + "lat": 42.2624, + "lng": -83.7102, + }, + ], +} + + +def generate_stores() -> list[dict]: + """Return 3 fixed store records.""" + now = datetime.now(tz=UTC) + stores = [] + for defn in _STORE_DEFS: + stores.append( + { + "id": 
uuid.uuid4(), + "name": defn["name"], + "slug": defn["slug"], + "logo_url": defn["logo_url"], + "website_url": defn["website_url"], + "created_at": now, + "updated_at": now, + } + ) + return stores + + +def generate_store_locations(stores: list[dict]) -> list[dict]: + """Return 5 locations per store (15 total).""" + now = datetime.now(tz=UTC) + slug_to_id = {s["slug"]: s["id"] for s in stores} + locations = [] + for slug, loc_defs in _LOCATION_DEFS.items(): + store_id = slug_to_id[slug] + for loc in loc_defs[:NUM_LOCATIONS_PER_STORE]: + locations.append( + { + "id": uuid.uuid4(), + "store_id": store_id, + "address": loc["address"], + "city": loc["city"], + "state": loc["state"], + "zip": loc["zip"], + "lat": loc["lat"], + "lng": loc["lng"], + "created_at": now, + "updated_at": now, + } + ) + return locations diff --git a/src/cartsnitch_common/seed/generators/users.py b/src/cartsnitch_common/seed/generators/users.py new file mode 100644 index 0000000..6757b23 --- /dev/null +++ b/src/cartsnitch_common/seed/generators/users.py @@ -0,0 +1,105 @@ +"""Generate User and UserStoreAccount seed data.""" + +import random +import uuid +from datetime import UTC, datetime, timedelta + +from faker import Faker + +from cartsnitch_common.constants import AccountStatus +from cartsnitch_common.seed.config import ( + NUM_ACTIVE_USERS, + NUM_USER_STORE_ACCOUNTS, + NUM_USERS, + SEED_END_DATE, +) + + +def generate_users(fake: Faker) -> list[dict]: + """Return NUM_USERS user records. 
First NUM_ACTIVE_USERS are active.""" + now = datetime.now(tz=UTC) + users = [] + for i in range(NUM_USERS): + created_at = now - timedelta(days=random.randint(30, 365)) + users.append( + { + "id": uuid.uuid4(), + "email": fake.unique.email(), + "hashed_password": fake.sha256(), + "display_name": fake.name() if random.random() > 0.2 else None, + "created_at": created_at, + "updated_at": created_at, + "_active": i < NUM_ACTIVE_USERS, + } + ) + return users + + +def generate_user_store_accounts( + users: list[dict], + stores: list[dict], +) -> list[dict]: + """Return ~NUM_USER_STORE_ACCOUNTS user-store account links. + + Active users get accounts at multiple stores; inactive users may have none. + """ + now = datetime.now(tz=UTC) + accounts = [] + seen: set[tuple] = set() + + active_users = [u for u in users if u["_active"]] + inactive_users = [u for u in users if not u["_active"]] + + # Active users: each gets 1-3 store accounts + for user in active_users: + num_accounts = random.randint(1, 3) + selected_stores = random.sample(stores, min(num_accounts, len(stores))) + for store in selected_stores: + key = (user["id"], store["id"]) + if key in seen: + continue + seen.add(key) + last_sync = datetime( + SEED_END_DATE.year, + SEED_END_DATE.month, + SEED_END_DATE.day, + tzinfo=UTC, + ) - timedelta(days=random.randint(0, 14)) + accounts.append( + { + "id": uuid.uuid4(), + "user_id": user["id"], + "store_id": store["id"], + "session_data": {"token": "SEED_FAKE_TOKEN", "expires": "2026-12-31"}, + "session_expires_at": now + timedelta(days=random.randint(1, 90)), + "last_sync_at": last_sync, + "status": AccountStatus.ACTIVE, + "created_at": user["created_at"], + "updated_at": user["updated_at"], + } + ) + + # Fill remaining slots from inactive users + remaining = NUM_USER_STORE_ACCOUNTS - len(accounts) + for user in random.sample(inactive_users, min(remaining, len(inactive_users))): + store = random.choice(stores) + key = (user["id"], store["id"]) + if key in seen: + 
continue + seen.add(key) + status = random.choice([AccountStatus.EXPIRED, AccountStatus.ERROR, AccountStatus.ACTIVE]) + accounts.append( + { + "id": uuid.uuid4(), + "user_id": user["id"], + "store_id": store["id"], + "session_data": None, + "session_expires_at": None, + "last_sync_at": None, + "status": status, + "created_at": user["created_at"], + "updated_at": user["updated_at"], + } + ) + + return accounts[: NUM_USER_STORE_ACCOUNTS + len(active_users) * 3] diff --git a/src/cartsnitch_common/seed/runner.py b/src/cartsnitch_common/seed/runner.py new file mode 100644 index 0000000..c2b7784 --- /dev/null +++ b/src/cartsnitch_common/seed/runner.py @@ -0,0 +1,189 @@ +"""Seed runner: orchestrates generation and DB insertion in FK-safe order.""" + +import random +import time +from typing import Any + +from faker import Faker +from sqlalchemy import text +from sqlalchemy.orm import Session + +from cartsnitch_common.database import get_sync_session_factory +from cartsnitch_common.models.coupon import Coupon +from cartsnitch_common.models.price import PriceHistory +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.models.purchase import Purchase, PurchaseItem +from cartsnitch_common.models.shrinkflation import ShrinkflationEvent +from cartsnitch_common.models.store import Store, StoreLocation +from cartsnitch_common.models.user import User, UserStoreAccount +from cartsnitch_common.seed.config import SEED_VALUE +from cartsnitch_common.seed.generators.coupons import generate_coupons +from cartsnitch_common.seed.generators.prices import generate_price_history +from cartsnitch_common.seed.generators.products import generate_products +from cartsnitch_common.seed.generators.purchases import generate_purchase_items, generate_purchases +from cartsnitch_common.seed.generators.shrinkflation import generate_shrinkflation_events +from cartsnitch_common.seed.generators.stores import generate_store_locations, generate_stores +from 
def _log(msg: str) -> None:
    """Print *msg* to standard output and flush immediately."""
    print(msg, flush=True)


def _bulk_insert(session: Session, model: type, rows: list[dict[str, Any]]) -> None:
    """Insert *rows* into the table of *model* with a single executed INSERT.

    Keys beginning with an underscore are generator-internal bookkeeping and
    are removed from every row before it is handed to the session. Nothing is
    executed when *rows* is empty.
    """
    if rows:
        public_rows = [
            {column: value for column, value in record.items() if column[:1] != "_"}
            for record in rows
        ]
        session.execute(model.__table__.insert(), public_rows)  # type: ignore[attr-defined]
+ """ + random.seed(seed_value) + fake = Faker() + Faker.seed(seed_value) + + _log("=== CartSnitch Seed Data Generator ===") + _log(f"Seed: {seed_value}") + + # --- Generation phase --- + t0 = time.monotonic() + + _log("Generating stores...") + stores = generate_stores() + _log(f" {len(stores)} stores ({time.monotonic() - t0:.2f}s)") + + _log("Generating store locations...") + store_locations = generate_store_locations(stores) + _log(f" {len(store_locations)} store locations ({time.monotonic() - t0:.2f}s)") + + _log("Generating users...") + users = generate_users(fake) + _log(f" {len(users)} users ({time.monotonic() - t0:.2f}s)") + + _log("Generating user store accounts...") + user_store_accounts = generate_user_store_accounts(users, stores) + _log(f" {len(user_store_accounts)} user store accounts ({time.monotonic() - t0:.2f}s)") + + _log("Generating products...") + products = generate_products(fake) + _log(f" {len(products)} products ({time.monotonic() - t0:.2f}s)") + + _log("Generating purchases...") + purchases = generate_purchases(users, stores, store_locations) + _log(f" {len(purchases)} purchases ({time.monotonic() - t0:.2f}s)") + + _log("Generating purchase items...") + purchase_items = generate_purchase_items(purchases, products) + _log(f" {len(purchase_items)} purchase items ({time.monotonic() - t0:.2f}s)") + + _log("Generating price history...") + price_history = generate_price_history(products, stores, purchase_items) + _log(f" {len(price_history)} price history records ({time.monotonic() - t0:.2f}s)") + + _log("Generating coupons...") + coupons = generate_coupons(fake, products, stores) + _log(f" {len(coupons)} coupons ({time.monotonic() - t0:.2f}s)") + + _log("Generating shrinkflation events...") + shrinkflation_events = generate_shrinkflation_events(products) + _log(f" {len(shrinkflation_events)} shrinkflation events ({time.monotonic() - t0:.2f}s)") + + _log("") + _log("=== Summary ===") + _log(f" stores: {len(stores)}") + _log(f" store_locations: 
{len(store_locations)}") + _log(f" users: {len(users)}") + _log(f" user_store_accounts: {len(user_store_accounts)}") + _log(f" normalized_products: {len(products)}") + _log(f" purchases: {len(purchases)}") + _log(f" purchase_items: {len(purchase_items)}") + _log(f" price_history: {len(price_history)}") + _log(f" coupons: {len(coupons)}") + _log(f" shrinkflation_events: {len(shrinkflation_events)}") + + if dry_run: + _log("") + _log("Dry run — no data written.") + return + + # --- DB insertion phase --- + factory = get_sync_session_factory(database_url) + with factory() as session: + _log("") + _log("Truncating tables (reverse FK order)...") + for table in _TRUNCATE_TABLES: + session.execute(text(f"TRUNCATE TABLE {table} CASCADE")) + _log(" done") + + _log("Inserting stores...") + _bulk_insert(session, Store, stores) + _log(f" {len(stores)} inserted") + + _log("Inserting store locations...") + _bulk_insert(session, StoreLocation, store_locations) + _log(f" {len(store_locations)} inserted") + + _log("Inserting users...") + _bulk_insert(session, User, users) + _log(f" {len(users)} inserted") + + _log("Inserting user store accounts...") + _bulk_insert(session, UserStoreAccount, user_store_accounts) + _log(f" {len(user_store_accounts)} inserted") + + _log("Inserting products...") + _bulk_insert(session, NormalizedProduct, products) + _log(f" {len(products)} inserted") + + _log("Inserting purchases...") + _bulk_insert(session, Purchase, purchases) + _log(f" {len(purchases)} inserted") + + _log("Inserting purchase items...") + _bulk_insert(session, PurchaseItem, purchase_items) + _log(f" {len(purchase_items)} inserted") + + _log("Inserting price history...") + _bulk_insert(session, PriceHistory, price_history) + _log(f" {len(price_history)} inserted") + + _log("Inserting coupons...") + _bulk_insert(session, Coupon, coupons) + _log(f" {len(coupons)} inserted") + + _log("Inserting shrinkflation events...") + _bulk_insert(session, ShrinkflationEvent, shrinkflation_events) + 
_log(f" {len(shrinkflation_events)} inserted") + + session.commit() + + elapsed = time.monotonic() - t0 + _log("") + _log(f"Seed complete in {elapsed:.1f}s") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6bfc994 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,24 @@ +"""Shared test fixtures for cartsnitch-common tests.""" + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from cartsnitch_common.models.base import Base + + +@pytest.fixture +def engine(): + """In-memory SQLite engine for unit tests.""" + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def session(engine): + """SQLAlchemy session bound to in-memory SQLite.""" + factory = sessionmaker(bind=engine) + with factory() as sess: + yield sess diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..8b5eb68 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,376 @@ +"""Tests for SQLAlchemy ORM models.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +import pytest +from sqlalchemy import inspect + +from cartsnitch_common.constants import ( + AccountStatus, + DiscountType, + PriceSource, + ProductCategory, + SizeUnit, + StoreSlug, +) +from cartsnitch_common.models import ( + Coupon, + NormalizedProduct, + PriceHistory, + Purchase, + PurchaseItem, + ShrinkflationEvent, + Store, + StoreLocation, + User, + UserStoreAccount, +) + + +class TestTableCreation: + """Verify all expected tables are created.""" + + def test_all_tables_exist(self, engine): + inspector = inspect(engine) + table_names = set(inspector.get_table_names()) + expected = { + "stores", + "store_locations", + "users", + "user_store_accounts", + "purchases", + "purchase_items", + "normalized_products", + 
"price_history", + "coupons", + "shrinkflation_events", + } + assert expected.issubset(table_names) + + def test_ten_tables_total(self, engine): + inspector = inspect(engine) + assert len(inspector.get_table_names()) == 10 + + +class TestUUIDPrimaryKeys: + """All models use UUID PKs.""" + + def test_store_uuid_pk(self, session): + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.commit() + assert isinstance(store.id, uuid.UUID) + + def test_user_uuid_pk(self, session): + user = User( + id=uuid.uuid4(), + email="test@example.com", + hashed_password="hashed", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(user) + session.commit() + assert isinstance(user.id, uuid.UUID) + + +class TestStoreModel: + def test_store_slug_enum(self, session): + store = Store( + id=uuid.uuid4(), + name="Kroger", + slug=StoreSlug.KROGER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.commit() + assert store.slug == StoreSlug.KROGER + + def test_store_unique_slug(self, session): + s1 = Store( + id=uuid.uuid4(), + name="Target", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + s2 = Store( + id=uuid.uuid4(), + name="Target Duplicate", + slug=StoreSlug.TARGET, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(s1) + session.commit() + session.add(s2) + with pytest.raises(Exception): # noqa: B017 + session.commit() + session.rollback() + + +class TestStoreLocationModel: + def test_store_location_fields(self, session): + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=StoreSlug.MEIJER, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.flush() + loc = StoreLocation( + id=uuid.uuid4(), + store_id=store.id, + address="123 Main St", + city="Ann Arbor", + 
class TestUserStoreAccountModel:
    """Behavior of ``UserStoreAccount``: status enum and per-user/store uniqueness."""

    def test_account_status_enum(self, session):
        """A committed account keeps the ``AccountStatus`` member it was given."""
        user = User(
            id=uuid.uuid4(),
            email="test@test.com",
            hashed_password="hashed",
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        store = Store(
            id=uuid.uuid4(),
            name="Kroger",
            slug=StoreSlug.KROGER,
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        session.add_all([user, store])
        # Flush so the parent rows exist before the account references them.
        session.flush()
        acct = UserStoreAccount(
            id=uuid.uuid4(),
            user_id=user.id,
            store_id=store.id,
            status=AccountStatus.ACTIVE,
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        session.add(acct)
        session.commit()
        assert acct.status == AccountStatus.ACTIVE

    def test_unique_user_store_constraint(self, session):
        """One account per user per store."""
        # Imported locally so this test names the exact failure it expects;
        # a bare ``Exception`` would also pass on unrelated errors.
        from sqlalchemy.exc import IntegrityError

        user = User(
            id=uuid.uuid4(),
            email="unique@test.com",
            hashed_password="hashed",
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        store = Store(
            id=uuid.uuid4(),
            name="Target",
            slug=StoreSlug.TARGET,
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        session.add_all([user, store])
        session.flush()
        a1 = UserStoreAccount(
            id=uuid.uuid4(),
            user_id=user.id,
            store_id=store.id,
            status=AccountStatus.ACTIVE,
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        a2 = UserStoreAccount(
            id=uuid.uuid4(),
            user_id=user.id,
            store_id=store.id,
            status=AccountStatus.EXPIRED,
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        session.add(a1)
        session.commit()
        session.add(a2)
        with pytest.raises(IntegrityError):
            session.commit()
        # Leave the session usable for fixture teardown after the failed flush.
        session.rollback()
class TestNormalizedProductModel:
    """Persistence of ``NormalizedProduct`` enum fields and UPC variants."""

    def test_product_with_upc_variants(self, session):
        """Category, size unit and the UPC variant list survive a commit."""
        upc_variants = ["0041250000001", "0041250000002"]
        product = NormalizedProduct(
            id=uuid.uuid4(),
            canonical_name="Whole Milk, 1 Gallon",
            category=ProductCategory.DAIRY,
            brand="Store Brand",
            size="128",
            size_unit=SizeUnit.FL_OZ,
            # Pass a copy so the comparison below is against an untouched list.
            upc_variants=list(upc_variants),
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        session.add(product)
        session.commit()
        assert product.category == ProductCategory.DAIRY
        assert product.size_unit == SizeUnit.FL_OZ
        # The test is named for the variants, so check them as well.
        assert list(product.upc_variants) == upc_variants
class TestCouponModel:
    """Persistence of ``Coupon`` discount fields."""

    def test_coupon_discount_types(self, session):
        """A fixed-amount coupon keeps its discount type and value after commit."""
        target = Store(
            **{
                "id": uuid.uuid4(),
                "name": "Target",
                "slug": StoreSlug.TARGET,
                "created_at": datetime.now(UTC),
                "updated_at": datetime.now(UTC),
            }
        )
        session.add(target)
        # The store row must exist before the coupon references it.
        session.flush()

        coupon_fields = {
            "id": uuid.uuid4(),
            "store_id": target.id,
            "title": "$2 off eggs",
            "discount_type": DiscountType.FIXED,
            "discount_value": Decimal("2.00"),
            "requires_clip": True,
            "created_at": datetime.now(UTC),
            "updated_at": datetime.now(UTC),
        }
        egg_coupon = Coupon(**coupon_fields)
        session.add(egg_coupon)
        session.commit()

        assert egg_coupon.discount_type == DiscountType.FIXED
        assert egg_coupon.discount_value == Decimal("2.00")
class TestCleanName:
    """``clean_name`` lowercases, strips size/noise tokens and tidies spacing."""

    def test_lowercase(self):
        """Output is lowercased; single spaces between words are kept."""
        assert clean_name("Kroger WHOLE MILK") == "kroger whole milk"

    def test_removes_size_info(self):
        """Size tokens such as ``16 oz`` are dropped from the cleaned name."""
        assert "oz" not in clean_name("Milk 16 oz Whole")

    def test_removes_noise_words(self):
        """Noise words are removed as whole tokens."""
        cleaned = clean_name("The Original Brand Milk")
        assert "the" not in cleaned.split()
        assert "original" not in cleaned.split()
        assert "brand" not in cleaned.split()

    def test_collapses_whitespace(self):
        """Runs of spaces in the input never survive as a double space."""
        # The input needs multi-space runs and the needle must be a double
        # space: a single space legitimately separates words in the result.
        assert "  " not in clean_name("Milk   Whole    Gallon")

    def test_removes_punctuation(self):
        """Apostrophes and parentheses are stripped."""
        cleaned = clean_name("Meijer's Best (Organic) Milk!")
        assert "'" not in cleaned
        assert "(" not in cleaned
class TestMatchByUPC:
    """Tests for ``match_by_upc`` against the shared ``session`` fixture."""

    def test_match_found(self, session):
        """A UPC listed in a product's ``upc_variants`` yields a UPC match."""
        product = NormalizedProduct(
            id=uuid.uuid4(),
            canonical_name="Whole Milk, Gallon",
            upc_variants=["0041250000001", "0041250000002"],
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        session.add(product)
        session.commit()
        # SQLite doesn't support JSONB containment — this will raise
        # In production (PostgreSQL), this would work
        # NOTE(review): the comment above says the call raises on SQLite, yet
        # the assertions below expect a successful match — confirm which holds
        # and either drop the comment or skip/xfail this test on SQLite.
        result = match_by_upc(session, "0041250000001")
        assert result is not None
        assert result.method == MatchMethod.UPC
        assert result.confidence == 1.0

    def test_no_match(self, session):
        """An unknown UPC returns ``None``."""
        result = match_by_upc(session, "9999999999999")
        assert result is None
class TestNormalizeProduct:
    """``normalize_product`` without a UPC: name matching or no result."""

    def test_name_fallback(self, session):
        """With ``upc=None`` a similar canonical name is matched by NAME."""
        eggs = NormalizedProduct(
            id=uuid.uuid4(),
            canonical_name="Large Eggs, 12 count",
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        session.add(eggs)
        session.commit()

        outcome = normalize_product(session, "Large Eggs 12 ct", upc=None)

        assert outcome is not None
        assert outcome.method == MatchMethod.NAME

    def test_no_match(self, session):
        """Nothing is returned when no stored product resembles the name."""
        assert normalize_product(session, "Nonexistent Product XYZ", upc=None) is None
+""" + +import uuid +from datetime import date +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session, sessionmaker + +from cartsnitch_common.constants import ( + EventType, + SizeUnit, + StoreSlug, +) +from cartsnitch_common.events import publish_event +from cartsnitch_common.models import ( + Base, + NormalizedProduct, + PriceHistory, + Purchase, + PurchaseItem, + ShrinkflationEvent, + Store, + User, +) +from cartsnitch_common.pipeline.matching import ProductMatcher +from cartsnitch_common.pipeline.price_tracking import ( + PriceDelta, + get_price_trend, + record_price_from_item, +) +from cartsnitch_common.pipeline.receipt import normalize_receipt, parse_meijer_item +from cartsnitch_common.pipeline.shrinkflation import detect_shrinkflation +from cartsnitch_common.schemas.events import EventEnvelope +from cartsnitch_common.schemas.purchase import PurchaseCreate + +# --------------------------------------------------------------------------- +# Fixtures: realistic scraper output from Meijer +# --------------------------------------------------------------------------- + +MEIJER_RECEIPT_FIXTURE = { + "receiptId": "MJ-2026-03-15-00042", + "date": "2026-03-15", + "total": "47.82", + "subtotal": "44.50", + "taxAmount": "3.32", + "totalSavings": "6.20", + "items": [ + { + "description": " Meijer Whole Milk 1 Gallon ", + "upcCode": "00041250010001", + "quantity": 1, + "unitPrice": "3.29", + "extendedPrice": "3.29", + "regularPrice": "3.49", + "salePrice": "3.29", + "category": "Dairy", + }, + { + "name": "BARILLA SPAGHETTI 16 OZ", + "upc": "076808280753", + "qty": 2, + "price": "1.69", + "totalPrice": "3.38", + "regularPrice": "1.89", + "couponDiscount": "0.40", + "department": "Pantry", + }, + { + "description": "Meijer Lean Ground Beef 1 lb", + "upcCode": "00041250022004", + "quantity": 1, + "unitPrice": "5.99", + "extendedPrice": "5.99", + "regularPrice": 
"6.49", + "loyaltyDiscount": "0.50", + "category": "Meat", + }, + { + "description": "Cheerios Original 12 oz", + "upcCode": "016000275645", + "quantity": 1, + "unitPrice": "4.49", + "extendedPrice": "4.49", + "regularPrice": "4.49", + "category": "Snacks", + }, + { + "description": "Fresh Bananas", + "quantity": 1, + "unitPrice": "0.69", + "extendedPrice": "0.69", + "category": "Produce", + }, + ], +} + +MEIJER_RECEIPT_SECOND_VISIT = { + "receiptId": "MJ-2026-03-18-00099", + "date": "2026-03-18", + "total": "12.47", + "items": [ + { + "description": "Meijer Whole Milk 1 Gallon", + "upcCode": "00041250010001", + "quantity": 1, + "unitPrice": "3.49", + "extendedPrice": "3.49", + "regularPrice": "3.49", + "category": "Dairy", + }, + { + "description": "BARILLA SPAGHETTI 16 OZ", + "upc": "076808280753", + "qty": 1, + "price": "1.99", + "totalPrice": "1.99", + "regularPrice": "1.99", + "department": "Pantry", + }, + { + "description": "Cheerios Original 10.8 oz", + "upcCode": "016000275645", + "quantity": 1, + "unitPrice": "4.49", + "extendedPrice": "4.49", + "regularPrice": "4.49", + "category": "Snacks", + }, + ], +} + + +@pytest.fixture +def e2e_engine(): + """In-memory SQLite engine for E2E tests.""" + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def e2e_session(e2e_engine): + """SQLAlchemy session with pre-seeded store and user.""" + factory = sessionmaker(bind=e2e_engine) + with factory() as sess: + yield sess + + +@pytest.fixture +def store(e2e_session: Session) -> Store: + """Seed a Meijer store.""" + s = Store(id=uuid.uuid4(), name="Meijer", slug=StoreSlug.MEIJER) + e2e_session.add(s) + e2e_session.flush() + return s + + +@pytest.fixture +def user(e2e_session: Session) -> User: + """Seed a test user.""" + u = User( + id=uuid.uuid4(), + email="tester@cartsnitch.com", + hashed_password="hashed_test_password", + display_name="Test User", + ) + e2e_session.add(u) + e2e_session.flush() + 
@pytest.fixture
def redis_mock():
    """Stand-in Redis client whose ``publish`` records every call.

    Each ``publish(channel, message)`` call appends ``(channel, message)`` to
    the list exposed as ``client._published`` and returns ``1``.
    """
    captured: list[tuple[str, str]] = []

    def _publish(channel: str, message: str) -> int:
        captured.append((channel, message))
        return 1

    fake_redis = MagicMock()
    fake_redis.publish = MagicMock(side_effect=_publish)
    fake_redis._published = captured
    return fake_redis
assert pasta.product_name_raw == "BARILLA SPAGHETTI 16 OZ" + assert pasta.upc == "76808280753" + assert pasta.quantity == Decimal("2") + assert pasta.extended_price == Decimal("3.38") + assert pasta.coupon_discount == Decimal("0.40") + + def test_upc_product_matching_and_storage(self, e2e_session: Session, user: User, store: Store): + """Full flow: normalize → match → store in DB. UPC matching works E2E.""" + purchase_schema = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + + # Run product matching + matcher = ProductMatcher(e2e_session, auto_create=True) + outcomes = matcher.match_items(purchase_schema.items) + + assert len(outcomes) == 5 + + # First item has a UPC — auto_create makes a new product + assert outcomes[0].created_new is True + + # Store the purchase in DB + purchase_db = Purchase( + id=uuid.uuid4(), + user_id=user.id, + store_id=store.id, + receipt_id=purchase_schema.receipt_id, + purchase_date=purchase_schema.purchase_date, + total=purchase_schema.total, + subtotal=purchase_schema.subtotal, + tax=purchase_schema.tax, + savings_total=purchase_schema.savings_total, + raw_data=purchase_schema.raw_data, + ) + e2e_session.add(purchase_db) + e2e_session.flush() + + # Store items linked to the purchase and matched products + for _i, item_schema in enumerate(purchase_schema.items): + item_db = PurchaseItem( + id=uuid.uuid4(), + purchase_id=purchase_db.id, + product_name_raw=item_schema.product_name_raw, + upc=item_schema.upc, + quantity=item_schema.quantity, + unit_price=item_schema.unit_price, + extended_price=item_schema.extended_price, + regular_price=item_schema.regular_price, + sale_price=item_schema.sale_price, + coupon_discount=item_schema.coupon_discount, + loyalty_discount=item_schema.loyalty_discount, + category_raw=item_schema.category_raw, + ) + e2e_session.add(item_db) + e2e_session.flush() + + # Verify data persisted correctly + stored_purchase = e2e_session.execute( + 
select(Purchase).where(Purchase.receipt_id == "MJ-2026-03-15-00042") + ).scalar_one() + assert stored_purchase.total == Decimal("47.82") + assert stored_purchase.user_id == user.id + assert stored_purchase.store_id == store.id + + stored_items = ( + e2e_session.execute( + select(PurchaseItem).where(PurchaseItem.purchase_id == stored_purchase.id) + ) + .scalars() + .all() + ) + assert len(stored_items) == 5 + + # Verify products were created in normalized_products table + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + assert len(products) == 5 # all 5 items auto-created products + + def test_second_visit_reuses_existing_products( + self, e2e_session: Session, user: User, store: Store + ): + """On second receipt, products matched by UPC reuse existing records.""" + # Ingest first receipt + first = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + matcher = ProductMatcher(e2e_session, auto_create=True) + matcher.match_items(first.items) + + products_after_first = e2e_session.execute(select(NormalizedProduct)).scalars().all() + first_count = len(products_after_first) + + # Ingest second receipt — overlapping UPCs + second = normalize_receipt( + MEIJER_RECEIPT_SECOND_VISIT, + user_id=str(user.id), + store_id=str(store.id), + ) + second_outcomes = matcher.match_items(second.items) + + # Milk, pasta, cheerios should match existing by UPC + assert second_outcomes[0].created_new is False # milk — UPC match + assert second_outcomes[1].created_new is False # pasta — UPC match + assert second_outcomes[2].created_new is False # cheerios — UPC match + + products_after_second = e2e_session.execute(select(NormalizedProduct)).scalars().all() + assert len(products_after_second) == first_count # no new products created + + +# =========================================================================== +# Test class: Price tracking and shrinkflation detection E2E +# 
=========================================================================== + + +class TestPriceTrackingE2E: + """Price recording from stored items and price delta detection.""" + + def test_price_recorded_from_ingested_receipt( + self, e2e_session: Session, user: User, store: Store + ): + """Ingest receipt → match products → record prices → verify price history.""" + purchase_schema = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + matcher = ProductMatcher(e2e_session, auto_create=True) + outcomes = matcher.match_items(purchase_schema.items) + + # Record prices for each matched item + price_entries = [] + for i, item_schema in enumerate(purchase_schema.items): + product = outcomes[i].match.product if outcomes[i].match else None + if product is None: + # Was auto-created — find the product directly + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == item_schema.product_name_raw: + product = p + break + + if product: + entry, delta = record_price_from_item( + e2e_session, + product_id=product.id, + store_id=store.id, + observed_date=purchase_schema.purchase_date, + regular_price=item_schema.regular_price or item_schema.unit_price, + sale_price=item_schema.sale_price, + ) + price_entries.append((entry, delta)) + + # First ingestion — no deltas expected + assert all(delta is None for _, delta in price_entries) + + # Verify price history stored + all_prices = e2e_session.execute(select(PriceHistory)).scalars().all() + assert len(all_prices) >= 4 # at least the items with regular_price + + def test_price_increase_detected_on_second_receipt( + self, e2e_session: Session, user: User, store: Store + ): + """Second receipt with higher price triggers a PriceDelta.""" + # Ingest first receipt + first = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + matcher = ProductMatcher(e2e_session, 
auto_create=True) + first_outcomes = matcher.match_items(first.items) + + # Record first prices + for i, item_schema in enumerate(first.items): + product = first_outcomes[i].match.product if first_outcomes[i].match else None + if product is None: + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == item_schema.product_name_raw: + product = p + break + if product: + record_price_from_item( + e2e_session, + product_id=product.id, + store_id=store.id, + observed_date=first.purchase_date, + regular_price=item_schema.regular_price or item_schema.unit_price, + sale_price=item_schema.sale_price, + ) + + # Ingest second receipt — pasta price went up ($1.89 → $1.99) + second = normalize_receipt( + MEIJER_RECEIPT_SECOND_VISIT, + user_id=str(user.id), + store_id=str(store.id), + ) + second_outcomes = matcher.match_items(second.items) + + # Record second prices and capture deltas + deltas: list[PriceDelta] = [] + for i, item_schema in enumerate(second.items): + product = second_outcomes[i].match.product if second_outcomes[i].match else None + if product is None: + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == item_schema.product_name_raw: + product = p + break + if product: + _, delta = record_price_from_item( + e2e_session, + product_id=product.id, + store_id=store.id, + observed_date=second.purchase_date, + regular_price=item_schema.regular_price or item_schema.unit_price, + sale_price=item_schema.sale_price, + ) + if delta: + deltas.append(delta) + + # Milk went from $3.49 → $3.49 (no change); pasta from $1.89 → $1.99 (increase) + price_increases = [d for d in deltas if d.is_increase] + assert len(price_increases) >= 1 + + pasta_delta = next( + (d for d in price_increases if d.old_price == Decimal("1.89")), + None, + ) + assert pasta_delta is not None + assert pasta_delta.new_price == Decimal("1.99") + assert 
class TestShrinkflationE2E:
    """Shrinkflation detection integrated with product matching."""

    def test_shrinkflation_detected_from_receipt_data(
        self, e2e_session: Session, user: User, store: Store
    ):
        """Cheerios went from 12 oz → 10.8 oz between receipts. Detect shrinkflation."""
        # Ingest first receipt — creates Cheerios product with size from name
        first = normalize_receipt(
            MEIJER_RECEIPT_FIXTURE,
            user_id=str(user.id),
            store_id=str(store.id),
        )
        matcher = ProductMatcher(e2e_session, auto_create=True)
        matcher.match_items(first.items)

        # Look the product up by name among everything the matcher persisted;
        # this works whether the outcome was a match or an auto-create.
        products = e2e_session.execute(select(NormalizedProduct)).scalars().all()
        cheerios_product = next(
            (p for p in products if "cheerios" in p.canonical_name.lower()),
            None,
        )

        assert cheerios_product is not None
        # The auto-created product should have extracted "12" and "oz" from name
        assert cheerios_product.size == "12"
        assert cheerios_product.size_unit == SizeUnit.OZ

        # Now detect shrinkflation: 12 oz → 10.8 oz
        event = detect_shrinkflation(
            e2e_session,
            product=cheerios_product,
            new_size="10.8",
            new_unit=SizeUnit.OZ,
            new_price=Decimal("4.49"),
            detected_date=date(2026, 3, 18),
        )

        assert event is not None
        assert isinstance(event, ShrinkflationEvent)
        assert event.old_size == "12"
        assert event.new_size == "10.8"
        assert event.old_unit == SizeUnit.OZ
        assert event.new_unit == SizeUnit.OZ
        assert event.confidence >= Decimal("0.85")  # 10% decrease → 0.95

        # Verify stored in DB
        stored = e2e_session.execute(
            select(ShrinkflationEvent).where(
                ShrinkflationEvent.normalized_product_id == cheerios_product.id
            )
        ).scalar_one()
        assert stored.id == event.id

    def test_shrinkflation_dedup_on_repeat_detection(
        self, e2e_session: Session, user: User, store: Store
    ):
        """Same shrinkflation detected twice returns the existing event, not a duplicate."""
        product = NormalizedProduct(
            id=uuid.uuid4(),
            canonical_name="Brand X Cereal 15 oz",
            size="15",
            size_unit=SizeUnit.OZ,
            upc_variants=["999888777"],
        )
        e2e_session.add(product)
        e2e_session.flush()

        first = detect_shrinkflation(e2e_session, product, new_size="13.5", new_unit=SizeUnit.OZ)
        second = detect_shrinkflation(e2e_session, product, new_size="13.5", new_unit=SizeUnit.OZ)

        assert first is not None
        assert second is not None
        assert first.id == second.id  # same event, not duplicated

        events = (
            e2e_session.execute(
                select(ShrinkflationEvent).where(
                    ShrinkflationEvent.normalized_product_id == product.id
                )
            )
            .scalars()
            .all()
        )
        assert len(events) == 1
EventType.RECEIPTS_INGESTED + assert envelope.service == "receiptwitness" + assert envelope.payload["receipt_id"] == "MJ-2026-03-15-00042" + assert envelope.payload["item_count"] == 5 + + def test_price_updated_event(self, redis_mock, user: User, store: Store): + """publish_event sends a valid envelope for PRICES_UPDATED.""" + subscribers = publish_event( + redis_mock, + EventType.PRICES_UPDATED, + service="cartsnitch-common", + payload={ + "product_id": str(uuid.uuid4()), + "store_slug": StoreSlug.MEIJER, + "old_price": "1.89", + "new_price": "1.99", + "change_percent": "5.29", + }, + ) + + assert subscribers == 1 + channel, raw_msg = redis_mock._published[0] + assert channel == EventType.PRICES_UPDATED.value + + envelope = EventEnvelope.model_validate_json(raw_msg) + assert envelope.event_type == EventType.PRICES_UPDATED + assert envelope.payload["old_price"] == "1.89" + + def test_products_normalized_event(self, redis_mock, user: User, store: Store): + """publish_event sends a valid envelope for PRODUCTS_NORMALIZED.""" + product_id = str(uuid.uuid4()) + subscribers = publish_event( + redis_mock, + EventType.PRODUCTS_NORMALIZED, + service="cartsnitch-common", + payload={ + "product_id": product_id, + "canonical_name": "Barilla Spaghetti", + "match_method": "upc", + "confidence": "high", + }, + ) + + assert subscribers == 1 + channel, raw_msg = redis_mock._published[0] + assert channel == EventType.PRODUCTS_NORMALIZED.value + envelope = EventEnvelope.model_validate_json(raw_msg) + assert envelope.payload["confidence"] == "high" + + def test_shrinkflation_alert_event(self, redis_mock, user: User, store: Store): + """publish_event sends a valid envelope for ALERT_SHRINKFLATION.""" + subscribers = publish_event( + redis_mock, + EventType.ALERT_SHRINKFLATION, + service="shrinkray", + payload={ + "product_id": str(uuid.uuid4()), + "product_name": "Cheerios Original", + "old_size": "12 oz", + "new_size": "10.8 oz", + "confidence": "0.95", + }, + ) + + assert subscribers 
== 1 + channel, raw_msg = redis_mock._published[0] + assert channel == EventType.ALERT_SHRINKFLATION.value + + def test_full_pipeline_emits_events_at_each_stage( + self, e2e_session: Session, redis_mock, user: User, store: Store + ): + """Full pipeline: ingest → match → record price → publish events at each stage.""" + # Stage 1: Normalize receipt + purchase_schema = normalize_receipt( + MEIJER_RECEIPT_FIXTURE, + user_id=str(user.id), + store_id=str(store.id), + ) + + # Publish receipt ingested + publish_event( + redis_mock, + EventType.RECEIPTS_INGESTED, + service="receiptwitness", + payload={ + "receipt_id": purchase_schema.receipt_id, + "item_count": len(purchase_schema.items), + }, + ) + + # Stage 2: Match products + matcher = ProductMatcher(e2e_session, auto_create=True) + outcomes = matcher.match_items(purchase_schema.items) + + for i, outcome in enumerate(outcomes): + product = outcome.match.product if outcome.match else None + if product is None: + # Auto-created — look up by name + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == purchase_schema.items[i].product_name_raw: + product = p + break + if product is None: + continue + publish_event( + redis_mock, + EventType.PRODUCTS_NORMALIZED, + service="cartsnitch-common", + payload={ + "product_id": str(product.id), + "match_method": outcome.match.method.value if outcome.match else "auto_create", + "confidence": outcome.confidence_level.value, + }, + ) + + # Stage 3: Record prices + for i, item_schema in enumerate(purchase_schema.items): + product = outcomes[i].match.product if outcomes[i].match else None + if product is None: + products = e2e_session.execute(select(NormalizedProduct)).scalars().all() + for p in products: + if p.canonical_name == item_schema.product_name_raw: + product = p + break + if product: + _, delta = record_price_from_item( + e2e_session, + product_id=product.id, + store_id=store.id, + 
observed_date=purchase_schema.purchase_date, + regular_price=item_schema.regular_price or item_schema.unit_price, + ) + if delta and delta.is_increase: + publish_event( + redis_mock, + EventType.ALERT_PRICE_INCREASE, + service="stickershock", + payload={ + "product_id": str(product.id), + "old_price": str(delta.old_price), + "new_price": str(delta.new_price), + }, + ) + + # Verify events published at each stage + channels = [ch for ch, _ in redis_mock._published] + assert EventType.RECEIPTS_INGESTED.value in channels + assert EventType.PRODUCTS_NORMALIZED.value in channels + # No price increases on first receipt, so no ALERT_PRICE_INCREASE expected + + # All messages are valid EventEnvelopes + for _, raw_msg in redis_mock._published: + envelope = EventEnvelope.model_validate_json(raw_msg) + assert envelope.timestamp is not None + assert envelope.service + + +# =========================================================================== +# Test class: Error handling for malformed scraper output +# =========================================================================== + + +class TestMalformedScraperOutput: + """Error handling for bad, partial, or unexpected scraper data.""" + + def test_missing_item_name_produces_empty_string(self): + """Item with no description/name field normalizes with empty product_name_raw.""" + item = parse_meijer_item({"unitPrice": "2.99"}) + assert item.product_name_raw == "" + assert item.unit_price == Decimal("2.99") + + def test_missing_price_defaults_to_zero(self): + """Item with no price fields defaults to zero.""" + item = parse_meijer_item({"description": "Mystery Product"}) + assert item.unit_price == Decimal("0") + assert item.extended_price == Decimal("0") + + def test_non_numeric_price_defaults_to_zero(self): + """Non-numeric price strings safely default to zero.""" + item = parse_meijer_item( + { + "description": "Bad Price Item", + "unitPrice": "not_a_number", + "extendedPrice": "$$$.xx", + } + ) + assert item.unit_price == 
Decimal("0") + assert item.extended_price == Decimal("0") + + def test_empty_receipt_produces_empty_items(self, user: User, store: Store): + """Receipt with no items normalizes cleanly.""" + raw = {"receiptId": "EMPTY-001", "date": "2026-03-15", "total": "0.00"} + purchase = normalize_receipt(raw, user_id=str(user.id), store_id=str(store.id)) + + assert purchase.receipt_id == "EMPTY-001" + assert purchase.total == Decimal("0.00") + assert len(purchase.items) == 0 + + def test_receipt_missing_date_defaults_to_today(self, user: User, store: Store): + """Receipt with no date field defaults to today.""" + raw = {"receiptId": "NO-DATE-001", "total": "5.00", "items": []} + purchase = normalize_receipt(raw, user_id=str(user.id), store_id=str(store.id)) + assert purchase.purchase_date == date.today() + + def test_receipt_missing_id_generates_uuid(self, user: User, store: Store): + """Receipt with no ID generates a UUID.""" + raw = {"date": "2026-03-15", "total": "10.00", "items": []} + purchase = normalize_receipt(raw, user_id=str(user.id), store_id=str(store.id)) + + # Should be a valid UUID string + uuid.UUID(purchase.receipt_id) + + def test_item_with_garbage_upc_preserves_it(self): + """UPC field with non-standard content is preserved as-is after strip.""" + item = parse_meijer_item( + { + "description": "Weird UPC Product", + "upc": " ABC-NOT-A-UPC ", + "unitPrice": "1.99", + } + ) + # lstrip("0") on "ABC-NOT-A-UPC" leaves it intact + assert item.upc == "ABC-NOT-A-UPC" + + def test_negative_prices_pass_through(self): + """Negative prices (refunds) are preserved, not zeroed.""" + item = parse_meijer_item( + { + "description": "Refund Item", + "unitPrice": "-5.99", + "extendedPrice": "-5.99", + } + ) + assert item.unit_price == Decimal("-5.99") + assert item.extended_price == Decimal("-5.99") + + def test_extended_price_auto_calculated(self): + """When extendedPrice is missing, it's calculated from unitPrice * quantity.""" + item = parse_meijer_item( + { + 
"description": "No Extended", + "unitPrice": "2.50", + "quantity": "3", + } + ) + assert item.extended_price == Decimal("7.50") + + def test_matching_with_malformed_items(self, e2e_session: Session): + """ProductMatcher handles items with missing/empty names gracefully.""" + matcher = ProductMatcher(e2e_session, auto_create=True) + + bad_items = [ + parse_meijer_item({"description": "", "unitPrice": "1.00"}), + parse_meijer_item({"unitPrice": "2.00"}), + ] + + outcomes = matcher.match_items(bad_items) + assert len(outcomes) == 2 + # Both should auto-create (no match possible for empty names) + assert all(o.created_new for o in outcomes) + + def test_completely_empty_receipt(self, user: User, store: Store): + """Totally empty dict produces a valid PurchaseCreate with defaults.""" + purchase = normalize_receipt({}, user_id=str(user.id), store_id=str(store.id)) + assert purchase.total == Decimal("0") + assert len(purchase.items) == 0 + assert purchase.purchase_date == date.today() + + def test_mixed_valid_and_malformed_items(self, user: User, store: Store): + """Receipt with a mix of good and bad items processes all of them.""" + raw = { + "receiptId": "MIX-001", + "date": "2026-03-15", + "total": "10.00", + "items": [ + { + "description": "Good Product 8 oz", + "upc": "1234567890", + "unitPrice": "3.99", + "extendedPrice": "3.99", + }, + { + "unitPrice": "not_a_price", + }, + { + "description": " *** Special Chars !!! 
", + "unitPrice": "2.50", + }, + ], + } + purchase = normalize_receipt(raw, user_id=str(user.id), store_id=str(store.id)) + assert len(purchase.items) == 3 + + # Good item + assert purchase.items[0].product_name_raw == "Good Product 8 oz" + assert purchase.items[0].upc == "1234567890" + + # Bad price item + assert purchase.items[1].unit_price == Decimal("0") + + # Special chars stripped + assert purchase.items[2].product_name_raw == "Special Chars" diff --git a/tests/test_pipeline_matching.py b/tests/test_pipeline_matching.py new file mode 100644 index 0000000..c73da9e --- /dev/null +++ b/tests/test_pipeline_matching.py @@ -0,0 +1,160 @@ +"""Tests for product matching & dedup pipeline.""" + +import uuid +from datetime import UTC, datetime +from decimal import Decimal + +from cartsnitch_common.constants import MatchConfidence +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.normalization import MatchMethod +from cartsnitch_common.pipeline.matching import ( + ProductMatcher, + classify_confidence, + match_purchase_item, +) +from cartsnitch_common.schemas.purchase import PurchaseItemCreate + + +class TestClassifyConfidence: + def test_upc_always_high(self): + assert classify_confidence(1.0, MatchMethod.UPC) == MatchConfidence.HIGH + assert classify_confidence(0.5, MatchMethod.UPC) == MatchConfidence.HIGH + + def test_name_high(self): + assert classify_confidence(0.9, MatchMethod.NAME) == MatchConfidence.HIGH + assert classify_confidence(0.8, MatchMethod.NAME) == MatchConfidence.HIGH + + def test_name_medium(self): + assert classify_confidence(0.6, MatchMethod.NAME) == MatchConfidence.MEDIUM + assert classify_confidence(0.5, MatchMethod.NAME) == MatchConfidence.MEDIUM + + def test_name_low(self): + assert classify_confidence(0.3, MatchMethod.NAME) == MatchConfidence.LOW + assert classify_confidence(0.0, MatchMethod.NAME) == MatchConfidence.LOW + + +class TestProductMatcher: + def _make_item(self, name: str, upc: str | None = None) 
-> PurchaseItemCreate: + return PurchaseItemCreate( + product_name_raw=name, + upc=upc, + unit_price=Decimal("3.99"), + extended_price=Decimal("3.99"), + ) + + def test_match_by_upc(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk Gallon", + upc_variants=["041250000001"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + matcher = ProductMatcher(session) + item = self._make_item("Kroger Milk", upc="041250000001") + prod, result, confidence = matcher.match_single(item) + + assert prod is not None + assert prod.id == product.id + assert result is not None + assert result.method == MatchMethod.UPC + assert confidence == MatchConfidence.HIGH + + def test_match_by_name(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Whole Milk Gallon", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + matcher = ProductMatcher(session, name_threshold=0.3) + item = self._make_item("Whole Milk Gallon Size") + prod, result, confidence = matcher.match_single(item) + + assert prod is not None + assert result is not None + assert result.method == MatchMethod.NAME + + def test_auto_create_when_no_match(self, session): + matcher = ProductMatcher(session, auto_create=True) + item = self._make_item("Unique Product XYZ 16 oz") + prod, result, confidence = matcher.match_single(item) + + assert prod is not None + assert result is None # No match found, was created + assert confidence == MatchConfidence.LOW + assert prod.canonical_name == "Unique Product XYZ 16 oz" + assert prod.size == "16" + assert prod.size_unit == "oz" + + def test_no_create_when_disabled(self, session): + matcher = ProductMatcher(session, auto_create=False) + item = self._make_item("Nonexistent Product") + prod, result, confidence = matcher.match_single(item) + + assert prod is None + assert result is None + + def 
test_batch_match(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Large Eggs 12 Count", + upc_variants=["012345"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + matcher = ProductMatcher(session) + items = [ + self._make_item("Large Eggs", upc="012345"), + self._make_item("Brand New Never Seen Product"), + ] + outcomes = matcher.match_items(items) + + assert len(outcomes) == 2 + assert outcomes[0].match is not None + assert outcomes[0].confidence_level == MatchConfidence.HIGH + assert outcomes[0].created_new is False + assert outcomes[1].match is None + assert outcomes[1].created_new is True + + +class TestMatchPurchaseItem: + def test_convenience_function(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="Ground Beef 80/20", + upc_variants=["999888"], + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.commit() + + item = PurchaseItemCreate( + product_name_raw="Ground Beef", + upc="999888", + unit_price=Decimal("5.99"), + extended_price=Decimal("5.99"), + ) + prod, confidence = match_purchase_item(session, item) + assert prod is not None + assert confidence == MatchConfidence.HIGH + + def test_auto_create_default(self, session): + item = PurchaseItemCreate( + product_name_raw="Totally New Item", + unit_price=Decimal("1.00"), + extended_price=Decimal("1.00"), + ) + prod, confidence = match_purchase_item(session, item) + assert prod is not None + assert confidence == MatchConfidence.LOW diff --git a/tests/test_pipeline_price.py b/tests/test_pipeline_price.py new file mode 100644 index 0000000..63cae15 --- /dev/null +++ b/tests/test_pipeline_price.py @@ -0,0 +1,282 @@ +"""Tests for price history tracking pipeline.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +from cartsnitch_common.constants import PriceSource, StoreSlug +from 
cartsnitch_common.models.price import PriceHistory +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.models.store import Store +from cartsnitch_common.pipeline.price_tracking import ( + PriceDelta, + get_latest_price, + get_price_trend, + record_price_from_item, +) + + +def _make_store(session, slug=StoreSlug.MEIJER) -> Store: + store = Store( + id=uuid.uuid4(), + name="Meijer", + slug=slug, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(store) + session.flush() + return store + + +def _make_product(session, name="Test Product") -> NormalizedProduct: + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name=name, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.flush() + return product + + +class TestGetLatestPrice: + def test_no_history(self, session): + product = _make_product(session) + store = _make_store(session) + result = get_latest_price(session, product.id, store.id) + assert result is None + + def test_returns_newest(self, session): + product = _make_product(session) + store = _make_store(session) + + # Add two entries + old = PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("3.99"), + source=PriceSource.RECEIPT, + ) + new = PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 10), + regular_price=Decimal("4.29"), + source=PriceSource.RECEIPT, + ) + session.add_all([old, new]) + session.flush() + + result = get_latest_price(session, product.id, store.id) + assert result is not None + assert result.regular_price == Decimal("4.29") + + +class TestRecordPriceFromItem: + def test_first_price_no_delta(self, session): + product = _make_product(session) + store = _make_store(session) + + entry, delta = record_price_from_item( + session, + product_id=product.id, + 
store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("3.99"), + ) + assert entry is not None + assert entry.regular_price == Decimal("3.99") + assert entry.source == PriceSource.RECEIPT + assert delta is None + + def test_price_increase_detected(self, session): + product = _make_product(session) + store = _make_store(session) + + # First price + record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("3.99"), + ) + + # Price increase + entry, delta = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.49"), + ) + + assert delta is not None + assert delta.old_price == Decimal("3.99") + assert delta.new_price == Decimal("4.49") + assert delta.change_amount == Decimal("0.50") + assert delta.is_increase is True + assert delta.is_decrease is False + assert delta.change_percent > Decimal("0") + + def test_price_decrease_detected(self, session): + product = _make_product(session) + store = _make_store(session) + + record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("5.00"), + ) + + _, delta = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.00"), + ) + + assert delta is not None + assert delta.is_decrease is True + assert delta.change_amount == Decimal("-1.00") + + def test_same_price_no_delta(self, session): + product = _make_product(session) + store = _make_store(session) + + record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 1), + regular_price=Decimal("3.99"), + ) + + _, delta = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("3.99"), + ) + 
assert delta is None + + def test_sale_and_loyalty_prices_recorded(self, session): + product = _make_product(session) + store = _make_store(session) + + entry, _ = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("5.99"), + sale_price=Decimal("4.99"), + loyalty_price=Decimal("4.49"), + coupon_price=Decimal("3.99"), + ) + assert entry.sale_price == Decimal("4.99") + assert entry.loyalty_price == Decimal("4.49") + assert entry.coupon_price == Decimal("3.99") + + def test_custom_source(self, session): + product = _make_product(session) + store = _make_store(session) + + entry, _ = record_price_from_item( + session, + product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, 15), + regular_price=Decimal("3.99"), + source=PriceSource.CATALOG, + ) + assert entry.source == PriceSource.CATALOG + + +class TestGetPriceTrend: + def test_empty_trend(self, session): + product = _make_product(session) + store = _make_store(session) + trend = get_price_trend(session, product.id, store.id) + assert trend == [] + + def test_returns_newest_first(self, session): + product = _make_product(session) + store = _make_store(session) + + for day in [1, 5, 10, 15]: + session.add( + PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, day), + regular_price=Decimal(str(3 + day * 0.1)), + source=PriceSource.RECEIPT, + ) + ) + session.flush() + + trend = get_price_trend(session, product.id, store.id) + assert len(trend) == 4 + assert trend[0].observed_date == date(2026, 3, 15) + assert trend[-1].observed_date == date(2026, 3, 1) + + def test_respects_limit(self, session): + product = _make_product(session) + store = _make_store(session) + + for day in range(1, 11): + session.add( + PriceHistory( + id=uuid.uuid4(), + normalized_product_id=product.id, + store_id=store.id, + observed_date=date(2026, 3, day), + 
regular_price=Decimal("3.99"), + source=PriceSource.RECEIPT, + ) + ) + session.flush() + + trend = get_price_trend(session, product.id, store.id, limit=3) + assert len(trend) == 3 + + +class TestPriceDelta: + def test_delta_properties(self): + delta = PriceDelta( + product_id=uuid.uuid4(), + store_id=uuid.uuid4(), + old_price=Decimal("3.99"), + new_price=Decimal("4.49"), + change_amount=Decimal("0.50"), + change_percent=Decimal("12.53"), + old_date=date(2026, 3, 1), + new_date=date(2026, 3, 15), + ) + assert delta.is_increase is True + assert delta.is_decrease is False + + def test_decrease_properties(self): + delta = PriceDelta( + product_id=uuid.uuid4(), + store_id=uuid.uuid4(), + old_price=Decimal("4.49"), + new_price=Decimal("3.99"), + change_amount=Decimal("-0.50"), + change_percent=Decimal("-11.14"), + old_date=date(2026, 3, 1), + new_date=date(2026, 3, 15), + ) + assert delta.is_decrease is True + assert delta.is_increase is False diff --git a/tests/test_pipeline_receipt.py b/tests/test_pipeline_receipt.py new file mode 100644 index 0000000..d937d39 --- /dev/null +++ b/tests/test_pipeline_receipt.py @@ -0,0 +1,204 @@ +"""Tests for receipt normalization pipeline.""" + +import uuid +from datetime import date +from decimal import Decimal + +from cartsnitch_common.pipeline.receipt import ( + _clean_product_name, + _safe_decimal, + normalize_receipt, + parse_meijer_item, +) + + +class TestCleanProductName: + def test_strips_whitespace(self): + assert _clean_product_name(" Milk ") == "Milk" + + def test_removes_leading_punctuation(self): + assert _clean_product_name("---Milk---") == "Milk" + + def test_collapses_internal_whitespace(self): + assert _clean_product_name("Whole Milk Gallon") == "Whole Milk Gallon" + + def test_empty_string(self): + assert _clean_product_name("") == "" + + +class TestSafeDecimal: + def test_string_input(self): + assert _safe_decimal("3.99") == Decimal("3.99") + + def test_float_input(self): + assert _safe_decimal(3.99) == 
Decimal("3.99") + + def test_int_input(self): + assert _safe_decimal(4) == Decimal("4") + + def test_none_returns_default(self): + assert _safe_decimal(None) == Decimal("0") + + def test_none_custom_default(self): + assert _safe_decimal(None, Decimal("1")) == Decimal("1") + + def test_invalid_returns_default(self): + assert _safe_decimal("not-a-number") == Decimal("0") + + def test_decimal_passthrough(self): + assert _safe_decimal(Decimal("5.50")) == Decimal("5.50") + + +class TestParseMeijerItem: + def test_basic_item(self): + raw = { + "description": "Kroger Whole Milk 1 Gallon", + "upc": "0041250000001", + "quantity": 1, + "unitPrice": "3.99", + "extendedPrice": "3.99", + "category": "DAIRY", + } + item = parse_meijer_item(raw) + assert item.product_name_raw == "Kroger Whole Milk 1 Gallon" + assert item.upc == "41250000001" # leading zeros stripped + assert item.quantity == Decimal("1") + assert item.unit_price == Decimal("3.99") + assert item.extended_price == Decimal("3.99") + assert item.category_raw == "DAIRY" + + def test_alternate_field_names(self): + raw = { + "name": "Eggs Large 12 ct", + "upcCode": "012345", + "qty": 2, + "price": "4.50", + "totalPrice": "9.00", + "department": "EGGS", + } + item = parse_meijer_item(raw) + assert item.product_name_raw == "Eggs Large 12 ct" + assert item.upc == "12345" + assert item.quantity == Decimal("2") + assert item.unit_price == Decimal("4.50") + assert item.extended_price == Decimal("9.00") + assert item.category_raw == "EGGS" + + def test_calculates_extended_from_unit_price(self): + raw = { + "description": "Bananas", + "unitPrice": "0.59", + "quantity": 3, + } + item = parse_meijer_item(raw) + assert item.extended_price == Decimal("1.77") + + def test_discounts_parsed(self): + raw = { + "description": "Cereal", + "unitPrice": "4.99", + "extendedPrice": "4.99", + "regularPrice": "5.99", + "salePrice": "4.99", + "couponAmount": "1.00", + "loyaltyAmount": "0.50", + } + item = parse_meijer_item(raw) + assert 
item.regular_price == Decimal("5.99") + assert item.sale_price == Decimal("4.99") + assert item.coupon_discount == Decimal("1.00") + assert item.loyalty_discount == Decimal("0.50") + + def test_alternate_discount_names(self): + raw = { + "description": "Bread", + "unitPrice": "2.99", + "extendedPrice": "2.99", + "couponDiscount": "0.75", + "loyaltyDiscount": "0.25", + } + item = parse_meijer_item(raw) + assert item.coupon_discount == Decimal("0.75") + assert item.loyalty_discount == Decimal("0.25") + + def test_missing_fields_default_gracefully(self): + raw = {"description": "Mystery Item"} + item = parse_meijer_item(raw) + assert item.product_name_raw == "Mystery Item" + assert item.upc is None + assert item.quantity == Decimal("1") + assert item.unit_price == Decimal("0") + assert item.regular_price is None + assert item.category_raw is None + + def test_no_upc_returns_none(self): + raw = {"description": "Loose Bananas", "unitPrice": "1.00", "extendedPrice": "1.00"} + item = parse_meijer_item(raw) + assert item.upc is None + + +class TestNormalizeReceipt: + def test_full_receipt(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = { + "receiptId": "REC-001", + "date": "2026-03-15", + "total": "25.47", + "subtotal": "23.00", + "tax": "2.47", + "savings": "3.00", + "items": [ + {"description": "Milk", "unitPrice": "3.99", "extendedPrice": "3.99"}, + {"description": "Bread", "unitPrice": "2.50", "extendedPrice": "2.50"}, + ], + } + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.receipt_id == "REC-001" + assert purchase.purchase_date == date(2026, 3, 15) + assert purchase.total == Decimal("25.47") + assert purchase.subtotal == Decimal("23.00") + assert purchase.tax == Decimal("2.47") + assert purchase.savings_total == Decimal("3.00") + assert len(purchase.items) == 2 + assert purchase.items[0].product_name_raw == "Milk" + assert purchase.raw_data == raw + + def test_alternate_receipt_fields(self): + user_id = 
str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = { + "receipt_id": "REC-002", + "purchaseDate": "2026-03-14", + "totalAmount": "10.00", + "taxAmount": "0.75", + "totalSavings": "1.50", + "items": [], + } + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.receipt_id == "REC-002" + assert purchase.purchase_date == date(2026, 3, 14) + assert purchase.total == Decimal("10.00") + assert purchase.tax == Decimal("0.75") + assert purchase.savings_total == Decimal("1.50") + + def test_missing_date_defaults_to_today(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = {"total": "5.00", "items": []} + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.purchase_date == date.today() + + def test_generates_receipt_id_if_missing(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = {"total": "5.00", "date": "2026-03-15", "items": []} + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.receipt_id # Should be a generated UUID string + + def test_date_object_passthrough(self): + user_id = str(uuid.uuid4()) + store_id = str(uuid.uuid4()) + raw = {"date": date(2026, 1, 1), "total": "5.00", "items": []} + purchase = normalize_receipt(raw, user_id, store_id) + assert purchase.purchase_date == date(2026, 1, 1) diff --git a/tests/test_pipeline_shrinkflation.py b/tests/test_pipeline_shrinkflation.py new file mode 100644 index 0000000..9c1bd0c --- /dev/null +++ b/tests/test_pipeline_shrinkflation.py @@ -0,0 +1,233 @@ +"""Tests for shrinkflation detection pipeline.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +from cartsnitch_common.constants import SizeUnit +from cartsnitch_common.models.product import NormalizedProduct +from cartsnitch_common.pipeline.shrinkflation import ( + _to_comparable, + _units_comparable, + detect_shrinkflation, +) + + +class TestToComparable: + def test_oz_to_grams(self): + result = _to_comparable("16", 
SizeUnit.OZ) + assert result is not None + assert result == Decimal("16") * Decimal("28.3495") + + def test_lb_to_grams(self): + result = _to_comparable("1", SizeUnit.LB) + assert result == Decimal("453.592") + + def test_ml_to_ml(self): + assert _to_comparable("500", SizeUnit.ML) == Decimal("500") + + def test_fl_oz_to_ml(self): + result = _to_comparable("12", SizeUnit.FL_OZ) + assert result is not None + assert result == Decimal("12") * Decimal("29.5735") + + def test_count_units(self): + assert _to_comparable("12", SizeUnit.CT) == Decimal("12") + assert _to_comparable("6", SizeUnit.PK) == Decimal("6") + + def test_invalid_size(self): + assert _to_comparable("abc", SizeUnit.OZ) is None + + +class TestUnitsComparable: + def test_weight_comparable(self): + assert _units_comparable(SizeUnit.OZ, SizeUnit.LB) is True + assert _units_comparable(SizeUnit.G, SizeUnit.KG) is True + + def test_volume_comparable(self): + assert _units_comparable(SizeUnit.ML, SizeUnit.L) is True + assert _units_comparable(SizeUnit.FL_OZ, SizeUnit.ML) is True + + def test_count_comparable(self): + assert _units_comparable(SizeUnit.CT, SizeUnit.PK) is True + + def test_not_comparable_across_systems(self): + assert _units_comparable(SizeUnit.OZ, SizeUnit.ML) is False + assert _units_comparable(SizeUnit.CT, SizeUnit.OZ) is False + assert _units_comparable(SizeUnit.LB, SizeUnit.L) is False + + +class TestDetectShrinkflation: + def _make_product(self, session, size: str, unit: SizeUnit, name: str = "Test Product"): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name=name, + size=size, + size_unit=unit, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.flush() + return product + + def test_detects_oz_decrease(self, session): + product = self._make_product(session, "16", SizeUnit.OZ) + event = detect_shrinkflation( + session, + product=product, + new_size="14", + new_unit=SizeUnit.OZ, + detected_date=date(2026, 3, 15), + ) + assert 
event is not None + assert event.old_size == "16" + assert event.new_size == "14" + assert "decreased" in event.notes.lower() + + def test_no_detection_when_size_increases(self, session): + product = self._make_product(session, "14", SizeUnit.OZ) + event = detect_shrinkflation( + session, + product=product, + new_size="16", + new_unit=SizeUnit.OZ, + ) + assert event is None + + def test_no_detection_same_size(self, session): + product = self._make_product(session, "16", SizeUnit.OZ) + event = detect_shrinkflation( + session, + product=product, + new_size="16", + new_unit=SizeUnit.OZ, + ) + assert event is None + + def test_no_detection_incompatible_units(self, session): + product = self._make_product(session, "16", SizeUnit.OZ) + event = detect_shrinkflation( + session, + product=product, + new_size="400", + new_unit=SizeUnit.ML, + ) + assert event is None + + def test_no_detection_without_existing_size(self, session): + product = NormalizedProduct( + id=uuid.uuid4(), + canonical_name="No Size Product", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + session.add(product) + session.flush() + + event = detect_shrinkflation( + session, + product=product, + new_size="12", + new_unit=SizeUnit.OZ, + ) + assert event is None + + def test_cross_unit_detection_same_system(self, session): + # 1 lb = 453.592g, 14 oz = 396.893g → size decreased + product = self._make_product(session, "1", SizeUnit.LB) + event = detect_shrinkflation( + session, + product=product, + new_size="14", + new_unit=SizeUnit.OZ, + detected_date=date(2026, 3, 15), + ) + assert event is not None + + def test_count_decrease(self, session): + product = self._make_product(session, "12", SizeUnit.CT) + event = detect_shrinkflation( + session, + product=product, + new_size="10", + new_unit=SizeUnit.CT, + detected_date=date(2026, 3, 15), + ) + assert event is not None + assert event.old_size == "12" + assert event.new_size == "10" + + def test_dedup_existing_event(self, session): + product 
= self._make_product(session, "16", SizeUnit.OZ) + + # First detection + event1 = detect_shrinkflation( + session, + product=product, + new_size="14", + new_unit=SizeUnit.OZ, + detected_date=date(2026, 3, 15), + ) + + # Same detection again — should return existing + event2 = detect_shrinkflation( + session, + product=product, + new_size="14", + new_unit=SizeUnit.OZ, + detected_date=date(2026, 3, 16), + ) + + assert event1 is not None + assert event2 is not None + assert event1.id == event2.id + + def test_confidence_scaling(self, session): + # Small decrease (< 5%) → 0.70 + product1 = self._make_product(session, "100", SizeUnit.G, "Product A") + event1 = detect_shrinkflation( + session, + product=product1, + new_size="97", + new_unit=SizeUnit.G, + detected_date=date(2026, 3, 15), + ) + assert event1 is not None + assert event1.confidence == Decimal("0.70") + + # Medium decrease (5-10%) → 0.85 + product2 = self._make_product(session, "100", SizeUnit.G, "Product B") + event2 = detect_shrinkflation( + session, + product=product2, + new_size="93", + new_unit=SizeUnit.G, + detected_date=date(2026, 3, 15), + ) + assert event2 is not None + assert event2.confidence == Decimal("0.85") + + # Large decrease (>= 10%) → 0.95 + product3 = self._make_product(session, "100", SizeUnit.G, "Product C") + event3 = detect_shrinkflation( + session, + product=product3, + new_size="85", + new_unit=SizeUnit.G, + detected_date=date(2026, 3, 15), + ) + assert event3 is not None + assert event3.confidence == Decimal("0.95") + + def test_min_size_decrease_threshold(self, session): + product = self._make_product(session, "100", SizeUnit.G) + # 0.5% decrease — below default 1% threshold + event = detect_shrinkflation( + session, + product=product, + new_size="99.5", + new_unit=SizeUnit.G, + min_size_decrease_pct=Decimal("1"), + ) + assert event is None diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..564665e --- /dev/null +++ 
b/tests/test_schemas.py @@ -0,0 +1,225 @@ +"""Tests for Pydantic v2 schemas.""" + +import uuid +from datetime import UTC, date, datetime +from decimal import Decimal + +import pytest +from pydantic import ValidationError + +from cartsnitch_common.constants import ( + AccountStatus, + DiscountType, + EventType, + PriceSource, + ProductCategory, + SizeUnit, + StoreSlug, +) +from cartsnitch_common.schemas import ( + CouponCreate, + EventEnvelope, + NormalizedProductCreate, + PriceHistoryCreate, + PurchaseCreate, + PurchaseItemCreate, + ShrinkflationEventCreate, + StoreCreate, + StoreLocationCreate, + StoreRead, + UserCreate, + UserStoreAccountCreate, +) + + +class TestStoreSchemas: + def test_store_create_valid(self): + s = StoreCreate(name="Meijer", slug=StoreSlug.MEIJER) + assert s.slug == StoreSlug.MEIJER + + def test_store_create_invalid_slug(self): + with pytest.raises(ValidationError): + StoreCreate(name="Walmart", slug="walmart") + + def test_store_read_from_attributes(self): + data = { + "id": uuid.uuid4(), + "name": "Kroger", + "slug": StoreSlug.KROGER, + "logo_url": None, + "website_url": None, + "created_at": datetime.now(UTC), + "updated_at": datetime.now(UTC), + } + s = StoreRead(**data) + assert s.slug == StoreSlug.KROGER + + +class TestStoreLocationSchemas: + def test_location_create(self): + loc = StoreLocationCreate( + store_id=uuid.uuid4(), + address="456 Oak Ave", + city="Detroit", + state="MI", + zip="48201", + ) + assert loc.city == "Detroit" + + +class TestUserSchemas: + def test_user_create_valid(self): + u = UserCreate(email="test@example.com", password="secret123") + assert u.email == "test@example.com" + + def test_user_create_invalid_email(self): + with pytest.raises(ValidationError): + UserCreate(email="not-an-email", password="secret123") + + +class TestUserStoreAccountSchemas: + def test_account_create_with_status(self): + a = UserStoreAccountCreate( + user_id=uuid.uuid4(), + store_id=uuid.uuid4(), + status=AccountStatus.EXPIRED, + ) + 
assert a.status == AccountStatus.EXPIRED + + def test_account_create_default_status(self): + a = UserStoreAccountCreate( + user_id=uuid.uuid4(), + store_id=uuid.uuid4(), + ) + assert a.status == AccountStatus.ACTIVE + + def test_account_create_invalid_status(self): + with pytest.raises(ValidationError): + UserStoreAccountCreate( + user_id=uuid.uuid4(), + store_id=uuid.uuid4(), + status="invalid_status", + ) + + +class TestPurchaseSchemas: + def test_purchase_create_with_items(self): + p = PurchaseCreate( + user_id=uuid.uuid4(), + store_id=uuid.uuid4(), + receipt_id="RCP-001", + purchase_date=date(2026, 3, 15), + total=Decimal("42.50"), + items=[ + PurchaseItemCreate( + product_name_raw="Milk", + unit_price=Decimal("3.49"), + extended_price=Decimal("3.49"), + ), + ], + ) + assert len(p.items) == 1 + assert p.items[0].quantity == Decimal("1") + + +class TestNormalizedProductSchemas: + def test_product_create_with_enums(self): + p = NormalizedProductCreate( + canonical_name="Whole Milk, 1 Gallon", + category=ProductCategory.DAIRY, + size_unit=SizeUnit.FL_OZ, + upc_variants=["0041250000001"], + ) + assert p.category == ProductCategory.DAIRY + + def test_product_create_invalid_category(self): + with pytest.raises(ValidationError): + NormalizedProductCreate( + canonical_name="Test", + category="invalid_category", + ) + + +class TestPriceHistorySchemas: + def test_price_create(self): + p = PriceHistoryCreate( + normalized_product_id=uuid.uuid4(), + store_id=uuid.uuid4(), + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.99"), + source=PriceSource.RECEIPT, + ) + assert p.source == PriceSource.RECEIPT + + def test_price_create_invalid_source(self): + with pytest.raises(ValidationError): + PriceHistoryCreate( + normalized_product_id=uuid.uuid4(), + store_id=uuid.uuid4(), + observed_date=date(2026, 3, 15), + regular_price=Decimal("4.99"), + source="invalid_source", + ) + + +class TestCouponSchemas: + def test_coupon_create(self): + c = CouponCreate( + 
store_id=uuid.uuid4(), + title="BOGO Chips", + discount_type=DiscountType.BOGO, + ) + assert c.discount_type == DiscountType.BOGO + + def test_coupon_create_invalid_discount_type(self): + with pytest.raises(ValidationError): + CouponCreate( + store_id=uuid.uuid4(), + title="Test", + discount_type="free_stuff", + ) + + +class TestShrinkflationEventSchemas: + def test_shrinkflation_create(self): + s = ShrinkflationEventCreate( + normalized_product_id=uuid.uuid4(), + detected_date=date(2026, 3, 10), + old_size="18", + new_size="15.4", + old_unit=SizeUnit.OZ, + new_unit=SizeUnit.OZ, + confidence=Decimal("0.95"), + ) + assert s.old_unit == SizeUnit.OZ + + def test_shrinkflation_create_invalid_unit(self): + with pytest.raises(ValidationError): + ShrinkflationEventCreate( + normalized_product_id=uuid.uuid4(), + detected_date=date(2026, 3, 10), + old_size="18", + new_size="15.4", + old_unit="bushels", + new_unit=SizeUnit.OZ, + ) + + +class TestEventEnvelope: + def test_valid_event(self): + e = EventEnvelope( + event_type=EventType.RECEIPTS_INGESTED, + timestamp=datetime.now(UTC), + service="receiptwitness", + payload={"receipt_id": "RCP-001"}, + ) + assert e.event_type == EventType.RECEIPTS_INGESTED + + def test_invalid_event_type(self): + with pytest.raises(ValidationError): + EventEnvelope( + event_type="invalid.event", + timestamp=datetime.now(UTC), + service="test", + payload={}, + ) diff --git a/tests/test_seed.py b/tests/test_seed.py new file mode 100644 index 0000000..8007f9c --- /dev/null +++ b/tests/test_seed.py @@ -0,0 +1,357 @@ +"""Tests for the seed data generator.""" + +import random + +from faker import Faker + +from cartsnitch_common.seed.config import ( + NUM_ACTIVE_USERS, + NUM_COUPONS, + NUM_PRICE_HISTORY, + NUM_PRODUCTS, + NUM_PURCHASE_ITEMS, + NUM_PURCHASES, + NUM_SHRINKFLATION_EVENTS, + NUM_STORES, + NUM_USERS, + SEED_END_DATE, + SEED_START_DATE, + SEED_VALUE, +) +from cartsnitch_common.seed.generators.coupons import generate_coupons +from 
cartsnitch_common.seed.generators.prices import generate_price_history +from cartsnitch_common.seed.generators.products import generate_products +from cartsnitch_common.seed.generators.purchases import generate_purchase_items, generate_purchases +from cartsnitch_common.seed.generators.shrinkflation import generate_shrinkflation_events +from cartsnitch_common.seed.generators.stores import generate_store_locations, generate_stores +from cartsnitch_common.seed.generators.users import generate_users + + +def _seed() -> None: + random.seed(SEED_VALUE) + Faker.seed(SEED_VALUE) + + +def _make_fake() -> Faker: + return Faker() + + +# --------------------------------------------------------------------------- +# Stores +# --------------------------------------------------------------------------- + + +def test_generate_stores_count() -> None: + stores = generate_stores() + assert len(stores) == NUM_STORES + + +def test_generate_stores_deterministic() -> None: + stores_a = generate_stores() + stores_b = generate_stores() + # Stores are fixed (no RNG), so slugs are stable + slugs_a = {s["slug"] for s in stores_a} + slugs_b = {s["slug"] for s in stores_b} + assert slugs_a == slugs_b + + +def test_generate_store_locations_count() -> None: + stores = generate_stores() + locs = generate_store_locations(stores) + assert len(locs) == 15 # 3 stores * 5 locations + + +def test_generate_store_locations_fk() -> None: + stores = generate_stores() + locs = generate_store_locations(stores) + store_ids = {s["id"] for s in stores} + for loc in locs: + assert loc["store_id"] in store_ids + + +# --------------------------------------------------------------------------- +# Users +# --------------------------------------------------------------------------- + + +def test_generate_users_count() -> None: + _seed() + fake = _make_fake() + users = generate_users(fake) + assert len(users) == NUM_USERS + + +def test_generate_users_active_count() -> None: + _seed() + fake = _make_fake() + users = 
generate_users(fake) + active = [u for u in users if u["_active"]] + assert len(active) == NUM_ACTIVE_USERS + + +def test_generate_users_deterministic() -> None: + _seed() + fake_a = _make_fake() + users_a = generate_users(fake_a) + + _seed() + fake_b = _make_fake() + users_b = generate_users(fake_b) + + # Emails should match (same seed → same Faker output) + emails_a = [u["email"] for u in users_a] + emails_b = [u["email"] for u in users_b] + assert emails_a == emails_b + + +def test_generate_users_unique_emails() -> None: + _seed() + fake = _make_fake() + users = generate_users(fake) + emails = [u["email"] for u in users] + assert len(emails) == len(set(emails)) + + +# --------------------------------------------------------------------------- +# Products +# --------------------------------------------------------------------------- + + +def test_generate_products_count() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + assert len(products) == NUM_PRODUCTS + + +def test_generate_products_deterministic() -> None: + _seed() + fake_a = _make_fake() + products_a = generate_products(fake_a) + + _seed() + fake_b = _make_fake() + products_b = generate_products(fake_b) + + names_a = [p["canonical_name"] for p in products_a] + names_b = [p["canonical_name"] for p in products_b] + assert names_a == names_b + + +def test_generate_products_have_categories() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + for product in products: + assert product["category"] is not None + + +def test_generate_products_have_upc_variants() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + for product in products: + assert product["upc_variants"] + assert isinstance(product["upc_variants"], list) + assert len(product["upc_variants"]) >= 1 + + +# --------------------------------------------------------------------------- +# Purchases +# 
--------------------------------------------------------------------------- + + +def test_generate_purchases_count() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + assert len(purchases) == NUM_PURCHASES + + +def test_generate_purchases_fk() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + + user_ids = {u["id"] for u in users} + store_ids = {s["id"] for s in stores} + for p in purchases: + assert p["user_id"] in user_ids + assert p["store_id"] in store_ids + + +def test_generate_purchase_items_count() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + # Should be close to target (within 20%) + assert abs(len(items) - NUM_PURCHASE_ITEMS) < NUM_PURCHASE_ITEMS * 0.20 + + +def test_generate_purchase_items_fk() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + + purchase_ids = {p["id"] for p in purchases} + product_ids = {p["id"] for p in products} + for item in items: + assert item["purchase_id"] in purchase_ids + assert item["normalized_product_id"] in product_ids + + +# --------------------------------------------------------------------------- +# Price History +# --------------------------------------------------------------------------- + + +def 
test_generate_price_history_count() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + prices = generate_price_history(products, stores, items) + # Should be within 10% of target + assert abs(len(prices) - NUM_PRICE_HISTORY) < NUM_PRICE_HISTORY * 0.10 + + +def test_generate_price_history_fk() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + prices = generate_price_history(products, stores, items) + + product_ids = {p["id"] for p in products} + store_ids = {s["id"] for s in stores} + for ph in prices: + assert ph["normalized_product_id"] in product_ids + assert ph["store_id"] in store_ids + assert ph["regular_price"] > 0 + + +def test_price_history_dates_in_range() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + store_locs = generate_store_locations(stores) + users = generate_users(fake) + purchases = generate_purchases(users, stores, store_locs) + products = generate_products(fake) + items = generate_purchase_items(purchases, products) + prices = generate_price_history(products, stores, items) + + for ph in prices: + assert SEED_START_DATE <= ph["observed_date"] <= SEED_END_DATE + + +# --------------------------------------------------------------------------- +# Coupons +# --------------------------------------------------------------------------- + + +def test_generate_coupons_count() -> None: + _seed() + fake = _make_fake() + stores = generate_stores() + products = generate_products(fake) + coupons = generate_coupons(fake, products, stores) + 
assert len(coupons) == NUM_COUPONS + + +def test_generate_coupons_mix() -> None: + """Verify ~60% expired and ~40% active.""" + _seed() + fake = _make_fake() + stores = generate_stores() + products = generate_products(fake) + coupons = generate_coupons(fake, products, stores) + + expired = [c for c in coupons if c["valid_to"] < SEED_END_DATE] + active = [c for c in coupons if c["valid_to"] >= SEED_END_DATE] + # Allow ±15% variance from target + assert len(expired) / NUM_COUPONS > 0.45 + assert len(active) / NUM_COUPONS > 0.25 + + +# --------------------------------------------------------------------------- +# Shrinkflation +# --------------------------------------------------------------------------- + + +def test_generate_shrinkflation_count() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + events = generate_shrinkflation_events(products) + assert len(events) == NUM_SHRINKFLATION_EVENTS + + +def test_generate_shrinkflation_fk() -> None: + _seed() + fake = _make_fake() + products = generate_products(fake) + events = generate_shrinkflation_events(products) + product_ids = {p["id"] for p in products} + for event in events: + assert event["normalized_product_id"] in product_ids + + +def test_generate_shrinkflation_price_held_or_increased() -> None: + """Validate shrinkflation: new_size < old_size, price maintained or up.""" + _seed() + fake = _make_fake() + products = generate_products(fake) + events = generate_shrinkflation_events(products) + for event in events: + old_size = float(event["old_size"]) + new_size = float(event["new_size"]) + assert new_size < old_size, f"Expected size reduction: {old_size} -> {new_size}" + if event["price_at_old_size"] and event["price_at_new_size"]: + # Price should be maintained or increased (not significantly dropped) + assert float(event["price_at_new_size"]) >= float(event["price_at_old_size"]) * 0.95 + + +def test_generate_shrinkflation_confidence_range() -> None: + _seed() + fake = _make_fake() + 
products = generate_products(fake) + events = generate_shrinkflation_events(products) + for event in events: + assert 0 <= float(event["confidence"]) <= 1.0 + + +# --------------------------------------------------------------------------- +# Dry-run smoke test +# --------------------------------------------------------------------------- + + +def test_dry_run_does_not_raise() -> None: + """Smoke test the full run_seed in dry-run mode.""" + from cartsnitch_common.seed.runner import run_seed + + run_seed(dry_run=True, seed_value=SEED_VALUE)