Compare commits
112 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a2b0e7cbd3 | |||
| 6d7d54729c | |||
| 895ad77850 | |||
| 5f1570e6d2 | |||
| 96c0f89a03 | |||
| 0f8aa2fe47 | |||
| 6903c7dde3 | |||
| 2946ac8dc5 | |||
| 6717e105f4 | |||
| 01ea36c5aa | |||
| 0bb4b7d183 | |||
| 7b9194a152 | |||
| ca8cf2a80a | |||
| 5bb0a5817b | |||
| 2444219f75 | |||
| c707caea41 | |||
| 30d670a257 | |||
| cfa4d8fa91 | |||
| 39e8d5c9f9 | |||
| 44c475265e | |||
| 8e1f61214c | |||
| fb1c5fb929 | |||
| 75be08ccf3 | |||
| 5596e22d0c | |||
| f45a49059e | |||
| 47ba602b02 | |||
| 5b12625e3f | |||
| d7a4086647 | |||
| b43ec1fb9b | |||
| 129f0adc96 | |||
| 587d444773 | |||
| ea789378dd | |||
| 2f096c985a | |||
| ad218c07ec | |||
| fff9f6f63a | |||
| b0ea4767b6 | |||
| c1778074e3 | |||
| 5de258220e | |||
| 003c62da3e | |||
| 57ce4315a1 | |||
| 7426ff1909 | |||
| 782448a54a | |||
| b9a66dfc8b | |||
| 7a1267de79 | |||
| 4415c56a53 | |||
| da8b413f76 | |||
| dd6a683b90 | |||
| cf8e821bdc | |||
| c9be9324cf | |||
| cc0957fc92 | |||
| f3a7b33093 | |||
| 342906c9d1 | |||
| b736e62d4f | |||
| 4cf6f91e95 | |||
| 27fe957074 | |||
| fc99e8a82e | |||
| cb1d926fc4 | |||
| fc689a3f90 | |||
| d2337a7ef7 | |||
| b7e7960f35 | |||
| aa4da81b6e | |||
| ce9e71c793 | |||
| e662ff5fab | |||
| 656c8d3842 | |||
| 853d722044 | |||
| 61540905dd | |||
| bea3342042 | |||
| 95317884ff | |||
| ca0dbd0e63 | |||
| cdcffc8582 | |||
| 8cccb8cbf0 | |||
| d201753d83 | |||
| 516697b4bd | |||
| b3aa18d7df | |||
| 6e681b9010 | |||
| 979a671300 | |||
| 860dd827d3 | |||
| 7d2e0ba64e | |||
| 118946898b | |||
| 90c81f9c8f | |||
| 4baac1ae26 | |||
| 0f1e158e89 | |||
| a9101246c9 | |||
| cf4ae49ad7 | |||
| 634d54b7fc | |||
| c74a4226f4 | |||
| 14c8aa5797 | |||
| 77c45e7eac | |||
| d6175760d1 | |||
| 6a130a9d76 | |||
| 38c860f1bb | |||
| 91ff8f76d0 | |||
| ab358f44bb | |||
| 5b8d132948 | |||
| 66565fff5c | |||
| a65361106c | |||
| 66376f6a87 | |||
| 580864ac69 | |||
| e8a53399c2 | |||
| b8091e367e | |||
| d0c887e29f | |||
| c81e14b8e7 | |||
| ec81004268 | |||
| fb6f4a0ed4 | |||
| e6f09a0212 | |||
| 58844b33fe | |||
| 0000297e4f | |||
| e572a32021 | |||
| 0789de39f0 | |||
| e57baa4468 | |||
| e42b7e1a66 | |||
| 265f2ae654 |
+157
-11
@@ -11,16 +11,17 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: cartsnitch/cartsnitch
|
||||
AUTH_IMAGE_NAME: cartsnitch/auth
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: local-ubuntu-latest-cartsnitch
|
||||
runs-on: runners-cartsnitch
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
@@ -34,7 +35,7 @@ jobs:
|
||||
run: npx tsc --noEmit
|
||||
|
||||
test:
|
||||
runs-on: local-ubuntu-latest-cartsnitch
|
||||
runs-on: runners-cartsnitch
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
@@ -45,17 +46,57 @@ jobs:
|
||||
- name: Run tests
|
||||
run: npx vitest run
|
||||
|
||||
build-and-push:
|
||||
runs-on: local-ubuntu-latest-cartsnitch
|
||||
needs: [lint, test]
|
||||
lighthouse:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [test]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
node-version: "20"
|
||||
cache: npm
|
||||
- run: npm ci
|
||||
- run: npm run build
|
||||
- name: Install Chromium for Lighthouse
|
||||
run: |
|
||||
npm install -g playwright
|
||||
npx playwright install --with-deps chromium
|
||||
- name: Start preview server
|
||||
run: |
|
||||
npm run preview &
|
||||
npx wait-on http://localhost:4173/ --timeout 30000
|
||||
- name: Run Lighthouse CI
|
||||
run: |
|
||||
CHROME_PATH=$(find /home/runner/.cache/ms-playwright -name chrome -type f 2>/dev/null | head -1)
|
||||
npm install -g @lhci/cli
|
||||
LHCI_CHROME_PATH="$CHROME_PATH" lhci autorun
|
||||
|
||||
build-and-push:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [lint, test]
|
||||
outputs:
|
||||
calver_tag: ${{ steps.calver.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Generate CalVer tag
|
||||
id: calver
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
DATE_TAG=$(date -u +%Y.%m.%d)
|
||||
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
|
||||
if [ -z "$EXISTING" ]; then
|
||||
VERSION="$DATE_TAG"
|
||||
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then
|
||||
VERSION="${DATE_TAG}.2"
|
||||
else
|
||||
BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//")
|
||||
VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"
|
||||
fi
|
||||
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
|
||||
echo "CalVer tag: $VERSION"
|
||||
|
||||
- name: Log in to GHCR
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
@@ -72,6 +113,7 @@ jobs:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=sha,prefix=sha-
|
||||
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
@@ -82,3 +124,107 @@ jobs:
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
target: prod
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Create git tag
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
git tag "v${{ steps.calver.outputs.version }}"
|
||||
git push origin "v${{ steps.calver.outputs.version }}"
|
||||
|
||||
build-and-push-auth:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [lint, test]
|
||||
outputs:
|
||||
calver_tag: ${{ steps.calver.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Generate CalVer tag
|
||||
id: calver
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
DATE_TAG=$(date -u +%Y.%m.%d)
|
||||
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
|
||||
if [ -z "$EXISTING" ]; then
|
||||
VERSION="$DATE_TAG"
|
||||
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then
|
||||
VERSION="${DATE_TAG}.2"
|
||||
else
|
||||
BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//")
|
||||
VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"
|
||||
fi
|
||||
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Log in to GHCR
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (auth)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.AUTH_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=sha,prefix=sha-
|
||||
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Build and push auth Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ./auth
|
||||
file: ./auth/Dockerfile
|
||||
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
deploy-dev:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [build-and-push, build-and-push-auth]
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
steps:
|
||||
- name: Generate GitHub App token
|
||||
id: app-token
|
||||
uses: actions/create-github-app-token@v1
|
||||
with:
|
||||
app-id: ${{ secrets.CARTSNITCH_APP_ID }}
|
||||
private-key: ${{ secrets.CARTSNITCH_APP_PRIVATE_KEY }}
|
||||
owner: ${{ github.repository_owner }}
|
||||
repositories: infra
|
||||
|
||||
- name: Checkout infra repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: cartsnitch/infra
|
||||
token: ${{ steps.app-token.outputs.token }}
|
||||
ref: main
|
||||
path: infra
|
||||
|
||||
- name: Install kubectl
|
||||
uses: azure/setup-kubectl@v4
|
||||
|
||||
- name: Install kustomize
|
||||
uses: imranismail/setup-kustomize@v2
|
||||
|
||||
- name: Update dev overlay image tag
|
||||
run: |
|
||||
cd infra/apps/overlays/dev
|
||||
kustomize edit set image ghcr.io/cartsnitch/cartsnitch:${{ needs.build-and-push.outputs.calver_tag }}
|
||||
kustomize edit set image ghcr.io/cartsnitch/auth:${{ needs.build-and-push-auth.outputs.calver_tag }}
|
||||
|
||||
- name: Commit and push to infra
|
||||
run: |
|
||||
cd infra
|
||||
git config user.name "cartsnitch-ci[bot]"
|
||||
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
|
||||
git add apps/overlays/dev/kustomization.yaml
|
||||
git commit -m "ci(dev): update cartsnitch and auth images to ${{ needs.build-and-push.outputs.calver_tag }}"
|
||||
git push origin main
|
||||
|
||||
@@ -11,6 +11,7 @@ node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
.env
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
# CartSnitch Frontend
|
||||
# CartSnitch Monorepo
|
||||
|
||||
## Project Context
|
||||
|
||||
CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/cartsnitch`) is the mobile-first Progressive Web App — the flagship repo and primary user interface.
|
||||
CartSnitch is a self-hosted grocery price intelligence platform. This repo (`cartsnitch/cartsnitch`) is the **monorepo** containing the flagship frontend PWA and core backend services.
|
||||
|
||||
**GitHub org:** github.com/cartsnitch
|
||||
**Domain:** cartsnitch.com
|
||||
|
||||
### CartSnitch Services
|
||||
### Monorepo Layout
|
||||
|
||||
| Directory | Service | Purpose |
|
||||
|-----------|---------|---------|
|
||||
| `/` (root) | Frontend | React PWA, mobile-first (this directory) |
|
||||
| `auth/` | Auth | Better-Auth Node.js service (session management, email/password, OAuth) |
|
||||
| `api/` | API Gateway | Frontend-facing REST API |
|
||||
| `common/` | Common | Shared Python models, schemas, Alembic migrations |
|
||||
| `receiptwitness/` | ReceiptWitness | Purchase data ingestion via retailer scrapers |
|
||||
|
||||
### Other CartSnitch Repos (still separate)
|
||||
|
||||
| Repo | Service | Purpose |
|
||||
|------|---------|---------|
|
||||
| `cartsnitch/common` | — | Shared models, schemas, utilities |
|
||||
| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers |
|
||||
| `cartsnitch/api` | API Gateway | Frontend-facing REST API |
|
||||
| `cartsnitch/cartsnitch` | Frontend | React PWA, mobile-first (this repo) |
|
||||
| `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison |
|
||||
| `cartsnitch/shrinkray` | ShrinkRay | Shrinkflation monitoring |
|
||||
| `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization |
|
||||
@@ -161,9 +167,13 @@ frontend/
|
||||
|
||||
All data comes from the CartSnitch API gateway (`cartsnitch/api`). Base URL configured via environment variable `VITE_API_URL`.
|
||||
|
||||
- JWT auth: store access token in memory (not localStorage), refresh token in httpOnly cookie if possible, or secure storage.
|
||||
- **Authentication via Better-Auth** (`auth/` service). Sessions are managed via httpOnly cookies — no tokens in localStorage or memory.
|
||||
- Auth service URL configured via `VITE_AUTH_URL` (default: `http://localhost:3001`)
|
||||
- Frontend uses `better-auth/react` client for sign-in, sign-up, sign-out, and `useSession()` hook
|
||||
- API gateway validates sessions by querying the shared `sessions` table in Postgres
|
||||
- Both cookie-based and Bearer token auth are supported (cookies for web, Bearer for API clients)
|
||||
- TanStack Query handles caching, background refetching, and optimistic updates.
|
||||
- API client should handle 401 responses by attempting token refresh before retrying.
|
||||
- API client sends `credentials: 'include'` on all requests to forward session cookies.
|
||||
|
||||
## Development Workflow
|
||||
|
||||
|
||||
+5
-4
@@ -9,13 +9,14 @@ RUN npm ci
|
||||
COPY . .
|
||||
RUN npm run build
|
||||
|
||||
# Stage 2: Production
|
||||
FROM nginx:stable-alpine AS prod
|
||||
# Stage 2: Production — uses nginxinc/nginx-unprivileged which runs as non-root (UID 101)
|
||||
FROM nginxinc/nginx-unprivileged:stable-alpine AS prod
|
||||
|
||||
COPY --from=build /app/dist /usr/share/nginx/html
|
||||
COPY nginx.conf /etc/nginx/conf.d/default.conf
|
||||
|
||||
EXPOSE 80
|
||||
USER 101
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD wget -qO- http://localhost/health || exit 1
|
||||
CMD wget -qO- http://localhost:8080/health || exit 1
|
||||
|
||||
@@ -1,73 +1 @@
|
||||
# React + TypeScript + Vite
|
||||
|
||||
This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
|
||||
|
||||
Currently, two official plugins are available:
|
||||
|
||||
- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Oxc](https://oxc.rs)
|
||||
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/)
|
||||
|
||||
## React Compiler
|
||||
|
||||
The React Compiler is not enabled on this template because of its impact on dev & build performances. To add it, see [this documentation](https://react.dev/learn/react-compiler/installation).
|
||||
|
||||
## Expanding the ESLint configuration
|
||||
|
||||
If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
|
||||
|
||||
```js
|
||||
export default defineConfig([
|
||||
globalIgnores(['dist']),
|
||||
{
|
||||
files: ['**/*.{ts,tsx}'],
|
||||
extends: [
|
||||
// Other configs...
|
||||
|
||||
// Remove tseslint.configs.recommended and replace with this
|
||||
tseslint.configs.recommendedTypeChecked,
|
||||
// Alternatively, use this for stricter rules
|
||||
tseslint.configs.strictTypeChecked,
|
||||
// Optionally, add this for stylistic rules
|
||||
tseslint.configs.stylisticTypeChecked,
|
||||
|
||||
// Other configs...
|
||||
],
|
||||
languageOptions: {
|
||||
parserOptions: {
|
||||
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
||||
tsconfigRootDir: import.meta.dirname,
|
||||
},
|
||||
// other options...
|
||||
},
|
||||
},
|
||||
])
|
||||
```
|
||||
|
||||
You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
|
||||
|
||||
```js
|
||||
// eslint.config.js
|
||||
import reactX from 'eslint-plugin-react-x'
|
||||
import reactDom from 'eslint-plugin-react-dom'
|
||||
|
||||
export default defineConfig([
|
||||
globalIgnores(['dist']),
|
||||
{
|
||||
files: ['**/*.{ts,tsx}'],
|
||||
extends: [
|
||||
// Other configs...
|
||||
// Enable lint rules for React
|
||||
reactX.configs['recommended-typescript'],
|
||||
// Enable lint rules for React DOM
|
||||
reactDom.configs.recommended,
|
||||
],
|
||||
languageOptions: {
|
||||
parserOptions: {
|
||||
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
||||
tsconfigRootDir: import.meta.dirname,
|
||||
},
|
||||
// other options...
|
||||
},
|
||||
},
|
||||
])
|
||||
```
|
||||
# CartSnitch
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
.git
|
||||
.github
|
||||
.pytest_cache
|
||||
.ruff_cache
|
||||
__pycache__
|
||||
*.py[cod]
|
||||
*.egg-info
|
||||
dist
|
||||
.venv
|
||||
.env
|
||||
tests
|
||||
openapi.json
|
||||
CLAUDE.md
|
||||
README.md
|
||||
Vendored
+164
@@ -0,0 +1,164 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
concurrency:
|
||||
group: ci-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: cartsnitch/api
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: runners-cartsnitch
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
cache: pip
|
||||
- run: pip install ruff
|
||||
- name: Ruff lint
|
||||
run: ruff check .
|
||||
- name: Ruff format check
|
||||
run: ruff format --check .
|
||||
|
||||
typecheck:
|
||||
runs-on: runners-cartsnitch
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
cache: pip
|
||||
- name: Install system dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential
|
||||
- name: Install cartsnitch-common from GitHub
|
||||
run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git"
|
||||
- run: pip install -e ".[dev]" mypy
|
||||
- name: Type check
|
||||
run: mypy src/cartsnitch_api
|
||||
|
||||
test:
|
||||
runs-on: runners-cartsnitch
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
env:
|
||||
POSTGRES_USER: cartsnitch
|
||||
POSTGRES_PASSWORD: cartsnitch_test
|
||||
POSTGRES_DB: cartsnitch_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
ports:
|
||||
- 6379:6379
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
env:
|
||||
CARTSNITCH_DATABASE_URL: postgresql+asyncpg://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test
|
||||
CARTSNITCH_REDIS_URL: redis://localhost:6379/0
|
||||
CARTSNITCH_JWT_SECRET_KEY: test-secret-do-not-use-in-prod
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
cache: pip
|
||||
- name: Install system dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential
|
||||
- name: Install cartsnitch-common from GitHub
|
||||
run: pip install "cartsnitch-common @ git+https://github.com/cartsnitch/common.git"
|
||||
- run: pip install -e ".[dev]"
|
||||
- name: Run tests
|
||||
run: pytest --tb=short -q
|
||||
|
||||
build-and-push:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [lint, test]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Generate CalVer tag
|
||||
id: calver
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
DATE_TAG=$(date -u +%Y.%m.%d)
|
||||
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
|
||||
if [ -z "$EXISTING" ]; then
|
||||
VERSION="$DATE_TAG"
|
||||
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then
|
||||
VERSION="${DATE_TAG}.2"
|
||||
else
|
||||
BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//")
|
||||
VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"
|
||||
fi
|
||||
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
|
||||
echo "CalVer tag: $VERSION"
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to GHCR
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=sha,prefix=sha-
|
||||
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
target: prod
|
||||
|
||||
- name: Create git tag
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
git tag "v${{ steps.calver.outputs.version }}"
|
||||
git push origin "v${{ steps.calver.outputs.version }}"
|
||||
@@ -0,0 +1,9 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.egg-info/
|
||||
dist/
|
||||
.venv/
|
||||
.env
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
openapi.json
|
||||
+175
@@ -0,0 +1,175 @@
|
||||
# CartSnitch API Gateway
|
||||
|
||||
## Project Context
|
||||
|
||||
CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/api`) is the public-facing API gateway that serves the frontend and proxies requests to internal services.
|
||||
|
||||
**GitHub org:** github.com/cartsnitch
|
||||
**Domain:** cartsnitch.com
|
||||
|
||||
### CartSnitch Services
|
||||
|
||||
| Repo | Service | Purpose |
|
||||
|------|---------|---------|
|
||||
| `cartsnitch/common` | — | Shared models, schemas, utilities |
|
||||
| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers |
|
||||
| `cartsnitch/api` | API Gateway | Frontend-facing REST API (this repo) |
|
||||
| `cartsnitch/cartsnitch` | Frontend | React PWA (mobile-first) |
|
||||
| `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison |
|
||||
| `cartsnitch/shrinkray` | ShrinkRay | Shrinkflation monitoring |
|
||||
| `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization |
|
||||
| `cartsnitch/infra` | — | K8s manifests, Flux kustomizations |
|
||||
|
||||
### Architecture Decisions
|
||||
|
||||
- **Polyrepo:** Each service has its own repo, Dockerfile, CI/CD pipeline.
|
||||
- **Shared DB:** One PostgreSQL cluster. This service reads from all tables for serving frontend queries. Models come from `cartsnitch-common`.
|
||||
- **Inter-service comms:** REST to internal services, Redis pub/sub for event subscriptions.
|
||||
- **Target scale:** 500–1,000 users initially.
|
||||
|
||||
## What This Service Does
|
||||
|
||||
The API Gateway is the single entry point for the frontend PWA and any external consumers. It:
|
||||
|
||||
1. **Handles user authentication** — registration, login, JWT token management
|
||||
2. **Serves purchase/product/price data** — reads from the shared DB
|
||||
3. **Proxies scraping operations** — forwards scrape triggers to ReceiptWitness
|
||||
4. **Serves coupon/deal data** — reads from shared DB (written by ClipArtist)
|
||||
5. **Serves alerts** — price increase alerts (StickerShock), shrinkflation alerts (ShrinkRay)
|
||||
6. **Provides public data endpoints** — aggregate price trends for the transparency/shaming features
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- Python 3.12+
|
||||
- FastAPI (async)
|
||||
- SQLAlchemy 2.0 (via `cartsnitch-common`, read-heavy)
|
||||
- Pydantic v2 (request/response validation)
|
||||
- python-jose or PyJWT (JWT auth)
|
||||
- passlib + bcrypt (password hashing)
|
||||
- httpx (async HTTP client for proxying to internal services)
|
||||
- Redis (subscribe to events for websocket push, caching)
|
||||
- uvicorn (ASGI server)
|
||||
|
||||
## Repo Structure
|
||||
|
||||
```
|
||||
api/
|
||||
├── CLAUDE.md
|
||||
├── README.md
|
||||
├── pyproject.toml
|
||||
├── Dockerfile
|
||||
├── docker-compose.yml
|
||||
├── src/
|
||||
│ └── cartsnitch_api/
|
||||
│ ├── __init__.py
|
||||
│ ├── config.py # Service-specific settings
|
||||
│ ├── main.py # FastAPI app factory, lifespan, middleware
|
||||
│ ├── auth/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── jwt.py # JWT creation/validation
|
||||
│ │ ├── passwords.py # Hashing, verification
|
||||
│ │ ├── dependencies.py # FastAPI dependency injection (get_current_user)
|
||||
│ │ └── routes.py # /auth/register, /auth/login, /auth/refresh
|
||||
│ ├── routes/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── purchases.py # Purchase history endpoints
|
||||
│ │ ├── products.py # Normalized product catalog
|
||||
│ │ ├── prices.py # Price history and trends
|
||||
│ │ ├── coupons.py # Active coupons and deals
|
||||
│ │ ├── alerts.py # Price increase / shrinkflation alerts
|
||||
│ │ ├── stores.py # Store info, user store account management
|
||||
│ │ ├── scraping.py # Proxy to ReceiptWitness (trigger scrape, status)
|
||||
│ │ ├── shopping.py # Optimized shopping list (proxy to ClipArtist)
|
||||
│ │ ├── public.py # Public price transparency endpoints (no auth)
|
||||
│ │ └── health.py
|
||||
│ ├── services/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── receiptwitness.py # HTTP client for ReceiptWitness internal API
|
||||
│ │ ├── stickershock.py # HTTP client for StickerShock internal API
|
||||
│ │ ├── clipartist.py # HTTP client for ClipArtist internal API
|
||||
│ │ └── shrinkray.py # HTTP client for ShrinkRay internal API
|
||||
│ ├── middleware/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── cors.py
|
||||
│ │ └── rate_limit.py
|
||||
│ └── cache.py # Redis caching helpers
|
||||
└── tests/
|
||||
├── conftest.py
|
||||
├── test_auth/
|
||||
├── test_routes/
|
||||
└── test_services/
|
||||
```
|
||||
|
||||
## API Endpoint Design
|
||||
|
||||
### Auth
|
||||
- `POST /auth/register` — create account
|
||||
- `POST /auth/login` — get JWT access + refresh tokens
|
||||
- `POST /auth/refresh` — refresh access token
|
||||
- `GET /auth/me` — current user profile
|
||||
|
||||
### Store Accounts
|
||||
- `GET /stores` — list supported stores
|
||||
- `GET /me/stores` — list user's connected store accounts + sync status
|
||||
- `POST /me/stores/{store_slug}/connect` — initiate store connection flow
|
||||
- `DELETE /me/stores/{store_slug}` — disconnect store account
|
||||
|
||||
### Purchases
|
||||
- `GET /purchases` — list user's purchases (paginated, filterable by store/date)
|
||||
- `GET /purchases/{id}` — purchase detail with line items
|
||||
- `GET /purchases/stats` — spending summary (by store, by category, by period)
|
||||
|
||||
### Products
|
||||
- `GET /products` — normalized product catalog (search, filter)
|
||||
- `GET /products/{id}` — product detail with cross-store price comparison
|
||||
- `GET /products/{id}/prices` — price history for a product across stores
|
||||
|
||||
### Prices
|
||||
- `GET /prices/trends` — aggregate price trends (public-capable)
|
||||
- `GET /prices/increases` — recent significant price increases
|
||||
- `GET /prices/comparison` — compare specific items across stores
|
||||
|
||||
### Coupons
|
||||
- `GET /coupons` — active coupons/deals (filterable by store)
|
||||
- `GET /coupons/relevant` — coupons relevant to user's purchase history
|
||||
|
||||
### Shopping
|
||||
- `POST /shopping/optimize` — input: shopping list → output: store-split + coupons
|
||||
- `GET /shopping/lists` — user's saved shopping lists
|
||||
|
||||
### Alerts
|
||||
- `GET /alerts` — user's price increase and shrinkflation alerts
|
||||
- `PUT /alerts/settings` — configure alert thresholds
|
||||
|
||||
### Public (No Auth)
|
||||
- `GET /public/trends/{product_id}` — public price trend for a product
|
||||
- `GET /public/store-comparison` — public store-vs-store price comparison
|
||||
- `GET /public/inflation` — price changes vs CPI baseline
|
||||
|
||||
### Scraping (Proxy to ReceiptWitness)
|
||||
- `POST /scraping/{store_slug}/sync` — trigger a sync for the current user
|
||||
- `GET /scraping/status` — sync status across all stores
|
||||
|
||||
## Authentication
|
||||
|
||||
- JWT-based auth with short-lived access tokens (15 min) and longer refresh tokens (7 days).
|
||||
- Passwords hashed with bcrypt via passlib.
|
||||
- All user-specific endpoints require a valid JWT in the `Authorization: Bearer` header.
|
||||
- Public endpoints under `/public/` do not require auth.
|
||||
- Internal service-to-service calls (ReceiptWitness, etc.) use a shared API key in the `X-Service-Key` header — not user JWTs.
|
||||
|
||||
## Development Workflow
|
||||
|
||||
- **Never push directly to main.** Always create feature branches and open PRs.
|
||||
- Branch naming: `feature/<description>` or `fix/<description>`
|
||||
- Use conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `chore:`
|
||||
- OpenAPI docs auto-generated at `/docs` (Swagger) and `/redoc`.
|
||||
- Write tests for all routes. Use httpx.AsyncClient with FastAPI's TestClient pattern.
|
||||
|
||||
## Important Notes
|
||||
|
||||
- This service is read-heavy on the shared DB. Use async SQLAlchemy sessions.
|
||||
- Consider Redis caching for expensive queries (price trends, product comparisons). Cache invalidation via Redis pub/sub events from other services.
|
||||
- Rate limiting on public endpoints is important — these could get hammered if the price transparency features get attention.
|
||||
- CORS must allow the frontend origin (cartsnitch.com and localhost for dev).
|
||||
- The store connection flow is the trickiest UX challenge: the user needs to authenticate with each retailer, and we need to capture the resulting session. This likely involves a controlled Playwright browser session that the user can see/interact with, or an OAuth-like redirect flow if the retailer supports it (Kroger does for its public API, but not for purchase history access).
|
||||
@@ -0,0 +1,26 @@
|
||||
FROM python:3.12-slim AS build
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libpq-dev \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
COPY pyproject.toml ./
|
||||
COPY src/ ./src/
|
||||
RUN pip install --no-cache-dir --prefix=/install .
|
||||
|
||||
FROM python:3.12-slim AS prod
|
||||
|
||||
WORKDIR /app
|
||||
RUN adduser --system --group --uid 1000 app
|
||||
COPY --from=build /install /usr/local
|
||||
COPY src/ ./src/
|
||||
|
||||
USER 1000
|
||||
EXPOSE 8000
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=3s \
|
||||
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"
|
||||
|
||||
CMD ["uvicorn", "cartsnitch_api.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,36 @@
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
sqlalchemy.url = postgresql://OVERRIDE_VIA_ENV_VAR
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Alembic environment configuration for CartSnitch."""
|
||||
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from alembic import context
|
||||
from cartsnitch_api.models import Base # noqa: F401 — imports all models for autogenerate
|
||||
|
||||
config = context.config
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
db_url = os.environ.get("CARTSNITCH_DATABASE_URL_SYNC")
|
||||
if not db_url:
|
||||
raise RuntimeError(
|
||||
"CARTSNITCH_DATABASE_URL_SYNC must be set. "
|
||||
"Example: postgresql://user:pass@localhost:5432/cartsnitch"
|
||||
)
|
||||
config.set_main_option("sqlalchemy.url", db_url)
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -0,0 +1,25 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
${imports if imports else ""}
|
||||
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Encrypt existing plaintext session_data with Fernet.
|
||||
|
||||
Revision ID: 001_encrypt_session_data
|
||||
Revises:
|
||||
Create Date: 2026-03-19
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import sqlalchemy as sa
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "001_encrypt_session_data"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
key = os.environ.get("CARTSNITCH_FERNET_KEY")
|
||||
if not key:
|
||||
raise RuntimeError("CARTSNITCH_FERNET_KEY must be set to run this migration")
|
||||
return Fernet(key.encode())
|
||||
|
||||
|
||||
def _is_fernet_token(value: str) -> bool:
|
||||
"""Check if a string looks like a Fernet token (base64 starting with gAAAAA)."""
|
||||
return value.startswith("gAAAAA")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Change column type from JSON to TEXT to hold Fernet ciphertext
|
||||
op.alter_column(
|
||||
"user_store_accounts",
|
||||
"session_data",
|
||||
type_=sa.Text(),
|
||||
existing_type=sa.JSON(),
|
||||
existing_nullable=True,
|
||||
postgresql_using="session_data::text",
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
rows = conn.execute(
|
||||
text("SELECT id, session_data FROM user_store_accounts WHERE session_data IS NOT NULL")
|
||||
).fetchall()
|
||||
|
||||
f = _get_fernet()
|
||||
for row_id, session_data in rows:
|
||||
raw = str(session_data)
|
||||
if _is_fernet_token(raw):
|
||||
continue
|
||||
plaintext = raw if isinstance(session_data, str) else json.dumps(session_data)
|
||||
encrypted = f.encrypt(plaintext.encode()).decode()
|
||||
conn.execute(
|
||||
text("UPDATE user_store_accounts SET session_data = :data WHERE id = :id"),
|
||||
{"data": encrypted, "id": row_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
rows = conn.execute(
|
||||
text("SELECT id, session_data FROM user_store_accounts WHERE session_data IS NOT NULL")
|
||||
).fetchall()
|
||||
|
||||
f = _get_fernet()
|
||||
for row_id, session_data in rows:
|
||||
raw = str(session_data)
|
||||
if not _is_fernet_token(raw):
|
||||
continue
|
||||
decrypted = f.decrypt(raw.encode()).decode()
|
||||
conn.execute(
|
||||
text("UPDATE user_store_accounts SET session_data = :data WHERE id = :id"),
|
||||
{"data": decrypted, "id": row_id},
|
||||
)
|
||||
|
||||
# Revert column type from TEXT back to JSON
|
||||
op.alter_column(
|
||||
"user_store_accounts",
|
||||
"session_data",
|
||||
type_=sa.JSON(),
|
||||
existing_type=sa.Text(),
|
||||
existing_nullable=True,
|
||||
postgresql_using="session_data::json",
|
||||
)
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Add Better-Auth tables and extend users table.
|
||||
|
||||
Creates sessions, accounts, and verifications tables for Better-Auth.
|
||||
Adds email_verified and image columns to existing users table.
|
||||
Migrates password hashes from users.hashed_password to accounts.password.
|
||||
|
||||
Revision ID: 002_better_auth_tables
|
||||
Revises: 001_encrypt_session_data
|
||||
Create Date: 2026-03-28
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "002_better_auth_tables"
|
||||
down_revision = "001_encrypt_session_data"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# --- Extend users table for Better-Auth compatibility ---
|
||||
op.add_column("users", sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false"))
|
||||
op.add_column("users", sa.Column("image", sa.Text(), nullable=True))
|
||||
|
||||
# --- Create sessions table ---
|
||||
op.create_table(
|
||||
"sessions",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("token", sa.Text(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("ip_address", sa.Text(), nullable=True),
|
||||
sa.Column("user_agent", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_sessions_token", "sessions", ["token"], unique=True)
|
||||
op.create_index("ix_sessions_user_id", "sessions", ["user_id"])
|
||||
|
||||
# --- Create accounts table ---
|
||||
op.create_table(
|
||||
"accounts",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("account_id", sa.Text(), nullable=False),
|
||||
sa.Column("provider_id", sa.Text(), nullable=False),
|
||||
sa.Column("access_token", sa.Text(), nullable=True),
|
||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||
sa.Column("access_token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("refresh_token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("scope", sa.Text(), nullable=True),
|
||||
sa.Column("id_token", sa.Text(), nullable=True),
|
||||
sa.Column("password", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_accounts_user_id", "accounts", ["user_id"])
|
||||
|
||||
# --- Create verifications table ---
|
||||
op.create_table(
|
||||
"verifications",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("identifier", sa.Text(), nullable=False),
|
||||
sa.Column("value", sa.Text(), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# --- Migrate existing password hashes to accounts table ---
|
||||
# For each user with a hashed_password, create a 'credential' account row
|
||||
conn = op.get_bind()
|
||||
users = conn.execute(
|
||||
text("SELECT id, hashed_password FROM users WHERE hashed_password IS NOT NULL")
|
||||
).fetchall()
|
||||
|
||||
for user_id, hashed_password in users:
|
||||
user_id_str = str(user_id)
|
||||
conn.execute(
|
||||
text(
|
||||
"INSERT INTO accounts (id, user_id, account_id, provider_id, password, created_at, updated_at) "
|
||||
"VALUES (gen_random_uuid()::text, :user_id, :account_id, 'credential', :password, now(), now())"
|
||||
),
|
||||
{"user_id": user_id_str, "account_id": user_id_str, "password": hashed_password},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("verifications")
|
||||
op.drop_table("accounts")
|
||||
op.drop_index("ix_sessions_user_id", table_name="sessions")
|
||||
op.drop_index("ix_sessions_token", table_name="sessions")
|
||||
op.drop_table("sessions")
|
||||
op.drop_column("users", "image")
|
||||
op.drop_column("users", "email_verified")
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Change users.id and FK columns from uuid to text.
|
||||
|
||||
Better-Auth generates nanoid-style text IDs (e.g. pGud2ln2WAFHC0KYjBVKR4Rc7mM8OcTI),
|
||||
but the users table was using PostgreSQL uuid type, causing INSERT failures.
|
||||
|
||||
Revision ID: 003_fix_user_id_text
|
||||
Revises: 002_better_auth_tables
|
||||
Create Date: 2026-03-31
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "003_fix_user_id_text"
|
||||
down_revision = "002_better_auth_tables"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Drop FK constraints that reference users.id
|
||||
op.execute(text("ALTER TABLE user_store_accounts DROP CONSTRAINT IF EXISTS user_store_accounts_user_id_fkey"))
|
||||
op.execute(text("ALTER TABLE purchases DROP CONSTRAINT IF EXISTS purchases_user_id_fkey"))
|
||||
|
||||
# Step 2: Alter users.id from uuid to text
|
||||
op.alter_column("users", "id", existing_type=sa.UUID(), type_=sa.Text(), existing_nullable=False, postgresql_using="id::text")
|
||||
|
||||
# Step 3: Alter user_store_accounts.user_id from uuid to text
|
||||
op.alter_column("user_store_accounts", "user_id", existing_type=sa.UUID(), type_=sa.Text(), existing_nullable=False, postgresql_using="user_id::text")
|
||||
|
||||
# Step 4: Alter purchases.user_id from uuid to text
|
||||
op.alter_column("purchases", "user_id", existing_type=sa.UUID(), type_=sa.Text(), existing_nullable=False, postgresql_using="user_id::text")
|
||||
|
||||
# Step 5: Re-add FK constraints
|
||||
op.execute(text("ALTER TABLE user_store_accounts ADD CONSTRAINT user_store_accounts_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id)"))
|
||||
op.execute(text("ALTER TABLE purchases ADD CONSTRAINT purchases_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id)"))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop FK constraints
|
||||
op.execute(text("ALTER TABLE purchases DROP CONSTRAINT IF EXISTS purchases_user_id_fkey"))
|
||||
op.execute(text("ALTER TABLE user_store_accounts DROP CONSTRAINT IF EXISTS user_store_accounts_user_id_fkey"))
|
||||
|
||||
# Alter back to UUID
|
||||
op.alter_column("purchases", "user_id", existing_type=sa.Text(), type_=sa.UUID(), existing_nullable=False, postgresql_using="user_id::uuid")
|
||||
op.alter_column("user_store_accounts", "user_id", existing_type=sa.Text(), type_=sa.UUID(), existing_nullable=False, postgresql_using="user_id::uuid")
|
||||
op.alter_column("users", "id", existing_type=sa.Text(), type_=sa.UUID(), existing_nullable=False, postgresql_using="id::uuid")
|
||||
|
||||
# Re-add FK constraints
|
||||
op.execute(text("ALTER TABLE user_store_accounts ADD CONSTRAINT user_store_accounts_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id)"))
|
||||
op.execute(text("ALTER TABLE purchases ADD CONSTRAINT purchases_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id)"))
|
||||
@@ -0,0 +1,58 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "cartsnitch-api"
|
||||
version = "0.1.0"
|
||||
description = "CartSnitch API Gateway — public-facing REST API"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.30.0",
|
||||
"pydantic[email]>=2.9.0",
|
||||
"pydantic-settings>=2.5.0",
|
||||
"sqlalchemy[asyncio]>=2.0.35",
|
||||
"asyncpg>=0.30.0",
|
||||
"alembic>=1.13,<2.0",
|
||||
"psycopg2>=2.9,<3.0",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"passlib[bcrypt]>=1.7.4",
|
||||
"httpx>=0.27.0",
|
||||
"redis[hiredis]>=5.2.0",
|
||||
"cryptography>=43.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.3.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
"httpx>=0.27.0",
|
||||
"ruff>=0.7.0",
|
||||
"psycopg2-binary>=2.9,<3.0",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/cartsnitch_api"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "UP", "B"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"src/cartsnitch_api/**/routes*.py" = ["B008"]
|
||||
"src/cartsnitch_api/**/dependencies.py" = ["B008"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
ignore_missing_imports = true
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||
"extends": ["local>cartsnitch/.github:renovate-config"]
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
"""FastAPI dependency injection for authentication.
|
||||
|
||||
Validates Better-Auth session tokens from cookies or Bearer header.
|
||||
Sessions are verified by querying the shared sessions table directly.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Cookie, Depends, Header, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.database import get_db
|
||||
|
||||
# Keep Bearer scheme as optional — Better-Auth primarily uses cookies,
|
||||
# but we support Bearer tokens for service-to-service or mobile clients.
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
# Better-Auth session cookie name
|
||||
SESSION_COOKIE_NAME = "better-auth.session_token"
|
||||
|
||||
|
||||
async def _validate_session_token(token: str, db: AsyncSession) -> UUID:
|
||||
"""Validate a Better-Auth session token against the sessions table.
|
||||
|
||||
Returns the user_id (as UUID) if the session is valid and not expired.
|
||||
"""
|
||||
result = await db.execute(
|
||||
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
|
||||
{"token": token},
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid session token",
|
||||
)
|
||||
|
||||
user_id, expires_at = row
|
||||
if expires_at.tzinfo is None:
|
||||
# Treat naive datetimes as UTC
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < datetime.now(UTC):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session expired",
|
||||
)
|
||||
|
||||
return UUID(str(user_id))
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> UUID:
|
||||
"""Extract and validate the session token from cookie or Authorization header.
|
||||
|
||||
Checks in order:
|
||||
1. Better-Auth session cookie (primary — web clients)
|
||||
2. Bearer token in Authorization header (fallback — API clients)
|
||||
"""
|
||||
token: str | None = None
|
||||
|
||||
# 1. Check session cookie
|
||||
cookie_token = request.cookies.get(SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
token = cookie_token
|
||||
|
||||
# 2. Fall back to Bearer header
|
||||
if not token and credentials:
|
||||
token = credentials.credentials
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
return await _validate_session_token(token, db)
|
||||
|
||||
|
||||
async def verify_service_key(x_service_key: str = Header()) -> None:
|
||||
if x_service_key != settings.service_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid service key",
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""JWT token creation and validation."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def create_access_token(user_id: UUID) -> str:
|
||||
expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes)
|
||||
payload = {"sub": str(user_id), "exp": expire, "type": "access"}
|
||||
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
||||
|
||||
|
||||
def create_refresh_token(user_id: UUID) -> str:
|
||||
expire = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days)
|
||||
payload = {"sub": str(user_id), "exp": expire, "type": "refresh"}
|
||||
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
try:
|
||||
return cast(
|
||||
dict[str, Any],
|
||||
jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]),
|
||||
)
|
||||
except JWTError as e:
|
||||
raise ValueError(f"Invalid token: {e}") from e
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Password hashing and verification with bcrypt."""
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Auth routes: user profile management.
|
||||
|
||||
Registration, login, refresh, and session management are handled by
|
||||
the Better-Auth service (auth/). This router provides user profile
|
||||
endpoints that query our own user data from the shared database.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import (
|
||||
UpdateUserRequest,
|
||||
UserResponse,
|
||||
)
|
||||
from cartsnitch_api.services.auth import AuthService
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.get_user(user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
) from None
|
||||
|
||||
|
||||
@router.patch("/me", response_model=UserResponse)
|
||||
async def update_me(
|
||||
body: UpdateUserRequest,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.update_user(user_id, email=body.email, display_name=body.display_name)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
) from None
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_me(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
await svc.delete_user(user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
) from None
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Redis/DragonflyDB caching helpers."""
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class CacheClient:
|
||||
"""Stub for Redis/DragonflyDB caching.
|
||||
|
||||
Will be used for expensive queries: price trends, product comparisons.
|
||||
Cache invalidation via Redis pub/sub events from other services.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.url = settings.redis_url
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
# TODO: implement with redis-py async
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: str, ttl_seconds: int = 300) -> None:
|
||||
# TODO: implement with redis-py async
|
||||
pass
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
# TODO: implement with redis-py async
|
||||
pass
|
||||
@@ -0,0 +1,53 @@
|
||||
import base64
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = {"env_prefix": "CARTSNITCH_"}
|
||||
|
||||
database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
jwt_secret_key: str = "change-me-in-production"
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_access_token_expire_minutes: int = 15
|
||||
jwt_refresh_token_expire_days: int = 7
|
||||
|
||||
service_key: str = "change-me-in-production"
|
||||
# Valid Fernet key for local dev — MUST be overridden in production
|
||||
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||
|
||||
auth_service_url: str = "http://auth:3001"
|
||||
|
||||
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
|
||||
|
||||
receiptwitness_url: str = "http://receiptwitness:8001"
|
||||
stickershock_url: str = "http://stickershock:8002"
|
||||
clipartist_url: str = "http://clipartist:8003"
|
||||
shrinkray_url: str = "http://shrinkray:8004"
|
||||
|
||||
rate_limit_requests: int = 60
|
||||
rate_limit_window_seconds: int = 60
|
||||
rate_limit_enabled: bool = True
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_fernet_key(self):
|
||||
"""Validate fernet_key is a valid 32-byte url-safe base64 key at startup."""
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
|
||||
if len(decoded) != 32:
|
||||
raise ValueError
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"CARTSNITCH_FERNET_KEY must be a valid Fernet key "
|
||||
"(32 bytes, url-safe base64 encoded). "
|
||||
"Generate one with: python -c "
|
||||
"'from cryptography.fernet import Fernet; "
|
||||
"print(Fernet.generate_key().decode())'"
|
||||
) from None
|
||||
return self
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Constants and enums shared across CartSnitch services."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class StoreSlug(StrEnum):
|
||||
"""Supported retailer slugs."""
|
||||
|
||||
MEIJER = "meijer"
|
||||
KROGER = "kroger"
|
||||
TARGET = "target"
|
||||
|
||||
|
||||
class AccountStatus(StrEnum):
|
||||
"""User store account link status."""
|
||||
|
||||
ACTIVE = "active"
|
||||
EXPIRED = "expired"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class DiscountType(StrEnum):
|
||||
"""Coupon discount type."""
|
||||
|
||||
PERCENT = "percent"
|
||||
FIXED = "fixed"
|
||||
BOGO = "bogo"
|
||||
BUY_X_GET_Y = "buy_x_get_y"
|
||||
|
||||
|
||||
class PriceSource(StrEnum):
|
||||
"""Source of a price observation."""
|
||||
|
||||
RECEIPT = "receipt"
|
||||
CATALOG = "catalog"
|
||||
WEEKLY_AD = "weekly_ad"
|
||||
|
||||
|
||||
class EventType(StrEnum):
|
||||
"""Redis pub/sub event types."""
|
||||
|
||||
RECEIPTS_INGESTED = "cartsnitch.receipts.ingested"
|
||||
PRICES_UPDATED = "cartsnitch.prices.updated"
|
||||
PRODUCTS_NORMALIZED = "cartsnitch.products.normalized"
|
||||
COUPONS_UPDATED = "cartsnitch.coupons.updated"
|
||||
ALERT_PRICE_INCREASE = "cartsnitch.alerts.price_increase"
|
||||
ALERT_SHRINKFLATION = "cartsnitch.alerts.shrinkflation"
|
||||
|
||||
|
||||
class ProductCategory(StrEnum):
|
||||
"""Top-level product categories."""
|
||||
|
||||
PRODUCE = "produce"
|
||||
DAIRY = "dairy"
|
||||
MEAT = "meat"
|
||||
BAKERY = "bakery"
|
||||
FROZEN = "frozen"
|
||||
PANTRY = "pantry"
|
||||
BEVERAGES = "beverages"
|
||||
SNACKS = "snacks"
|
||||
HOUSEHOLD = "household"
|
||||
PERSONAL_CARE = "personal_care"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class MatchConfidence(StrEnum):
|
||||
"""Confidence level for product matching."""
|
||||
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
class SizeUnit(StrEnum):
|
||||
"""Standardized product size units."""
|
||||
|
||||
OZ = "oz"
|
||||
FL_OZ = "fl_oz"
|
||||
LB = "lb"
|
||||
G = "g"
|
||||
KG = "kg"
|
||||
ML = "ml"
|
||||
L = "l"
|
||||
CT = "ct"
|
||||
PK = "pk"
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Database session management for the API gateway."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
engine = create_async_engine(settings.database_url, echo=False)
|
||||
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI dependency that yields an async DB session."""
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
@@ -0,0 +1,62 @@
|
||||
"""FastAPI app factory for CartSnitch API Gateway."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from cartsnitch_api.auth.routes import router as auth_router
|
||||
from cartsnitch_api.middleware.cors import add_cors_middleware
|
||||
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
|
||||
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
|
||||
from cartsnitch_api.routes.alerts import router as alerts_router
|
||||
from cartsnitch_api.routes.coupons import router as coupons_router
|
||||
from cartsnitch_api.routes.health import router as health_router
|
||||
from cartsnitch_api.routes.prices import router as prices_router
|
||||
from cartsnitch_api.routes.products import router as products_router
|
||||
from cartsnitch_api.routes.public import router as public_router
|
||||
from cartsnitch_api.routes.purchases import router as purchases_router
|
||||
from cartsnitch_api.routes.scraping import router as scraping_router
|
||||
from cartsnitch_api.routes.shopping import router as shopping_router
|
||||
from cartsnitch_api.routes.stores import router as stores_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# TODO: initialize DB session pool, Redis connection, service clients
|
||||
yield
|
||||
# TODO: cleanup connections
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="CartSnitch API",
|
||||
description="Grocery price tracking and shrinkflation detection API",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Middleware (order matters — outermost first)
|
||||
add_cors_middleware(app)
|
||||
add_error_monitor_middleware(app)
|
||||
add_rate_limit_middleware(app)
|
||||
|
||||
# Exception handlers
|
||||
add_error_handlers(app)
|
||||
|
||||
# Routers
|
||||
app.include_router(health_router)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(stores_router)
|
||||
app.include_router(purchases_router)
|
||||
app.include_router(products_router)
|
||||
app.include_router(prices_router)
|
||||
app.include_router(coupons_router)
|
||||
app.include_router(shopping_router)
|
||||
app.include_router(alerts_router)
|
||||
app.include_router(scraping_router)
|
||||
app.include_router(public_router)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
@@ -0,0 +1,16 @@
|
||||
"""CORS middleware configuration."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def add_cors_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Structured error responses and error monitoring.
|
||||
|
||||
Ensures all errors return a consistent JSON shape and never leak stack traces.
|
||||
Provides hooks for error monitoring/alerting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger("cartsnitch_api.errors")
|
||||
|
||||
|
||||
def _error_response(
|
||||
status_code: int,
|
||||
detail: str,
|
||||
code: str | None = None,
|
||||
errors: list[dict] | None = None,
|
||||
) -> JSONResponse:
|
||||
"""Build a consistent error response."""
|
||||
body: dict = {"detail": detail}
|
||||
if code:
|
||||
body["code"] = code
|
||||
if errors:
|
||||
body["errors"] = errors
|
||||
return JSONResponse(status_code=status_code, content=body)
|
||||
|
||||
|
||||
def add_error_handlers(app: FastAPI) -> None:
|
||||
"""Register global exception handlers for consistent error responses."""
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_error_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
"""Return 422 with structured field-level error details."""
|
||||
field_errors = []
|
||||
for err in exc.errors():
|
||||
loc = err.get("loc", ())
|
||||
field_errors.append(
|
||||
{
|
||||
"field": ".".join(str(p) for p in loc[1:]) if len(loc) > 1 else str(loc),
|
||||
"message": err.get("msg", "Invalid value"),
|
||||
"type": err.get("type", "value_error"),
|
||||
}
|
||||
)
|
||||
return _error_response(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Validation error",
|
||||
code="VALIDATION_ERROR",
|
||||
errors=field_errors,
|
||||
)
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
|
||||
"""Wrap HTTP exceptions (Starlette and FastAPI) in consistent format."""
|
||||
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
||||
return _error_response(
|
||||
status_code=exc.status_code,
|
||||
detail=detail,
|
||||
code=_status_to_code(exc.status_code),
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Catch-all: log full traceback, return safe 500 to client."""
|
||||
logger.error(
|
||||
"Unhandled exception on %s %s: %s\n%s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
_notify_error_monitor(request, exc)
|
||||
|
||||
return _error_response(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error",
|
||||
code="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
def _status_to_code(status_code: int) -> str:
|
||||
"""Map HTTP status code to a machine-readable error code."""
|
||||
mapping = {
|
||||
400: "BAD_REQUEST",
|
||||
401: "UNAUTHORIZED",
|
||||
403: "FORBIDDEN",
|
||||
404: "NOT_FOUND",
|
||||
409: "CONFLICT",
|
||||
422: "VALIDATION_ERROR",
|
||||
429: "RATE_LIMITED",
|
||||
502: "BAD_GATEWAY",
|
||||
503: "SERVICE_UNAVAILABLE",
|
||||
}
|
||||
return mapping.get(status_code, f"HTTP_{status_code}")
|
||||
|
||||
|
||||
# ---------- Error Monitoring ----------
|
||||
|
||||
|
||||
class _ErrorMonitor:
|
||||
"""Simple error counter for monitoring and alerting hooks.
|
||||
|
||||
Tracks error counts and rates. In production, this would forward
|
||||
to an external monitoring service (Prometheus, Sentry, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.error_counts: dict[int, int] = {}
|
||||
self.recent_5xx: list[dict] = []
|
||||
self._max_recent = 100
|
||||
|
||||
def record(self, status_code: int, path: str, method: str, error: str | None = None) -> None:
|
||||
self.error_counts[status_code] = self.error_counts.get(status_code, 0) + 1
|
||||
|
||||
if status_code >= 500:
|
||||
entry = {
|
||||
"timestamp": time.time(),
|
||||
"status": status_code,
|
||||
"path": path,
|
||||
"method": method,
|
||||
"error": error,
|
||||
}
|
||||
self.recent_5xx.append(entry)
|
||||
if len(self.recent_5xx) > self._max_recent:
|
||||
self.recent_5xx = self.recent_5xx[-self._max_recent :]
|
||||
|
||||
logger.warning(
|
||||
"5xx error recorded: %s %s -> %d (%s)",
|
||||
method,
|
||||
path,
|
||||
status_code,
|
||||
error or "unknown",
|
||||
)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
return {
|
||||
"error_counts": dict(self.error_counts),
|
||||
"recent_5xx_count": len(self.recent_5xx),
|
||||
}
|
||||
|
||||
|
||||
_monitor = _ErrorMonitor()
|
||||
|
||||
|
||||
def get_error_monitor() -> _ErrorMonitor:
|
||||
"""Access the global error monitor (for health/metrics endpoints)."""
|
||||
return _monitor
|
||||
|
||||
|
||||
def _notify_error_monitor(request: Request, exc: Exception) -> None:
|
||||
"""Record unhandled exception in the error monitor."""
|
||||
_monitor.record(
|
||||
status_code=500,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
error=str(exc)[:200],
|
||||
)
|
||||
|
||||
|
||||
class ErrorMonitorMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to track all 4xx/5xx responses for monitoring."""
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable],
|
||||
):
|
||||
response = await call_next(request)
|
||||
|
||||
if response.status_code >= 400:
|
||||
_monitor.record(
|
||||
status_code=response.status_code,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def add_error_monitor_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(ErrorMonitorMiddleware)
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Rate limiting middleware for public and authenticated endpoints.
|
||||
|
||||
Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
|
||||
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class _SlidingWindowCounter:
|
||||
"""Thread-safe in-memory sliding window rate limiter."""
|
||||
|
||||
def __init__(self, max_requests: int, window_seconds: int) -> None:
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
self._hits: dict[str, list[float]] = defaultdict(list)
|
||||
self._lock = Lock()
|
||||
|
||||
def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
with self._lock:
|
||||
# Prune expired entries
|
||||
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
|
||||
|
||||
current_count = len(self._hits[key])
|
||||
if current_count >= self.max_requests:
|
||||
retry_after = int(self._hits[key][0] - cutoff) + 1
|
||||
return False, 0, retry_after
|
||||
|
||||
self._hits[key].append(now)
|
||||
remaining = self.max_requests - current_count - 1
|
||||
return True, remaining, 0
|
||||
|
||||
|
||||
# Module-level counters — one for public (per-IP), one for auth (per-token)
|
||||
_public_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests,
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
_auth_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract client IP, respecting X-Forwarded-For behind a reverse proxy."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
|
||||
"""Determine rate limit key and which limiter to use."""
|
||||
if request.url.path.startswith("/public"):
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
# For authenticated endpoints, use Bearer token as key if present
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
# Use last 16 chars of token as key to avoid storing full tokens
|
||||
return f"token:{token[-16:]}", _auth_limiter
|
||||
|
||||
# Fallback to IP for unauthenticated non-public endpoints
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip rate limiting when disabled (e.g. in tests) or for health checks
|
||||
if not settings.rate_limit_enabled or request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
key, limiter = _get_rate_limit_key(request)
|
||||
allowed, remaining, retry_after = limiter.is_allowed(key)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={
|
||||
"detail": "Rate limit exceeded",
|
||||
"code": "RATE_LIMITED",
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"X-RateLimit-Limit": str(limiter.max_requests),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
|
||||
def add_rate_limit_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""SQLAlchemy ORM models — re-exports all models for convenience."""
|
||||
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
from cartsnitch_api.models.coupon import Coupon
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.purchase import Purchase, PurchaseItem
|
||||
from cartsnitch_api.models.shrinkflation import ShrinkflationEvent
|
||||
from cartsnitch_api.models.store import Store, StoreLocation
|
||||
from cartsnitch_api.models.user import User, UserStoreAccount
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"TimestampMixin",
|
||||
"UUIDPrimaryKeyMixin",
|
||||
"Store",
|
||||
"StoreLocation",
|
||||
"User",
|
||||
"UserStoreAccount",
|
||||
"Purchase",
|
||||
"PurchaseItem",
|
||||
"NormalizedProduct",
|
||||
"PriceHistory",
|
||||
"Coupon",
|
||||
"ShrinkflationEvent",
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Base model and mixins for all CartSnitch ORM models."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all CartSnitch models."""
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin providing created_at / updated_at columns."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
class UUIDPrimaryKeyMixin:
|
||||
"""Mixin providing a UUID primary key."""
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid()
|
||||
)
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Coupon model."""
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import DiscountType
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.store import Store
|
||||
|
||||
|
||||
class Coupon(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""A coupon or deal for a product at a store."""
|
||||
|
||||
__tablename__ = "coupons"
|
||||
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
normalized_product_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("normalized_products.id")
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(String(1000))
|
||||
discount_type: Mapped[DiscountType] = mapped_column(String(20), nullable=False)
|
||||
discount_value: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
min_purchase: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
valid_from: Mapped[date | None] = mapped_column(Date)
|
||||
valid_to: Mapped[date | None] = mapped_column(Date)
|
||||
requires_clip: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
coupon_code: Mapped[str | None] = mapped_column(String(100))
|
||||
source_url: Mapped[str | None] = mapped_column(String(500))
|
||||
scraped_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
# Relationships
|
||||
store: Mapped["Store"] = relationship(back_populates="coupons")
|
||||
normalized_product: Mapped["NormalizedProduct | None"] = relationship(back_populates="coupons")
|
||||
@@ -0,0 +1,50 @@
|
||||
"""PriceHistory model — tracks product prices over time."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Date, ForeignKey, Index, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import PriceSource
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.purchase import PurchaseItem
|
||||
from cartsnitch_api.models.store import Store
|
||||
|
||||
|
||||
class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""A single price observation for a product at a store on a date."""
|
||||
|
||||
__tablename__ = "price_history"
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_price_history_product_store_date",
|
||||
"normalized_product_id",
|
||||
"store_id",
|
||||
"observed_date",
|
||||
),
|
||||
)
|
||||
|
||||
normalized_product_id: Mapped[uuid.UUID] = mapped_column(
|
||||
ForeignKey("normalized_products.id"), nullable=False
|
||||
)
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
observed_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
regular_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
loyalty_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
coupon_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
source: Mapped[PriceSource] = mapped_column(String(20), nullable=False)
|
||||
purchase_item_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("purchase_items.id"))
|
||||
|
||||
# Relationships
|
||||
normalized_product: Mapped["NormalizedProduct"] = relationship(back_populates="price_histories")
|
||||
store: Mapped["Store"] = relationship(back_populates="price_histories")
|
||||
purchase_item: Mapped["PurchaseItem | None"] = relationship(
|
||||
back_populates="price_history_entries"
|
||||
)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""NormalizedProduct model — the canonical product identity."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import ProductCategory, SizeUnit
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.coupon import Coupon
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.purchase import PurchaseItem
|
||||
from cartsnitch_api.models.shrinkflation import ShrinkflationEvent
|
||||
|
||||
|
||||
class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Canonical product identity — matches products across retailers."""
|
||||
|
||||
__tablename__ = "normalized_products"
|
||||
|
||||
canonical_name: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
category: Mapped[ProductCategory | None] = mapped_column(String(50))
|
||||
subcategory: Mapped[str | None] = mapped_column(String(100))
|
||||
brand: Mapped[str | None] = mapped_column(String(200))
|
||||
size: Mapped[str | None] = mapped_column(String(50))
|
||||
size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
|
||||
upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list)
|
||||
|
||||
# Relationships
|
||||
purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
|
||||
price_histories: Mapped[list["PriceHistory"]] = relationship(
|
||||
back_populates="normalized_product"
|
||||
)
|
||||
coupons: Mapped[list["Coupon"]] = relationship(back_populates="normalized_product")
|
||||
shrinkflation_events: Mapped[list["ShrinkflationEvent"]] = relationship(
|
||||
back_populates="normalized_product"
|
||||
)
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Purchase and PurchaseItem models."""
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Date,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Numeric,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.store import Store, StoreLocation
|
||||
from cartsnitch_api.models.user import User
|
||||
|
||||
|
||||
class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""A single shopping trip / receipt."""
|
||||
|
||||
__tablename__ = "purchases"
|
||||
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id"))
|
||||
receipt_id: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
purchase_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
total: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
subtotal: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
tax: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
savings_total: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
source_url: Mapped[str | None] = mapped_column(String(500))
|
||||
raw_data: Mapped[dict | None] = mapped_column(JSON)
|
||||
ingested_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(back_populates="purchases")
|
||||
store: Mapped["Store"] = relationship(back_populates="purchases")
|
||||
store_location: Mapped["StoreLocation | None"] = relationship(back_populates="purchases")
|
||||
items: Mapped[list["PurchaseItem"]] = relationship(back_populates="purchase")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_purchases_user_store", "user_id", "store_id"),
|
||||
UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"),
|
||||
)
|
||||
|
||||
|
||||
class PurchaseItem(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Individual line item on a receipt."""
|
||||
|
||||
__tablename__ = "purchase_items"
|
||||
|
||||
purchase_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("purchases.id"), nullable=False)
|
||||
product_name_raw: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
upc: Mapped[str | None] = mapped_column(String(20))
|
||||
quantity: Mapped[Decimal] = mapped_column(Numeric(10, 3), nullable=False, default=1)
|
||||
unit_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
extended_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
regular_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
coupon_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
loyalty_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
category_raw: Mapped[str | None] = mapped_column(String(100))
|
||||
normalized_product_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("normalized_products.id")
|
||||
)
|
||||
|
||||
# Relationships
|
||||
purchase: Mapped["Purchase"] = relationship(back_populates="items")
|
||||
normalized_product: Mapped["NormalizedProduct | None"] = relationship(
|
||||
back_populates="purchase_items"
|
||||
)
|
||||
price_history_entries: Mapped[list["PriceHistory"]] = relationship(
|
||||
back_populates="purchase_item"
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""ShrinkflationEvent model."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Date, ForeignKey, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import SizeUnit
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
|
||||
|
||||
class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Detected shrinkflation event — product size changed while price held or rose."""
|
||||
|
||||
__tablename__ = "shrinkflation_events"
|
||||
|
||||
normalized_product_id: Mapped[uuid.UUID] = mapped_column(
|
||||
ForeignKey("normalized_products.id"), nullable=False
|
||||
)
|
||||
detected_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
old_size: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
new_size: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
old_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
|
||||
new_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
|
||||
price_at_old_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
price_at_new_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
confidence: Mapped[Decimal] = mapped_column(
|
||||
Numeric(3, 2), nullable=False, default=Decimal("1.00")
|
||||
)
|
||||
notes: Mapped[str | None] = mapped_column(String(1000))
|
||||
|
||||
# Relationships
|
||||
normalized_product: Mapped["NormalizedProduct"] = relationship(
|
||||
back_populates="shrinkflation_events"
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Store and StoreLocation models."""
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Float, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import StoreSlug
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.coupon import Coupon
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.purchase import Purchase
|
||||
from cartsnitch_api.models.user import UserStoreAccount
|
||||
|
||||
|
||||
class Store(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Supported retailer."""
|
||||
|
||||
__tablename__ = "stores"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
slug: Mapped[StoreSlug] = mapped_column(String(20), nullable=False, unique=True)
|
||||
logo_url: Mapped[str | None] = mapped_column(String(500))
|
||||
website_url: Mapped[str | None] = mapped_column(String(500))
|
||||
|
||||
# Relationships
|
||||
locations: Mapped[list["StoreLocation"]] = relationship(back_populates="store")
|
||||
purchases: Mapped[list["Purchase"]] = relationship(back_populates="store")
|
||||
user_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="store")
|
||||
price_histories: Mapped[list["PriceHistory"]] = relationship(back_populates="store")
|
||||
coupons: Mapped[list["Coupon"]] = relationship(back_populates="store")
|
||||
|
||||
|
||||
class StoreLocation(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Physical store location."""
|
||||
|
||||
__tablename__ = "store_locations"
|
||||
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
address: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
city: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
state: Mapped[str] = mapped_column(String(2), nullable=False)
|
||||
zip: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
lat: Mapped[float | None] = mapped_column(Float)
|
||||
lng: Mapped[float | None] = mapped_column(Float)
|
||||
|
||||
# Relationships
|
||||
store: Mapped["Store"] = relationship(back_populates="locations")
|
||||
purchases: Mapped[list["Purchase"]] = relationship(back_populates="store_location")
|
||||
@@ -0,0 +1,50 @@
|
||||
"""User and UserStoreAccount models."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import AccountStatus
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
from cartsnitch_api.types import EncryptedJSON
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.purchase import Purchase
|
||||
from cartsnitch_api.models.store import Store
|
||||
|
||||
|
||||
class User(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Application user."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
display_name: Mapped[str | None] = mapped_column(String(100))
|
||||
|
||||
# Relationships
|
||||
store_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="user")
|
||||
purchases: Mapped[list["Purchase"]] = relationship(back_populates="user")
|
||||
|
||||
|
||||
class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Link between a user and their retailer account credentials."""
|
||||
|
||||
__tablename__ = "user_store_accounts"
|
||||
__table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),)
|
||||
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
session_data: Mapped[dict | None] = mapped_column(EncryptedJSON)
|
||||
session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
last_sync_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
status: Mapped[AccountStatus] = mapped_column(
|
||||
String(20), nullable=False, default=AccountStatus.ACTIVE
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(back_populates="store_accounts")
|
||||
store: Mapped["Store"] = relationship(back_populates="user_accounts")
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Alert routes: list alerts, manage settings."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import AlertResponse, AlertSettingsRequest, AlertSettingsResponse
|
||||
from cartsnitch_api.services.alerts import AlertService
|
||||
|
||||
router = APIRouter(prefix="/alerts", tags=["alerts"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[AlertResponse])
|
||||
async def list_alerts(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AlertService(db)
|
||||
return await svc.list_alerts(user_id)
|
||||
|
||||
|
||||
@router.get("/settings", response_model=AlertSettingsResponse)
|
||||
async def get_alert_settings(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AlertService(db)
|
||||
return await svc.get_settings(user_id)
|
||||
|
||||
|
||||
@router.put("/settings")
|
||||
async def update_alert_settings(
|
||||
body: AlertSettingsRequest,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Alert settings persistence not yet implemented. "
|
||||
"Use GET /alerts/settings for current defaults.",
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Coupon routes: browse, relevant matches."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import CouponResponse
|
||||
from cartsnitch_api.services.coupons import CouponService
|
||||
|
||||
router = APIRouter(prefix="/coupons", tags=["coupons"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[CouponResponse])
|
||||
async def list_coupons(
|
||||
store_id: UUID | None = Query(None),
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CouponService(db)
|
||||
return await svc.list_coupons(store_id)
|
||||
|
||||
|
||||
@router.get("/relevant", response_model=list[CouponResponse])
|
||||
async def relevant_coupons(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CouponService(db)
|
||||
return await svc.relevant_coupons(user_id)
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Health check and error metrics endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from cartsnitch_api.auth.dependencies import verify_service_key
|
||||
from cartsnitch_api.middleware.error_handler import get_error_monitor
|
||||
|
||||
router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)])
|
||||
async def error_stats():
|
||||
"""Error monitoring stats — internal only (requires X-Service-Key)."""
|
||||
monitor = get_error_monitor()
|
||||
return monitor.get_stats()
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Price routes: trends, increases, comparison."""
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import (
|
||||
PriceComparisonResponse,
|
||||
PriceIncreaseResponse,
|
||||
PriceTrendResponse,
|
||||
)
|
||||
from cartsnitch_api.services.prices import PriceService
|
||||
|
||||
router = APIRouter(prefix="/prices", tags=["prices"])
|
||||
|
||||
|
||||
@router.get("/trends", response_model=list[PriceTrendResponse])
|
||||
async def price_trends(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
category: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PriceService(db)
|
||||
return await svc.get_trends(category)
|
||||
|
||||
|
||||
@router.get("/increases", response_model=list[PriceIncreaseResponse])
|
||||
async def price_increases(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PriceService(db)
|
||||
return await svc.get_increases()
|
||||
|
||||
|
||||
@router.get("/comparison", response_model=list[PriceComparisonResponse])
|
||||
async def price_comparison(
|
||||
product_ids: Annotated[list[UUID], Query()],
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PriceService(db)
|
||||
return await svc.get_comparison(product_ids)
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Product routes: search/list, detail, price history."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import PriceTrendResponse, ProductDetailResponse, ProductResponse
|
||||
from cartsnitch_api.services.products import ProductService
|
||||
|
||||
router = APIRouter(prefix="/products", tags=["products"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[ProductResponse])
|
||||
async def list_products(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
q: str | None = Query(None),
|
||||
category: str | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = ProductService(db)
|
||||
return await svc.list_products(q, category, page, page_size)
|
||||
|
||||
|
||||
@router.get("/{product_id}", response_model=ProductDetailResponse)
|
||||
async def get_product(
|
||||
product_id: UUID,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = ProductService(db)
|
||||
try:
|
||||
return await svc.get_product(product_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/{product_id}/prices", response_model=PriceTrendResponse)
|
||||
async def get_product_prices(
|
||||
product_id: UUID,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = ProductService(db)
|
||||
try:
|
||||
return await svc.get_price_history(product_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
) from None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Public endpoints: price transparency data (no auth required)."""
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import (
|
||||
PublicInflationResponse,
|
||||
PublicStoreComparisonResponse,
|
||||
PublicTrendResponse,
|
||||
)
|
||||
from cartsnitch_api.services.public import PublicService
|
||||
|
||||
router = APIRouter(prefix="/public", tags=["public"])
|
||||
|
||||
|
||||
@router.get("/trends/{product_id}", response_model=PublicTrendResponse)
|
||||
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
svc = PublicService(db)
|
||||
try:
|
||||
return await svc.get_trend(product_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
|
||||
async def public_store_comparison(
|
||||
product_ids: Annotated[list[UUID], Query(max_length=20)],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not product_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one product_id is required",
|
||||
)
|
||||
svc = PublicService(db)
|
||||
return await svc.get_store_comparison(product_ids)
|
||||
|
||||
|
||||
@router.get("/inflation", response_model=PublicInflationResponse)
|
||||
async def public_inflation(db: AsyncSession = Depends(get_db)):
|
||||
svc = PublicService(db)
|
||||
return await svc.get_inflation()
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Purchase routes: list, detail, stats."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import PurchaseDetailResponse, PurchaseResponse, PurchaseStatsResponse
|
||||
from cartsnitch_api.services.purchases import PurchaseService
|
||||
|
||||
router = APIRouter(prefix="/purchases", tags=["purchases"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[PurchaseResponse])
|
||||
async def list_purchases(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
store_id: UUID | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PurchaseService(db)
|
||||
return await svc.list_purchases(user_id, store_id, page, page_size)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=PurchaseStatsResponse)
|
||||
async def purchase_stats(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PurchaseService(db)
|
||||
return await svc.get_stats(user_id)
|
||||
|
||||
|
||||
@router.get("/{purchase_id}", response_model=PurchaseDetailResponse)
|
||||
async def get_purchase(
|
||||
purchase_id: UUID,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PurchaseService(db)
|
||||
try:
|
||||
return await svc.get_purchase(purchase_id, user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Purchase not found"
|
||||
) from None
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Scraping routes: trigger sync, check status (proxy to ReceiptWitness)."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from httpx import HTTPStatusError, RequestError
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.schemas import SyncStatusResponse, SyncTriggerResponse
|
||||
from cartsnitch_api.services.receiptwitness import ReceiptWitnessClient
|
||||
|
||||
router = APIRouter(prefix="/scraping", tags=["scraping"])
|
||||
|
||||
|
||||
@router.post("/{store_slug}/sync", response_model=SyncTriggerResponse)
|
||||
async def trigger_sync(store_slug: str, user_id: UUID = Depends(get_current_user)):
|
||||
client = ReceiptWitnessClient()
|
||||
try:
|
||||
result = await client.trigger_sync(str(user_id), store_slug)
|
||||
return result
|
||||
except HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=e.response.status_code,
|
||||
detail="Sync service error",
|
||||
) from e
|
||||
except RequestError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach sync service",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/status", response_model=list[SyncStatusResponse])
|
||||
async def sync_status(user_id: UUID = Depends(get_current_user)):
|
||||
client = ReceiptWitnessClient()
|
||||
try:
|
||||
return await client.get_sync_status(str(user_id))
|
||||
except (HTTPStatusError, RequestError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach sync service",
|
||||
) from None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Shopping routes: optimize list, saved lists."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from httpx import HTTPStatusError, RequestError
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.schemas import OptimizeRequest, OptimizeResponse, ShoppingListResponse
|
||||
from cartsnitch_api.services.clipartist import ClipArtistClient
|
||||
|
||||
router = APIRouter(prefix="/shopping", tags=["shopping"])
|
||||
|
||||
|
||||
@router.post("/optimize", response_model=OptimizeResponse)
|
||||
async def optimize_shopping(body: OptimizeRequest, user_id: UUID = Depends(get_current_user)):
|
||||
client = ClipArtistClient()
|
||||
try:
|
||||
result = await client.optimize(
|
||||
user_id=str(user_id),
|
||||
items=[item.model_dump() for item in body.items],
|
||||
preferred_stores=(
|
||||
[str(s) for s in body.preferred_stores] if body.preferred_stores else None
|
||||
),
|
||||
)
|
||||
return result
|
||||
except HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=e.response.status_code,
|
||||
detail="Shopping optimization service error",
|
||||
) from e
|
||||
except RequestError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach shopping optimization service",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/lists", response_model=list[ShoppingListResponse])
|
||||
async def list_shopping_lists(user_id: UUID = Depends(get_current_user)):
|
||||
client = ClipArtistClient()
|
||||
try:
|
||||
return await client.get_shopping_lists(str(user_id))
|
||||
except (HTTPStatusError, RequestError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach shopping service",
|
||||
) from None
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Store routes: list stores, manage user store connections."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import ConnectStoreRequest, StoreAccountResponse, StoreResponse
|
||||
from cartsnitch_api.services.stores import StoreService
|
||||
|
||||
router = APIRouter(tags=["stores"])
|
||||
|
||||
|
||||
@router.get("/stores", response_model=list[StoreResponse])
|
||||
async def list_stores(db: AsyncSession = Depends(get_db)):
|
||||
svc = StoreService(db)
|
||||
return await svc.list_stores()
|
||||
|
||||
|
||||
@router.get("/me/stores", response_model=list[StoreAccountResponse])
|
||||
async def list_user_stores(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = StoreService(db)
|
||||
return await svc.list_user_stores(user_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/me/stores/{store_slug}/connect",
|
||||
response_model=StoreAccountResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def connect_store(
|
||||
store_slug: str,
|
||||
body: ConnectStoreRequest,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = StoreService(db)
|
||||
try:
|
||||
return await svc.connect_store(user_id, store_slug, body.credentials)
|
||||
except LookupError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def disconnect_store(
|
||||
store_slug: str,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = StoreService(db)
|
||||
try:
|
||||
await svc.disconnect_store(user_id, store_slug)
|
||||
except LookupError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
@@ -0,0 +1,271 @@
|
||||
"""Pydantic v2 request/response schemas for all API endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
# ---------- Auth ----------
|
||||
# Registration, login, and session management are handled by Better-Auth (auth/ service).
|
||||
# These schemas are for the profile management endpoints only.
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
email: EmailStr | None = None
|
||||
display_name: str | None = Field(None, min_length=1, max_length=100)
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
display_name: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------- Stores ----------
|
||||
|
||||
|
||||
class StoreResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
slug: str
|
||||
logo_url: str | None = None
|
||||
supported: bool = True
|
||||
|
||||
|
||||
class StoreAccountResponse(BaseModel):
|
||||
store: StoreResponse
|
||||
connected: bool
|
||||
last_sync_at: datetime | None = None
|
||||
sync_status: str | None = None
|
||||
|
||||
|
||||
class ConnectStoreRequest(BaseModel):
|
||||
credentials: dict | None = None
|
||||
|
||||
|
||||
# ---------- Purchases ----------
|
||||
|
||||
|
||||
class LineItemResponse(BaseModel):
|
||||
id: UUID
|
||||
product_id: UUID | None = None
|
||||
name: str
|
||||
quantity: float
|
||||
unit_price: float
|
||||
total_price: float
|
||||
|
||||
|
||||
class PurchaseResponse(BaseModel):
|
||||
id: UUID
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
purchased_at: datetime
|
||||
total: float
|
||||
item_count: int
|
||||
|
||||
|
||||
class PurchaseDetailResponse(PurchaseResponse):
|
||||
line_items: list[LineItemResponse]
|
||||
|
||||
|
||||
class PurchaseStatsResponse(BaseModel):
|
||||
total_spent: float
|
||||
purchase_count: int
|
||||
by_store: dict[str, float]
|
||||
by_period: dict[str, float]
|
||||
|
||||
|
||||
# ---------- Products ----------
|
||||
|
||||
|
||||
class ProductResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
brand: str | None = None
|
||||
category: str | None = None
|
||||
upc: str | None = None
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
class ProductDetailResponse(ProductResponse):
|
||||
prices_by_store: list["StorePriceResponse"]
|
||||
|
||||
|
||||
class StorePriceResponse(BaseModel):
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
current_price: float
|
||||
last_seen_at: datetime
|
||||
|
||||
|
||||
# ---------- Prices ----------
|
||||
|
||||
|
||||
class PriceTrendResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
data_points: list["PricePointResponse"]
|
||||
|
||||
|
||||
class PricePointResponse(BaseModel):
|
||||
date: datetime
|
||||
price: float
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
|
||||
|
||||
class PriceIncreaseResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
store_name: str
|
||||
old_price: float
|
||||
new_price: float
|
||||
increase_pct: float
|
||||
detected_at: datetime
|
||||
|
||||
|
||||
class PriceComparisonResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
prices: list[StorePriceResponse]
|
||||
|
||||
|
||||
# ---------- Coupons ----------
|
||||
|
||||
|
||||
class CouponResponse(BaseModel):
|
||||
id: UUID
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
description: str
|
||||
discount_value: float
|
||||
discount_type: str
|
||||
product_id: UUID | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
# ---------- Shopping ----------
|
||||
|
||||
|
||||
class ShoppingListItemRequest(BaseModel):
|
||||
product_id: UUID | None = None
|
||||
name: str
|
||||
quantity: int = 1
|
||||
|
||||
|
||||
class OptimizeRequest(BaseModel):
|
||||
items: list[ShoppingListItemRequest]
|
||||
preferred_stores: list[UUID] | None = None
|
||||
|
||||
|
||||
class OptimizedStoreTrip(BaseModel):
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
items: list["OptimizedItemResponse"]
|
||||
subtotal: float
|
||||
coupons: list[CouponResponse]
|
||||
savings: float
|
||||
|
||||
|
||||
class OptimizedItemResponse(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
product_id: UUID | None = None
|
||||
|
||||
|
||||
class OptimizeResponse(BaseModel):
|
||||
trips: list[OptimizedStoreTrip]
|
||||
total_cost: float
|
||||
total_savings: float
|
||||
|
||||
|
||||
class ShoppingListResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
item_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# ---------- Alerts ----------
|
||||
|
||||
|
||||
class AlertResponse(BaseModel):
|
||||
id: UUID
|
||||
alert_type: str
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
message: str
|
||||
triggered_at: datetime
|
||||
read: bool = False
|
||||
|
||||
|
||||
class AlertSettingsRequest(BaseModel):
|
||||
price_increase_threshold_pct: float | None = None
|
||||
shrinkflation_enabled: bool | None = None
|
||||
email_notifications: bool | None = None
|
||||
|
||||
|
||||
class AlertSettingsResponse(BaseModel):
|
||||
price_increase_threshold_pct: float
|
||||
shrinkflation_enabled: bool
|
||||
email_notifications: bool
|
||||
|
||||
|
||||
# ---------- Scraping ----------
|
||||
|
||||
|
||||
class SyncTriggerResponse(BaseModel):
|
||||
job_id: UUID
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class SyncStatusResponse(BaseModel):
|
||||
store_slug: str
|
||||
status: str
|
||||
last_sync_at: datetime | None = None
|
||||
items_synced: int | None = None
|
||||
|
||||
|
||||
# ---------- Public ----------
|
||||
|
||||
|
||||
class PublicTrendResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
data_points: list[PricePointResponse]
|
||||
|
||||
|
||||
class PublicStoreComparisonResponse(BaseModel):
|
||||
products: list[PriceComparisonResponse]
|
||||
|
||||
|
||||
class PublicInflationResponse(BaseModel):
|
||||
period: str
|
||||
cartsnitch_index: float
|
||||
cpi_baseline: float
|
||||
categories: dict[str, float]
|
||||
|
||||
|
||||
# ---------- Common ----------
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
items: list
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
detail: str
|
||||
code: str | None = None
|
||||
|
||||
|
||||
# Rebuild forward refs
|
||||
ProductDetailResponse.model_rebuild()
|
||||
PriceTrendResponse.model_rebuild()
|
||||
OptimizedStoreTrip.model_rebuild()
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Alert service — price and shrinkflation alerts for users.
|
||||
|
||||
Alerts are generated by StickerShock and ShrinkRay services and written to the DB.
|
||||
This service reads them for the API gateway.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
class AlertService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_alerts(self, user_id: UUID) -> list[dict]:
|
||||
"""List shrinkflation events for products the user has purchased."""
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent
|
||||
|
||||
# Get product IDs from user's purchases
|
||||
items_result = await self.db.execute(
|
||||
select(PurchaseItem.normalized_product_id)
|
||||
.join(Purchase)
|
||||
.where(
|
||||
Purchase.user_id == user_id,
|
||||
PurchaseItem.normalized_product_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
product_ids = [row[0] for row in items_result.all()]
|
||||
|
||||
if not product_ids:
|
||||
return []
|
||||
|
||||
result = await self.db.execute(
|
||||
select(ShrinkflationEvent)
|
||||
.where(ShrinkflationEvent.normalized_product_id.in_(product_ids))
|
||||
.options(selectinload(ShrinkflationEvent.normalized_product))
|
||||
.order_by(ShrinkflationEvent.detected_date.desc())
|
||||
)
|
||||
events = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": e.id,
|
||||
"alert_type": "shrinkflation",
|
||||
"product_id": e.normalized_product_id,
|
||||
"product_name": e.normalized_product.canonical_name,
|
||||
"message": (
|
||||
f"Size changed from {e.old_size}{e.old_unit} to {e.new_size}{e.new_unit}"
|
||||
),
|
||||
"triggered_at": e.detected_date,
|
||||
"read": False,
|
||||
}
|
||||
for e in events
|
||||
]
|
||||
|
||||
async def get_settings(self, user_id: UUID) -> dict:
|
||||
# Alert settings would be stored in a user_settings table.
|
||||
# For now, return defaults since the table doesn't exist yet in common lib.
|
||||
return {
|
||||
"price_increase_threshold_pct": 5.0,
|
||||
"shrinkflation_enabled": True,
|
||||
"email_notifications": False,
|
||||
}
|
||||
|
||||
async def update_settings(self, user_id: UUID, **fields) -> dict:
|
||||
# Would update user_settings table. Return merged defaults for now.
|
||||
current = await self.get_settings(user_id)
|
||||
for k, v in fields.items():
|
||||
if v is not None and k in current:
|
||||
current[k] = v
|
||||
return current
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Auth service — user profile management.
|
||||
|
||||
Registration, login, token management, and session handling are now
|
||||
handled by the Better-Auth service (auth/). This service provides
|
||||
user lookup and profile update operations for the API gateway.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_user(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"display_name": user.display_name,
|
||||
"created_at": user.created_at,
|
||||
}
|
||||
|
||||
async def update_user(self, user_id: UUID, **fields) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
|
||||
if "display_name" in fields and fields["display_name"] is not None:
|
||||
user.display_name = fields["display_name"]
|
||||
if "email" in fields and fields["email"] is not None:
|
||||
existing = await self.db.execute(
|
||||
select(User).where(User.email == fields["email"], User.id != user_id)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Email already in use")
|
||||
user.email = fields["email"]
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"display_name": user.display_name,
|
||||
"created_at": user.created_at,
|
||||
}
|
||||
|
||||
async def delete_user(self, user_id: UUID) -> None:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
|
||||
await self.db.delete(user)
|
||||
await self.db.commit()
|
||||
@@ -0,0 +1,52 @@
|
||||
"""HTTP client for ClipArtist internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class ClipArtistClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.clipartist_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def optimize(
|
||||
self,
|
||||
user_id: str,
|
||||
items: list[dict],
|
||||
preferred_stores: list[str] | None = None,
|
||||
) -> dict:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/optimize",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"user_id": user_id,
|
||||
"items": items,
|
||||
"preferred_stores": preferred_stores,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def get_shopping_lists(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/shopping-lists",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
|
||||
async def get_relevant_coupons(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/coupons/relevant",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Coupon service — browse coupons, find relevant ones."""
|
||||
|
||||
from datetime import date
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
class CouponService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_coupons(self, store_id: UUID | None = None) -> list[dict]:
|
||||
from cartsnitch_api.models import Coupon
|
||||
|
||||
today = date.today()
|
||||
query = (
|
||||
select(Coupon)
|
||||
.where((Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)))
|
||||
.options(selectinload(Coupon.store))
|
||||
.order_by(Coupon.valid_to.asc().nullslast())
|
||||
)
|
||||
if store_id:
|
||||
query = query.where(Coupon.store_id == store_id)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
coupons = result.scalars().all()
|
||||
return [self._to_dict(c) for c in coupons]
|
||||
|
||||
async def relevant_coupons(self, user_id: UUID) -> list[dict]:
|
||||
"""Coupons for products the user has purchased."""
|
||||
from cartsnitch_api.models import Coupon, PurchaseItem
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Get product IDs from user's purchase history
|
||||
from cartsnitch_api.models import Purchase
|
||||
|
||||
items_result = await self.db.execute(
|
||||
select(PurchaseItem.normalized_product_id)
|
||||
.join(Purchase)
|
||||
.where(
|
||||
Purchase.user_id == user_id,
|
||||
PurchaseItem.normalized_product_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
product_ids = [row[0] for row in items_result.all()]
|
||||
|
||||
if not product_ids:
|
||||
return []
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Coupon)
|
||||
.where(
|
||||
Coupon.normalized_product_id.in_(product_ids),
|
||||
(Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)),
|
||||
)
|
||||
.options(selectinload(Coupon.store))
|
||||
)
|
||||
coupons = result.scalars().all()
|
||||
return [self._to_dict(c) for c in coupons]
|
||||
|
||||
def _to_dict(self, c) -> dict:
|
||||
return {
|
||||
"id": c.id,
|
||||
"store_id": c.store_id,
|
||||
"store_name": c.store.name,
|
||||
"description": c.description or c.title,
|
||||
"discount_value": float(c.discount_value) if c.discount_value else 0,
|
||||
"discount_type": c.discount_type,
|
||||
"product_id": c.normalized_product_id,
|
||||
"expires_at": c.valid_to,
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Price service — trends, increases, comparison."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.services.queries import latest_price_per_store
|
||||
|
||||
|
||||
class PriceService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_trends(self, category: str | None = None) -> list[dict]:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
query = (
|
||||
select(PriceHistory)
|
||||
.join(NormalizedProduct)
|
||||
.options(
|
||||
selectinload(PriceHistory.store),
|
||||
selectinload(PriceHistory.normalized_product),
|
||||
)
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
if category:
|
||||
query = query.where(NormalizedProduct.category == category)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
prices = result.scalars().all()
|
||||
|
||||
# Group by product
|
||||
by_product: dict[UUID, dict] = {}
|
||||
for ph in prices:
|
||||
pid = ph.normalized_product_id
|
||||
if pid not in by_product:
|
||||
by_product[pid] = {
|
||||
"product_id": pid,
|
||||
"product_name": ph.normalized_product.canonical_name,
|
||||
"data_points": [],
|
||||
}
|
||||
by_product[pid]["data_points"].append(
|
||||
{
|
||||
"date": ph.observed_date,
|
||||
"price": float(ph.regular_price),
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
}
|
||||
)
|
||||
return list(by_product.values())
|
||||
|
||||
async def get_increases(self) -> list[dict]:
|
||||
"""Find products with recent significant price increases.
|
||||
|
||||
Uses a window function (lag) to compare each price observation with the
|
||||
previous one per product+store, avoiding the N+1 query pattern.
|
||||
"""
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
# Use lag() window function to get previous price in a single query
|
||||
prev_price = (
|
||||
func.lag(PriceHistory.regular_price)
|
||||
.over(
|
||||
partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id],
|
||||
order_by=PriceHistory.observed_date,
|
||||
)
|
||||
.label("prev_price")
|
||||
)
|
||||
|
||||
row_num = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id],
|
||||
order_by=PriceHistory.observed_date.desc(),
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
|
||||
inner = select(
|
||||
PriceHistory.normalized_product_id,
|
||||
PriceHistory.store_id,
|
||||
PriceHistory.regular_price,
|
||||
PriceHistory.observed_date,
|
||||
prev_price,
|
||||
row_num,
|
||||
).subquery()
|
||||
|
||||
# Only keep the latest row (rn=1) where price increased
|
||||
result = await self.db.execute(
|
||||
select(
|
||||
inner.c.normalized_product_id,
|
||||
inner.c.store_id,
|
||||
inner.c.regular_price,
|
||||
inner.c.observed_date,
|
||||
inner.c.prev_price,
|
||||
NormalizedProduct.canonical_name,
|
||||
Store.name.label("store_name"),
|
||||
)
|
||||
.join(NormalizedProduct, NormalizedProduct.id == inner.c.normalized_product_id)
|
||||
.join(Store, Store.id == inner.c.store_id)
|
||||
.where(
|
||||
inner.c.rn == 1,
|
||||
inner.c.prev_price.isnot(None),
|
||||
inner.c.regular_price > inner.c.prev_price,
|
||||
)
|
||||
)
|
||||
|
||||
increases = []
|
||||
for row in result.all():
|
||||
old = float(row.prev_price)
|
||||
new = float(row.regular_price)
|
||||
increases.append(
|
||||
{
|
||||
"product_id": row.normalized_product_id,
|
||||
"product_name": row.canonical_name,
|
||||
"store_name": row.store_name,
|
||||
"old_price": old,
|
||||
"new_price": new,
|
||||
"increase_pct": round((new - old) / old * 100, 2),
|
||||
"detected_at": row.observed_date,
|
||||
}
|
||||
)
|
||||
|
||||
increases.sort(key=lambda x: x["increase_pct"], reverse=True)
|
||||
return increases
|
||||
|
||||
async def get_comparison(self, product_ids: list[UUID]) -> list[dict]:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
if not product_ids:
|
||||
return []
|
||||
|
||||
# Fetch all requested products in one query
|
||||
prod_result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
)
|
||||
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
||||
|
||||
# Latest prices for all requested products in one query
|
||||
subq = latest_price_per_store(product_ids)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
subq,
|
||||
and_(
|
||||
PriceHistory.store_id == subq.c.store_id,
|
||||
PriceHistory.observed_date == subq.c.max_date,
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
all_prices = prices_result.scalars().all()
|
||||
|
||||
# Group prices by product
|
||||
prices_by_product: dict[UUID, list] = {pid: [] for pid in product_ids}
|
||||
for ph in all_prices:
|
||||
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
||||
|
||||
comparisons = []
|
||||
for pid in product_ids:
|
||||
product = products_by_id.get(pid)
|
||||
if not product:
|
||||
continue
|
||||
comparisons.append(
|
||||
{
|
||||
"product_id": pid,
|
||||
"product_name": product.canonical_name,
|
||||
"prices": [
|
||||
{
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
"current_price": float(ph.regular_price),
|
||||
"last_seen_at": ph.observed_date,
|
||||
}
|
||||
for ph in prices_by_product.get(pid, [])
|
||||
],
|
||||
}
|
||||
)
|
||||
return comparisons
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Product service — catalog, detail, price history."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.services.queries import latest_price_per_store
|
||||
|
||||
|
||||
class ProductService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_products(
|
||||
self,
|
||||
q: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[dict]:
|
||||
from cartsnitch_api.models import NormalizedProduct
|
||||
|
||||
query = select(NormalizedProduct)
|
||||
if q:
|
||||
# Escape SQL LIKE wildcards in user input
|
||||
safe_q = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
query = query.where(NormalizedProduct.canonical_name.ilike(f"%{safe_q}%"))
|
||||
if category:
|
||||
query = query.where(NormalizedProduct.category == category)
|
||||
query = query.order_by(NormalizedProduct.canonical_name)
|
||||
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
products = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": p.id,
|
||||
"name": p.canonical_name,
|
||||
"brand": p.brand,
|
||||
"category": p.category,
|
||||
"upc": (p.upc_variants[0] if p.upc_variants else None),
|
||||
"image_url": None,
|
||||
}
|
||||
for p in products
|
||||
]
|
||||
|
||||
async def get_product(self, product_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
|
||||
)
|
||||
product = result.scalar_one_or_none()
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
# Get latest price per store
|
||||
subq = latest_price_per_store([product_id])
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
subq,
|
||||
and_(
|
||||
PriceHistory.store_id == subq.c.store_id,
|
||||
PriceHistory.observed_date == subq.c.max_date,
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
prices = prices_result.scalars().all()
|
||||
|
||||
return {
|
||||
"id": product.id,
|
||||
"name": product.canonical_name,
|
||||
"brand": product.brand,
|
||||
"category": product.category,
|
||||
"upc": (product.upc_variants[0] if product.upc_variants else None),
|
||||
"image_url": None,
|
||||
"prices_by_store": [
|
||||
{
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
"current_price": float(ph.regular_price),
|
||||
"last_seen_at": ph.observed_date,
|
||||
}
|
||||
for ph in prices
|
||||
],
|
||||
}
|
||||
|
||||
async def get_price_history(self, product_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
|
||||
)
|
||||
product = result.scalar_one_or_none()
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
prices = prices_result.scalars().all()
|
||||
|
||||
return {
|
||||
"product_id": product.id,
|
||||
"product_name": product.canonical_name,
|
||||
"data_points": [
|
||||
{
|
||||
"date": ph.observed_date,
|
||||
"price": float(ph.regular_price),
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
}
|
||||
for ph in prices
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Public service — unauthenticated price transparency endpoints."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.services.queries import latest_price_per_store
|
||||
|
||||
|
||||
class PublicService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_trend(self, product_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
|
||||
)
|
||||
product = result.scalar_one_or_none()
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
prices = prices_result.scalars().all()
|
||||
|
||||
return {
|
||||
"product_id": product.id,
|
||||
"product_name": product.canonical_name,
|
||||
"data_points": [
|
||||
{
|
||||
"date": ph.observed_date,
|
||||
"price": float(ph.regular_price),
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
}
|
||||
for ph in prices
|
||||
],
|
||||
}
|
||||
|
||||
async def get_store_comparison(self, product_ids: list[UUID]) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
if not product_ids:
|
||||
return {"products": []}
|
||||
|
||||
# Fetch all products in one query
|
||||
prod_result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
)
|
||||
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
||||
|
||||
# Latest prices for all requested products in one query
|
||||
subq = latest_price_per_store(product_ids)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
subq,
|
||||
and_(
|
||||
PriceHistory.store_id == subq.c.store_id,
|
||||
PriceHistory.observed_date == subq.c.max_date,
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
all_prices = prices_result.scalars().all()
|
||||
|
||||
# Group by product
|
||||
prices_by_product: dict[UUID, list] = {}
|
||||
for ph in all_prices:
|
||||
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
||||
|
||||
products = []
|
||||
for pid in product_ids:
|
||||
product = products_by_id.get(pid)
|
||||
if not product:
|
||||
continue
|
||||
products.append(
|
||||
{
|
||||
"product_id": pid,
|
||||
"product_name": product.canonical_name,
|
||||
"prices": [
|
||||
{
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
"current_price": float(ph.regular_price),
|
||||
"last_seen_at": ph.observed_date,
|
||||
}
|
||||
for ph in prices_by_product.get(pid, [])
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return {"products": products}
|
||||
|
||||
async def get_inflation(self) -> dict:
|
||||
"""Aggregate price change stats. Compares average prices across periods."""
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
# Get average prices grouped by category for recent vs older data
|
||||
result = await self.db.execute(
|
||||
select(
|
||||
NormalizedProduct.category,
|
||||
func.avg(PriceHistory.regular_price),
|
||||
)
|
||||
.join(NormalizedProduct)
|
||||
.group_by(NormalizedProduct.category)
|
||||
)
|
||||
categories = {}
|
||||
for row in result.all():
|
||||
cat, avg_price = row
|
||||
if cat:
|
||||
categories[cat] = float(avg_price) if avg_price else 0.0
|
||||
|
||||
return {
|
||||
"period": "all-time",
|
||||
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
|
||||
"cpi_baseline": 100.0,
|
||||
"categories": categories,
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Purchase service — list, detail, stats."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
class PurchaseService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_purchases(
|
||||
self,
|
||||
user_id: UUID,
|
||||
store_id: UUID | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[dict]:
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store
|
||||
|
||||
# Count items per purchase in a single subquery instead of N+1
|
||||
item_counts = (
|
||||
select(
|
||||
PurchaseItem.purchase_id,
|
||||
func.count().label("item_count"),
|
||||
)
|
||||
.group_by(PurchaseItem.purchase_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = (
|
||||
select(Purchase, item_counts.c.item_count, Store.name.label("store_name"))
|
||||
.join(Store, Store.id == Purchase.store_id)
|
||||
.outerjoin(item_counts, item_counts.c.purchase_id == Purchase.id)
|
||||
.where(Purchase.user_id == user_id)
|
||||
)
|
||||
if store_id:
|
||||
query = query.where(Purchase.store_id == store_id)
|
||||
|
||||
query = query.order_by(Purchase.purchase_date.desc())
|
||||
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": p.id,
|
||||
"store_id": p.store_id,
|
||||
"store_name": store_name,
|
||||
"purchased_at": p.purchase_date,
|
||||
"total": float(p.total),
|
||||
"item_count": item_count or 0,
|
||||
}
|
||||
for p, item_count, store_name in result.all()
|
||||
]
|
||||
|
||||
async def get_purchase(self, purchase_id: UUID, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import Purchase
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Purchase)
|
||||
.where(Purchase.id == purchase_id, Purchase.user_id == user_id)
|
||||
.options(selectinload(Purchase.store), selectinload(Purchase.items))
|
||||
)
|
||||
purchase = result.scalar_one_or_none()
|
||||
if not purchase:
|
||||
raise LookupError("Purchase not found")
|
||||
|
||||
return {
|
||||
"id": purchase.id,
|
||||
"store_id": purchase.store_id,
|
||||
"store_name": purchase.store.name,
|
||||
"purchased_at": purchase.purchase_date,
|
||||
"total": float(purchase.total),
|
||||
"item_count": len(purchase.items),
|
||||
"line_items": [
|
||||
{
|
||||
"id": item.id,
|
||||
"product_id": item.normalized_product_id,
|
||||
"name": item.product_name_raw,
|
||||
"quantity": float(item.quantity),
|
||||
"unit_price": float(item.unit_price),
|
||||
"total_price": float(item.extended_price),
|
||||
}
|
||||
for item in purchase.items
|
||||
],
|
||||
}
|
||||
|
||||
async def get_stats(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import Purchase
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Purchase)
|
||||
.where(Purchase.user_id == user_id)
|
||||
.options(selectinload(Purchase.store))
|
||||
)
|
||||
purchases = result.scalars().all()
|
||||
|
||||
total_spent = sum(float(p.total) for p in purchases)
|
||||
by_store: dict[str, float] = {}
|
||||
by_period: dict[str, float] = {}
|
||||
|
||||
for p in purchases:
|
||||
store_name = p.store.name
|
||||
by_store[store_name] = by_store.get(store_name, 0) + float(p.total)
|
||||
period = p.purchase_date.strftime("%Y-%m")
|
||||
by_period[period] = by_period.get(period, 0) + float(p.total)
|
||||
|
||||
return {
|
||||
"total_spent": total_spent,
|
||||
"purchase_count": len(purchases),
|
||||
"by_store": by_store,
|
||||
"by_period": by_period,
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Shared query helpers for service layer."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
|
||||
def latest_price_per_store(product_ids: list[UUID] | None = None):
|
||||
"""Subquery returning the latest observed_date per product+store.
|
||||
|
||||
Optionally filtered to a list of product IDs. Returns a subquery with
|
||||
columns: normalized_product_id, store_id, max_date.
|
||||
"""
|
||||
from cartsnitch_api.models import PriceHistory
|
||||
|
||||
query = select(
|
||||
PriceHistory.normalized_product_id,
|
||||
PriceHistory.store_id,
|
||||
func.max(PriceHistory.observed_date).label("max_date"),
|
||||
).group_by(PriceHistory.normalized_product_id, PriceHistory.store_id)
|
||||
if product_ids is not None:
|
||||
query = query.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
return query.subquery()
|
||||
@@ -0,0 +1,33 @@
|
||||
"""HTTP client for ReceiptWitness internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class ReceiptWitnessClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.receiptwitness_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def trigger_sync(self, user_id: str, store_slug: str) -> dict:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/sync/{store_slug}",
|
||||
headers=self.headers,
|
||||
json={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def get_sync_status(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/sync/status",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
@@ -0,0 +1,23 @@
|
||||
"""HTTP client for ShrinkRay internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class ShrinkRayClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.shrinkray_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def get_shrinkflation_alerts(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/alerts",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
@@ -0,0 +1,32 @@
|
||||
"""HTTP client for StickerShock internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class StickerShockClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.stickershock_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def get_price_increases(self, params: dict | None = None) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/increases",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
|
||||
async def get_inflation_data(self) -> dict:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/inflation",
|
||||
headers=self.headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Store service — list stores, manage user store account connections."""
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
return Fernet(settings.fernet_key.encode())
|
||||
|
||||
|
||||
class StoreService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_stores(self) -> list[dict]:
|
||||
from cartsnitch_api.models import Store
|
||||
|
||||
result = await self.db.execute(select(Store).order_by(Store.name))
|
||||
stores = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"slug": s.slug,
|
||||
"logo_url": s.logo_url,
|
||||
"supported": True,
|
||||
}
|
||||
for s in stores
|
||||
]
|
||||
|
||||
async def list_user_stores(self, user_id: UUID) -> list[dict]:
|
||||
from cartsnitch_api.models import UserStoreAccount
|
||||
|
||||
result = await self.db.execute(
|
||||
select(UserStoreAccount)
|
||||
.where(UserStoreAccount.user_id == user_id)
|
||||
.options(selectinload(UserStoreAccount.store))
|
||||
)
|
||||
accounts = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"store": {
|
||||
"id": a.store.id,
|
||||
"name": a.store.name,
|
||||
"slug": a.store.slug,
|
||||
"logo_url": a.store.logo_url,
|
||||
"supported": True,
|
||||
},
|
||||
"connected": a.status == "active",
|
||||
"last_sync_at": a.last_sync_at,
|
||||
"sync_status": a.status,
|
||||
}
|
||||
for a in accounts
|
||||
]
|
||||
|
||||
async def connect_store(self, user_id: UUID, store_slug: str, credentials: dict | None) -> dict:
|
||||
from cartsnitch_api.models import Store, UserStoreAccount
|
||||
|
||||
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
||||
store = result.scalar_one_or_none()
|
||||
if not store:
|
||||
raise LookupError(f"Store '{store_slug}' not found")
|
||||
|
||||
existing = await self.db.execute(
|
||||
select(UserStoreAccount).where(
|
||||
UserStoreAccount.user_id == user_id,
|
||||
UserStoreAccount.store_id == store.id,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Store account already connected")
|
||||
|
||||
encrypted_data = None
|
||||
if credentials:
|
||||
fernet = _get_fernet()
|
||||
encrypted_data = {
|
||||
"encrypted": fernet.encrypt(json.dumps(credentials).encode()).decode()
|
||||
}
|
||||
|
||||
account = UserStoreAccount(
|
||||
user_id=user_id,
|
||||
store_id=store.id,
|
||||
session_data=encrypted_data,
|
||||
status="active",
|
||||
)
|
||||
self.db.add(account)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(account)
|
||||
|
||||
return {
|
||||
"store": {
|
||||
"id": store.id,
|
||||
"name": store.name,
|
||||
"slug": store.slug,
|
||||
"logo_url": store.logo_url,
|
||||
"supported": True,
|
||||
},
|
||||
"connected": True,
|
||||
"last_sync_at": None,
|
||||
"sync_status": "active",
|
||||
}
|
||||
|
||||
async def disconnect_store(self, user_id: UUID, store_slug: str) -> None:
|
||||
from cartsnitch_api.models import Store, UserStoreAccount
|
||||
|
||||
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
||||
store = result.scalar_one_or_none()
|
||||
if not store:
|
||||
raise LookupError(f"Store '{store_slug}' not found")
|
||||
|
||||
result = await self.db.execute(
|
||||
select(UserStoreAccount).where(
|
||||
UserStoreAccount.user_id == user_id,
|
||||
UserStoreAccount.store_id == store.id,
|
||||
)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise LookupError("Store account not connected")
|
||||
|
||||
await self.db.delete(account)
|
||||
await self.db.commit()
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Custom SQLAlchemy column types."""
|
||||
|
||||
import json
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
return Fernet(settings.fernet_key.encode())
|
||||
|
||||
|
||||
class EncryptedJSON(TypeDecorator):
|
||||
"""SQLAlchemy type that transparently encrypts/decrypts JSON using Fernet.
|
||||
|
||||
Stores data as a Fernet-encrypted text blob in the database.
|
||||
On read, decrypts and deserialises back to a Python dict/list.
|
||||
"""
|
||||
|
||||
impl = Text
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return None
|
||||
plaintext = json.dumps(value).encode()
|
||||
return _get_fernet().encrypt(plaintext).decode()
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return None
|
||||
decrypted = _get_fernet().decrypt(value.encode())
|
||||
return json.loads(decrypted)
|
||||
@@ -0,0 +1,187 @@
|
||||
"""Shared test fixtures with in-memory SQLite database.
|
||||
|
||||
Session-based auth: tests create users and sessions directly in the DB,
|
||||
matching the Better-Auth session validation flow.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from cartsnitch_api.config import settings as cartsnitch_settings
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.main import create_app
|
||||
from cartsnitch_api.models import Base
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limiting():
|
||||
"""Disable rate limiting for all tests to prevent 429 interference."""
|
||||
cartsnitch_settings.rate_limit_enabled = False
|
||||
yield
|
||||
cartsnitch_settings.rate_limit_enabled = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
"""Sync in-memory SQLite engine for model unit tests."""
|
||||
eng = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(eng)
|
||||
yield eng
|
||||
eng.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(engine):
|
||||
"""Sync SQLAlchemy session for model unit tests."""
|
||||
factory = sessionmaker(bind=engine)
|
||||
with factory() as sess:
|
||||
yield sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
||||
await conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
user_id TEXT NOT NULL,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
ip_address TEXT,
|
||||
user_agent TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
await conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
account_id TEXT NOT NULL,
|
||||
provider_id TEXT NOT NULL,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
access_token_expires_at TIMESTAMP,
|
||||
refresh_token_expires_at TIMESTAMP,
|
||||
scope TEXT,
|
||||
id_token TEXT,
|
||||
password TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
await conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS verifications (
|
||||
id TEXT PRIMARY KEY,
|
||||
identifier TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
|
||||
yield engine
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(db_engine):
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
async def override_get_db():
|
||||
async with factory() as session:
|
||||
yield session
|
||||
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
||||
"""Create a test user and a valid session directly in the DB.
|
||||
|
||||
Returns (user_dict, session_token).
|
||||
"""
|
||||
user_id = str(uuid.uuid4())
|
||||
email = user_overrides.get("email", "test@example.com")
|
||||
display_name = user_overrides.get("display_name", "Test User")
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
session_id = str(uuid.uuid4())
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hashed_password, :display_name, :email_verified, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": email,
|
||||
"hashed_password": "not-used-with-better-auth",
|
||||
"display_name": display_name,
|
||||
"email_verified": False,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": session_id,
|
||||
"token": session_token,
|
||||
"user_id": user_id,
|
||||
"expires_at": expires,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
|
||||
return {"id": user_id, "email": email, "display_name": display_name}, session_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client, db_engine):
|
||||
"""Create a test user with a valid session and return auth headers."""
|
||||
_, session_token = await _create_test_user_and_session(client, db_engine)
|
||||
return {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Integration tests for auth profile endpoints.
|
||||
|
||||
Registration, login, and session management are handled by the Better-Auth
|
||||
service. These tests cover the profile endpoints (GET/PATCH/DELETE /auth/me)
|
||||
which validate sessions via the shared sessions table.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me(client, auth_headers):
|
||||
resp = await client.get("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["email"] == "test@example.com"
|
||||
assert data["display_name"] == "Test User"
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_unauthorized(client):
|
||||
resp = await client.get("/auth/me")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_invalid_session(client):
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": "better-auth.session_token=invalid-token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_with_bearer_token(client, db_engine):
|
||||
"""Session tokens can also be passed as Bearer tokens for API clients."""
|
||||
from tests.conftest import _create_test_user_and_session
|
||||
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="bearer@example.com", display_name="Bearer User"
|
||||
)
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Authorization": f"Bearer {session_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == "bearer@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me(client, auth_headers):
|
||||
resp = await client.patch(
|
||||
"/auth/me",
|
||||
headers=auth_headers,
|
||||
json={"display_name": "Updated Name"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["display_name"] == "Updated Name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_me(client, auth_headers):
|
||||
resp = await client.delete("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Session is still valid but user is gone
|
||||
resp = await client.get("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_session_rejected(client, db_engine):
|
||||
"""Expired sessions must be rejected."""
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
user_id = str(uuid.uuid4())
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": "expired@example.com",
|
||||
"hp": "unused",
|
||||
"dn": "Expired User",
|
||||
"ev": False,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"uid": user_id,
|
||||
"ea": expired,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Shared fixtures for E2E integration tests.
|
||||
|
||||
Seeds a realistic dataset with stores, products, price history,
|
||||
purchases, coupons, and shrinkflation events so E2E flows can
|
||||
exercise cross-resource queries against real data.
|
||||
"""
|
||||
|
||||
from datetime import date, timedelta
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import (
|
||||
Coupon,
|
||||
NormalizedProduct,
|
||||
PriceHistory,
|
||||
Purchase,
|
||||
PurchaseItem,
|
||||
ShrinkflationEvent,
|
||||
Store,
|
||||
)
|
||||
|
||||
# Shared test constants
|
||||
ZERO_UUID = "00000000-0000-0000-0000-000000000000"
|
||||
BAD_UUID = "not-a-uuid"
|
||||
# Fixed anchor date for deterministic tests
|
||||
ANCHOR_DATE = date(2026, 3, 15)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def seed_data(db_engine, auth_headers):
|
||||
"""Seed a full dataset and return identifiers for test assertions."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
# -- Stores --
|
||||
meijer = Store(name="Meijer", slug="meijer")
|
||||
kroger = Store(name="Kroger", slug="kroger")
|
||||
target = Store(name="Target", slug="target")
|
||||
session.add_all([meijer, kroger, target])
|
||||
await session.flush()
|
||||
|
||||
# -- Products --
|
||||
cheerios = NormalizedProduct(
|
||||
canonical_name="Cheerios 18oz",
|
||||
category="pantry",
|
||||
brand="General Mills",
|
||||
size="18",
|
||||
size_unit="oz",
|
||||
upc_variants=["016000275263"],
|
||||
)
|
||||
milk = NormalizedProduct(
|
||||
canonical_name="Whole Milk 1gal",
|
||||
category="dairy",
|
||||
brand="Meijer",
|
||||
size="1",
|
||||
size_unit="gal",
|
||||
)
|
||||
chicken = NormalizedProduct(
|
||||
canonical_name="Chicken Breast 1lb",
|
||||
category="meat",
|
||||
brand=None,
|
||||
size="1",
|
||||
size_unit="lb",
|
||||
)
|
||||
session.add_all([cheerios, milk, chicken])
|
||||
await session.flush()
|
||||
|
||||
# -- Price history (multiple dates, multiple stores) --
|
||||
today = ANCHOR_DATE
|
||||
prices = []
|
||||
# Cheerios at Meijer: price increase over time
|
||||
for i, price_val in enumerate([Decimal("3.99"), Decimal("4.29"), Decimal("4.79")]):
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=cheerios.id,
|
||||
store_id=meijer.id,
|
||||
observed_date=today - timedelta(days=60 - i * 30),
|
||||
regular_price=price_val,
|
||||
source="receipt",
|
||||
)
|
||||
)
|
||||
# Cheerios at Kroger: stable price
|
||||
for i in range(3):
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=cheerios.id,
|
||||
store_id=kroger.id,
|
||||
observed_date=today - timedelta(days=60 - i * 30),
|
||||
regular_price=Decimal("4.49"),
|
||||
source="catalog",
|
||||
)
|
||||
)
|
||||
# Milk at Meijer
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=milk.id,
|
||||
store_id=meijer.id,
|
||||
observed_date=today - timedelta(days=7),
|
||||
regular_price=Decimal("3.29"),
|
||||
source="receipt",
|
||||
)
|
||||
)
|
||||
# Milk at Kroger
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=milk.id,
|
||||
store_id=kroger.id,
|
||||
observed_date=today - timedelta(days=5),
|
||||
regular_price=Decimal("3.49"),
|
||||
source="catalog",
|
||||
)
|
||||
)
|
||||
# Chicken at Target
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=chicken.id,
|
||||
store_id=target.id,
|
||||
observed_date=today - timedelta(days=3),
|
||||
regular_price=Decimal("5.99"),
|
||||
source="catalog",
|
||||
)
|
||||
)
|
||||
session.add_all(prices)
|
||||
await session.flush()
|
||||
|
||||
# -- Get the user_id from the session token in auth_headers --
|
||||
cookie_str = auth_headers.get("Cookie", "")
|
||||
session_token = cookie_str.split("=", 1)[1] if "=" in cookie_str else ""
|
||||
|
||||
result = await session.execute(
|
||||
text("SELECT user_id FROM sessions WHERE token = :token"),
|
||||
{"token": session_token},
|
||||
)
|
||||
row = result.first()
|
||||
user_id = UUID(row[0])
|
||||
|
||||
purchase1 = Purchase(
|
||||
user_id=user_id,
|
||||
store_id=meijer.id,
|
||||
receipt_id="meijer-2026-001",
|
||||
purchase_date=today - timedelta(days=10),
|
||||
total=Decimal("23.45"),
|
||||
subtotal=Decimal("21.50"),
|
||||
tax=Decimal("1.95"),
|
||||
)
|
||||
purchase2 = Purchase(
|
||||
user_id=user_id,
|
||||
store_id=kroger.id,
|
||||
receipt_id="kroger-2026-001",
|
||||
purchase_date=today - timedelta(days=5),
|
||||
total=Decimal("15.78"),
|
||||
subtotal=Decimal("14.50"),
|
||||
tax=Decimal("1.28"),
|
||||
)
|
||||
session.add_all([purchase1, purchase2])
|
||||
await session.flush()
|
||||
|
||||
# -- Purchase Items --
|
||||
item1 = PurchaseItem(
|
||||
purchase_id=purchase1.id,
|
||||
product_name_raw="Cheerios 18oz Box",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("4.79"),
|
||||
extended_price=Decimal("4.79"),
|
||||
normalized_product_id=cheerios.id,
|
||||
)
|
||||
item2 = PurchaseItem(
|
||||
purchase_id=purchase1.id,
|
||||
product_name_raw="Meijer Whole Milk 1gal",
|
||||
quantity=Decimal("2"),
|
||||
unit_price=Decimal("3.29"),
|
||||
extended_price=Decimal("6.58"),
|
||||
normalized_product_id=milk.id,
|
||||
)
|
||||
item3 = PurchaseItem(
|
||||
purchase_id=purchase2.id,
|
||||
product_name_raw="KRO CHEERIOS 18OZ",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("4.49"),
|
||||
extended_price=Decimal("4.49"),
|
||||
normalized_product_id=cheerios.id,
|
||||
)
|
||||
session.add_all([item1, item2, item3])
|
||||
await session.flush()
|
||||
|
||||
# -- Coupons --
|
||||
coupon1 = Coupon(
|
||||
store_id=meijer.id,
|
||||
normalized_product_id=cheerios.id,
|
||||
title="$1 off Cheerios",
|
||||
description="Save $1 on any Cheerios 18oz or larger",
|
||||
discount_type="fixed",
|
||||
discount_value=Decimal("1.00"),
|
||||
valid_from=today - timedelta(days=7),
|
||||
valid_to=today + timedelta(days=30),
|
||||
)
|
||||
coupon2 = Coupon(
|
||||
store_id=kroger.id,
|
||||
normalized_product_id=None,
|
||||
title="10% off dairy",
|
||||
description="10% off all dairy products",
|
||||
discount_type="percent",
|
||||
discount_value=Decimal("10.00"),
|
||||
valid_from=today - timedelta(days=3),
|
||||
valid_to=today + timedelta(days=14),
|
||||
)
|
||||
session.add_all([coupon1, coupon2])
|
||||
await session.flush()
|
||||
|
||||
# -- Shrinkflation events --
|
||||
shrink = ShrinkflationEvent(
|
||||
normalized_product_id=cheerios.id,
|
||||
detected_date=today - timedelta(days=15),
|
||||
old_size="20",
|
||||
new_size="18",
|
||||
old_unit="oz",
|
||||
new_unit="oz",
|
||||
price_at_old_size=Decimal("3.99"),
|
||||
price_at_new_size=Decimal("4.29"),
|
||||
confidence=Decimal("0.95"),
|
||||
notes="Size reduced from 20oz to 18oz while price increased",
|
||||
)
|
||||
session.add(shrink)
|
||||
await session.commit()
|
||||
|
||||
for obj in [
|
||||
meijer,
|
||||
kroger,
|
||||
target,
|
||||
cheerios,
|
||||
milk,
|
||||
chicken,
|
||||
purchase1,
|
||||
purchase2,
|
||||
item1,
|
||||
item2,
|
||||
item3,
|
||||
coupon1,
|
||||
coupon2,
|
||||
shrink,
|
||||
]:
|
||||
await session.refresh(obj)
|
||||
|
||||
return {
|
||||
"headers": auth_headers,
|
||||
"user_id": user_id,
|
||||
"stores": {"meijer": meijer, "kroger": kroger, "target": target},
|
||||
"products": {"cheerios": cheerios, "milk": milk, "chicken": chicken},
|
||||
"purchases": {"meijer_trip": purchase1, "kroger_trip": purchase2},
|
||||
"items": {"cheerios_meijer": item1, "milk_meijer": item2, "cheerios_kroger": item3},
|
||||
"coupons": {"cheerios_coupon": coupon1, "dairy_coupon": coupon2},
|
||||
"shrinkflation": {"cheerios_shrink": shrink},
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
"""E2E: Auth and session validation flows.
|
||||
|
||||
Registration and login are handled by the Better-Auth service.
|
||||
These tests validate session token handling at the API gateway level.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import _create_test_user_and_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSessionValidation:
|
||||
"""Session edge cases and error responses."""
|
||||
|
||||
async def test_invalid_session_token_rejected(self, client, db_engine):
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": "better-auth.session_token=not-a-real-token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_missing_auth(self, client, db_engine):
|
||||
resp = await client.get("/auth/me")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
async def test_bearer_token_also_works(self, client, db_engine):
|
||||
"""Session tokens passed as Bearer tokens should also be accepted."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="bearer@e2e.com", display_name="Bearer E2E"
|
||||
)
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Authorization": f"Bearer {session_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == "bearer@e2e.com"
|
||||
|
||||
async def test_deleted_user_session_returns_not_found(self, client, db_engine):
|
||||
"""After deleting a user, their session should result in 404 for profile."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="delete-me@e2e.com", display_name="Delete Me"
|
||||
)
|
||||
headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
delete_resp = await client.delete("/auth/me", headers=headers)
|
||||
assert delete_resp.status_code == 204
|
||||
|
||||
me = await client.get("/auth/me", headers=headers)
|
||||
assert me.status_code == 404
|
||||
|
||||
async def test_expired_session_rejected(self, client, db_engine):
|
||||
"""Expired sessions must be rejected."""
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
user_id = str(uuid.uuid4())
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": "expired@e2e.com",
|
||||
"hp": "unused",
|
||||
"dn": "Expired User",
|
||||
"ev": False,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"uid": user_id,
|
||||
"ea": expired,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthProtectedEndpoints:
|
||||
"""Verify auth is enforced on all user-specific endpoints."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method,path",
|
||||
[
|
||||
("GET", "/purchases"),
|
||||
("GET", "/products"),
|
||||
("GET", "/prices/trends"),
|
||||
("GET", "/prices/increases"),
|
||||
("GET", "/coupons"),
|
||||
("GET", "/alerts"),
|
||||
("GET", "/me/stores"),
|
||||
],
|
||||
)
|
||||
async def test_endpoints_require_auth(self, client, db_engine, method, path):
|
||||
resp = await client.request(method, path)
|
||||
assert resp.status_code in (401, 403), f"{method} {path} should require auth"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCrossUserDataIsolation:
|
||||
"""Verify that users cannot access other users' data."""
|
||||
|
||||
async def test_user_b_cannot_access_user_a_purchases(self, client, db_engine, seed_data):
|
||||
"""A second user cannot see User A's purchases."""
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="userb@e2e.com", display_name="User B"
|
||||
)
|
||||
user_b_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
|
||||
assert resp.status_code in (403, 404), (
|
||||
"User B should not be able to access User A's purchase"
|
||||
)
|
||||
|
||||
async def test_user_b_purchase_list_is_empty(self, client, db_engine, seed_data):
|
||||
"""A new user should see no purchases."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="userc@e2e.com", display_name="User C"
|
||||
)
|
||||
user_c_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
resp = await client.get("/purchases", headers=user_c_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0, "New user should have no purchases"
|
||||
|
||||
async def test_user_b_stores_isolated(self, client, db_engine, seed_data):
|
||||
"""User B's connected stores should be independent from User A."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="userd@e2e.com", display_name="User D"
|
||||
)
|
||||
user_d_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
resp = await client.get("/me/stores", headers=user_d_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0, "New user should have no connected stores"
|
||||
@@ -0,0 +1,114 @@
|
||||
"""E2E: Cross-resource flows — store connect → purchases → prices → coupons → alerts."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestStoreConnectToPurchaseFlow:
|
||||
"""Connect a store, then verify purchases and related data are accessible."""
|
||||
|
||||
async def test_connect_store_then_list(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
# Connect to Meijer
|
||||
resp = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
|
||||
assert resp.status_code in (200, 201)
|
||||
|
||||
# Verify store appears in user's connected stores
|
||||
stores = await client.get("/me/stores", headers=headers)
|
||||
assert stores.status_code == 200
|
||||
slugs = [s["store"]["slug"] for s in stores.json()]
|
||||
assert "meijer" in slugs
|
||||
|
||||
async def test_disconnect_store(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
await client.post("/me/stores/kroger/connect", json={}, headers=headers)
|
||||
resp = await client.delete("/me/stores/kroger", headers=headers)
|
||||
assert resp.status_code in (200, 204)
|
||||
|
||||
# Verify store no longer in connected list
|
||||
stores = await client.get("/me/stores", headers=headers)
|
||||
slugs = [s["store"]["slug"] for s in stores.json()]
|
||||
assert "kroger" not in slugs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseToPriceFlow:
|
||||
"""Verify purchase data links to price comparison data."""
|
||||
|
||||
async def test_purchase_items_link_to_products(self, client, seed_data):
|
||||
"""Items from purchases reference products that have price data."""
|
||||
headers = seed_data["headers"]
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
|
||||
# Get purchase detail
|
||||
purchase = await client.get(f"/purchases/{purchase_id}", headers=headers)
|
||||
assert purchase.status_code == 200
|
||||
items = purchase.json()["line_items"]
|
||||
|
||||
# Get product detail for an item that has a product_id
|
||||
product_ids = [li["product_id"] for li in items if li.get("product_id")]
|
||||
assert len(product_ids) >= 1
|
||||
|
||||
for pid in product_ids:
|
||||
product = await client.get(f"/products/{pid}", headers=headers)
|
||||
assert product.status_code == 200
|
||||
assert len(product.json()["prices_by_store"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCouponFlow:
|
||||
"""Verify coupon listing and relevance filtering."""
|
||||
|
||||
async def test_list_all_coupons(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/coupons", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 2
|
||||
descriptions = [c["description"] for c in data]
|
||||
assert any("Cheerios" in d for d in descriptions)
|
||||
|
||||
async def test_filter_coupons_by_store(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
meijer_id = str(seed_data["stores"]["meijer"].id)
|
||||
resp = await client.get("/coupons", params={"store_id": meijer_id}, headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert all(c["store_name"] == "Meijer" for c in data)
|
||||
|
||||
async def test_relevant_coupons_for_user(self, client, seed_data):
|
||||
"""User bought Cheerios, so the Cheerios coupon should be relevant."""
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/coupons/relevant", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1, "Expected at least one relevant coupon for user with purchases"
|
||||
descriptions = [c["description"] for c in data]
|
||||
assert any("Cheerios" in d for d in descriptions)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAlertFlow:
|
||||
"""Verify alert listing with seeded data."""
|
||||
|
||||
async def test_list_alerts(self, client, seed_data):
|
||||
"""User bought Cheerios which has a shrinkflation event — may appear as alert."""
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/alerts", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, list)
|
||||
# If alerts are generated synchronously, verify shrinkflation alert content
|
||||
if len(data) > 0:
|
||||
alert_types = [a["alert_type"] for a in data]
|
||||
product_names = [a["product_name"] for a in data]
|
||||
assert any(t in ("shrinkflation", "price_increase") for t in alert_types)
|
||||
assert any("Cheerios" in name for name in product_names)
|
||||
|
||||
async def test_alert_settings_default(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/alerts/settings", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "price_increase_threshold_pct" in data
|
||||
assert "shrinkflation_enabled" in data
|
||||
@@ -0,0 +1,127 @@
|
||||
"""E2E: Error responses for bad input across all endpoint categories."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRegistrationErrors:
|
||||
"""Validation errors during user registration."""
|
||||
|
||||
async def test_short_password(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "short@example.com", "password": "short", "display_name": "Test"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_email(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "not-an-email", "password": "securepass123", "display_name": "Test"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_missing_fields(self, client, db_engine):
|
||||
resp = await client.post("/auth/register", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_empty_display_name(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "empty@example.com", "password": "securepass123", "display_name": ""},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_duplicate_email(self, client, db_engine):
|
||||
payload = {
|
||||
"email": "dupe@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "First",
|
||||
}
|
||||
first = await client.post("/auth/register", json=payload)
|
||||
assert first.status_code == 201
|
||||
second = await client.post("/auth/register", json=payload)
|
||||
assert second.status_code == 409
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLoginErrors:
|
||||
"""Login failure modes."""
|
||||
|
||||
async def test_wrong_password(self, client, db_engine):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "login-err@example.com",
|
||||
"password": "correctpass1",
|
||||
"display_name": "Login",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "login-err@example.com", "password": "wrongpass123"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_nonexistent_user(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "nobody@example.com", "password": "doesntmatter"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestNotFoundErrors:
|
||||
"""404 responses for missing resources."""
|
||||
|
||||
async def test_product_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_purchase_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/purchases/{ZERO_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_public_trend_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/public/trends/{ZERO_UUID}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestMalformedInput:
|
||||
"""Invalid UUID formats and bad query params."""
|
||||
|
||||
async def test_invalid_uuid_product(self, client, seed_data):
|
||||
resp = await client.get(f"/products/{BAD_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_uuid_purchase(self, client, seed_data):
|
||||
resp = await client.get(f"/purchases/{BAD_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_uuid_public_trend(self, client, seed_data):
|
||||
resp = await client.get(f"/public/trends/{BAD_UUID}")
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestStoreConnectionErrors:
|
||||
"""Store connection edge cases."""
|
||||
|
||||
async def test_connect_nonexistent_store(self, client, seed_data):
|
||||
resp = await client.post(
|
||||
"/me/stores/nonexistent-store/connect",
|
||||
json={},
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_connect_store_twice(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
first = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
|
||||
assert first.status_code in (200, 201)
|
||||
second = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
|
||||
assert second.status_code == 409
|
||||
@@ -0,0 +1,102 @@
|
||||
"""E2E: Price history queries returning correct data."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPriceTrends:
|
||||
"""Verify price trend aggregation against seeded history."""
|
||||
|
||||
async def test_trends_returns_all_products(self, client, seed_data):
|
||||
resp = await client.get("/prices/trends", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
product_names = [t["product_name"] for t in data]
|
||||
assert "Cheerios 18oz" in product_names
|
||||
assert "Whole Milk 1gal" in product_names
|
||||
|
||||
async def test_trends_filter_by_category(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
"/prices/trends", params={"category": "dairy"}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
# Only dairy products should appear
|
||||
for trend in data:
|
||||
assert trend["product_name"] == "Whole Milk 1gal"
|
||||
|
||||
async def test_trends_contain_data_points(self, client, seed_data):
|
||||
resp = await client.get("/prices/trends", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
cheerios_trend = next(t for t in data if t["product_name"] == "Cheerios 18oz")
|
||||
assert len(cheerios_trend["data_points"]) >= 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPriceIncreases:
|
||||
"""Detect price increases from seeded price history."""
|
||||
|
||||
async def test_increases_detected(self, client, seed_data):
|
||||
resp = await client.get("/prices/increases", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# Cheerios at Meijer went from 3.99 → 4.29 → 4.79
|
||||
cheerios_increases = [inc for inc in data if inc["product_name"] == "Cheerios 18oz"]
|
||||
assert len(cheerios_increases) >= 1
|
||||
# Verify the increase data makes sense
|
||||
for inc in cheerios_increases:
|
||||
assert inc["new_price"] > inc["old_price"]
|
||||
assert inc["increase_pct"] > 0
|
||||
assert inc["store_name"] == "Meijer"
|
||||
|
||||
async def test_stable_prices_not_flagged(self, client, seed_data):
|
||||
"""Kroger Cheerios price is stable at $4.49 — should not appear as increase."""
|
||||
resp = await client.get("/prices/increases", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
kroger_increases = [
|
||||
inc
|
||||
for inc in data
|
||||
if inc["product_name"] == "Cheerios 18oz" and inc["store_name"] == "Kroger"
|
||||
]
|
||||
assert len(kroger_increases) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPriceComparison:
|
||||
"""Compare prices across stores for specific products."""
|
||||
|
||||
async def test_compare_cheerios_across_stores(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(
|
||||
"/prices/comparison",
|
||||
params={"product_ids": cheerios_id},
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
cheerios_cmp = data[0]
|
||||
assert cheerios_cmp["product_name"] == "Cheerios 18oz"
|
||||
store_names = [p["store_name"] for p in cheerios_cmp["prices"]]
|
||||
assert "Meijer" in store_names
|
||||
assert "Kroger" in store_names
|
||||
|
||||
async def test_compare_requires_product_ids(self, client, seed_data):
|
||||
"""product_ids is required — omitting it must return 422."""
|
||||
resp = await client.get("/prices/comparison", headers=seed_data["headers"])
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_compare_multiple_products(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
milk_id = str(seed_data["products"]["milk"].id)
|
||||
resp = await client.get(
|
||||
"/prices/comparison",
|
||||
params=[("product_ids", cheerios_id), ("product_ids", milk_id)],
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
names = [c["product_name"] for c in data]
|
||||
assert "Cheerios 18oz" in names
|
||||
assert "Whole Milk 1gal" in names
|
||||
@@ -0,0 +1,82 @@
|
||||
"""E2E: Product search/lookup endpoints with real DB fixtures."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.test_e2e.conftest import ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestProductSearch:
|
||||
"""Search and filter products against seeded data."""
|
||||
|
||||
async def test_list_all_products(self, client, seed_data):
|
||||
resp = await client.get("/products", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()
|
||||
names = [p["name"] for p in products]
|
||||
assert "Cheerios 18oz" in names
|
||||
assert "Whole Milk 1gal" in names
|
||||
assert "Chicken Breast 1lb" in names
|
||||
|
||||
async def test_search_by_name(self, client, seed_data):
|
||||
resp = await client.get("/products", params={"q": "cheerios"}, headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()
|
||||
assert len(products) >= 1
|
||||
assert all("cheerios" in p["name"].lower() for p in products)
|
||||
|
||||
async def test_search_by_category(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
"/products", params={"category": "dairy"}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()
|
||||
assert len(products) >= 1
|
||||
assert all(p["category"] == "dairy" for p in products)
|
||||
|
||||
async def test_search_no_results(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
"/products", params={"q": "nonexistentxyz"}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestProductLookup:
|
||||
"""Detailed product lookups with cross-store pricing."""
|
||||
|
||||
async def test_get_product_detail_with_prices(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "Cheerios 18oz"
|
||||
assert data["brand"] == "General Mills"
|
||||
assert data["category"] == "pantry"
|
||||
# Should have prices from both Meijer and Kroger
|
||||
store_names = [p["store_name"] for p in data["prices_by_store"]]
|
||||
assert "Meijer" in store_names
|
||||
assert "Kroger" in store_names
|
||||
|
||||
async def test_product_prices_reflect_latest(self, client, seed_data):
|
||||
"""The latest Meijer price for Cheerios should be 4.79 (the increase)."""
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
meijer_price = next(p for p in data["prices_by_store"] if p["store_name"] == "Meijer")
|
||||
assert meijer_price["current_price"] == 4.79
|
||||
|
||||
async def test_product_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_product_price_history(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/products/{cheerios_id}/prices", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["data_points"]) >= 3 # At least the 3 Meijer observations
|
||||
# Verify chronological ordering exists
|
||||
prices = [dp["price"] for dp in data["data_points"]]
|
||||
assert len(prices) >= 3
|
||||
@@ -0,0 +1,59 @@
|
||||
"""E2E: Public price transparency endpoints (no auth required)."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPublicTrends:
|
||||
"""Public price trend endpoint — no auth, real data."""
|
||||
|
||||
async def test_public_trend_returns_data(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/public/trends/{cheerios_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["product_name"] == "Cheerios 18oz"
|
||||
assert len(data["data_points"]) >= 3
|
||||
|
||||
async def test_public_trend_no_auth_needed(self, client, seed_data):
|
||||
"""Confirm no Authorization header is required."""
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/public/trends/{cheerios_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPublicStoreComparison:
|
||||
"""Public store comparison endpoint."""
|
||||
|
||||
async def test_store_comparison(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(
|
||||
"/public/store-comparison",
|
||||
params=[("product_ids", cheerios_id)],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "products" in data
|
||||
assert len(data["products"]) >= 1
|
||||
|
||||
async def test_store_comparison_rejects_more_than_20_ids(self, client):
|
||||
"""max_length=20 guard: 21 product IDs must return 422."""
|
||||
too_many = [("product_ids", str(uuid.uuid4())) for _ in range(21)]
|
||||
resp = await client.get("/public/store-comparison", params=too_many)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPublicInflation:
|
||||
"""Public inflation index endpoint."""
|
||||
|
||||
async def test_inflation_returns_index(self, client, seed_data):
|
||||
resp = await client.get("/public/inflation")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "cartsnitch_index" in data
|
||||
assert "cpi_baseline" in data
|
||||
assert "categories" in data
|
||||
@@ -0,0 +1,87 @@
|
||||
"""E2E: Purchase listing, detail, and stats against real DB fixtures."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.test_e2e.conftest import ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseList:
|
||||
"""List and filter a user's purchases."""
|
||||
|
||||
async def test_list_user_purchases(self, client, seed_data):
|
||||
resp = await client.get("/purchases", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 2
|
||||
store_names = [p["store_name"] for p in data]
|
||||
assert "Meijer" in store_names
|
||||
assert "Kroger" in store_names
|
||||
|
||||
async def test_filter_purchases_by_store(self, client, seed_data):
|
||||
meijer_id = str(seed_data["stores"]["meijer"].id)
|
||||
resp = await client.get(
|
||||
"/purchases", params={"store_id": meijer_id}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert all(p["store_name"] == "Meijer" for p in data)
|
||||
|
||||
async def test_purchases_require_auth(self, client, seed_data):
|
||||
resp = await client.get("/purchases")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseDetail:
|
||||
"""Retrieve individual purchase with line items."""
|
||||
|
||||
async def test_get_purchase_detail(self, client, seed_data):
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["store_name"] == "Meijer"
|
||||
assert data["total"] == 23.45
|
||||
assert len(data["line_items"]) == 2
|
||||
item_names = [li["name"] for li in data["line_items"]]
|
||||
assert "Cheerios 18oz Box" in item_names
|
||||
assert "Meijer Whole Milk 1gal" in item_names
|
||||
|
||||
async def test_line_item_amounts_correct(self, client, seed_data):
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
cheerios_item = next(li for li in data["line_items"] if "Cheerios" in li["name"])
|
||||
assert cheerios_item["unit_price"] == 4.79
|
||||
assert cheerios_item["quantity"] == 1.0
|
||||
assert cheerios_item["total_price"] == 4.79
|
||||
|
||||
async def test_purchase_not_found(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
f"/purchases/{ZERO_UUID}",
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseStats:
|
||||
"""Verify spending aggregation across purchases."""
|
||||
|
||||
async def test_purchase_stats_totals(self, client, seed_data):
|
||||
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["purchase_count"] == 2
|
||||
# 23.45 + 15.78 = 39.23
|
||||
assert abs(data["total_spent"] - 39.23) < 0.01
|
||||
|
||||
async def test_purchase_stats_by_store(self, client, seed_data):
|
||||
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
assert "Meijer" in data["by_store"]
|
||||
assert "Kroger" in data["by_store"]
|
||||
assert abs(data["by_store"]["Meijer"] - 23.45) < 0.01
|
||||
assert abs(data["by_store"]["Kroger"] - 15.78) < 0.01
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Tests for EncryptedJSON TypeDecorator and session_data encryption."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import Fernet
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import column, create_engine, table, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.models import Base
|
||||
from cartsnitch_api.models.store import Store
|
||||
from cartsnitch_api.models.user import User, UserStoreAccount
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
eng = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(eng)
|
||||
yield eng
|
||||
eng.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(engine):
|
||||
factory = sessionmaker(bind=engine)
|
||||
with factory() as sess:
|
||||
yield sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(session):
|
||||
s = Store(name="Test Store", slug="test-store")
|
||||
session.add(s)
|
||||
session.commit()
|
||||
session.refresh(s)
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user(session):
|
||||
u = User(email="alice@example.com", hashed_password="fakehash")
|
||||
session.add(u)
|
||||
session.commit()
|
||||
session.refresh(u)
|
||||
return u
|
||||
|
||||
|
||||
class TestEncryptedJSONType:
|
||||
"""Unit tests for the EncryptedJSON TypeDecorator."""
|
||||
|
||||
def test_round_trip(self, session, user, store):
|
||||
"""Data written via the ORM comes back as the original dict."""
|
||||
original = {"token": "abc123", "cookies": {"session_id": "xyz"}}
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data == original
|
||||
|
||||
def test_stored_value_is_encrypted(self, session, user, store):
|
||||
"""The raw value in the DB should be a Fernet token, not plaintext JSON."""
|
||||
original = {"secret": "do-not-leak"}
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
# Use a raw table construct to bypass TypeDecorator on read
|
||||
raw_table = table("user_store_accounts", column("id"), column("session_data"))
|
||||
raw = session.execute(raw_table.select().where(raw_table.c.id == str(account.id))).first()
|
||||
# If UUID matching fails with str, try bytes format
|
||||
if raw is None:
|
||||
raw = session.execute(
|
||||
text("SELECT session_data FROM user_store_accounts LIMIT 1")
|
||||
).scalar_one()
|
||||
else:
|
||||
raw = raw[1]
|
||||
|
||||
assert raw != json.dumps(original)
|
||||
assert raw.startswith("gAAAAA")
|
||||
|
||||
# Verify we can decrypt the raw value manually
|
||||
f = Fernet(settings.fernet_key.encode())
|
||||
decrypted = json.loads(f.decrypt(raw.encode()))
|
||||
assert decrypted == original
|
||||
|
||||
def test_null_round_trip(self, session, user, store):
|
||||
"""NULL session_data stays NULL."""
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=None)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data is None
|
||||
|
||||
def test_empty_dict_round_trip(self, session, user, store):
|
||||
"""Empty dict round-trips correctly."""
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={})
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data == {}
|
||||
|
||||
def test_update_session_data(self, session, user, store):
|
||||
"""Updating session_data re-encrypts the new value."""
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={"v": 1})
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
account.session_data = {"v": 2, "new_field": True}
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data == {"v": 2, "new_field": True}
|
||||
|
||||
|
||||
class TestEncryptionKeyValidation:
|
||||
"""Test that invalid/missing keys are caught at startup."""
|
||||
|
||||
def test_invalid_fernet_key_rejected(self, monkeypatch):
|
||||
"""Settings validation rejects a bad key."""
|
||||
monkeypatch.setenv("CARTSNITCH_FERNET_KEY", "not-a-valid-key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
from cartsnitch_api.config import Settings
|
||||
|
||||
Settings()
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Conftest for middleware tests — re-enables rate limiting after global disable."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.config import settings as cartsnitch_settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_rate_limiting():
|
||||
"""Re-enable rate limiting after the global disable_rate_limiting fixture runs.
|
||||
|
||||
The root conftest disables rate limiting for all tests to prevent 429
|
||||
interference. Middleware tests need it active to verify headers and
|
||||
enforcement. This fixture runs after the root fixture (more local = later
|
||||
in setup order) so True is the effective value during the test body.
|
||||
"""
|
||||
cartsnitch_settings.rate_limit_enabled = True
|
||||
yield
|
||||
cartsnitch_settings.rate_limit_enabled = False
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Tests for structured error responses and error monitoring."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_404_returns_structured_error(client):
|
||||
"""Non-existent route should return structured error."""
|
||||
resp = await client.get("/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
body = resp.json()
|
||||
assert "detail" in body
|
||||
assert "code" in body
|
||||
assert body["code"] == "NOT_FOUND"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_error_returns_422_with_field_errors(client):
|
||||
"""Invalid request body should return structured validation errors."""
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "not-an-email", "password": "short", "display_name": ""},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
body = resp.json()
|
||||
assert body["code"] == "VALIDATION_ERROR"
|
||||
assert "errors" in body
|
||||
assert isinstance(body["errors"], list)
|
||||
assert len(body["errors"]) > 0
|
||||
# Each error should have field, message, type
|
||||
for err in body["errors"]:
|
||||
assert "field" in err
|
||||
assert "message" in err
|
||||
assert "type" in err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_stats_requires_service_key(client):
|
||||
"""Error stats endpoint should require X-Service-Key."""
|
||||
resp = await client.get("/internal/error-stats")
|
||||
assert resp.status_code == 422 # Missing required header
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_stats_with_valid_key(client):
|
||||
"""Error stats endpoint returns monitoring data with valid key."""
|
||||
resp = await client.get(
|
||||
"/internal/error-stats",
|
||||
headers={"X-Service-Key": "change-me-in-production"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "error_counts" in body
|
||||
assert "recent_5xx_count" in body
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter
|
||||
|
||||
|
||||
class TestSlidingWindowCounter:
|
||||
def test_allows_within_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
|
||||
for i in range(5):
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 4 - i
|
||||
|
||||
def test_blocks_over_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
|
||||
for _ in range(3):
|
||||
counter.is_allowed("test-key")
|
||||
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
assert remaining == 0
|
||||
assert retry > 0
|
||||
|
||||
def test_separate_keys(self):
|
||||
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60)
|
||||
# Fill key-a
|
||||
counter.is_allowed("key-a")
|
||||
counter.is_allowed("key-a")
|
||||
allowed_a, _, _ = counter.is_allowed("key-a")
|
||||
assert allowed_a is False
|
||||
|
||||
# key-b should still be allowed
|
||||
allowed_b, remaining, _ = counter.is_allowed("key-b")
|
||||
assert allowed_b is True
|
||||
assert remaining == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_returns_429(client):
|
||||
"""Public endpoint should return 429 after limit exceeded."""
|
||||
# The default limit is 60/min — we won't hit it in normal tests,
|
||||
# but we verify the middleware adds rate limit headers.
|
||||
resp = await client.get("/public/inflation")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_skips_rate_limit(client):
|
||||
"""Health endpoint should not have rate limit headers."""
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert "x-ratelimit-limit" not in resp.headers
|
||||
@@ -0,0 +1,376 @@
|
||||
"""Tests for SQLAlchemy ORM models."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from cartsnitch_api.constants import (
|
||||
AccountStatus,
|
||||
DiscountType,
|
||||
PriceSource,
|
||||
ProductCategory,
|
||||
SizeUnit,
|
||||
StoreSlug,
|
||||
)
|
||||
from cartsnitch_api.models import (
|
||||
Coupon,
|
||||
NormalizedProduct,
|
||||
PriceHistory,
|
||||
Purchase,
|
||||
PurchaseItem,
|
||||
ShrinkflationEvent,
|
||||
Store,
|
||||
StoreLocation,
|
||||
User,
|
||||
UserStoreAccount,
|
||||
)
|
||||
|
||||
|
||||
class TestTableCreation:
|
||||
"""Verify all expected tables are created."""
|
||||
|
||||
def test_all_tables_exist(self, engine):
|
||||
inspector = inspect(engine)
|
||||
table_names = set(inspector.get_table_names())
|
||||
expected = {
|
||||
"stores",
|
||||
"store_locations",
|
||||
"users",
|
||||
"user_store_accounts",
|
||||
"purchases",
|
||||
"purchase_items",
|
||||
"normalized_products",
|
||||
"price_history",
|
||||
"coupons",
|
||||
"shrinkflation_events",
|
||||
}
|
||||
assert expected.issubset(table_names)
|
||||
|
||||
def test_ten_tables_total(self, engine):
|
||||
inspector = inspect(engine)
|
||||
assert len(inspector.get_table_names()) == 10
|
||||
|
||||
|
||||
class TestUUIDPrimaryKeys:
|
||||
"""All models use UUID PKs."""
|
||||
|
||||
def test_store_uuid_pk(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Meijer",
|
||||
slug=StoreSlug.MEIJER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.commit()
|
||||
assert isinstance(store.id, uuid.UUID)
|
||||
|
||||
def test_user_uuid_pk(self, session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
assert isinstance(user.id, uuid.UUID)
|
||||
|
||||
|
||||
class TestStoreModel:
|
||||
def test_store_slug_enum(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Kroger",
|
||||
slug=StoreSlug.KROGER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.commit()
|
||||
assert store.slug == StoreSlug.KROGER
|
||||
|
||||
def test_store_unique_slug(self, session):
|
||||
s1 = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
s2 = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target Duplicate",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(s1)
|
||||
session.commit()
|
||||
session.add(s2)
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
session.commit()
|
||||
session.rollback()
|
||||
|
||||
|
||||
class TestStoreLocationModel:
|
||||
def test_store_location_fields(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Meijer",
|
||||
slug=StoreSlug.MEIJER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.flush()
|
||||
loc = StoreLocation(
|
||||
id=uuid.uuid4(),
|
||||
store_id=store.id,
|
||||
address="123 Main St",
|
||||
city="Ann Arbor",
|
||||
state="MI",
|
||||
zip="48104",
|
||||
lat=42.2808,
|
||||
lng=-83.7430,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(loc)
|
||||
session.commit()
|
||||
assert loc.city == "Ann Arbor"
|
||||
assert loc.lat == pytest.approx(42.2808)
|
||||
|
||||
|
||||
class TestUserStoreAccountModel:
|
||||
def test_account_status_enum(self, session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@test.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Kroger",
|
||||
slug=StoreSlug.KROGER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([user, store])
|
||||
session.flush()
|
||||
acct = UserStoreAccount(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
status=AccountStatus.ACTIVE,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(acct)
|
||||
session.commit()
|
||||
assert acct.status == AccountStatus.ACTIVE
|
||||
|
||||
def test_unique_user_store_constraint(self, session):
|
||||
"""One account per user per store."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="unique@test.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([user, store])
|
||||
session.flush()
|
||||
a1 = UserStoreAccount(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
status=AccountStatus.ACTIVE,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
a2 = UserStoreAccount(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
status=AccountStatus.EXPIRED,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(a1)
|
||||
session.commit()
|
||||
session.add(a2)
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
session.commit()
|
||||
session.rollback()
|
||||
|
||||
|
||||
class TestPurchaseModel:
|
||||
def test_purchase_with_items(self, session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="buyer@test.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Meijer",
|
||||
slug=StoreSlug.MEIJER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([user, store])
|
||||
session.flush()
|
||||
purchase = Purchase(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
receipt_id="RCP-001",
|
||||
purchase_date=date(2026, 3, 15),
|
||||
total=Decimal("42.50"),
|
||||
ingested_at=datetime.now(UTC),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(purchase)
|
||||
session.flush()
|
||||
item = PurchaseItem(
|
||||
id=uuid.uuid4(),
|
||||
purchase_id=purchase.id,
|
||||
product_name_raw="Meijer Whole Milk 1 Gallon",
|
||||
upc="0041250000001",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("3.49"),
|
||||
extended_price=Decimal("3.49"),
|
||||
)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
assert item.product_name_raw == "Meijer Whole Milk 1 Gallon"
|
||||
assert item.unit_price == Decimal("3.49")
|
||||
|
||||
|
||||
class TestNormalizedProductModel:
|
||||
def test_product_with_upc_variants(self, session):
|
||||
product = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Whole Milk, 1 Gallon",
|
||||
category=ProductCategory.DAIRY,
|
||||
brand="Store Brand",
|
||||
size="128",
|
||||
size_unit=SizeUnit.FL_OZ,
|
||||
upc_variants=["0041250000001", "0041250000002"],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(product)
|
||||
session.commit()
|
||||
assert product.category == ProductCategory.DAIRY
|
||||
assert product.size_unit == SizeUnit.FL_OZ
|
||||
|
||||
|
||||
class TestPriceHistoryModel:
|
||||
def test_price_source_enum(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Kroger",
|
||||
slug=StoreSlug.KROGER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
product = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Eggs, Large, 12ct",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([store, product])
|
||||
session.flush()
|
||||
ph = PriceHistory(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 15),
|
||||
regular_price=Decimal("4.99"),
|
||||
sale_price=Decimal("3.99"),
|
||||
source=PriceSource.RECEIPT,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(ph)
|
||||
session.commit()
|
||||
assert ph.source == PriceSource.RECEIPT
|
||||
assert ph.regular_price == Decimal("4.99")
|
||||
|
||||
|
||||
class TestCouponModel:
|
||||
def test_coupon_discount_types(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.flush()
|
||||
coupon = Coupon(
|
||||
id=uuid.uuid4(),
|
||||
store_id=store.id,
|
||||
title="$2 off eggs",
|
||||
discount_type=DiscountType.FIXED,
|
||||
discount_value=Decimal("2.00"),
|
||||
requires_clip=True,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(coupon)
|
||||
session.commit()
|
||||
assert coupon.discount_type == DiscountType.FIXED
|
||||
assert coupon.discount_value == Decimal("2.00")
|
||||
|
||||
|
||||
class TestShrinkflationEventModel:
|
||||
def test_shrinkflation_event(self, session):
|
||||
product = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Cereal, Honey Oats",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(product)
|
||||
session.flush()
|
||||
event = ShrinkflationEvent(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=product.id,
|
||||
detected_date=date(2026, 3, 10),
|
||||
old_size="18",
|
||||
new_size="15.4",
|
||||
old_unit=SizeUnit.OZ,
|
||||
new_unit=SizeUnit.OZ,
|
||||
price_at_old_size=Decimal("4.99"),
|
||||
price_at_new_size=Decimal("4.99"),
|
||||
confidence=Decimal("0.95"),
|
||||
notes="Size reduced by 14.4%, price unchanged",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(event)
|
||||
session.commit()
|
||||
assert event.confidence == Decimal("0.95")
|
||||
assert event.old_unit == SizeUnit.OZ
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Verify all expected routes are present in the OpenAPI spec."""
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from cartsnitch_api.main import app
|
||||
|
||||
EXPECTED_ROUTES = [
|
||||
# Auth (6)
|
||||
("post", "/auth/register"),
|
||||
("post", "/auth/login"),
|
||||
("post", "/auth/refresh"),
|
||||
("get", "/auth/me"),
|
||||
("patch", "/auth/me"),
|
||||
("delete", "/auth/me"),
|
||||
# Stores (4)
|
||||
("get", "/stores"),
|
||||
("get", "/me/stores"),
|
||||
("post", "/me/stores/{store_slug}/connect"),
|
||||
("delete", "/me/stores/{store_slug}"),
|
||||
# Purchases (3)
|
||||
("get", "/purchases"),
|
||||
("get", "/purchases/stats"),
|
||||
("get", "/purchases/{purchase_id}"),
|
||||
# Products (3)
|
||||
("get", "/products"),
|
||||
("get", "/products/{product_id}"),
|
||||
("get", "/products/{product_id}/prices"),
|
||||
# Prices (3)
|
||||
("get", "/prices/trends"),
|
||||
("get", "/prices/increases"),
|
||||
("get", "/prices/comparison"),
|
||||
# Coupons (2)
|
||||
("get", "/coupons"),
|
||||
("get", "/coupons/relevant"),
|
||||
# Shopping (2)
|
||||
("post", "/shopping/optimize"),
|
||||
("get", "/shopping/lists"),
|
||||
# Alerts (3)
|
||||
("get", "/alerts"),
|
||||
("get", "/alerts/settings"),
|
||||
("put", "/alerts/settings"),
|
||||
# Scraping (2)
|
||||
("post", "/scraping/{store_slug}/sync"),
|
||||
("get", "/scraping/status"),
|
||||
# Public (3)
|
||||
("get", "/public/trends/{product_id}"),
|
||||
("get", "/public/store-comparison"),
|
||||
("get", "/public/inflation"),
|
||||
# Health (1)
|
||||
("get", "/health"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_routes_in_openapi():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/openapi.json")
|
||||
assert resp.status_code == 200
|
||||
spec = resp.json()
|
||||
paths = spec["paths"]
|
||||
|
||||
registered = set()
|
||||
for path, methods in paths.items():
|
||||
for method in methods:
|
||||
if method in ("get", "post", "put", "delete", "patch"):
|
||||
registered.add((method, path))
|
||||
|
||||
missing = []
|
||||
for method, path in EXPECTED_ROUTES:
|
||||
if (method, path) not in registered:
|
||||
missing.append(f"{method.upper()} {path}")
|
||||
|
||||
assert not missing, "Missing routes in OpenAPI spec:\n" + "\n".join(missing)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_count():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/openapi.json")
|
||||
spec = resp.json()
|
||||
paths = spec["paths"]
|
||||
|
||||
count = 0
|
||||
for _path, methods in paths.items():
|
||||
for method in methods:
|
||||
if method in ("get", "post", "put", "delete", "patch"):
|
||||
count += 1
|
||||
|
||||
assert count == 33, f"Expected 33 routes, found {count}"
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Integration tests for alert endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_alerts_empty(client, auth_headers):
|
||||
"""No purchases means no alerts."""
|
||||
resp = await client.get("/alerts", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_alert_settings(client, auth_headers):
|
||||
resp = await client.get("/alerts/settings", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["price_increase_threshold_pct"] == 5.0
|
||||
assert data["shrinkflation_enabled"] is True
|
||||
assert data["email_notifications"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_alert_settings_returns_501(client, auth_headers):
|
||||
resp = await client.put(
|
||||
"/alerts/settings",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"price_increase_threshold_pct": 10.0,
|
||||
"shrinkflation_enabled": False,
|
||||
"email_notifications": True,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 501
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Integration tests for coupon endpoints."""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import Coupon, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def coupon_data(db_engine, auth_headers):
|
||||
"""Seed stores and coupons."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Target", slug="target")
|
||||
session.add(store)
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
|
||||
coupon = Coupon(
|
||||
store_id=store.id,
|
||||
title="$2 off laundry",
|
||||
description="$2 off any laundry detergent",
|
||||
discount_value=Decimal("2.00"),
|
||||
discount_type="fixed",
|
||||
valid_from=date(2026, 1, 1),
|
||||
valid_to=date(2026, 12, 31),
|
||||
)
|
||||
session.add(coupon)
|
||||
await session.commit()
|
||||
|
||||
return {"store": store, "coupon": coupon, "headers": auth_headers}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_coupons(client, coupon_data):
|
||||
resp = await client.get("/coupons", headers=coupon_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_coupons_by_store(client, coupon_data):
|
||||
store_id = str(coupon_data["store"].id)
|
||||
resp = await client.get(f"/coupons?store_id={store_id}", headers=coupon_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relevant_coupons_empty(client, auth_headers):
|
||||
"""No purchases means no relevant coupons."""
|
||||
resp = await client.get("/coupons/relevant", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Integration tests for price endpoints."""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def price_data(db_engine, auth_headers):
|
||||
"""Seed products with price history showing an increase."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Walmart", slug="walmart")
|
||||
product = NormalizedProduct(
|
||||
canonical_name="Tide Pods 42ct",
|
||||
category="household",
|
||||
brand="Tide",
|
||||
)
|
||||
session.add_all([store, product])
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
await session.refresh(product)
|
||||
|
||||
# Two price points — second is higher (increase)
|
||||
ph1 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 2, 1),
|
||||
regular_price=Decimal("12.99"),
|
||||
source="receipt",
|
||||
)
|
||||
ph2 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 1),
|
||||
regular_price=Decimal("14.49"),
|
||||
source="receipt",
|
||||
)
|
||||
session.add_all([ph1, ph2])
|
||||
await session.commit()
|
||||
|
||||
return {"product": product, "store": store, "headers": auth_headers}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_trends(client, price_data):
|
||||
resp = await client.get("/prices/trends", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["product_name"] == "Tide Pods 42ct"
|
||||
assert len(data[0]["data_points"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_trends_by_category(client, price_data):
|
||||
resp = await client.get("/prices/trends?category=household", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
resp = await client.get("/prices/trends?category=nonexistent", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_increases(client, price_data):
|
||||
resp = await client.get("/prices/increases", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
increase = data[0]
|
||||
assert increase["old_price"] == 12.99
|
||||
assert increase["new_price"] == 14.49
|
||||
assert increase["increase_pct"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_comparison(client, price_data):
|
||||
pid = str(price_data["product"].id)
|
||||
resp = await client.get(f"/prices/comparison?product_ids={pid}", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["product_name"] == "Tide Pods 42ct"
|
||||
assert len(data[0]["prices"]) >= 1
|
||||
@@ -0,0 +1,94 @@
|
||||
"""Integration tests for product endpoints."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def product_data(db_engine, auth_headers):
|
||||
"""Seed products and price history."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Meijer", slug="meijer")
|
||||
product = NormalizedProduct(
|
||||
canonical_name="Cheerios 18oz",
|
||||
category="pantry",
|
||||
brand="General Mills",
|
||||
upc_variants=["016000275263"],
|
||||
)
|
||||
session.add_all([store, product])
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
await session.refresh(product)
|
||||
|
||||
ph1 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 1),
|
||||
regular_price=Decimal("4.99"),
|
||||
source="receipt",
|
||||
)
|
||||
ph2 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 10),
|
||||
regular_price=Decimal("5.49"),
|
||||
source="receipt",
|
||||
)
|
||||
session.add_all([ph1, ph2])
|
||||
await session.commit()
|
||||
|
||||
return {"product": product, "store": store, "headers": auth_headers}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_products(client, product_data):
|
||||
resp = await client.get("/products", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["name"] == "Cheerios 18oz"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_products(client, product_data):
|
||||
resp = await client.get("/products?q=Cheerios", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
resp = await client.get("/products?q=nonexistent", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_product_detail(client, product_data):
|
||||
pid = str(product_data["product"].id)
|
||||
resp = await client.get(f"/products/{pid}", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "Cheerios 18oz"
|
||||
assert data["brand"] == "General Mills"
|
||||
assert len(data["prices_by_store"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_product_not_found(client, auth_headers):
|
||||
resp = await client.get(f"/products/{uuid.uuid4()}", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_product_prices(client, product_data):
|
||||
pid = str(product_data["product"].id)
|
||||
resp = await client.get(f"/products/{pid}/prices", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["product_name"] == "Cheerios 18oz"
|
||||
assert len(data["data_points"]) == 2
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Integration tests for public endpoints (no auth)."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def public_data(db_engine):
|
||||
"""Seed data for public endpoints."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Target", slug="target")
|
||||
product = NormalizedProduct(
|
||||
canonical_name="Skippy PB 16oz",
|
||||
category="pantry",
|
||||
brand="Skippy",
|
||||
)
|
||||
session.add_all([store, product])
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
await session.refresh(product)
|
||||
|
||||
ph = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 5),
|
||||
regular_price=Decimal("3.99"),
|
||||
source="receipt",
|
||||
)
|
||||
session.add(ph)
|
||||
await session.commit()
|
||||
|
||||
return {"product": product, "store": store}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_trend(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["product_name"] == "Skippy PB 16oz"
|
||||
assert len(data["data_points"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_trend_not_found(client):
|
||||
resp = await client.get(f"/public/trends/{uuid.uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_store_comparison(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/store-comparison?product_ids={pid}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["products"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_inflation(client, public_data):
|
||||
resp = await client.get("/public/inflation")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "categories" in data
|
||||
assert "cartsnitch_index" in data
|
||||
@@ -0,0 +1,114 @@
|
||||
"""Integration tests for purchase endpoints."""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def purchase_data(db_engine):
|
||||
"""Seed a user, store, purchase, items, and a valid session."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
user = User(
|
||||
email="buyer@example.com",
|
||||
hashed_password="not-used-with-better-auth",
|
||||
display_name="Buyer",
|
||||
)
|
||||
store = Store(name="Kroger", slug="kroger")
|
||||
session.add_all([user, store])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(store)
|
||||
|
||||
purchase = Purchase(
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
receipt_id="receipt-001",
|
||||
purchase_date=date(2026, 3, 10),
|
||||
total=Decimal("42.50"),
|
||||
)
|
||||
session.add(purchase)
|
||||
await session.commit()
|
||||
await session.refresh(purchase)
|
||||
|
||||
item = PurchaseItem(
|
||||
purchase_id=purchase.id,
|
||||
product_name_raw="Organic Milk 1gal",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("5.99"),
|
||||
extended_price=Decimal("5.99"),
|
||||
)
|
||||
session.add(item)
|
||||
await session.commit()
|
||||
|
||||
# Create a session token directly in the sessions table
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"user_id": str(user.id),
|
||||
"expires_at": expires,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"store": store,
|
||||
"purchase": purchase,
|
||||
"headers": {"Cookie": f"better-auth.session_token={session_token}"},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_purchases(client, purchase_data):
|
||||
resp = await client.get("/purchases", headers=purchase_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["store_name"] == "Kroger"
|
||||
assert data[0]["total"] == 42.50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_purchase_detail(client, purchase_data):
|
||||
pid = str(purchase_data["purchase"].id)
|
||||
resp = await client.get(f"/purchases/{pid}", headers=purchase_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["line_items"]) == 1
|
||||
assert data["line_items"][0]["name"] == "Organic Milk 1gal"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_purchase_not_found(client, auth_headers):
|
||||
resp = await client.get(f"/purchases/{uuid.uuid4()}", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_purchase_stats(client, purchase_data):
|
||||
resp = await client.get("/purchases/stats", headers=purchase_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total_spent"] == 42.50
|
||||
assert data["purchase_count"] == 1
|
||||
assert "Kroger" in data["by_store"]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Integration tests for store endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.models import Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def seeded_store(db_engine):
|
||||
"""Insert a test store directly into the DB."""
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Meijer", slug="meijer", logo_url=None, website_url=None)
|
||||
session.add(store)
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_stores(client, seeded_store):
|
||||
resp = await client.get("/stores")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["slug"] == "meijer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_stores_empty(client, auth_headers):
|
||||
resp = await client.get("/me/stores", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_and_disconnect_store(client, auth_headers, seeded_store):
|
||||
# Connect
|
||||
resp = await client.post(
|
||||
"/me/stores/meijer/connect",
|
||||
headers=auth_headers,
|
||||
json={"credentials": None},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["connected"] is True
|
||||
|
||||
# List should show connected
|
||||
resp = await client.get("/me/stores", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
# Disconnect
|
||||
resp = await client.delete("/me/stores/meijer", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# List should be empty again
|
||||
resp = await client.get("/me/stores", headers=auth_headers)
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_nonexistent_store(client, auth_headers):
|
||||
resp = await client.post(
|
||||
"/me/stores/nonexistent/connect",
|
||||
headers=auth_headers,
|
||||
json={},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_duplicate_store(client, auth_headers, seeded_store):
|
||||
await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
|
||||
resp = await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
|
||||
assert resp.status_code == 409
|
||||
@@ -0,0 +1,11 @@
|
||||
# Required: Generate with `openssl rand -base64 32`
|
||||
BETTER_AUTH_SECRET=change-me-in-production-min-32-chars!!
|
||||
|
||||
# Base URL of the auth service
|
||||
BETTER_AUTH_URL=http://localhost:3001
|
||||
|
||||
# Shared PostgreSQL database
|
||||
DATABASE_URL=postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch
|
||||
|
||||
# Port the auth service listens on
|
||||
PORT=3001
|
||||
@@ -0,0 +1,17 @@
|
||||
FROM node:22-alpine AS builder
|
||||
WORKDIR /app
|
||||
COPY package.json package-lock.json* ./
|
||||
RUN npm ci
|
||||
COPY tsconfig.json ./
|
||||
COPY src/ src/
|
||||
RUN npm run build
|
||||
|
||||
FROM node:22-alpine
|
||||
WORKDIR /app
|
||||
ENV NODE_ENV=production
|
||||
COPY package.json package-lock.json* ./
|
||||
RUN npm ci --omit=dev
|
||||
COPY --from=builder /app/dist/ dist/
|
||||
USER 101
|
||||
EXPOSE 3001
|
||||
CMD ["node", "dist/index.js"]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user