Compare commits
125 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e9bc46121f | |||
| 56d9d5ad2e | |||
| b7b9e987df | |||
| e6ed9d9193 | |||
| f0c60778cc | |||
| 7d31491114 | |||
| d0cecf9686 | |||
| ee731c4aa3 | |||
| 98d95a661a | |||
| de120cb429 | |||
| b18cb24ec4 | |||
| 1491974aba | |||
| fe8e2567a2 | |||
| ea8dcad398 | |||
| e9eb9cf489 | |||
| 14ba9d0b82 | |||
| 6b73647689 | |||
| 4f42247bf2 | |||
| d5ee743d84 | |||
| 41380e9526 | |||
| 4c29d8a241 | |||
| 31b7c14719 | |||
| 6b6b9e7d01 | |||
| c62a151210 | |||
| 835aff3522 | |||
| 5588c1b5d8 | |||
| c5ed863ab1 | |||
| 8d0552f73f | |||
| 3a75ee7aee | |||
| 30d670a257 | |||
| cfa4d8fa91 | |||
| 39e8d5c9f9 | |||
| 44c475265e | |||
| 8e1f61214c | |||
| fb1c5fb929 | |||
| 75be08ccf3 | |||
| 5596e22d0c | |||
| f45a49059e | |||
| 47ba602b02 | |||
| 5b12625e3f | |||
| d7a4086647 | |||
| b43ec1fb9b | |||
| 129f0adc96 | |||
| 587d444773 | |||
| ea789378dd | |||
| 2f096c985a | |||
| ad218c07ec | |||
| fff9f6f63a | |||
| b0ea4767b6 | |||
| c1778074e3 | |||
| 5de258220e | |||
| 003c62da3e | |||
| 57ce4315a1 | |||
| 7426ff1909 | |||
| 782448a54a | |||
| b9a66dfc8b | |||
| 7a1267de79 | |||
| 4415c56a53 | |||
| da8b413f76 | |||
| dd6a683b90 | |||
| cf8e821bdc | |||
| c9be9324cf | |||
| cc0957fc92 | |||
| f3a7b33093 | |||
| 342906c9d1 | |||
| b736e62d4f | |||
| 4cf6f91e95 | |||
| 27fe957074 | |||
| fc99e8a82e | |||
| cb1d926fc4 | |||
| fc689a3f90 | |||
| d2337a7ef7 | |||
| b7e7960f35 | |||
| aa4da81b6e | |||
| ce9e71c793 | |||
| e662ff5fab | |||
| 656c8d3842 | |||
| 853d722044 | |||
| 61540905dd | |||
| bea3342042 | |||
| 95317884ff | |||
| ca0dbd0e63 | |||
| cdcffc8582 | |||
| 8cccb8cbf0 | |||
| d201753d83 | |||
| 516697b4bd | |||
| b3aa18d7df | |||
| 6e681b9010 | |||
| 979a671300 | |||
| 860dd827d3 | |||
| 7d2e0ba64e | |||
| 118946898b | |||
| 90c81f9c8f | |||
| 4baac1ae26 | |||
| 0f1e158e89 | |||
| a9101246c9 | |||
| cf4ae49ad7 | |||
| 634d54b7fc | |||
| c74a4226f4 | |||
| 14c8aa5797 | |||
| 77c45e7eac | |||
| d6175760d1 | |||
| 6a130a9d76 | |||
| 38c860f1bb | |||
| 91ff8f76d0 | |||
| ab358f44bb | |||
| 5b8d132948 | |||
| 66565fff5c | |||
| a65361106c | |||
| 66376f6a87 | |||
| 580864ac69 | |||
| e8a53399c2 | |||
| b8091e367e | |||
| d0c887e29f | |||
| c81e14b8e7 | |||
| ec81004268 | |||
| fb6f4a0ed4 | |||
| e6f09a0212 | |||
| 58844b33fe | |||
| 0000297e4f | |||
| e572a32021 | |||
| 0789de39f0 | |||
| e57baa4468 | |||
| e42b7e1a66 | |||
| 265f2ae654 |
+244
-11
@@ -11,16 +11,19 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: cartsnitch/cartsnitch
|
||||
AUTH_IMAGE_NAME: cartsnitch/auth
|
||||
RECEIPTWITNESS_IMAGE_NAME: cartsnitch/receiptwitness
|
||||
API_IMAGE_NAME: cartsnitch/api
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: local-ubuntu-latest-cartsnitch
|
||||
runs-on: runners-cartsnitch
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
@@ -34,7 +37,7 @@ jobs:
|
||||
run: npx tsc --noEmit
|
||||
|
||||
test:
|
||||
runs-on: local-ubuntu-latest-cartsnitch
|
||||
runs-on: runners-cartsnitch
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
@@ -45,17 +48,44 @@ jobs:
|
||||
- name: Run tests
|
||||
run: npx vitest run
|
||||
|
||||
build-and-push:
|
||||
runs-on: local-ubuntu-latest-cartsnitch
|
||||
needs: [lint, test]
|
||||
e2e:
|
||||
runs-on: runners-cartsnitch
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
node-version: "20"
|
||||
cache: npm
|
||||
- run: npm ci
|
||||
- run: npx playwright install --with-deps chromium
|
||||
- run: npx playwright test
|
||||
|
||||
build-and-push:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [lint, test, e2e]
|
||||
outputs:
|
||||
calver_tag: ${{ steps.calver.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Generate CalVer tag
|
||||
id: calver
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
DATE_TAG=$(date -u +%Y.%m.%d)
|
||||
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
|
||||
if [ -z "$EXISTING" ]; then
|
||||
VERSION="$DATE_TAG"
|
||||
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then
|
||||
VERSION="${DATE_TAG}.2"
|
||||
else
|
||||
BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//")
|
||||
VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"
|
||||
fi
|
||||
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
|
||||
echo "CalVer tag: $VERSION"
|
||||
|
||||
- name: Log in to GHCR
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
@@ -72,6 +102,7 @@ jobs:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=sha,prefix=sha-
|
||||
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
@@ -82,3 +113,205 @@ jobs:
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
target: prod
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Create git tag
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
git tag "v${{ steps.calver.outputs.version }}"
|
||||
git push origin "v${{ steps.calver.outputs.version }}"
|
||||
|
||||
build-and-push-auth:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [lint, test, e2e]
|
||||
outputs:
|
||||
calver_tag: ${{ steps.calver.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Generate CalVer tag
|
||||
id: calver
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
DATE_TAG=$(date -u +%Y.%m.%d)
|
||||
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
|
||||
if [ -z "$EXISTING" ]; then
|
||||
VERSION="$DATE_TAG"
|
||||
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then
|
||||
VERSION="${DATE_TAG}.2"
|
||||
else
|
||||
BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//")
|
||||
VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"
|
||||
fi
|
||||
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Log in to GHCR
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (auth)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.AUTH_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=sha,prefix=sha-
|
||||
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Build and push auth Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ./auth
|
||||
file: ./auth/Dockerfile
|
||||
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
build-and-push-receiptwitness:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [lint, test]
|
||||
outputs:
|
||||
calver_tag: ${{ steps.calver.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Generate CalVer tag
|
||||
id: calver
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
DATE_TAG=$(date -u +%Y.%m.%d)
|
||||
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
|
||||
if [ -z "$EXISTING" ]; then VERSION="$DATE_TAG"
|
||||
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then VERSION="${DATE_TAG}.2"
|
||||
else BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//"); VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"; fi
|
||||
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Log in to GHCR
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.RECEIPTWITNESS_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=sha,prefix=sha-
|
||||
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Build and push receiptwitness image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./receiptwitness/Dockerfile
|
||||
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
build-and-push-api:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [lint, test]
|
||||
outputs:
|
||||
calver_tag: ${{ steps.calver.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Generate CalVer tag
|
||||
id: calver
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
DATE_TAG=$(date -u +%Y.%m.%d)
|
||||
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
|
||||
if [ -z "$EXISTING" ]; then VERSION="$DATE_TAG"
|
||||
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then VERSION="${DATE_TAG}.2"
|
||||
else BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//"); VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"; fi
|
||||
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Log in to GHCR
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (API)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.API_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=sha,prefix=sha-
|
||||
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Build and push API Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./api/Dockerfile
|
||||
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
deploy-dev:
|
||||
runs-on: runners-cartsnitch
|
||||
needs: [build-and-push, build-and-push-auth, build-and-push-receiptwitness, build-and-push-api]
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
steps:
|
||||
- name: Generate GitHub App token
|
||||
id: app-token
|
||||
uses: actions/create-github-app-token@v1
|
||||
with:
|
||||
app-id: ${{ secrets.CARTSNITCH_APP_ID }}
|
||||
private-key: ${{ secrets.CARTSNITCH_APP_PRIVATE_KEY }}
|
||||
owner: ${{ github.repository_owner }}
|
||||
repositories: infra
|
||||
|
||||
- name: Checkout infra repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: cartsnitch/infra
|
||||
token: ${{ steps.app-token.outputs.token }}
|
||||
ref: main
|
||||
path: infra
|
||||
|
||||
- name: Install kubectl
|
||||
uses: azure/setup-kubectl@v4
|
||||
|
||||
- name: Install kustomize
|
||||
uses: imranismail/setup-kustomize@v2
|
||||
|
||||
- name: Update dev overlay image tags
|
||||
run: |
|
||||
cd infra/apps/overlays/dev
|
||||
kustomize edit set image ghcr.io/cartsnitch/cartsnitch:${{ needs.build-and-push.outputs.calver_tag }}
|
||||
kustomize edit set image ghcr.io/cartsnitch/auth:${{ needs.build-and-push-auth.outputs.calver_tag }}
|
||||
kustomize edit set image ghcr.io/cartsnitch/receiptwitness:${{ needs.build-and-push-receiptwitness.outputs.calver_tag }}
|
||||
kustomize edit set image ghcr.io/cartsnitch/api:${{ needs.build-and-push-api.outputs.calver_tag }}
|
||||
|
||||
- name: Commit and push to infra
|
||||
run: |
|
||||
cd infra
|
||||
git config user.name "cartsnitch-ci[bot]"
|
||||
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
|
||||
git add apps/overlays/dev/kustomization.yaml
|
||||
git commit -m "ci(dev): update cartsnitch, auth, receiptwitness, and api images"
|
||||
git push origin main
|
||||
|
||||
@@ -11,6 +11,7 @@ node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
.env
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
# CartSnitch Frontend
|
||||
# CartSnitch Monorepo
|
||||
|
||||
## Project Context
|
||||
|
||||
CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/cartsnitch`) is the mobile-first Progressive Web App — the flagship repo and primary user interface.
|
||||
CartSnitch is a self-hosted grocery price intelligence platform. This repo (`cartsnitch/cartsnitch`) is the **monorepo** containing the flagship frontend PWA and core backend services.
|
||||
|
||||
**GitHub org:** github.com/cartsnitch
|
||||
**Domain:** cartsnitch.com
|
||||
|
||||
### CartSnitch Services
|
||||
### Monorepo Layout
|
||||
|
||||
| Directory | Service | Purpose |
|
||||
|-----------|---------|---------|
|
||||
| `/` (root) | Frontend | React PWA, mobile-first (this directory) |
|
||||
| `auth/` | Auth | Better-Auth Node.js service (session management, email/password, OAuth) |
|
||||
| `api/` | API Gateway | Frontend-facing REST API |
|
||||
| `common/` | Common | Shared Python models, schemas, Alembic migrations |
|
||||
| `receiptwitness/` | ReceiptWitness | Purchase data ingestion via retailer scrapers |
|
||||
|
||||
### Other CartSnitch Repos (still separate)
|
||||
|
||||
| Repo | Service | Purpose |
|
||||
|------|---------|---------|
|
||||
| `cartsnitch/common` | — | Shared models, schemas, utilities |
|
||||
| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers |
|
||||
| `cartsnitch/api` | API Gateway | Frontend-facing REST API |
|
||||
| `cartsnitch/cartsnitch` | Frontend | React PWA, mobile-first (this repo) |
|
||||
| `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison |
|
||||
| `cartsnitch/shrinkray` | ShrinkRay | Shrinkflation monitoring |
|
||||
| `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization |
|
||||
@@ -161,9 +167,13 @@ frontend/
|
||||
|
||||
All data comes from the CartSnitch API gateway (`cartsnitch/api`). Base URL configured via environment variable `VITE_API_URL`.
|
||||
|
||||
- JWT auth: store access token in memory (not localStorage), refresh token in httpOnly cookie if possible, or secure storage.
|
||||
- **Authentication via Better-Auth** (`auth/` service). Sessions are managed via httpOnly cookies — no tokens in localStorage or memory.
|
||||
- Auth service URL configured via `VITE_AUTH_URL` (default: `http://localhost:3001`)
|
||||
- Frontend uses `better-auth/react` client for sign-in, sign-up, sign-out, and `useSession()` hook
|
||||
- API gateway validates sessions by querying the shared `sessions` table in Postgres
|
||||
- Both cookie-based and Bearer token auth are supported (cookies for web, Bearer for API clients)
|
||||
- TanStack Query handles caching, background refetching, and optimistic updates.
|
||||
- API client should handle 401 responses by attempting token refresh before retrying.
|
||||
- API client sends `credentials: 'include'` on all requests to forward session cookies.
|
||||
|
||||
## Development Workflow
|
||||
|
||||
|
||||
+5
-4
@@ -9,13 +9,14 @@ RUN npm ci
|
||||
COPY . .
|
||||
RUN npm run build
|
||||
|
||||
# Stage 2: Production
|
||||
FROM nginx:stable-alpine AS prod
|
||||
# Stage 2: Production — uses nginxinc/nginx-unprivileged which runs as non-root (UID 101)
|
||||
FROM nginxinc/nginx-unprivileged:stable-alpine AS prod
|
||||
|
||||
COPY --from=build /app/dist /usr/share/nginx/html
|
||||
COPY nginx.conf /etc/nginx/conf.d/default.conf
|
||||
|
||||
EXPOSE 80
|
||||
USER 101
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD wget -qO- http://localhost/health || exit 1
|
||||
CMD wget -qO- http://localhost:8080/health || exit 1
|
||||
|
||||
@@ -1,73 +1 @@
|
||||
# React + TypeScript + Vite
|
||||
|
||||
This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
|
||||
|
||||
Currently, two official plugins are available:
|
||||
|
||||
- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Oxc](https://oxc.rs)
|
||||
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/)
|
||||
|
||||
## React Compiler
|
||||
|
||||
The React Compiler is not enabled on this template because of its impact on dev & build performances. To add it, see [this documentation](https://react.dev/learn/react-compiler/installation).
|
||||
|
||||
## Expanding the ESLint configuration
|
||||
|
||||
If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
|
||||
|
||||
```js
|
||||
export default defineConfig([
|
||||
globalIgnores(['dist']),
|
||||
{
|
||||
files: ['**/*.{ts,tsx}'],
|
||||
extends: [
|
||||
// Other configs...
|
||||
|
||||
// Remove tseslint.configs.recommended and replace with this
|
||||
tseslint.configs.recommendedTypeChecked,
|
||||
// Alternatively, use this for stricter rules
|
||||
tseslint.configs.strictTypeChecked,
|
||||
// Optionally, add this for stylistic rules
|
||||
tseslint.configs.stylisticTypeChecked,
|
||||
|
||||
// Other configs...
|
||||
],
|
||||
languageOptions: {
|
||||
parserOptions: {
|
||||
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
||||
tsconfigRootDir: import.meta.dirname,
|
||||
},
|
||||
// other options...
|
||||
},
|
||||
},
|
||||
])
|
||||
```
|
||||
|
||||
You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
|
||||
|
||||
```js
|
||||
// eslint.config.js
|
||||
import reactX from 'eslint-plugin-react-x'
|
||||
import reactDom from 'eslint-plugin-react-dom'
|
||||
|
||||
export default defineConfig([
|
||||
globalIgnores(['dist']),
|
||||
{
|
||||
files: ['**/*.{ts,tsx}'],
|
||||
extends: [
|
||||
// Other configs...
|
||||
// Enable lint rules for React
|
||||
reactX.configs['recommended-typescript'],
|
||||
// Enable lint rules for React DOM
|
||||
reactDom.configs.recommended,
|
||||
],
|
||||
languageOptions: {
|
||||
parserOptions: {
|
||||
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
||||
tsconfigRootDir: import.meta.dirname,
|
||||
},
|
||||
// other options...
|
||||
},
|
||||
},
|
||||
])
|
||||
```
|
||||
# CartSnitch
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
.git
|
||||
.github
|
||||
.pytest_cache
|
||||
.ruff_cache
|
||||
__pycache__
|
||||
*.py[cod]
|
||||
*.egg-info
|
||||
dist
|
||||
.venv
|
||||
.env
|
||||
tests
|
||||
openapi.json
|
||||
CLAUDE.md
|
||||
README.md
|
||||
@@ -0,0 +1,9 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.egg-info/
|
||||
dist/
|
||||
.venv/
|
||||
.env
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
openapi.json
|
||||
+175
@@ -0,0 +1,175 @@
|
||||
# CartSnitch API Gateway
|
||||
|
||||
## Project Context
|
||||
|
||||
CartSnitch is a self-hosted grocery price intelligence platform built as a polyrepo microservices architecture. This repo (`cartsnitch/api`) is the public-facing API gateway that serves the frontend and proxies requests to internal services.
|
||||
|
||||
**GitHub org:** github.com/cartsnitch
|
||||
**Domain:** cartsnitch.com
|
||||
|
||||
### CartSnitch Services
|
||||
|
||||
| Repo | Service | Purpose |
|
||||
|------|---------|---------|
|
||||
| `cartsnitch/common` | — | Shared models, schemas, utilities |
|
||||
| `cartsnitch/receiptwitness` | ReceiptWitness | Purchase data ingestion via retailer scrapers |
|
||||
| `cartsnitch/api` | API Gateway | Frontend-facing REST API (this repo) |
|
||||
| `cartsnitch/cartsnitch` | Frontend | React PWA (mobile-first) |
|
||||
| `cartsnitch/stickershock` | StickerShock | Price increase detection & CPI comparison |
|
||||
| `cartsnitch/shrinkray` | ShrinkRay | Shrinkflation monitoring |
|
||||
| `cartsnitch/clipartist` | ClipArtist | Coupon/deal watching & shopping optimization |
|
||||
| `cartsnitch/infra` | — | K8s manifests, Flux kustomizations |
|
||||
|
||||
### Architecture Decisions
|
||||
|
||||
- **Polyrepo:** Each service has its own repo, Dockerfile, CI/CD pipeline.
|
||||
- **Shared DB:** One PostgreSQL cluster. This service reads from all tables for serving frontend queries. Models come from `cartsnitch-common`.
|
||||
- **Inter-service comms:** REST to internal services, Redis pub/sub for event subscriptions.
|
||||
- **Target scale:** 500–1,000 users initially.
|
||||
|
||||
## What This Service Does
|
||||
|
||||
The API Gateway is the single entry point for the frontend PWA and any external consumers. It:
|
||||
|
||||
1. **Handles user authentication** — registration, login, JWT token management
|
||||
2. **Serves purchase/product/price data** — reads from the shared DB
|
||||
3. **Proxies scraping operations** — forwards scrape triggers to ReceiptWitness
|
||||
4. **Serves coupon/deal data** — reads from shared DB (written by ClipArtist)
|
||||
5. **Serves alerts** — price increase alerts (StickerShock), shrinkflation alerts (ShrinkRay)
|
||||
6. **Provides public data endpoints** — aggregate price trends for the transparency/shaming features
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- Python 3.12+
|
||||
- FastAPI (async)
|
||||
- SQLAlchemy 2.0 (via `cartsnitch-common`, read-heavy)
|
||||
- Pydantic v2 (request/response validation)
|
||||
- python-jose or PyJWT (JWT auth)
|
||||
- passlib + bcrypt (password hashing)
|
||||
- httpx (async HTTP client for proxying to internal services)
|
||||
- Redis (subscribe to events for websocket push, caching)
|
||||
- uvicorn (ASGI server)
|
||||
|
||||
## Repo Structure
|
||||
|
||||
```
|
||||
api/
|
||||
├── CLAUDE.md
|
||||
├── README.md
|
||||
├── pyproject.toml
|
||||
├── Dockerfile
|
||||
├── docker-compose.yml
|
||||
├── src/
|
||||
│ └── cartsnitch_api/
|
||||
│ ├── __init__.py
|
||||
│ ├── config.py # Service-specific settings
|
||||
│ ├── main.py # FastAPI app factory, lifespan, middleware
|
||||
│ ├── auth/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── jwt.py # JWT creation/validation
|
||||
│ │ ├── passwords.py # Hashing, verification
|
||||
│ │ ├── dependencies.py # FastAPI dependency injection (get_current_user)
|
||||
│ │ └── routes.py # /auth/register, /auth/login, /auth/refresh
|
||||
│ ├── routes/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── purchases.py # Purchase history endpoints
|
||||
│ │ ├── products.py # Normalized product catalog
|
||||
│ │ ├── prices.py # Price history and trends
|
||||
│ │ ├── coupons.py # Active coupons and deals
|
||||
│ │ ├── alerts.py # Price increase / shrinkflation alerts
|
||||
│ │ ├── stores.py # Store info, user store account management
|
||||
│ │ ├── scraping.py # Proxy to ReceiptWitness (trigger scrape, status)
|
||||
│ │ ├── shopping.py # Optimized shopping list (proxy to ClipArtist)
|
||||
│ │ ├── public.py # Public price transparency endpoints (no auth)
|
||||
│ │ └── health.py
|
||||
│ ├── services/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── receiptwitness.py # HTTP client for ReceiptWitness internal API
|
||||
│ │ ├── stickershock.py # HTTP client for StickerShock internal API
|
||||
│ │ ├── clipartist.py # HTTP client for ClipArtist internal API
|
||||
│ │ └── shrinkray.py # HTTP client for ShrinkRay internal API
|
||||
│ ├── middleware/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── cors.py
|
||||
│ │ └── rate_limit.py
|
||||
│ └── cache.py # Redis caching helpers
|
||||
└── tests/
|
||||
├── conftest.py
|
||||
├── test_auth/
|
||||
├── test_routes/
|
||||
└── test_services/
|
||||
```
|
||||
|
||||
## API Endpoint Design
|
||||
|
||||
### Auth
|
||||
- `POST /auth/register` — create account
|
||||
- `POST /auth/login` — get JWT access + refresh tokens
|
||||
- `POST /auth/refresh` — refresh access token
|
||||
- `GET /auth/me` — current user profile
|
||||
|
||||
### Store Accounts
|
||||
- `GET /stores` — list supported stores
|
||||
- `GET /me/stores` — list user's connected store accounts + sync status
|
||||
- `POST /me/stores/{store_slug}/connect` — initiate store connection flow
|
||||
- `DELETE /me/stores/{store_slug}` — disconnect store account
|
||||
|
||||
### Purchases
|
||||
- `GET /purchases` — list user's purchases (paginated, filterable by store/date)
|
||||
- `GET /purchases/{id}` — purchase detail with line items
|
||||
- `GET /purchases/stats` — spending summary (by store, by category, by period)
|
||||
|
||||
### Products
|
||||
- `GET /products` — normalized product catalog (search, filter)
|
||||
- `GET /products/{id}` — product detail with cross-store price comparison
|
||||
- `GET /products/{id}/prices` — price history for a product across stores
|
||||
|
||||
### Prices
|
||||
- `GET /prices/trends` — aggregate price trends (public-capable)
|
||||
- `GET /prices/increases` — recent significant price increases
|
||||
- `GET /prices/comparison` — compare specific items across stores
|
||||
|
||||
### Coupons
|
||||
- `GET /coupons` — active coupons/deals (filterable by store)
|
||||
- `GET /coupons/relevant` — coupons relevant to user's purchase history
|
||||
|
||||
### Shopping
|
||||
- `POST /shopping/optimize` — input: shopping list → output: store-split + coupons
|
||||
- `GET /shopping/lists` — user's saved shopping lists
|
||||
|
||||
### Alerts
|
||||
- `GET /alerts` — user's price increase and shrinkflation alerts
|
||||
- `PUT /alerts/settings` — configure alert thresholds
|
||||
|
||||
### Public (No Auth)
|
||||
- `GET /public/trends/{product_id}` — public price trend for a product
|
||||
- `GET /public/store-comparison` — public store-vs-store price comparison
|
||||
- `GET /public/inflation` — price changes vs CPI baseline
|
||||
|
||||
### Scraping (Proxy to ReceiptWitness)
|
||||
- `POST /scraping/{store_slug}/sync` — trigger a sync for the current user
|
||||
- `GET /scraping/status` — sync status across all stores
|
||||
|
||||
## Authentication
|
||||
|
||||
- JWT-based auth with short-lived access tokens (15 min) and longer refresh tokens (7 days).
|
||||
- Passwords hashed with bcrypt via passlib.
|
||||
- All user-specific endpoints require a valid JWT in the `Authorization: Bearer` header.
|
||||
- Public endpoints under `/public/` do not require auth.
|
||||
- Internal service-to-service calls (ReceiptWitness, etc.) use a shared API key in the `X-Service-Key` header — not user JWTs.
|
||||
|
||||
## Development Workflow
|
||||
|
||||
- **Never push directly to main.** Always create feature branches and open PRs.
|
||||
- Branch naming: `feature/<description>` or `fix/<description>`
|
||||
- Use conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `chore:`
|
||||
- OpenAPI docs auto-generated at `/docs` (Swagger) and `/redoc`.
|
||||
- Write tests for all routes. Use httpx.AsyncClient with FastAPI's TestClient pattern.
|
||||
|
||||
## Important Notes
|
||||
|
||||
- This service is read-heavy on the shared DB. Use async SQLAlchemy sessions.
|
||||
- Consider Redis caching for expensive queries (price trends, product comparisons). Cache invalidation via Redis pub/sub events from other services.
|
||||
- Rate limiting on public endpoints is important — these could get hammered if the price transparency features get attention.
|
||||
- CORS must allow the frontend origin (cartsnitch.com and localhost for dev).
|
||||
- The store connection flow is the trickiest UX challenge: the user needs to authenticate with each retailer, and we need to capture the resulting session. This likely involves a controlled Playwright browser session that the user can see/interact with, or an OAuth-like redirect flow if the retailer supports it (Kroger does for its public API, but not for purchase history access).
|
||||
@@ -0,0 +1,33 @@
|
||||
# Stage 1: Build dependencies
|
||||
# Build context is the repo root. Paths below are relative to the root.
|
||||
FROM python:3.12-slim AS build
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libpq-dev \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
COPY api/pyproject.toml ./
|
||||
COPY api/src/ ./src/
|
||||
RUN pip install --no-cache-dir --prefix=/install .
|
||||
|
||||
# Stage 2: Production image
|
||||
FROM python:3.12-slim AS prod
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends libpq5 && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
RUN adduser --system --group --uid 1000 app
|
||||
COPY --from=build /install /usr/local
|
||||
COPY api/src/ ./src/
|
||||
COPY api/alembic.ini ./
|
||||
COPY api/alembic/ ./alembic/
|
||||
|
||||
USER 1000
|
||||
EXPOSE 8000
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=3s \
|
||||
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"
|
||||
|
||||
CMD ["uvicorn", "cartsnitch_api.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,36 @@
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
sqlalchemy.url = postgresql://OVERRIDE_VIA_ENV_VAR
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Alembic environment configuration for CartSnitch."""
|
||||
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from alembic import context
|
||||
from cartsnitch_api.models import Base # noqa: F401 — imports all models for autogenerate
|
||||
|
||||
config = context.config
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
db_url = os.environ.get("CARTSNITCH_DATABASE_URL_SYNC")
|
||||
if not db_url:
|
||||
raise RuntimeError(
|
||||
"CARTSNITCH_DATABASE_URL_SYNC must be set. "
|
||||
"Example: postgresql://user:pass@localhost:5432/cartsnitch"
|
||||
)
|
||||
config.set_main_option("sqlalchemy.url", db_url)
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -0,0 +1,25 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
${imports if imports else ""}
|
||||
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Encrypt existing plaintext session_data with Fernet.
|
||||
|
||||
Revision ID: 001_encrypt_session_data
|
||||
Revises:
|
||||
Create Date: 2026-03-19
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import sqlalchemy as sa
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "001_encrypt_session_data"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
key = os.environ.get("CARTSNITCH_FERNET_KEY")
|
||||
if not key:
|
||||
raise RuntimeError("CARTSNITCH_FERNET_KEY must be set to run this migration")
|
||||
return Fernet(key.encode())
|
||||
|
||||
|
||||
def _is_fernet_token(value: str) -> bool:
|
||||
"""Check if a string looks like a Fernet token (base64 starting with gAAAAA)."""
|
||||
return value.startswith("gAAAAA")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Change column type from JSON to TEXT to hold Fernet ciphertext
|
||||
op.alter_column(
|
||||
"user_store_accounts",
|
||||
"session_data",
|
||||
type_=sa.Text(),
|
||||
existing_type=sa.JSON(),
|
||||
existing_nullable=True,
|
||||
postgresql_using="session_data::text",
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
rows = conn.execute(
|
||||
text("SELECT id, session_data FROM user_store_accounts WHERE session_data IS NOT NULL")
|
||||
).fetchall()
|
||||
|
||||
f = _get_fernet()
|
||||
for row_id, session_data in rows:
|
||||
raw = str(session_data)
|
||||
if _is_fernet_token(raw):
|
||||
continue
|
||||
plaintext = raw if isinstance(session_data, str) else json.dumps(session_data)
|
||||
encrypted = f.encrypt(plaintext.encode()).decode()
|
||||
conn.execute(
|
||||
text("UPDATE user_store_accounts SET session_data = :data WHERE id = :id"),
|
||||
{"data": encrypted, "id": row_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
rows = conn.execute(
|
||||
text("SELECT id, session_data FROM user_store_accounts WHERE session_data IS NOT NULL")
|
||||
).fetchall()
|
||||
|
||||
f = _get_fernet()
|
||||
for row_id, session_data in rows:
|
||||
raw = str(session_data)
|
||||
if not _is_fernet_token(raw):
|
||||
continue
|
||||
decrypted = f.decrypt(raw.encode()).decode()
|
||||
conn.execute(
|
||||
text("UPDATE user_store_accounts SET session_data = :data WHERE id = :id"),
|
||||
{"data": decrypted, "id": row_id},
|
||||
)
|
||||
|
||||
# Revert column type from TEXT back to JSON
|
||||
op.alter_column(
|
||||
"user_store_accounts",
|
||||
"session_data",
|
||||
type_=sa.JSON(),
|
||||
existing_type=sa.Text(),
|
||||
existing_nullable=True,
|
||||
postgresql_using="session_data::json",
|
||||
)
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Add Better-Auth tables and extend users table.
|
||||
|
||||
Creates sessions, accounts, and verifications tables for Better-Auth.
|
||||
Adds email_verified and image columns to existing users table.
|
||||
Migrates password hashes from users.hashed_password to accounts.password.
|
||||
|
||||
Revision ID: 002_better_auth_tables
|
||||
Revises: 001_encrypt_session_data
|
||||
Create Date: 2026-03-28
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "002_better_auth_tables"
|
||||
down_revision = "001_encrypt_session_data"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# --- Extend users table for Better-Auth compatibility ---
|
||||
op.add_column("users", sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false"))
|
||||
op.add_column("users", sa.Column("image", sa.Text(), nullable=True))
|
||||
|
||||
# --- Create sessions table ---
|
||||
op.create_table(
|
||||
"sessions",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("token", sa.Text(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("ip_address", sa.Text(), nullable=True),
|
||||
sa.Column("user_agent", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_sessions_token", "sessions", ["token"], unique=True)
|
||||
op.create_index("ix_sessions_user_id", "sessions", ["user_id"])
|
||||
|
||||
# --- Create accounts table ---
|
||||
op.create_table(
|
||||
"accounts",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("account_id", sa.Text(), nullable=False),
|
||||
sa.Column("provider_id", sa.Text(), nullable=False),
|
||||
sa.Column("access_token", sa.Text(), nullable=True),
|
||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||
sa.Column("access_token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("refresh_token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("scope", sa.Text(), nullable=True),
|
||||
sa.Column("id_token", sa.Text(), nullable=True),
|
||||
sa.Column("password", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_accounts_user_id", "accounts", ["user_id"])
|
||||
|
||||
# --- Create verifications table ---
|
||||
op.create_table(
|
||||
"verifications",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("identifier", sa.Text(), nullable=False),
|
||||
sa.Column("value", sa.Text(), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# --- Migrate existing password hashes to accounts table ---
|
||||
# For each user with a hashed_password, create a 'credential' account row
|
||||
conn = op.get_bind()
|
||||
users = conn.execute(
|
||||
text("SELECT id, hashed_password FROM users WHERE hashed_password IS NOT NULL")
|
||||
).fetchall()
|
||||
|
||||
for user_id, hashed_password in users:
|
||||
user_id_str = str(user_id)
|
||||
conn.execute(
|
||||
text(
|
||||
"INSERT INTO accounts (id, user_id, account_id, provider_id, password, created_at, updated_at) "
|
||||
"VALUES (gen_random_uuid()::text, :user_id, :account_id, 'credential', :password, now(), now())"
|
||||
),
|
||||
{"user_id": user_id_str, "account_id": user_id_str, "password": hashed_password},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("verifications")
|
||||
op.drop_table("accounts")
|
||||
op.drop_index("ix_sessions_user_id", table_name="sessions")
|
||||
op.drop_index("ix_sessions_token", table_name="sessions")
|
||||
op.drop_table("sessions")
|
||||
op.drop_column("users", "image")
|
||||
op.drop_column("users", "email_verified")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Make users.hashed_password nullable.
|
||||
|
||||
Better-Auth inserts users without hashed_password (passwords live in the
|
||||
accounts table). This column is now purely optional.
|
||||
|
||||
Revision ID: 003_make_users_hashed_password_nullable
|
||||
Revises: 002_better_auth_tables
|
||||
Create Date: 2026-03-30
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "003_make_users_hashed_password_nullable"
|
||||
down_revision = "002_better_auth_tables"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=False)
|
||||
@@ -0,0 +1,58 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "cartsnitch-api"
|
||||
version = "0.1.0"
|
||||
description = "CartSnitch API Gateway — public-facing REST API"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.30.0",
|
||||
"pydantic[email]>=2.9.0",
|
||||
"pydantic-settings>=2.5.0",
|
||||
"sqlalchemy[asyncio]>=2.0.35",
|
||||
"asyncpg>=0.30.0",
|
||||
"alembic>=1.13,<2.0",
|
||||
"psycopg2>=2.9,<3.0",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"passlib[bcrypt]>=1.7.4",
|
||||
"httpx>=0.27.0",
|
||||
"redis[hiredis]>=5.2.0",
|
||||
"cryptography>=43.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.3.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
"httpx>=0.27.0",
|
||||
"ruff>=0.7.0",
|
||||
"psycopg2-binary>=2.9,<3.0",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/cartsnitch_api"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "UP", "B"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"src/cartsnitch_api/**/routes*.py" = ["B008"]
|
||||
"src/cartsnitch_api/**/dependencies.py" = ["B008"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
ignore_missing_imports = true
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||
"extends": ["local>cartsnitch/.github:renovate-config"]
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
"""FastAPI dependency injection for authentication.
|
||||
|
||||
Validates Better-Auth session tokens from cookies or Bearer header.
|
||||
Sessions are verified by querying the shared sessions table directly.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Cookie, Depends, Header, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.database import get_db
|
||||
|
||||
# Keep Bearer scheme as optional — Better-Auth primarily uses cookies,
|
||||
# but we support Bearer tokens for service-to-service or mobile clients.
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
# Better-Auth session cookie name
|
||||
SESSION_COOKIE_NAME = "better-auth.session_token"
|
||||
|
||||
|
||||
async def _validate_session_token(token: str, db: AsyncSession) -> UUID:
|
||||
"""Validate a Better-Auth session token against the sessions table.
|
||||
|
||||
Returns the user_id (as UUID) if the session is valid and not expired.
|
||||
"""
|
||||
result = await db.execute(
|
||||
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
|
||||
{"token": token},
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid session token",
|
||||
)
|
||||
|
||||
user_id, expires_at = row
|
||||
if expires_at.tzinfo is None:
|
||||
# Treat naive datetimes as UTC
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < datetime.now(UTC):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session expired",
|
||||
)
|
||||
|
||||
return UUID(str(user_id))
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> UUID:
|
||||
"""Extract and validate the session token from cookie or Authorization header.
|
||||
|
||||
Checks in order:
|
||||
1. Better-Auth session cookie (primary — web clients)
|
||||
2. Bearer token in Authorization header (fallback — API clients)
|
||||
"""
|
||||
token: str | None = None
|
||||
|
||||
# 1. Check session cookie
|
||||
cookie_token = request.cookies.get(SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
token = cookie_token
|
||||
|
||||
# 2. Fall back to Bearer header
|
||||
if not token and credentials:
|
||||
token = credentials.credentials
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
return await _validate_session_token(token, db)
|
||||
|
||||
|
||||
async def verify_service_key(x_service_key: str = Header()) -> None:
|
||||
if x_service_key != settings.service_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid service key",
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""JWT token creation and validation."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def create_access_token(user_id: UUID) -> str:
|
||||
expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes)
|
||||
payload = {"sub": str(user_id), "exp": expire, "type": "access"}
|
||||
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
||||
|
||||
|
||||
def create_refresh_token(user_id: UUID) -> str:
|
||||
expire = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days)
|
||||
payload = {"sub": str(user_id), "exp": expire, "type": "refresh"}
|
||||
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
try:
|
||||
return cast(
|
||||
dict[str, Any],
|
||||
jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]),
|
||||
)
|
||||
except JWTError as e:
|
||||
raise ValueError(f"Invalid token: {e}") from e
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Password hashing and verification with bcrypt."""
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Auth routes: user profile management.
|
||||
|
||||
Registration, login, refresh, and session management are handled by
|
||||
the Better-Auth service (auth/). This router provides user profile
|
||||
endpoints that query our own user data from the shared database.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import (
|
||||
UpdateUserRequest,
|
||||
UserResponse,
|
||||
)
|
||||
from cartsnitch_api.services.auth import AuthService
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.get_user(user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
) from None
|
||||
|
||||
|
||||
@router.patch("/me", response_model=UserResponse)
|
||||
async def update_me(
|
||||
body: UpdateUserRequest,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
return await svc.update_user(user_id, email=body.email, display_name=body.display_name)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
) from None
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_me(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AuthService(db)
|
||||
try:
|
||||
await svc.delete_user(user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
) from None
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Redis/DragonflyDB caching helpers."""
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class CacheClient:
|
||||
"""Stub for Redis/DragonflyDB caching.
|
||||
|
||||
Will be used for expensive queries: price trends, product comparisons.
|
||||
Cache invalidation via Redis pub/sub events from other services.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.url = settings.redis_url
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
# TODO: implement with redis-py async
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: str, ttl_seconds: int = 300) -> None:
|
||||
# TODO: implement with redis-py async
|
||||
pass
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
# TODO: implement with redis-py async
|
||||
pass
|
||||
@@ -0,0 +1,53 @@
|
||||
import base64
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = {"env_prefix": "CARTSNITCH_"}
|
||||
|
||||
database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
jwt_secret_key: str = "change-me-in-production"
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_access_token_expire_minutes: int = 15
|
||||
jwt_refresh_token_expire_days: int = 7
|
||||
|
||||
service_key: str = "change-me-in-production"
|
||||
# Valid Fernet key for local dev — MUST be overridden in production
|
||||
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||
|
||||
auth_service_url: str = "http://auth:3001"
|
||||
|
||||
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
|
||||
|
||||
receiptwitness_url: str = "http://receiptwitness:8001"
|
||||
stickershock_url: str = "http://stickershock:8002"
|
||||
clipartist_url: str = "http://clipartist:8003"
|
||||
shrinkray_url: str = "http://shrinkray:8004"
|
||||
|
||||
rate_limit_requests: int = 60
|
||||
rate_limit_window_seconds: int = 60
|
||||
rate_limit_enabled: bool = True
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_fernet_key(self):
|
||||
"""Validate fernet_key is a valid 32-byte url-safe base64 key at startup."""
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
|
||||
if len(decoded) != 32:
|
||||
raise ValueError
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"CARTSNITCH_FERNET_KEY must be a valid Fernet key "
|
||||
"(32 bytes, url-safe base64 encoded). "
|
||||
"Generate one with: python -c "
|
||||
"'from cryptography.fernet import Fernet; "
|
||||
"print(Fernet.generate_key().decode())'"
|
||||
) from None
|
||||
return self
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Constants and enums shared across CartSnitch services."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class StoreSlug(StrEnum):
|
||||
"""Supported retailer slugs."""
|
||||
|
||||
MEIJER = "meijer"
|
||||
KROGER = "kroger"
|
||||
TARGET = "target"
|
||||
|
||||
|
||||
class AccountStatus(StrEnum):
|
||||
"""User store account link status."""
|
||||
|
||||
ACTIVE = "active"
|
||||
EXPIRED = "expired"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class DiscountType(StrEnum):
|
||||
"""Coupon discount type."""
|
||||
|
||||
PERCENT = "percent"
|
||||
FIXED = "fixed"
|
||||
BOGO = "bogo"
|
||||
BUY_X_GET_Y = "buy_x_get_y"
|
||||
|
||||
|
||||
class PriceSource(StrEnum):
|
||||
"""Source of a price observation."""
|
||||
|
||||
RECEIPT = "receipt"
|
||||
CATALOG = "catalog"
|
||||
WEEKLY_AD = "weekly_ad"
|
||||
|
||||
|
||||
class EventType(StrEnum):
|
||||
"""Redis pub/sub event types."""
|
||||
|
||||
RECEIPTS_INGESTED = "cartsnitch.receipts.ingested"
|
||||
PRICES_UPDATED = "cartsnitch.prices.updated"
|
||||
PRODUCTS_NORMALIZED = "cartsnitch.products.normalized"
|
||||
COUPONS_UPDATED = "cartsnitch.coupons.updated"
|
||||
ALERT_PRICE_INCREASE = "cartsnitch.alerts.price_increase"
|
||||
ALERT_SHRINKFLATION = "cartsnitch.alerts.shrinkflation"
|
||||
|
||||
|
||||
class ProductCategory(StrEnum):
|
||||
"""Top-level product categories."""
|
||||
|
||||
PRODUCE = "produce"
|
||||
DAIRY = "dairy"
|
||||
MEAT = "meat"
|
||||
BAKERY = "bakery"
|
||||
FROZEN = "frozen"
|
||||
PANTRY = "pantry"
|
||||
BEVERAGES = "beverages"
|
||||
SNACKS = "snacks"
|
||||
HOUSEHOLD = "household"
|
||||
PERSONAL_CARE = "personal_care"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class MatchConfidence(StrEnum):
|
||||
"""Confidence level for product matching."""
|
||||
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
class SizeUnit(StrEnum):
|
||||
"""Standardized product size units."""
|
||||
|
||||
OZ = "oz"
|
||||
FL_OZ = "fl_oz"
|
||||
LB = "lb"
|
||||
G = "g"
|
||||
KG = "kg"
|
||||
ML = "ml"
|
||||
L = "l"
|
||||
CT = "ct"
|
||||
PK = "pk"
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Database session management for the API gateway."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
engine = create_async_engine(settings.database_url, echo=False)
|
||||
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI dependency that yields an async DB session."""
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
@@ -0,0 +1,62 @@
|
||||
"""FastAPI app factory for CartSnitch API Gateway."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from cartsnitch_api.auth.routes import router as auth_router
|
||||
from cartsnitch_api.middleware.cors import add_cors_middleware
|
||||
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
|
||||
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
|
||||
from cartsnitch_api.routes.alerts import router as alerts_router
|
||||
from cartsnitch_api.routes.coupons import router as coupons_router
|
||||
from cartsnitch_api.routes.health import router as health_router
|
||||
from cartsnitch_api.routes.prices import router as prices_router
|
||||
from cartsnitch_api.routes.products import router as products_router
|
||||
from cartsnitch_api.routes.public import router as public_router
|
||||
from cartsnitch_api.routes.purchases import router as purchases_router
|
||||
from cartsnitch_api.routes.scraping import router as scraping_router
|
||||
from cartsnitch_api.routes.shopping import router as shopping_router
|
||||
from cartsnitch_api.routes.stores import router as stores_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# TODO: initialize DB session pool, Redis connection, service clients
|
||||
yield
|
||||
# TODO: cleanup connections
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="CartSnitch API",
|
||||
description="Grocery price tracking and shrinkflation detection API",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Middleware (order matters — outermost first)
|
||||
add_cors_middleware(app)
|
||||
add_error_monitor_middleware(app)
|
||||
add_rate_limit_middleware(app)
|
||||
|
||||
# Exception handlers
|
||||
add_error_handlers(app)
|
||||
|
||||
# Routers
|
||||
app.include_router(health_router)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(stores_router)
|
||||
app.include_router(purchases_router)
|
||||
app.include_router(products_router)
|
||||
app.include_router(prices_router)
|
||||
app.include_router(coupons_router)
|
||||
app.include_router(shopping_router)
|
||||
app.include_router(alerts_router)
|
||||
app.include_router(scraping_router)
|
||||
app.include_router(public_router)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
@@ -0,0 +1,16 @@
|
||||
"""CORS middleware configuration."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def add_cors_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Structured error responses and error monitoring.
|
||||
|
||||
Ensures all errors return a consistent JSON shape and never leak stack traces.
|
||||
Provides hooks for error monitoring/alerting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger("cartsnitch_api.errors")
|
||||
|
||||
|
||||
def _error_response(
|
||||
status_code: int,
|
||||
detail: str,
|
||||
code: str | None = None,
|
||||
errors: list[dict] | None = None,
|
||||
) -> JSONResponse:
|
||||
"""Build a consistent error response."""
|
||||
body: dict = {"detail": detail}
|
||||
if code:
|
||||
body["code"] = code
|
||||
if errors:
|
||||
body["errors"] = errors
|
||||
return JSONResponse(status_code=status_code, content=body)
|
||||
|
||||
|
||||
def add_error_handlers(app: FastAPI) -> None:
|
||||
"""Register global exception handlers for consistent error responses."""
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_error_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
"""Return 422 with structured field-level error details."""
|
||||
field_errors = []
|
||||
for err in exc.errors():
|
||||
loc = err.get("loc", ())
|
||||
field_errors.append(
|
||||
{
|
||||
"field": ".".join(str(p) for p in loc[1:]) if len(loc) > 1 else str(loc),
|
||||
"message": err.get("msg", "Invalid value"),
|
||||
"type": err.get("type", "value_error"),
|
||||
}
|
||||
)
|
||||
return _error_response(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Validation error",
|
||||
code="VALIDATION_ERROR",
|
||||
errors=field_errors,
|
||||
)
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
|
||||
"""Wrap HTTP exceptions (Starlette and FastAPI) in consistent format."""
|
||||
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
||||
return _error_response(
|
||||
status_code=exc.status_code,
|
||||
detail=detail,
|
||||
code=_status_to_code(exc.status_code),
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Catch-all: log full traceback, return safe 500 to client."""
|
||||
logger.error(
|
||||
"Unhandled exception on %s %s: %s\n%s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
_notify_error_monitor(request, exc)
|
||||
|
||||
return _error_response(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error",
|
||||
code="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
def _status_to_code(status_code: int) -> str:
|
||||
"""Map HTTP status code to a machine-readable error code."""
|
||||
mapping = {
|
||||
400: "BAD_REQUEST",
|
||||
401: "UNAUTHORIZED",
|
||||
403: "FORBIDDEN",
|
||||
404: "NOT_FOUND",
|
||||
409: "CONFLICT",
|
||||
422: "VALIDATION_ERROR",
|
||||
429: "RATE_LIMITED",
|
||||
502: "BAD_GATEWAY",
|
||||
503: "SERVICE_UNAVAILABLE",
|
||||
}
|
||||
return mapping.get(status_code, f"HTTP_{status_code}")
|
||||
|
||||
|
||||
# ---------- Error Monitoring ----------
|
||||
|
||||
|
||||
class _ErrorMonitor:
|
||||
"""Simple error counter for monitoring and alerting hooks.
|
||||
|
||||
Tracks error counts and rates. In production, this would forward
|
||||
to an external monitoring service (Prometheus, Sentry, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.error_counts: dict[int, int] = {}
|
||||
self.recent_5xx: list[dict] = []
|
||||
self._max_recent = 100
|
||||
|
||||
def record(self, status_code: int, path: str, method: str, error: str | None = None) -> None:
|
||||
self.error_counts[status_code] = self.error_counts.get(status_code, 0) + 1
|
||||
|
||||
if status_code >= 500:
|
||||
entry = {
|
||||
"timestamp": time.time(),
|
||||
"status": status_code,
|
||||
"path": path,
|
||||
"method": method,
|
||||
"error": error,
|
||||
}
|
||||
self.recent_5xx.append(entry)
|
||||
if len(self.recent_5xx) > self._max_recent:
|
||||
self.recent_5xx = self.recent_5xx[-self._max_recent :]
|
||||
|
||||
logger.warning(
|
||||
"5xx error recorded: %s %s -> %d (%s)",
|
||||
method,
|
||||
path,
|
||||
status_code,
|
||||
error or "unknown",
|
||||
)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
return {
|
||||
"error_counts": dict(self.error_counts),
|
||||
"recent_5xx_count": len(self.recent_5xx),
|
||||
}
|
||||
|
||||
|
||||
_monitor = _ErrorMonitor()
|
||||
|
||||
|
||||
def get_error_monitor() -> _ErrorMonitor:
|
||||
"""Access the global error monitor (for health/metrics endpoints)."""
|
||||
return _monitor
|
||||
|
||||
|
||||
def _notify_error_monitor(request: Request, exc: Exception) -> None:
|
||||
"""Record unhandled exception in the error monitor."""
|
||||
_monitor.record(
|
||||
status_code=500,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
error=str(exc)[:200],
|
||||
)
|
||||
|
||||
|
||||
class ErrorMonitorMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to track all 4xx/5xx responses for monitoring."""
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable],
|
||||
):
|
||||
response = await call_next(request)
|
||||
|
||||
if response.status_code >= 400:
|
||||
_monitor.record(
|
||||
status_code=response.status_code,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def add_error_monitor_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(ErrorMonitorMiddleware)
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Rate limiting middleware for public and authenticated endpoints.
|
||||
|
||||
Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
|
||||
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class _SlidingWindowCounter:
|
||||
"""Thread-safe in-memory sliding window rate limiter."""
|
||||
|
||||
def __init__(self, max_requests: int, window_seconds: int) -> None:
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
self._hits: dict[str, list[float]] = defaultdict(list)
|
||||
self._lock = Lock()
|
||||
|
||||
def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
with self._lock:
|
||||
# Prune expired entries
|
||||
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
|
||||
|
||||
current_count = len(self._hits[key])
|
||||
if current_count >= self.max_requests:
|
||||
retry_after = int(self._hits[key][0] - cutoff) + 1
|
||||
return False, 0, retry_after
|
||||
|
||||
self._hits[key].append(now)
|
||||
remaining = self.max_requests - current_count - 1
|
||||
return True, remaining, 0
|
||||
|
||||
|
||||
# Module-level counters — one for public (per-IP), one for auth (per-token)
|
||||
_public_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests,
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
_auth_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract client IP, respecting X-Forwarded-For behind a reverse proxy."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
|
||||
"""Determine rate limit key and which limiter to use."""
|
||||
if request.url.path.startswith("/public"):
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
# For authenticated endpoints, use Bearer token as key if present
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
# Use last 16 chars of token as key to avoid storing full tokens
|
||||
return f"token:{token[-16:]}", _auth_limiter
|
||||
|
||||
# Fallback to IP for unauthenticated non-public endpoints
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip rate limiting when disabled (e.g. in tests) or for health checks
|
||||
if not settings.rate_limit_enabled or request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
key, limiter = _get_rate_limit_key(request)
|
||||
allowed, remaining, retry_after = limiter.is_allowed(key)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={
|
||||
"detail": "Rate limit exceeded",
|
||||
"code": "RATE_LIMITED",
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"X-RateLimit-Limit": str(limiter.max_requests),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
|
||||
def add_rate_limit_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""SQLAlchemy ORM models — re-exports all models for convenience."""
|
||||
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
from cartsnitch_api.models.coupon import Coupon
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.purchase import Purchase, PurchaseItem
|
||||
from cartsnitch_api.models.shrinkflation import ShrinkflationEvent
|
||||
from cartsnitch_api.models.store import Store, StoreLocation
|
||||
from cartsnitch_api.models.user import User, UserStoreAccount
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"TimestampMixin",
|
||||
"UUIDPrimaryKeyMixin",
|
||||
"Store",
|
||||
"StoreLocation",
|
||||
"User",
|
||||
"UserStoreAccount",
|
||||
"Purchase",
|
||||
"PurchaseItem",
|
||||
"NormalizedProduct",
|
||||
"PriceHistory",
|
||||
"Coupon",
|
||||
"ShrinkflationEvent",
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Base model and mixins for all CartSnitch ORM models."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all CartSnitch models."""
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin providing created_at / updated_at columns."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
class UUIDPrimaryKeyMixin:
|
||||
"""Mixin providing a UUID primary key."""
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
primary_key=True, default=uuid.uuid4, server_default=func.gen_random_uuid()
|
||||
)
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Coupon model."""
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import DiscountType
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.store import Store
|
||||
|
||||
|
||||
class Coupon(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""A coupon or deal for a product at a store."""
|
||||
|
||||
__tablename__ = "coupons"
|
||||
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
normalized_product_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("normalized_products.id")
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(String(1000))
|
||||
discount_type: Mapped[DiscountType] = mapped_column(String(20), nullable=False)
|
||||
discount_value: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
min_purchase: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
valid_from: Mapped[date | None] = mapped_column(Date)
|
||||
valid_to: Mapped[date | None] = mapped_column(Date)
|
||||
requires_clip: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
coupon_code: Mapped[str | None] = mapped_column(String(100))
|
||||
source_url: Mapped[str | None] = mapped_column(String(500))
|
||||
scraped_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
# Relationships
|
||||
store: Mapped["Store"] = relationship(back_populates="coupons")
|
||||
normalized_product: Mapped["NormalizedProduct | None"] = relationship(back_populates="coupons")
|
||||
@@ -0,0 +1,50 @@
|
||||
"""PriceHistory model — tracks product prices over time."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Date, ForeignKey, Index, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import PriceSource
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.purchase import PurchaseItem
|
||||
from cartsnitch_api.models.store import Store
|
||||
|
||||
|
||||
class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""A single price observation for a product at a store on a date."""
|
||||
|
||||
__tablename__ = "price_history"
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_price_history_product_store_date",
|
||||
"normalized_product_id",
|
||||
"store_id",
|
||||
"observed_date",
|
||||
),
|
||||
)
|
||||
|
||||
normalized_product_id: Mapped[uuid.UUID] = mapped_column(
|
||||
ForeignKey("normalized_products.id"), nullable=False
|
||||
)
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
observed_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
regular_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
loyalty_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
coupon_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
source: Mapped[PriceSource] = mapped_column(String(20), nullable=False)
|
||||
purchase_item_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("purchase_items.id"))
|
||||
|
||||
# Relationships
|
||||
normalized_product: Mapped["NormalizedProduct"] = relationship(back_populates="price_histories")
|
||||
store: Mapped["Store"] = relationship(back_populates="price_histories")
|
||||
purchase_item: Mapped["PurchaseItem | None"] = relationship(
|
||||
back_populates="price_history_entries"
|
||||
)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""NormalizedProduct model — the canonical product identity."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import ProductCategory, SizeUnit
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.coupon import Coupon
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.purchase import PurchaseItem
|
||||
from cartsnitch_api.models.shrinkflation import ShrinkflationEvent
|
||||
|
||||
|
||||
class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Canonical product identity — matches products across retailers."""
|
||||
|
||||
__tablename__ = "normalized_products"
|
||||
|
||||
canonical_name: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
category: Mapped[ProductCategory | None] = mapped_column(String(50))
|
||||
subcategory: Mapped[str | None] = mapped_column(String(100))
|
||||
brand: Mapped[str | None] = mapped_column(String(200))
|
||||
size: Mapped[str | None] = mapped_column(String(50))
|
||||
size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
|
||||
upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list)
|
||||
|
||||
# Relationships
|
||||
purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
|
||||
price_histories: Mapped[list["PriceHistory"]] = relationship(
|
||||
back_populates="normalized_product"
|
||||
)
|
||||
coupons: Mapped[list["Coupon"]] = relationship(back_populates="normalized_product")
|
||||
shrinkflation_events: Mapped[list["ShrinkflationEvent"]] = relationship(
|
||||
back_populates="normalized_product"
|
||||
)
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Purchase and PurchaseItem models."""
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Date,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Numeric,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
from cartsnitch_api.models.store import Store, StoreLocation
|
||||
from cartsnitch_api.models.user import User
|
||||
|
||||
|
||||
class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""A single shopping trip / receipt."""
|
||||
|
||||
__tablename__ = "purchases"
|
||||
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
store_location_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("store_locations.id"))
|
||||
receipt_id: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
purchase_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
total: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
subtotal: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
tax: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
savings_total: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
source_url: Mapped[str | None] = mapped_column(String(500))
|
||||
raw_data: Mapped[dict | None] = mapped_column(JSON)
|
||||
ingested_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(back_populates="purchases")
|
||||
store: Mapped["Store"] = relationship(back_populates="purchases")
|
||||
store_location: Mapped["StoreLocation | None"] = relationship(back_populates="purchases")
|
||||
items: Mapped[list["PurchaseItem"]] = relationship(back_populates="purchase")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_purchases_user_store", "user_id", "store_id"),
|
||||
UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"),
|
||||
)
|
||||
|
||||
|
||||
class PurchaseItem(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Individual line item on a receipt."""
|
||||
|
||||
__tablename__ = "purchase_items"
|
||||
|
||||
purchase_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("purchases.id"), nullable=False)
|
||||
product_name_raw: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
upc: Mapped[str | None] = mapped_column(String(20))
|
||||
quantity: Mapped[Decimal] = mapped_column(Numeric(10, 3), nullable=False, default=1)
|
||||
unit_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
extended_price: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False)
|
||||
regular_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
sale_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
coupon_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
loyalty_discount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
category_raw: Mapped[str | None] = mapped_column(String(100))
|
||||
normalized_product_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("normalized_products.id")
|
||||
)
|
||||
|
||||
# Relationships
|
||||
purchase: Mapped["Purchase"] = relationship(back_populates="items")
|
||||
normalized_product: Mapped["NormalizedProduct | None"] = relationship(
|
||||
back_populates="purchase_items"
|
||||
)
|
||||
price_history_entries: Mapped[list["PriceHistory"]] = relationship(
|
||||
back_populates="purchase_item"
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""ShrinkflationEvent model."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Date, ForeignKey, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import SizeUnit
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.product import NormalizedProduct
|
||||
|
||||
|
||||
class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Detected shrinkflation event — product size changed while price held or rose."""
|
||||
|
||||
__tablename__ = "shrinkflation_events"
|
||||
|
||||
normalized_product_id: Mapped[uuid.UUID] = mapped_column(
|
||||
ForeignKey("normalized_products.id"), nullable=False
|
||||
)
|
||||
detected_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
old_size: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
new_size: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
old_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
|
||||
new_unit: Mapped[SizeUnit] = mapped_column(String(10), nullable=False)
|
||||
price_at_old_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
price_at_new_size: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
confidence: Mapped[Decimal] = mapped_column(
|
||||
Numeric(3, 2), nullable=False, default=Decimal("1.00")
|
||||
)
|
||||
notes: Mapped[str | None] = mapped_column(String(1000))
|
||||
|
||||
# Relationships
|
||||
normalized_product: Mapped["NormalizedProduct"] = relationship(
|
||||
back_populates="shrinkflation_events"
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Store and StoreLocation models."""
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Float, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import StoreSlug
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.coupon import Coupon
|
||||
from cartsnitch_api.models.price import PriceHistory
|
||||
from cartsnitch_api.models.purchase import Purchase
|
||||
from cartsnitch_api.models.user import UserStoreAccount
|
||||
|
||||
|
||||
class Store(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Supported retailer."""
|
||||
|
||||
__tablename__ = "stores"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
slug: Mapped[StoreSlug] = mapped_column(String(20), nullable=False, unique=True)
|
||||
logo_url: Mapped[str | None] = mapped_column(String(500))
|
||||
website_url: Mapped[str | None] = mapped_column(String(500))
|
||||
|
||||
# Relationships
|
||||
locations: Mapped[list["StoreLocation"]] = relationship(back_populates="store")
|
||||
purchases: Mapped[list["Purchase"]] = relationship(back_populates="store")
|
||||
user_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="store")
|
||||
price_histories: Mapped[list["PriceHistory"]] = relationship(back_populates="store")
|
||||
coupons: Mapped[list["Coupon"]] = relationship(back_populates="store")
|
||||
|
||||
|
||||
class StoreLocation(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Physical store location."""
|
||||
|
||||
__tablename__ = "store_locations"
|
||||
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
address: Mapped[str] = mapped_column(String(300), nullable=False)
|
||||
city: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
state: Mapped[str] = mapped_column(String(2), nullable=False)
|
||||
zip: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
lat: Mapped[float | None] = mapped_column(Float)
|
||||
lng: Mapped[float | None] = mapped_column(Float)
|
||||
|
||||
# Relationships
|
||||
store: Mapped["Store"] = relationship(back_populates="locations")
|
||||
purchases: Mapped[list["Purchase"]] = relationship(back_populates="store_location")
|
||||
@@ -0,0 +1,50 @@
|
||||
"""User and UserStoreAccount models."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_api.constants import AccountStatus
|
||||
from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
from cartsnitch_api.types import EncryptedJSON
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cartsnitch_api.models.purchase import Purchase
|
||||
from cartsnitch_api.models.store import Store
|
||||
|
||||
|
||||
class User(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Application user."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
display_name: Mapped[str | None] = mapped_column(String(100))
|
||||
|
||||
# Relationships
|
||||
store_accounts: Mapped[list["UserStoreAccount"]] = relationship(back_populates="user")
|
||||
purchases: Mapped[list["Purchase"]] = relationship(back_populates="user")
|
||||
|
||||
|
||||
class UserStoreAccount(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
"""Link between a user and their retailer account credentials."""
|
||||
|
||||
__tablename__ = "user_store_accounts"
|
||||
__table_args__ = (UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),)
|
||||
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||
store_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("stores.id"), nullable=False)
|
||||
session_data: Mapped[dict | None] = mapped_column(EncryptedJSON)
|
||||
session_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
last_sync_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
status: Mapped[AccountStatus] = mapped_column(
|
||||
String(20), nullable=False, default=AccountStatus.ACTIVE
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(back_populates="store_accounts")
|
||||
store: Mapped["Store"] = relationship(back_populates="user_accounts")
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Alert routes: list alerts, manage settings."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import AlertResponse, AlertSettingsRequest, AlertSettingsResponse
|
||||
from cartsnitch_api.services.alerts import AlertService
|
||||
|
||||
router = APIRouter(prefix="/alerts", tags=["alerts"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[AlertResponse])
|
||||
async def list_alerts(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AlertService(db)
|
||||
return await svc.list_alerts(user_id)
|
||||
|
||||
|
||||
@router.get("/settings", response_model=AlertSettingsResponse)
|
||||
async def get_alert_settings(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = AlertService(db)
|
||||
return await svc.get_settings(user_id)
|
||||
|
||||
|
||||
@router.put("/settings")
|
||||
async def update_alert_settings(
|
||||
body: AlertSettingsRequest,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Alert settings persistence not yet implemented. "
|
||||
"Use GET /alerts/settings for current defaults.",
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Coupon routes: browse, relevant matches."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import CouponResponse
|
||||
from cartsnitch_api.services.coupons import CouponService
|
||||
|
||||
router = APIRouter(prefix="/coupons", tags=["coupons"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[CouponResponse])
|
||||
async def list_coupons(
|
||||
store_id: UUID | None = Query(None),
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CouponService(db)
|
||||
return await svc.list_coupons(store_id)
|
||||
|
||||
|
||||
@router.get("/relevant", response_model=list[CouponResponse])
|
||||
async def relevant_coupons(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CouponService(db)
|
||||
return await svc.relevant_coupons(user_id)
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Health check and error metrics endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from cartsnitch_api.auth.dependencies import verify_service_key
|
||||
from cartsnitch_api.middleware.error_handler import get_error_monitor
|
||||
|
||||
router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)])
|
||||
async def error_stats():
|
||||
"""Error monitoring stats — internal only (requires X-Service-Key)."""
|
||||
monitor = get_error_monitor()
|
||||
return monitor.get_stats()
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Price routes: trends, increases, comparison."""
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import (
|
||||
PriceComparisonResponse,
|
||||
PriceIncreaseResponse,
|
||||
PriceTrendResponse,
|
||||
)
|
||||
from cartsnitch_api.services.prices import PriceService
|
||||
|
||||
router = APIRouter(prefix="/prices", tags=["prices"])
|
||||
|
||||
|
||||
@router.get("/trends", response_model=list[PriceTrendResponse])
|
||||
async def price_trends(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
category: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PriceService(db)
|
||||
return await svc.get_trends(category)
|
||||
|
||||
|
||||
@router.get("/increases", response_model=list[PriceIncreaseResponse])
|
||||
async def price_increases(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PriceService(db)
|
||||
return await svc.get_increases()
|
||||
|
||||
|
||||
@router.get("/comparison", response_model=list[PriceComparisonResponse])
|
||||
async def price_comparison(
|
||||
product_ids: Annotated[list[UUID], Query()],
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PriceService(db)
|
||||
return await svc.get_comparison(product_ids)
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Product routes: search/list, detail, price history."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import PriceTrendResponse, ProductDetailResponse, ProductResponse
|
||||
from cartsnitch_api.services.products import ProductService
|
||||
|
||||
router = APIRouter(prefix="/products", tags=["products"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[ProductResponse])
|
||||
async def list_products(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
q: str | None = Query(None),
|
||||
category: str | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = ProductService(db)
|
||||
return await svc.list_products(q, category, page, page_size)
|
||||
|
||||
|
||||
@router.get("/{product_id}", response_model=ProductDetailResponse)
|
||||
async def get_product(
|
||||
product_id: UUID,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = ProductService(db)
|
||||
try:
|
||||
return await svc.get_product(product_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/{product_id}/prices", response_model=PriceTrendResponse)
|
||||
async def get_product_prices(
|
||||
product_id: UUID,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = ProductService(db)
|
||||
try:
|
||||
return await svc.get_price_history(product_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
) from None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Public endpoints: price transparency data (no auth required)."""
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import (
|
||||
PublicInflationResponse,
|
||||
PublicStoreComparisonResponse,
|
||||
PublicTrendResponse,
|
||||
)
|
||||
from cartsnitch_api.services.public import PublicService
|
||||
|
||||
router = APIRouter(prefix="/public", tags=["public"])
|
||||
|
||||
|
||||
@router.get("/trends/{product_id}", response_model=PublicTrendResponse)
|
||||
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
svc = PublicService(db)
|
||||
try:
|
||||
return await svc.get_trend(product_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
|
||||
async def public_store_comparison(
|
||||
product_ids: Annotated[list[UUID], Query(max_length=20)],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not product_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one product_id is required",
|
||||
)
|
||||
svc = PublicService(db)
|
||||
return await svc.get_store_comparison(product_ids)
|
||||
|
||||
|
||||
@router.get("/inflation", response_model=PublicInflationResponse)
|
||||
async def public_inflation(db: AsyncSession = Depends(get_db)):
|
||||
svc = PublicService(db)
|
||||
return await svc.get_inflation()
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Purchase routes: list, detail, stats."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import PurchaseDetailResponse, PurchaseResponse, PurchaseStatsResponse
|
||||
from cartsnitch_api.services.purchases import PurchaseService
|
||||
|
||||
router = APIRouter(prefix="/purchases", tags=["purchases"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[PurchaseResponse])
|
||||
async def list_purchases(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
store_id: UUID | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PurchaseService(db)
|
||||
return await svc.list_purchases(user_id, store_id, page, page_size)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=PurchaseStatsResponse)
|
||||
async def purchase_stats(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PurchaseService(db)
|
||||
return await svc.get_stats(user_id)
|
||||
|
||||
|
||||
@router.get("/{purchase_id}", response_model=PurchaseDetailResponse)
|
||||
async def get_purchase(
|
||||
purchase_id: UUID,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PurchaseService(db)
|
||||
try:
|
||||
return await svc.get_purchase(purchase_id, user_id)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Purchase not found"
|
||||
) from None
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Scraping routes: trigger sync, check status (proxy to ReceiptWitness)."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from httpx import HTTPStatusError, RequestError
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.schemas import SyncStatusResponse, SyncTriggerResponse
|
||||
from cartsnitch_api.services.receiptwitness import ReceiptWitnessClient
|
||||
|
||||
router = APIRouter(prefix="/scraping", tags=["scraping"])
|
||||
|
||||
|
||||
@router.post("/{store_slug}/sync", response_model=SyncTriggerResponse)
|
||||
async def trigger_sync(store_slug: str, user_id: UUID = Depends(get_current_user)):
|
||||
client = ReceiptWitnessClient()
|
||||
try:
|
||||
result = await client.trigger_sync(str(user_id), store_slug)
|
||||
return result
|
||||
except HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=e.response.status_code,
|
||||
detail="Sync service error",
|
||||
) from e
|
||||
except RequestError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach sync service",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/status", response_model=list[SyncStatusResponse])
|
||||
async def sync_status(user_id: UUID = Depends(get_current_user)):
|
||||
client = ReceiptWitnessClient()
|
||||
try:
|
||||
return await client.get_sync_status(str(user_id))
|
||||
except (HTTPStatusError, RequestError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach sync service",
|
||||
) from None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Shopping routes: optimize list, saved lists."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from httpx import HTTPStatusError, RequestError
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.schemas import OptimizeRequest, OptimizeResponse, ShoppingListResponse
|
||||
from cartsnitch_api.services.clipartist import ClipArtistClient
|
||||
|
||||
router = APIRouter(prefix="/shopping", tags=["shopping"])
|
||||
|
||||
|
||||
@router.post("/optimize", response_model=OptimizeResponse)
|
||||
async def optimize_shopping(body: OptimizeRequest, user_id: UUID = Depends(get_current_user)):
|
||||
client = ClipArtistClient()
|
||||
try:
|
||||
result = await client.optimize(
|
||||
user_id=str(user_id),
|
||||
items=[item.model_dump() for item in body.items],
|
||||
preferred_stores=(
|
||||
[str(s) for s in body.preferred_stores] if body.preferred_stores else None
|
||||
),
|
||||
)
|
||||
return result
|
||||
except HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=e.response.status_code,
|
||||
detail="Shopping optimization service error",
|
||||
) from e
|
||||
except RequestError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach shopping optimization service",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/lists", response_model=list[ShoppingListResponse])
|
||||
async def list_shopping_lists(user_id: UUID = Depends(get_current_user)):
|
||||
client = ClipArtistClient()
|
||||
try:
|
||||
return await client.get_shopping_lists(str(user_id))
|
||||
except (HTTPStatusError, RequestError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Unable to reach shopping service",
|
||||
) from None
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Store routes: list stores, manage user store connections."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cartsnitch_api.auth.dependencies import get_current_user
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.schemas import ConnectStoreRequest, StoreAccountResponse, StoreResponse
|
||||
from cartsnitch_api.services.stores import StoreService
|
||||
|
||||
router = APIRouter(tags=["stores"])
|
||||
|
||||
|
||||
@router.get("/stores", response_model=list[StoreResponse])
|
||||
async def list_stores(db: AsyncSession = Depends(get_db)):
|
||||
svc = StoreService(db)
|
||||
return await svc.list_stores()
|
||||
|
||||
|
||||
@router.get("/me/stores", response_model=list[StoreAccountResponse])
|
||||
async def list_user_stores(
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = StoreService(db)
|
||||
return await svc.list_user_stores(user_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/me/stores/{store_slug}/connect",
|
||||
response_model=StoreAccountResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def connect_store(
|
||||
store_slug: str,
|
||||
body: ConnectStoreRequest,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = StoreService(db)
|
||||
try:
|
||||
return await svc.connect_store(user_id, store_slug, body.credentials)
|
||||
except LookupError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def disconnect_store(
|
||||
store_slug: str,
|
||||
user_id: UUID = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = StoreService(db)
|
||||
try:
|
||||
await svc.disconnect_store(user_id, store_slug)
|
||||
except LookupError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
@@ -0,0 +1,271 @@
|
||||
"""Pydantic v2 request/response schemas for all API endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
# ---------- Auth ----------
|
||||
# Registration, login, and session management are handled by Better-Auth (auth/ service).
|
||||
# These schemas are for the profile management endpoints only.
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
email: EmailStr | None = None
|
||||
display_name: str | None = Field(None, min_length=1, max_length=100)
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
display_name: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------- Stores ----------
|
||||
|
||||
|
||||
class StoreResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
slug: str
|
||||
logo_url: str | None = None
|
||||
supported: bool = True
|
||||
|
||||
|
||||
class StoreAccountResponse(BaseModel):
|
||||
store: StoreResponse
|
||||
connected: bool
|
||||
last_sync_at: datetime | None = None
|
||||
sync_status: str | None = None
|
||||
|
||||
|
||||
class ConnectStoreRequest(BaseModel):
|
||||
credentials: dict | None = None
|
||||
|
||||
|
||||
# ---------- Purchases ----------
|
||||
|
||||
|
||||
class LineItemResponse(BaseModel):
|
||||
id: UUID
|
||||
product_id: UUID | None = None
|
||||
name: str
|
||||
quantity: float
|
||||
unit_price: float
|
||||
total_price: float
|
||||
|
||||
|
||||
class PurchaseResponse(BaseModel):
|
||||
id: UUID
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
purchased_at: datetime
|
||||
total: float
|
||||
item_count: int
|
||||
|
||||
|
||||
class PurchaseDetailResponse(PurchaseResponse):
|
||||
line_items: list[LineItemResponse]
|
||||
|
||||
|
||||
class PurchaseStatsResponse(BaseModel):
|
||||
total_spent: float
|
||||
purchase_count: int
|
||||
by_store: dict[str, float]
|
||||
by_period: dict[str, float]
|
||||
|
||||
|
||||
# ---------- Products ----------
|
||||
|
||||
|
||||
class ProductResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
brand: str | None = None
|
||||
category: str | None = None
|
||||
upc: str | None = None
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
class ProductDetailResponse(ProductResponse):
|
||||
prices_by_store: list["StorePriceResponse"]
|
||||
|
||||
|
||||
class StorePriceResponse(BaseModel):
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
current_price: float
|
||||
last_seen_at: datetime
|
||||
|
||||
|
||||
# ---------- Prices ----------
|
||||
|
||||
|
||||
class PriceTrendResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
data_points: list["PricePointResponse"]
|
||||
|
||||
|
||||
class PricePointResponse(BaseModel):
|
||||
date: datetime
|
||||
price: float
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
|
||||
|
||||
class PriceIncreaseResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
store_name: str
|
||||
old_price: float
|
||||
new_price: float
|
||||
increase_pct: float
|
||||
detected_at: datetime
|
||||
|
||||
|
||||
class PriceComparisonResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
prices: list[StorePriceResponse]
|
||||
|
||||
|
||||
# ---------- Coupons ----------
|
||||
|
||||
|
||||
class CouponResponse(BaseModel):
|
||||
id: UUID
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
description: str
|
||||
discount_value: float
|
||||
discount_type: str
|
||||
product_id: UUID | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
# ---------- Shopping ----------
|
||||
|
||||
|
||||
class ShoppingListItemRequest(BaseModel):
|
||||
product_id: UUID | None = None
|
||||
name: str
|
||||
quantity: int = 1
|
||||
|
||||
|
||||
class OptimizeRequest(BaseModel):
|
||||
items: list[ShoppingListItemRequest]
|
||||
preferred_stores: list[UUID] | None = None
|
||||
|
||||
|
||||
class OptimizedStoreTrip(BaseModel):
|
||||
store_id: UUID
|
||||
store_name: str
|
||||
items: list["OptimizedItemResponse"]
|
||||
subtotal: float
|
||||
coupons: list[CouponResponse]
|
||||
savings: float
|
||||
|
||||
|
||||
class OptimizedItemResponse(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
product_id: UUID | None = None
|
||||
|
||||
|
||||
class OptimizeResponse(BaseModel):
|
||||
trips: list[OptimizedStoreTrip]
|
||||
total_cost: float
|
||||
total_savings: float
|
||||
|
||||
|
||||
class ShoppingListResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
item_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# ---------- Alerts ----------
|
||||
|
||||
|
||||
class AlertResponse(BaseModel):
|
||||
id: UUID
|
||||
alert_type: str
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
message: str
|
||||
triggered_at: datetime
|
||||
read: bool = False
|
||||
|
||||
|
||||
class AlertSettingsRequest(BaseModel):
|
||||
price_increase_threshold_pct: float | None = None
|
||||
shrinkflation_enabled: bool | None = None
|
||||
email_notifications: bool | None = None
|
||||
|
||||
|
||||
class AlertSettingsResponse(BaseModel):
|
||||
price_increase_threshold_pct: float
|
||||
shrinkflation_enabled: bool
|
||||
email_notifications: bool
|
||||
|
||||
|
||||
# ---------- Scraping ----------
|
||||
|
||||
|
||||
class SyncTriggerResponse(BaseModel):
|
||||
job_id: UUID
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class SyncStatusResponse(BaseModel):
|
||||
store_slug: str
|
||||
status: str
|
||||
last_sync_at: datetime | None = None
|
||||
items_synced: int | None = None
|
||||
|
||||
|
||||
# ---------- Public ----------
|
||||
|
||||
|
||||
class PublicTrendResponse(BaseModel):
|
||||
product_id: UUID
|
||||
product_name: str
|
||||
data_points: list[PricePointResponse]
|
||||
|
||||
|
||||
class PublicStoreComparisonResponse(BaseModel):
|
||||
products: list[PriceComparisonResponse]
|
||||
|
||||
|
||||
class PublicInflationResponse(BaseModel):
|
||||
period: str
|
||||
cartsnitch_index: float
|
||||
cpi_baseline: float
|
||||
categories: dict[str, float]
|
||||
|
||||
|
||||
# ---------- Common ----------
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
items: list
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
detail: str
|
||||
code: str | None = None
|
||||
|
||||
|
||||
# Rebuild forward refs
|
||||
ProductDetailResponse.model_rebuild()
|
||||
PriceTrendResponse.model_rebuild()
|
||||
OptimizedStoreTrip.model_rebuild()
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Alert service — price and shrinkflation alerts for users.
|
||||
|
||||
Alerts are generated by StickerShock and ShrinkRay services and written to the DB.
|
||||
This service reads them for the API gateway.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
class AlertService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_alerts(self, user_id: UUID) -> list[dict]:
|
||||
"""List shrinkflation events for products the user has purchased."""
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent
|
||||
|
||||
# Get product IDs from user's purchases
|
||||
items_result = await self.db.execute(
|
||||
select(PurchaseItem.normalized_product_id)
|
||||
.join(Purchase)
|
||||
.where(
|
||||
Purchase.user_id == user_id,
|
||||
PurchaseItem.normalized_product_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
product_ids = [row[0] for row in items_result.all()]
|
||||
|
||||
if not product_ids:
|
||||
return []
|
||||
|
||||
result = await self.db.execute(
|
||||
select(ShrinkflationEvent)
|
||||
.where(ShrinkflationEvent.normalized_product_id.in_(product_ids))
|
||||
.options(selectinload(ShrinkflationEvent.normalized_product))
|
||||
.order_by(ShrinkflationEvent.detected_date.desc())
|
||||
)
|
||||
events = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": e.id,
|
||||
"alert_type": "shrinkflation",
|
||||
"product_id": e.normalized_product_id,
|
||||
"product_name": e.normalized_product.canonical_name,
|
||||
"message": (
|
||||
f"Size changed from {e.old_size}{e.old_unit} to {e.new_size}{e.new_unit}"
|
||||
),
|
||||
"triggered_at": e.detected_date,
|
||||
"read": False,
|
||||
}
|
||||
for e in events
|
||||
]
|
||||
|
||||
async def get_settings(self, user_id: UUID) -> dict:
|
||||
# Alert settings would be stored in a user_settings table.
|
||||
# For now, return defaults since the table doesn't exist yet in common lib.
|
||||
return {
|
||||
"price_increase_threshold_pct": 5.0,
|
||||
"shrinkflation_enabled": True,
|
||||
"email_notifications": False,
|
||||
}
|
||||
|
||||
async def update_settings(self, user_id: UUID, **fields) -> dict:
|
||||
# Would update user_settings table. Return merged defaults for now.
|
||||
current = await self.get_settings(user_id)
|
||||
for k, v in fields.items():
|
||||
if v is not None and k in current:
|
||||
current[k] = v
|
||||
return current
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Auth service — user profile management.
|
||||
|
||||
Registration, login, token management, and session handling are now
|
||||
handled by the Better-Auth service (auth/). This service provides
|
||||
user lookup and profile update operations for the API gateway.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_user(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"display_name": user.display_name,
|
||||
"created_at": user.created_at,
|
||||
}
|
||||
|
||||
async def update_user(self, user_id: UUID, **fields) -> dict:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
|
||||
if "display_name" in fields and fields["display_name"] is not None:
|
||||
user.display_name = fields["display_name"]
|
||||
if "email" in fields and fields["email"] is not None:
|
||||
existing = await self.db.execute(
|
||||
select(User).where(User.email == fields["email"], User.id != user_id)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Email already in use")
|
||||
user.email = fields["email"]
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"display_name": user.display_name,
|
||||
"created_at": user.created_at,
|
||||
}
|
||||
|
||||
async def delete_user(self, user_id: UUID) -> None:
|
||||
from cartsnitch_api.models import User
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise LookupError("User not found")
|
||||
|
||||
await self.db.delete(user)
|
||||
await self.db.commit()
|
||||
@@ -0,0 +1,52 @@
|
||||
"""HTTP client for ClipArtist internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class ClipArtistClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.clipartist_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def optimize(
|
||||
self,
|
||||
user_id: str,
|
||||
items: list[dict],
|
||||
preferred_stores: list[str] | None = None,
|
||||
) -> dict:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/optimize",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"user_id": user_id,
|
||||
"items": items,
|
||||
"preferred_stores": preferred_stores,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def get_shopping_lists(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/shopping-lists",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
|
||||
async def get_relevant_coupons(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/coupons/relevant",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Coupon service — browse coupons, find relevant ones."""
|
||||
|
||||
from datetime import date
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
class CouponService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_coupons(self, store_id: UUID | None = None) -> list[dict]:
|
||||
from cartsnitch_api.models import Coupon
|
||||
|
||||
today = date.today()
|
||||
query = (
|
||||
select(Coupon)
|
||||
.where((Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)))
|
||||
.options(selectinload(Coupon.store))
|
||||
.order_by(Coupon.valid_to.asc().nullslast())
|
||||
)
|
||||
if store_id:
|
||||
query = query.where(Coupon.store_id == store_id)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
coupons = result.scalars().all()
|
||||
return [self._to_dict(c) for c in coupons]
|
||||
|
||||
async def relevant_coupons(self, user_id: UUID) -> list[dict]:
|
||||
"""Coupons for products the user has purchased."""
|
||||
from cartsnitch_api.models import Coupon, PurchaseItem
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Get product IDs from user's purchase history
|
||||
from cartsnitch_api.models import Purchase
|
||||
|
||||
items_result = await self.db.execute(
|
||||
select(PurchaseItem.normalized_product_id)
|
||||
.join(Purchase)
|
||||
.where(
|
||||
Purchase.user_id == user_id,
|
||||
PurchaseItem.normalized_product_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
product_ids = [row[0] for row in items_result.all()]
|
||||
|
||||
if not product_ids:
|
||||
return []
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Coupon)
|
||||
.where(
|
||||
Coupon.normalized_product_id.in_(product_ids),
|
||||
(Coupon.valid_to >= today) | (Coupon.valid_to.is_(None)),
|
||||
)
|
||||
.options(selectinload(Coupon.store))
|
||||
)
|
||||
coupons = result.scalars().all()
|
||||
return [self._to_dict(c) for c in coupons]
|
||||
|
||||
def _to_dict(self, c) -> dict:
|
||||
return {
|
||||
"id": c.id,
|
||||
"store_id": c.store_id,
|
||||
"store_name": c.store.name,
|
||||
"description": c.description or c.title,
|
||||
"discount_value": float(c.discount_value) if c.discount_value else 0,
|
||||
"discount_type": c.discount_type,
|
||||
"product_id": c.normalized_product_id,
|
||||
"expires_at": c.valid_to,
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Price service — trends, increases, comparison."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.services.queries import latest_price_per_store
|
||||
|
||||
|
||||
class PriceService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_trends(self, category: str | None = None) -> list[dict]:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
query = (
|
||||
select(PriceHistory)
|
||||
.join(NormalizedProduct)
|
||||
.options(
|
||||
selectinload(PriceHistory.store),
|
||||
selectinload(PriceHistory.normalized_product),
|
||||
)
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
if category:
|
||||
query = query.where(NormalizedProduct.category == category)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
prices = result.scalars().all()
|
||||
|
||||
# Group by product
|
||||
by_product: dict[UUID, dict] = {}
|
||||
for ph in prices:
|
||||
pid = ph.normalized_product_id
|
||||
if pid not in by_product:
|
||||
by_product[pid] = {
|
||||
"product_id": pid,
|
||||
"product_name": ph.normalized_product.canonical_name,
|
||||
"data_points": [],
|
||||
}
|
||||
by_product[pid]["data_points"].append(
|
||||
{
|
||||
"date": ph.observed_date,
|
||||
"price": float(ph.regular_price),
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
}
|
||||
)
|
||||
return list(by_product.values())
|
||||
|
||||
async def get_increases(self) -> list[dict]:
|
||||
"""Find products with recent significant price increases.
|
||||
|
||||
Uses a window function (lag) to compare each price observation with the
|
||||
previous one per product+store, avoiding the N+1 query pattern.
|
||||
"""
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
# Use lag() window function to get previous price in a single query
|
||||
prev_price = (
|
||||
func.lag(PriceHistory.regular_price)
|
||||
.over(
|
||||
partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id],
|
||||
order_by=PriceHistory.observed_date,
|
||||
)
|
||||
.label("prev_price")
|
||||
)
|
||||
|
||||
row_num = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=[PriceHistory.normalized_product_id, PriceHistory.store_id],
|
||||
order_by=PriceHistory.observed_date.desc(),
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
|
||||
inner = select(
|
||||
PriceHistory.normalized_product_id,
|
||||
PriceHistory.store_id,
|
||||
PriceHistory.regular_price,
|
||||
PriceHistory.observed_date,
|
||||
prev_price,
|
||||
row_num,
|
||||
).subquery()
|
||||
|
||||
# Only keep the latest row (rn=1) where price increased
|
||||
result = await self.db.execute(
|
||||
select(
|
||||
inner.c.normalized_product_id,
|
||||
inner.c.store_id,
|
||||
inner.c.regular_price,
|
||||
inner.c.observed_date,
|
||||
inner.c.prev_price,
|
||||
NormalizedProduct.canonical_name,
|
||||
Store.name.label("store_name"),
|
||||
)
|
||||
.join(NormalizedProduct, NormalizedProduct.id == inner.c.normalized_product_id)
|
||||
.join(Store, Store.id == inner.c.store_id)
|
||||
.where(
|
||||
inner.c.rn == 1,
|
||||
inner.c.prev_price.isnot(None),
|
||||
inner.c.regular_price > inner.c.prev_price,
|
||||
)
|
||||
)
|
||||
|
||||
increases = []
|
||||
for row in result.all():
|
||||
old = float(row.prev_price)
|
||||
new = float(row.regular_price)
|
||||
increases.append(
|
||||
{
|
||||
"product_id": row.normalized_product_id,
|
||||
"product_name": row.canonical_name,
|
||||
"store_name": row.store_name,
|
||||
"old_price": old,
|
||||
"new_price": new,
|
||||
"increase_pct": round((new - old) / old * 100, 2),
|
||||
"detected_at": row.observed_date,
|
||||
}
|
||||
)
|
||||
|
||||
increases.sort(key=lambda x: x["increase_pct"], reverse=True)
|
||||
return increases
|
||||
|
||||
async def get_comparison(self, product_ids: list[UUID]) -> list[dict]:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
if not product_ids:
|
||||
return []
|
||||
|
||||
# Fetch all requested products in one query
|
||||
prod_result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
)
|
||||
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
||||
|
||||
# Latest prices for all requested products in one query
|
||||
subq = latest_price_per_store(product_ids)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
subq,
|
||||
and_(
|
||||
PriceHistory.store_id == subq.c.store_id,
|
||||
PriceHistory.observed_date == subq.c.max_date,
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
all_prices = prices_result.scalars().all()
|
||||
|
||||
# Group prices by product
|
||||
prices_by_product: dict[UUID, list] = {pid: [] for pid in product_ids}
|
||||
for ph in all_prices:
|
||||
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
||||
|
||||
comparisons = []
|
||||
for pid in product_ids:
|
||||
product = products_by_id.get(pid)
|
||||
if not product:
|
||||
continue
|
||||
comparisons.append(
|
||||
{
|
||||
"product_id": pid,
|
||||
"product_name": product.canonical_name,
|
||||
"prices": [
|
||||
{
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
"current_price": float(ph.regular_price),
|
||||
"last_seen_at": ph.observed_date,
|
||||
}
|
||||
for ph in prices_by_product.get(pid, [])
|
||||
],
|
||||
}
|
||||
)
|
||||
return comparisons
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Product service — catalog, detail, price history."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.services.queries import latest_price_per_store
|
||||
|
||||
|
||||
class ProductService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_products(
|
||||
self,
|
||||
q: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[dict]:
|
||||
from cartsnitch_api.models import NormalizedProduct
|
||||
|
||||
query = select(NormalizedProduct)
|
||||
if q:
|
||||
# Escape SQL LIKE wildcards in user input
|
||||
safe_q = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
query = query.where(NormalizedProduct.canonical_name.ilike(f"%{safe_q}%"))
|
||||
if category:
|
||||
query = query.where(NormalizedProduct.category == category)
|
||||
query = query.order_by(NormalizedProduct.canonical_name)
|
||||
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
products = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": p.id,
|
||||
"name": p.canonical_name,
|
||||
"brand": p.brand,
|
||||
"category": p.category,
|
||||
"upc": (p.upc_variants[0] if p.upc_variants else None),
|
||||
"image_url": None,
|
||||
}
|
||||
for p in products
|
||||
]
|
||||
|
||||
async def get_product(self, product_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
|
||||
)
|
||||
product = result.scalar_one_or_none()
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
# Get latest price per store
|
||||
subq = latest_price_per_store([product_id])
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
subq,
|
||||
and_(
|
||||
PriceHistory.store_id == subq.c.store_id,
|
||||
PriceHistory.observed_date == subq.c.max_date,
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
prices = prices_result.scalars().all()
|
||||
|
||||
return {
|
||||
"id": product.id,
|
||||
"name": product.canonical_name,
|
||||
"brand": product.brand,
|
||||
"category": product.category,
|
||||
"upc": (product.upc_variants[0] if product.upc_variants else None),
|
||||
"image_url": None,
|
||||
"prices_by_store": [
|
||||
{
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
"current_price": float(ph.regular_price),
|
||||
"last_seen_at": ph.observed_date,
|
||||
}
|
||||
for ph in prices
|
||||
],
|
||||
}
|
||||
|
||||
async def get_price_history(self, product_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
|
||||
)
|
||||
product = result.scalar_one_or_none()
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
prices = prices_result.scalars().all()
|
||||
|
||||
return {
|
||||
"product_id": product.id,
|
||||
"product_name": product.canonical_name,
|
||||
"data_points": [
|
||||
{
|
||||
"date": ph.observed_date,
|
||||
"price": float(ph.regular_price),
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
}
|
||||
for ph in prices
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Public service — unauthenticated price transparency endpoints."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.services.queries import latest_price_per_store
|
||||
|
||||
|
||||
class PublicService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_trend(self, product_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id == product_id)
|
||||
)
|
||||
product = result.scalar_one_or_none()
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
prices = prices_result.scalars().all()
|
||||
|
||||
return {
|
||||
"product_id": product.id,
|
||||
"product_name": product.canonical_name,
|
||||
"data_points": [
|
||||
{
|
||||
"date": ph.observed_date,
|
||||
"price": float(ph.regular_price),
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
}
|
||||
for ph in prices
|
||||
],
|
||||
}
|
||||
|
||||
async def get_store_comparison(self, product_ids: list[UUID]) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
if not product_ids:
|
||||
return {"products": []}
|
||||
|
||||
# Fetch all products in one query
|
||||
prod_result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
)
|
||||
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
||||
|
||||
# Latest prices for all requested products in one query
|
||||
subq = latest_price_per_store(product_ids)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
subq,
|
||||
and_(
|
||||
PriceHistory.store_id == subq.c.store_id,
|
||||
PriceHistory.observed_date == subq.c.max_date,
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
all_prices = prices_result.scalars().all()
|
||||
|
||||
# Group by product
|
||||
prices_by_product: dict[UUID, list] = {}
|
||||
for ph in all_prices:
|
||||
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
||||
|
||||
products = []
|
||||
for pid in product_ids:
|
||||
product = products_by_id.get(pid)
|
||||
if not product:
|
||||
continue
|
||||
products.append(
|
||||
{
|
||||
"product_id": pid,
|
||||
"product_name": product.canonical_name,
|
||||
"prices": [
|
||||
{
|
||||
"store_id": ph.store_id,
|
||||
"store_name": ph.store.name,
|
||||
"current_price": float(ph.regular_price),
|
||||
"last_seen_at": ph.observed_date,
|
||||
}
|
||||
for ph in prices_by_product.get(pid, [])
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return {"products": products}
|
||||
|
||||
async def get_inflation(self) -> dict:
|
||||
"""Aggregate price change stats. Compares average prices across periods."""
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
# Get average prices grouped by category for recent vs older data
|
||||
result = await self.db.execute(
|
||||
select(
|
||||
NormalizedProduct.category,
|
||||
func.avg(PriceHistory.regular_price),
|
||||
)
|
||||
.join(NormalizedProduct)
|
||||
.group_by(NormalizedProduct.category)
|
||||
)
|
||||
categories = {}
|
||||
for row in result.all():
|
||||
cat, avg_price = row
|
||||
if cat:
|
||||
categories[cat] = float(avg_price) if avg_price else 0.0
|
||||
|
||||
return {
|
||||
"period": "all-time",
|
||||
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
|
||||
"cpi_baseline": 100.0,
|
||||
"categories": categories,
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Purchase service — list, detail, stats."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
||||
class PurchaseService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_purchases(
|
||||
self,
|
||||
user_id: UUID,
|
||||
store_id: UUID | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[dict]:
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store
|
||||
|
||||
# Count items per purchase in a single subquery instead of N+1
|
||||
item_counts = (
|
||||
select(
|
||||
PurchaseItem.purchase_id,
|
||||
func.count().label("item_count"),
|
||||
)
|
||||
.group_by(PurchaseItem.purchase_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = (
|
||||
select(Purchase, item_counts.c.item_count, Store.name.label("store_name"))
|
||||
.join(Store, Store.id == Purchase.store_id)
|
||||
.outerjoin(item_counts, item_counts.c.purchase_id == Purchase.id)
|
||||
.where(Purchase.user_id == user_id)
|
||||
)
|
||||
if store_id:
|
||||
query = query.where(Purchase.store_id == store_id)
|
||||
|
||||
query = query.order_by(Purchase.purchase_date.desc())
|
||||
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": p.id,
|
||||
"store_id": p.store_id,
|
||||
"store_name": store_name,
|
||||
"purchased_at": p.purchase_date,
|
||||
"total": float(p.total),
|
||||
"item_count": item_count or 0,
|
||||
}
|
||||
for p, item_count, store_name in result.all()
|
||||
]
|
||||
|
||||
async def get_purchase(self, purchase_id: UUID, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import Purchase
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Purchase)
|
||||
.where(Purchase.id == purchase_id, Purchase.user_id == user_id)
|
||||
.options(selectinload(Purchase.store), selectinload(Purchase.items))
|
||||
)
|
||||
purchase = result.scalar_one_or_none()
|
||||
if not purchase:
|
||||
raise LookupError("Purchase not found")
|
||||
|
||||
return {
|
||||
"id": purchase.id,
|
||||
"store_id": purchase.store_id,
|
||||
"store_name": purchase.store.name,
|
||||
"purchased_at": purchase.purchase_date,
|
||||
"total": float(purchase.total),
|
||||
"item_count": len(purchase.items),
|
||||
"line_items": [
|
||||
{
|
||||
"id": item.id,
|
||||
"product_id": item.normalized_product_id,
|
||||
"name": item.product_name_raw,
|
||||
"quantity": float(item.quantity),
|
||||
"unit_price": float(item.unit_price),
|
||||
"total_price": float(item.extended_price),
|
||||
}
|
||||
for item in purchase.items
|
||||
],
|
||||
}
|
||||
|
||||
async def get_stats(self, user_id: UUID) -> dict:
|
||||
from cartsnitch_api.models import Purchase
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Purchase)
|
||||
.where(Purchase.user_id == user_id)
|
||||
.options(selectinload(Purchase.store))
|
||||
)
|
||||
purchases = result.scalars().all()
|
||||
|
||||
total_spent = sum(float(p.total) for p in purchases)
|
||||
by_store: dict[str, float] = {}
|
||||
by_period: dict[str, float] = {}
|
||||
|
||||
for p in purchases:
|
||||
store_name = p.store.name
|
||||
by_store[store_name] = by_store.get(store_name, 0) + float(p.total)
|
||||
period = p.purchase_date.strftime("%Y-%m")
|
||||
by_period[period] = by_period.get(period, 0) + float(p.total)
|
||||
|
||||
return {
|
||||
"total_spent": total_spent,
|
||||
"purchase_count": len(purchases),
|
||||
"by_store": by_store,
|
||||
"by_period": by_period,
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Shared query helpers for service layer."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
|
||||
def latest_price_per_store(product_ids: list[UUID] | None = None):
|
||||
"""Subquery returning the latest observed_date per product+store.
|
||||
|
||||
Optionally filtered to a list of product IDs. Returns a subquery with
|
||||
columns: normalized_product_id, store_id, max_date.
|
||||
"""
|
||||
from cartsnitch_api.models import PriceHistory
|
||||
|
||||
query = select(
|
||||
PriceHistory.normalized_product_id,
|
||||
PriceHistory.store_id,
|
||||
func.max(PriceHistory.observed_date).label("max_date"),
|
||||
).group_by(PriceHistory.normalized_product_id, PriceHistory.store_id)
|
||||
if product_ids is not None:
|
||||
query = query.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
return query.subquery()
|
||||
@@ -0,0 +1,33 @@
|
||||
"""HTTP client for ReceiptWitness internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class ReceiptWitnessClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.receiptwitness_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def trigger_sync(self, user_id: str, store_slug: str) -> dict:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/sync/{store_slug}",
|
||||
headers=self.headers,
|
||||
json={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def get_sync_status(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/sync/status",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
@@ -0,0 +1,23 @@
|
||||
"""HTTP client for ShrinkRay internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class ShrinkRayClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.shrinkray_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def get_shrinkflation_alerts(self, user_id: str) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/alerts",
|
||||
headers=self.headers,
|
||||
params={"user_id": user_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
@@ -0,0 +1,32 @@
|
||||
"""HTTP client for StickerShock internal API."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class StickerShockClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.stickershock_url
|
||||
self.headers = {"X-Service-Key": settings.service_key}
|
||||
|
||||
async def get_price_increases(self, params: dict | None = None) -> list[dict]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/increases",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(list[dict[str, Any]], resp.json())
|
||||
|
||||
async def get_inflation_data(self) -> dict:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/inflation",
|
||||
headers=self.headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Store service — list stores, manage user store account connections."""
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
return Fernet(settings.fernet_key.encode())
|
||||
|
||||
|
||||
class StoreService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def list_stores(self) -> list[dict]:
|
||||
from cartsnitch_api.models import Store
|
||||
|
||||
result = await self.db.execute(select(Store).order_by(Store.name))
|
||||
stores = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"slug": s.slug,
|
||||
"logo_url": s.logo_url,
|
||||
"supported": True,
|
||||
}
|
||||
for s in stores
|
||||
]
|
||||
|
||||
async def list_user_stores(self, user_id: UUID) -> list[dict]:
|
||||
from cartsnitch_api.models import UserStoreAccount
|
||||
|
||||
result = await self.db.execute(
|
||||
select(UserStoreAccount)
|
||||
.where(UserStoreAccount.user_id == user_id)
|
||||
.options(selectinload(UserStoreAccount.store))
|
||||
)
|
||||
accounts = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"store": {
|
||||
"id": a.store.id,
|
||||
"name": a.store.name,
|
||||
"slug": a.store.slug,
|
||||
"logo_url": a.store.logo_url,
|
||||
"supported": True,
|
||||
},
|
||||
"connected": a.status == "active",
|
||||
"last_sync_at": a.last_sync_at,
|
||||
"sync_status": a.status,
|
||||
}
|
||||
for a in accounts
|
||||
]
|
||||
|
||||
async def connect_store(self, user_id: UUID, store_slug: str, credentials: dict | None) -> dict:
|
||||
from cartsnitch_api.models import Store, UserStoreAccount
|
||||
|
||||
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
||||
store = result.scalar_one_or_none()
|
||||
if not store:
|
||||
raise LookupError(f"Store '{store_slug}' not found")
|
||||
|
||||
existing = await self.db.execute(
|
||||
select(UserStoreAccount).where(
|
||||
UserStoreAccount.user_id == user_id,
|
||||
UserStoreAccount.store_id == store.id,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Store account already connected")
|
||||
|
||||
encrypted_data = None
|
||||
if credentials:
|
||||
fernet = _get_fernet()
|
||||
encrypted_data = {
|
||||
"encrypted": fernet.encrypt(json.dumps(credentials).encode()).decode()
|
||||
}
|
||||
|
||||
account = UserStoreAccount(
|
||||
user_id=user_id,
|
||||
store_id=store.id,
|
||||
session_data=encrypted_data,
|
||||
status="active",
|
||||
)
|
||||
self.db.add(account)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(account)
|
||||
|
||||
return {
|
||||
"store": {
|
||||
"id": store.id,
|
||||
"name": store.name,
|
||||
"slug": store.slug,
|
||||
"logo_url": store.logo_url,
|
||||
"supported": True,
|
||||
},
|
||||
"connected": True,
|
||||
"last_sync_at": None,
|
||||
"sync_status": "active",
|
||||
}
|
||||
|
||||
async def disconnect_store(self, user_id: UUID, store_slug: str) -> None:
|
||||
from cartsnitch_api.models import Store, UserStoreAccount
|
||||
|
||||
result = await self.db.execute(select(Store).where(Store.slug == store_slug))
|
||||
store = result.scalar_one_or_none()
|
||||
if not store:
|
||||
raise LookupError(f"Store '{store_slug}' not found")
|
||||
|
||||
result = await self.db.execute(
|
||||
select(UserStoreAccount).where(
|
||||
UserStoreAccount.user_id == user_id,
|
||||
UserStoreAccount.store_id == store.id,
|
||||
)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise LookupError("Store account not connected")
|
||||
|
||||
await self.db.delete(account)
|
||||
await self.db.commit()
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Custom SQLAlchemy column types."""
|
||||
|
||||
import json
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
return Fernet(settings.fernet_key.encode())
|
||||
|
||||
|
||||
class EncryptedJSON(TypeDecorator):
|
||||
"""SQLAlchemy type that transparently encrypts/decrypts JSON using Fernet.
|
||||
|
||||
Stores data as a Fernet-encrypted text blob in the database.
|
||||
On read, decrypts and deserialises back to a Python dict/list.
|
||||
"""
|
||||
|
||||
impl = Text
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return None
|
||||
plaintext = json.dumps(value).encode()
|
||||
return _get_fernet().encrypt(plaintext).decode()
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return None
|
||||
decrypted = _get_fernet().decrypt(value.encode())
|
||||
return json.loads(decrypted)
|
||||
@@ -0,0 +1,187 @@
|
||||
"""Shared test fixtures with in-memory SQLite database.
|
||||
|
||||
Session-based auth: tests create users and sessions directly in the DB,
|
||||
matching the Better-Auth session validation flow.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from cartsnitch_api.config import settings as cartsnitch_settings
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.main import create_app
|
||||
from cartsnitch_api.models import Base
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limiting():
|
||||
"""Disable rate limiting for all tests to prevent 429 interference."""
|
||||
cartsnitch_settings.rate_limit_enabled = False
|
||||
yield
|
||||
cartsnitch_settings.rate_limit_enabled = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
"""Sync in-memory SQLite engine for model unit tests."""
|
||||
eng = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(eng)
|
||||
yield eng
|
||||
eng.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(engine):
|
||||
"""Sync SQLAlchemy session for model unit tests."""
|
||||
factory = sessionmaker(bind=engine)
|
||||
with factory() as sess:
|
||||
yield sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
||||
await conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
user_id TEXT NOT NULL,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
ip_address TEXT,
|
||||
user_agent TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
await conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
account_id TEXT NOT NULL,
|
||||
provider_id TEXT NOT NULL,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
access_token_expires_at TIMESTAMP,
|
||||
refresh_token_expires_at TIMESTAMP,
|
||||
scope TEXT,
|
||||
id_token TEXT,
|
||||
password TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
await conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS verifications (
|
||||
id TEXT PRIMARY KEY,
|
||||
identifier TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
|
||||
yield engine
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(db_engine):
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
async def override_get_db():
|
||||
async with factory() as session:
|
||||
yield session
|
||||
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
||||
"""Create a test user and a valid session directly in the DB.
|
||||
|
||||
Returns (user_dict, session_token).
|
||||
"""
|
||||
user_id = str(uuid.uuid4())
|
||||
email = user_overrides.get("email", "test@example.com")
|
||||
display_name = user_overrides.get("display_name", "Test User")
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
session_id = str(uuid.uuid4())
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hashed_password, :display_name, :email_verified, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": email,
|
||||
"hashed_password": "not-used-with-better-auth",
|
||||
"display_name": display_name,
|
||||
"email_verified": False,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": session_id,
|
||||
"token": session_token,
|
||||
"user_id": user_id,
|
||||
"expires_at": expires,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
|
||||
return {"id": user_id, "email": email, "display_name": display_name}, session_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client, db_engine):
|
||||
"""Create a test user with a valid session and return auth headers."""
|
||||
_, session_token = await _create_test_user_and_session(client, db_engine)
|
||||
return {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Integration tests for auth profile endpoints.
|
||||
|
||||
Registration, login, and session management are handled by the Better-Auth
|
||||
service. These tests cover the profile endpoints (GET/PATCH/DELETE /auth/me)
|
||||
which validate sessions via the shared sessions table.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me(client, auth_headers):
|
||||
resp = await client.get("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["email"] == "test@example.com"
|
||||
assert data["display_name"] == "Test User"
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_unauthorized(client):
|
||||
resp = await client.get("/auth/me")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_invalid_session(client):
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": "better-auth.session_token=invalid-token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_with_bearer_token(client, db_engine):
|
||||
"""Session tokens can also be passed as Bearer tokens for API clients."""
|
||||
from tests.conftest import _create_test_user_and_session
|
||||
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="bearer@example.com", display_name="Bearer User"
|
||||
)
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Authorization": f"Bearer {session_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == "bearer@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me(client, auth_headers):
|
||||
resp = await client.patch(
|
||||
"/auth/me",
|
||||
headers=auth_headers,
|
||||
json={"display_name": "Updated Name"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["display_name"] == "Updated Name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_me(client, auth_headers):
|
||||
resp = await client.delete("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Session is still valid but user is gone
|
||||
resp = await client.get("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_session_rejected(client, db_engine):
|
||||
"""Expired sessions must be rejected."""
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
user_id = str(uuid.uuid4())
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": "expired@example.com",
|
||||
"hp": "unused",
|
||||
"dn": "Expired User",
|
||||
"ev": False,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"uid": user_id,
|
||||
"ea": expired,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Shared fixtures for E2E integration tests.
|
||||
|
||||
Seeds a realistic dataset with stores, products, price history,
|
||||
purchases, coupons, and shrinkflation events so E2E flows can
|
||||
exercise cross-resource queries against real data.
|
||||
"""
|
||||
|
||||
from datetime import date, timedelta
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import (
|
||||
Coupon,
|
||||
NormalizedProduct,
|
||||
PriceHistory,
|
||||
Purchase,
|
||||
PurchaseItem,
|
||||
ShrinkflationEvent,
|
||||
Store,
|
||||
)
|
||||
|
||||
# Shared test constants
|
||||
ZERO_UUID = "00000000-0000-0000-0000-000000000000"
|
||||
BAD_UUID = "not-a-uuid"
|
||||
# Fixed anchor date for deterministic tests
|
||||
ANCHOR_DATE = date(2026, 3, 15)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def seed_data(db_engine, auth_headers):
|
||||
"""Seed a full dataset and return identifiers for test assertions."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
# -- Stores --
|
||||
meijer = Store(name="Meijer", slug="meijer")
|
||||
kroger = Store(name="Kroger", slug="kroger")
|
||||
target = Store(name="Target", slug="target")
|
||||
session.add_all([meijer, kroger, target])
|
||||
await session.flush()
|
||||
|
||||
# -- Products --
|
||||
cheerios = NormalizedProduct(
|
||||
canonical_name="Cheerios 18oz",
|
||||
category="pantry",
|
||||
brand="General Mills",
|
||||
size="18",
|
||||
size_unit="oz",
|
||||
upc_variants=["016000275263"],
|
||||
)
|
||||
milk = NormalizedProduct(
|
||||
canonical_name="Whole Milk 1gal",
|
||||
category="dairy",
|
||||
brand="Meijer",
|
||||
size="1",
|
||||
size_unit="gal",
|
||||
)
|
||||
chicken = NormalizedProduct(
|
||||
canonical_name="Chicken Breast 1lb",
|
||||
category="meat",
|
||||
brand=None,
|
||||
size="1",
|
||||
size_unit="lb",
|
||||
)
|
||||
session.add_all([cheerios, milk, chicken])
|
||||
await session.flush()
|
||||
|
||||
# -- Price history (multiple dates, multiple stores) --
|
||||
today = ANCHOR_DATE
|
||||
prices = []
|
||||
# Cheerios at Meijer: price increase over time
|
||||
for i, price_val in enumerate([Decimal("3.99"), Decimal("4.29"), Decimal("4.79")]):
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=cheerios.id,
|
||||
store_id=meijer.id,
|
||||
observed_date=today - timedelta(days=60 - i * 30),
|
||||
regular_price=price_val,
|
||||
source="receipt",
|
||||
)
|
||||
)
|
||||
# Cheerios at Kroger: stable price
|
||||
for i in range(3):
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=cheerios.id,
|
||||
store_id=kroger.id,
|
||||
observed_date=today - timedelta(days=60 - i * 30),
|
||||
regular_price=Decimal("4.49"),
|
||||
source="catalog",
|
||||
)
|
||||
)
|
||||
# Milk at Meijer
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=milk.id,
|
||||
store_id=meijer.id,
|
||||
observed_date=today - timedelta(days=7),
|
||||
regular_price=Decimal("3.29"),
|
||||
source="receipt",
|
||||
)
|
||||
)
|
||||
# Milk at Kroger
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=milk.id,
|
||||
store_id=kroger.id,
|
||||
observed_date=today - timedelta(days=5),
|
||||
regular_price=Decimal("3.49"),
|
||||
source="catalog",
|
||||
)
|
||||
)
|
||||
# Chicken at Target
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=chicken.id,
|
||||
store_id=target.id,
|
||||
observed_date=today - timedelta(days=3),
|
||||
regular_price=Decimal("5.99"),
|
||||
source="catalog",
|
||||
)
|
||||
)
|
||||
session.add_all(prices)
|
||||
await session.flush()
|
||||
|
||||
# -- Get the user_id from the session token in auth_headers --
|
||||
cookie_str = auth_headers.get("Cookie", "")
|
||||
session_token = cookie_str.split("=", 1)[1] if "=" in cookie_str else ""
|
||||
|
||||
result = await session.execute(
|
||||
text("SELECT user_id FROM sessions WHERE token = :token"),
|
||||
{"token": session_token},
|
||||
)
|
||||
row = result.first()
|
||||
user_id = UUID(row[0])
|
||||
|
||||
purchase1 = Purchase(
|
||||
user_id=user_id,
|
||||
store_id=meijer.id,
|
||||
receipt_id="meijer-2026-001",
|
||||
purchase_date=today - timedelta(days=10),
|
||||
total=Decimal("23.45"),
|
||||
subtotal=Decimal("21.50"),
|
||||
tax=Decimal("1.95"),
|
||||
)
|
||||
purchase2 = Purchase(
|
||||
user_id=user_id,
|
||||
store_id=kroger.id,
|
||||
receipt_id="kroger-2026-001",
|
||||
purchase_date=today - timedelta(days=5),
|
||||
total=Decimal("15.78"),
|
||||
subtotal=Decimal("14.50"),
|
||||
tax=Decimal("1.28"),
|
||||
)
|
||||
session.add_all([purchase1, purchase2])
|
||||
await session.flush()
|
||||
|
||||
# -- Purchase Items --
|
||||
item1 = PurchaseItem(
|
||||
purchase_id=purchase1.id,
|
||||
product_name_raw="Cheerios 18oz Box",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("4.79"),
|
||||
extended_price=Decimal("4.79"),
|
||||
normalized_product_id=cheerios.id,
|
||||
)
|
||||
item2 = PurchaseItem(
|
||||
purchase_id=purchase1.id,
|
||||
product_name_raw="Meijer Whole Milk 1gal",
|
||||
quantity=Decimal("2"),
|
||||
unit_price=Decimal("3.29"),
|
||||
extended_price=Decimal("6.58"),
|
||||
normalized_product_id=milk.id,
|
||||
)
|
||||
item3 = PurchaseItem(
|
||||
purchase_id=purchase2.id,
|
||||
product_name_raw="KRO CHEERIOS 18OZ",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("4.49"),
|
||||
extended_price=Decimal("4.49"),
|
||||
normalized_product_id=cheerios.id,
|
||||
)
|
||||
session.add_all([item1, item2, item3])
|
||||
await session.flush()
|
||||
|
||||
# -- Coupons --
|
||||
coupon1 = Coupon(
|
||||
store_id=meijer.id,
|
||||
normalized_product_id=cheerios.id,
|
||||
title="$1 off Cheerios",
|
||||
description="Save $1 on any Cheerios 18oz or larger",
|
||||
discount_type="fixed",
|
||||
discount_value=Decimal("1.00"),
|
||||
valid_from=today - timedelta(days=7),
|
||||
valid_to=today + timedelta(days=30),
|
||||
)
|
||||
coupon2 = Coupon(
|
||||
store_id=kroger.id,
|
||||
normalized_product_id=None,
|
||||
title="10% off dairy",
|
||||
description="10% off all dairy products",
|
||||
discount_type="percent",
|
||||
discount_value=Decimal("10.00"),
|
||||
valid_from=today - timedelta(days=3),
|
||||
valid_to=today + timedelta(days=14),
|
||||
)
|
||||
session.add_all([coupon1, coupon2])
|
||||
await session.flush()
|
||||
|
||||
# -- Shrinkflation events --
|
||||
shrink = ShrinkflationEvent(
|
||||
normalized_product_id=cheerios.id,
|
||||
detected_date=today - timedelta(days=15),
|
||||
old_size="20",
|
||||
new_size="18",
|
||||
old_unit="oz",
|
||||
new_unit="oz",
|
||||
price_at_old_size=Decimal("3.99"),
|
||||
price_at_new_size=Decimal("4.29"),
|
||||
confidence=Decimal("0.95"),
|
||||
notes="Size reduced from 20oz to 18oz while price increased",
|
||||
)
|
||||
session.add(shrink)
|
||||
await session.commit()
|
||||
|
||||
for obj in [
|
||||
meijer,
|
||||
kroger,
|
||||
target,
|
||||
cheerios,
|
||||
milk,
|
||||
chicken,
|
||||
purchase1,
|
||||
purchase2,
|
||||
item1,
|
||||
item2,
|
||||
item3,
|
||||
coupon1,
|
||||
coupon2,
|
||||
shrink,
|
||||
]:
|
||||
await session.refresh(obj)
|
||||
|
||||
return {
|
||||
"headers": auth_headers,
|
||||
"user_id": user_id,
|
||||
"stores": {"meijer": meijer, "kroger": kroger, "target": target},
|
||||
"products": {"cheerios": cheerios, "milk": milk, "chicken": chicken},
|
||||
"purchases": {"meijer_trip": purchase1, "kroger_trip": purchase2},
|
||||
"items": {"cheerios_meijer": item1, "milk_meijer": item2, "cheerios_kroger": item3},
|
||||
"coupons": {"cheerios_coupon": coupon1, "dairy_coupon": coupon2},
|
||||
"shrinkflation": {"cheerios_shrink": shrink},
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
"""E2E: Auth and session validation flows.
|
||||
|
||||
Registration and login are handled by the Better-Auth service.
|
||||
These tests validate session token handling at the API gateway level.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import _create_test_user_and_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSessionValidation:
|
||||
"""Session edge cases and error responses."""
|
||||
|
||||
async def test_invalid_session_token_rejected(self, client, db_engine):
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": "better-auth.session_token=not-a-real-token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_missing_auth(self, client, db_engine):
|
||||
resp = await client.get("/auth/me")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
async def test_bearer_token_also_works(self, client, db_engine):
|
||||
"""Session tokens passed as Bearer tokens should also be accepted."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="bearer@e2e.com", display_name="Bearer E2E"
|
||||
)
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Authorization": f"Bearer {session_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == "bearer@e2e.com"
|
||||
|
||||
async def test_deleted_user_session_returns_not_found(self, client, db_engine):
|
||||
"""After deleting a user, their session should result in 404 for profile."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="delete-me@e2e.com", display_name="Delete Me"
|
||||
)
|
||||
headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
delete_resp = await client.delete("/auth/me", headers=headers)
|
||||
assert delete_resp.status_code == 204
|
||||
|
||||
me = await client.get("/auth/me", headers=headers)
|
||||
assert me.status_code == 404
|
||||
|
||||
async def test_expired_session_rejected(self, client, db_engine):
|
||||
"""Expired sessions must be rejected."""
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
user_id = str(uuid.uuid4())
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
|
||||
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": user_id,
|
||||
"email": "expired@e2e.com",
|
||||
"hp": "unused",
|
||||
"dn": "Expired User",
|
||||
"ev": False,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :uid, :ea, :ca, :ua)"
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"uid": user_id,
|
||||
"ea": expired,
|
||||
"ca": now,
|
||||
"ua": now,
|
||||
},
|
||||
)
|
||||
|
||||
resp = await client.get(
|
||||
"/auth/me",
|
||||
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthProtectedEndpoints:
|
||||
"""Verify auth is enforced on all user-specific endpoints."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method,path",
|
||||
[
|
||||
("GET", "/purchases"),
|
||||
("GET", "/products"),
|
||||
("GET", "/prices/trends"),
|
||||
("GET", "/prices/increases"),
|
||||
("GET", "/coupons"),
|
||||
("GET", "/alerts"),
|
||||
("GET", "/me/stores"),
|
||||
],
|
||||
)
|
||||
async def test_endpoints_require_auth(self, client, db_engine, method, path):
|
||||
resp = await client.request(method, path)
|
||||
assert resp.status_code in (401, 403), f"{method} {path} should require auth"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCrossUserDataIsolation:
|
||||
"""Verify that users cannot access other users' data."""
|
||||
|
||||
async def test_user_b_cannot_access_user_a_purchases(self, client, db_engine, seed_data):
|
||||
"""A second user cannot see User A's purchases."""
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="userb@e2e.com", display_name="User B"
|
||||
)
|
||||
user_b_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
|
||||
assert resp.status_code in (403, 404), (
|
||||
"User B should not be able to access User A's purchase"
|
||||
)
|
||||
|
||||
async def test_user_b_purchase_list_is_empty(self, client, db_engine, seed_data):
|
||||
"""A new user should see no purchases."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="userc@e2e.com", display_name="User C"
|
||||
)
|
||||
user_c_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
resp = await client.get("/purchases", headers=user_c_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0, "New user should have no purchases"
|
||||
|
||||
async def test_user_b_stores_isolated(self, client, db_engine, seed_data):
|
||||
"""User B's connected stores should be independent from User A."""
|
||||
_, session_token = await _create_test_user_and_session(
|
||||
client, db_engine, email="userd@e2e.com", display_name="User D"
|
||||
)
|
||||
user_d_headers = {"Cookie": f"better-auth.session_token={session_token}"}
|
||||
|
||||
resp = await client.get("/me/stores", headers=user_d_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0, "New user should have no connected stores"
|
||||
@@ -0,0 +1,114 @@
|
||||
"""E2E: Cross-resource flows — store connect → purchases → prices → coupons → alerts."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestStoreConnectToPurchaseFlow:
|
||||
"""Connect a store, then verify purchases and related data are accessible."""
|
||||
|
||||
async def test_connect_store_then_list(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
# Connect to Meijer
|
||||
resp = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
|
||||
assert resp.status_code in (200, 201)
|
||||
|
||||
# Verify store appears in user's connected stores
|
||||
stores = await client.get("/me/stores", headers=headers)
|
||||
assert stores.status_code == 200
|
||||
slugs = [s["store"]["slug"] for s in stores.json()]
|
||||
assert "meijer" in slugs
|
||||
|
||||
async def test_disconnect_store(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
await client.post("/me/stores/kroger/connect", json={}, headers=headers)
|
||||
resp = await client.delete("/me/stores/kroger", headers=headers)
|
||||
assert resp.status_code in (200, 204)
|
||||
|
||||
# Verify store no longer in connected list
|
||||
stores = await client.get("/me/stores", headers=headers)
|
||||
slugs = [s["store"]["slug"] for s in stores.json()]
|
||||
assert "kroger" not in slugs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseToPriceFlow:
|
||||
"""Verify purchase data links to price comparison data."""
|
||||
|
||||
async def test_purchase_items_link_to_products(self, client, seed_data):
|
||||
"""Items from purchases reference products that have price data."""
|
||||
headers = seed_data["headers"]
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
|
||||
# Get purchase detail
|
||||
purchase = await client.get(f"/purchases/{purchase_id}", headers=headers)
|
||||
assert purchase.status_code == 200
|
||||
items = purchase.json()["line_items"]
|
||||
|
||||
# Get product detail for an item that has a product_id
|
||||
product_ids = [li["product_id"] for li in items if li.get("product_id")]
|
||||
assert len(product_ids) >= 1
|
||||
|
||||
for pid in product_ids:
|
||||
product = await client.get(f"/products/{pid}", headers=headers)
|
||||
assert product.status_code == 200
|
||||
assert len(product.json()["prices_by_store"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCouponFlow:
|
||||
"""Verify coupon listing and relevance filtering."""
|
||||
|
||||
async def test_list_all_coupons(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/coupons", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 2
|
||||
descriptions = [c["description"] for c in data]
|
||||
assert any("Cheerios" in d for d in descriptions)
|
||||
|
||||
async def test_filter_coupons_by_store(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
meijer_id = str(seed_data["stores"]["meijer"].id)
|
||||
resp = await client.get("/coupons", params={"store_id": meijer_id}, headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert all(c["store_name"] == "Meijer" for c in data)
|
||||
|
||||
async def test_relevant_coupons_for_user(self, client, seed_data):
|
||||
"""User bought Cheerios, so the Cheerios coupon should be relevant."""
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/coupons/relevant", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1, "Expected at least one relevant coupon for user with purchases"
|
||||
descriptions = [c["description"] for c in data]
|
||||
assert any("Cheerios" in d for d in descriptions)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAlertFlow:
|
||||
"""Verify alert listing with seeded data."""
|
||||
|
||||
async def test_list_alerts(self, client, seed_data):
|
||||
"""User bought Cheerios which has a shrinkflation event — may appear as alert."""
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/alerts", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, list)
|
||||
# If alerts are generated synchronously, verify shrinkflation alert content
|
||||
if len(data) > 0:
|
||||
alert_types = [a["alert_type"] for a in data]
|
||||
product_names = [a["product_name"] for a in data]
|
||||
assert any(t in ("shrinkflation", "price_increase") for t in alert_types)
|
||||
assert any("Cheerios" in name for name in product_names)
|
||||
|
||||
async def test_alert_settings_default(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/alerts/settings", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "price_increase_threshold_pct" in data
|
||||
assert "shrinkflation_enabled" in data
|
||||
@@ -0,0 +1,127 @@
|
||||
"""E2E: Error responses for bad input across all endpoint categories."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRegistrationErrors:
|
||||
"""Validation errors during user registration."""
|
||||
|
||||
async def test_short_password(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "short@example.com", "password": "short", "display_name": "Test"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_email(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "not-an-email", "password": "securepass123", "display_name": "Test"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_missing_fields(self, client, db_engine):
|
||||
resp = await client.post("/auth/register", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_empty_display_name(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "empty@example.com", "password": "securepass123", "display_name": ""},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_duplicate_email(self, client, db_engine):
|
||||
payload = {
|
||||
"email": "dupe@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "First",
|
||||
}
|
||||
first = await client.post("/auth/register", json=payload)
|
||||
assert first.status_code == 201
|
||||
second = await client.post("/auth/register", json=payload)
|
||||
assert second.status_code == 409
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLoginErrors:
|
||||
"""Login failure modes."""
|
||||
|
||||
async def test_wrong_password(self, client, db_engine):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "login-err@example.com",
|
||||
"password": "correctpass1",
|
||||
"display_name": "Login",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "login-err@example.com", "password": "wrongpass123"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_nonexistent_user(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "nobody@example.com", "password": "doesntmatter"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestNotFoundErrors:
|
||||
"""404 responses for missing resources."""
|
||||
|
||||
async def test_product_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_purchase_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/purchases/{ZERO_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_public_trend_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/public/trends/{ZERO_UUID}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestMalformedInput:
|
||||
"""Invalid UUID formats and bad query params."""
|
||||
|
||||
async def test_invalid_uuid_product(self, client, seed_data):
|
||||
resp = await client.get(f"/products/{BAD_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_uuid_purchase(self, client, seed_data):
|
||||
resp = await client.get(f"/purchases/{BAD_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_uuid_public_trend(self, client, seed_data):
|
||||
resp = await client.get(f"/public/trends/{BAD_UUID}")
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestStoreConnectionErrors:
|
||||
"""Store connection edge cases."""
|
||||
|
||||
async def test_connect_nonexistent_store(self, client, seed_data):
|
||||
resp = await client.post(
|
||||
"/me/stores/nonexistent-store/connect",
|
||||
json={},
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_connect_store_twice(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
first = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
|
||||
assert first.status_code in (200, 201)
|
||||
second = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
|
||||
assert second.status_code == 409
|
||||
@@ -0,0 +1,102 @@
|
||||
"""E2E: Price history queries returning correct data."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPriceTrends:
|
||||
"""Verify price trend aggregation against seeded history."""
|
||||
|
||||
async def test_trends_returns_all_products(self, client, seed_data):
|
||||
resp = await client.get("/prices/trends", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
product_names = [t["product_name"] for t in data]
|
||||
assert "Cheerios 18oz" in product_names
|
||||
assert "Whole Milk 1gal" in product_names
|
||||
|
||||
async def test_trends_filter_by_category(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
"/prices/trends", params={"category": "dairy"}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
# Only dairy products should appear
|
||||
for trend in data:
|
||||
assert trend["product_name"] == "Whole Milk 1gal"
|
||||
|
||||
async def test_trends_contain_data_points(self, client, seed_data):
|
||||
resp = await client.get("/prices/trends", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
cheerios_trend = next(t for t in data if t["product_name"] == "Cheerios 18oz")
|
||||
assert len(cheerios_trend["data_points"]) >= 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPriceIncreases:
|
||||
"""Detect price increases from seeded price history."""
|
||||
|
||||
async def test_increases_detected(self, client, seed_data):
|
||||
resp = await client.get("/prices/increases", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# Cheerios at Meijer went from 3.99 → 4.29 → 4.79
|
||||
cheerios_increases = [inc for inc in data if inc["product_name"] == "Cheerios 18oz"]
|
||||
assert len(cheerios_increases) >= 1
|
||||
# Verify the increase data makes sense
|
||||
for inc in cheerios_increases:
|
||||
assert inc["new_price"] > inc["old_price"]
|
||||
assert inc["increase_pct"] > 0
|
||||
assert inc["store_name"] == "Meijer"
|
||||
|
||||
async def test_stable_prices_not_flagged(self, client, seed_data):
|
||||
"""Kroger Cheerios price is stable at $4.49 — should not appear as increase."""
|
||||
resp = await client.get("/prices/increases", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
kroger_increases = [
|
||||
inc
|
||||
for inc in data
|
||||
if inc["product_name"] == "Cheerios 18oz" and inc["store_name"] == "Kroger"
|
||||
]
|
||||
assert len(kroger_increases) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPriceComparison:
|
||||
"""Compare prices across stores for specific products."""
|
||||
|
||||
async def test_compare_cheerios_across_stores(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(
|
||||
"/prices/comparison",
|
||||
params={"product_ids": cheerios_id},
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
cheerios_cmp = data[0]
|
||||
assert cheerios_cmp["product_name"] == "Cheerios 18oz"
|
||||
store_names = [p["store_name"] for p in cheerios_cmp["prices"]]
|
||||
assert "Meijer" in store_names
|
||||
assert "Kroger" in store_names
|
||||
|
||||
async def test_compare_requires_product_ids(self, client, seed_data):
|
||||
"""product_ids is required — omitting it must return 422."""
|
||||
resp = await client.get("/prices/comparison", headers=seed_data["headers"])
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_compare_multiple_products(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
milk_id = str(seed_data["products"]["milk"].id)
|
||||
resp = await client.get(
|
||||
"/prices/comparison",
|
||||
params=[("product_ids", cheerios_id), ("product_ids", milk_id)],
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
names = [c["product_name"] for c in data]
|
||||
assert "Cheerios 18oz" in names
|
||||
assert "Whole Milk 1gal" in names
|
||||
@@ -0,0 +1,82 @@
|
||||
"""E2E: Product search/lookup endpoints with real DB fixtures."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.test_e2e.conftest import ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestProductSearch:
|
||||
"""Search and filter products against seeded data."""
|
||||
|
||||
async def test_list_all_products(self, client, seed_data):
|
||||
resp = await client.get("/products", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()
|
||||
names = [p["name"] for p in products]
|
||||
assert "Cheerios 18oz" in names
|
||||
assert "Whole Milk 1gal" in names
|
||||
assert "Chicken Breast 1lb" in names
|
||||
|
||||
async def test_search_by_name(self, client, seed_data):
|
||||
resp = await client.get("/products", params={"q": "cheerios"}, headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()
|
||||
assert len(products) >= 1
|
||||
assert all("cheerios" in p["name"].lower() for p in products)
|
||||
|
||||
async def test_search_by_category(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
"/products", params={"category": "dairy"}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()
|
||||
assert len(products) >= 1
|
||||
assert all(p["category"] == "dairy" for p in products)
|
||||
|
||||
async def test_search_no_results(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
"/products", params={"q": "nonexistentxyz"}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestProductLookup:
|
||||
"""Detailed product lookups with cross-store pricing."""
|
||||
|
||||
async def test_get_product_detail_with_prices(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "Cheerios 18oz"
|
||||
assert data["brand"] == "General Mills"
|
||||
assert data["category"] == "pantry"
|
||||
# Should have prices from both Meijer and Kroger
|
||||
store_names = [p["store_name"] for p in data["prices_by_store"]]
|
||||
assert "Meijer" in store_names
|
||||
assert "Kroger" in store_names
|
||||
|
||||
async def test_product_prices_reflect_latest(self, client, seed_data):
|
||||
"""The latest Meijer price for Cheerios should be 4.79 (the increase)."""
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
meijer_price = next(p for p in data["prices_by_store"] if p["store_name"] == "Meijer")
|
||||
assert meijer_price["current_price"] == 4.79
|
||||
|
||||
async def test_product_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_product_price_history(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/products/{cheerios_id}/prices", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["data_points"]) >= 3 # At least the 3 Meijer observations
|
||||
# Verify chronological ordering exists
|
||||
prices = [dp["price"] for dp in data["data_points"]]
|
||||
assert len(prices) >= 3
|
||||
@@ -0,0 +1,59 @@
|
||||
"""E2E: Public price transparency endpoints (no auth required)."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPublicTrends:
|
||||
"""Public price trend endpoint — no auth, real data."""
|
||||
|
||||
async def test_public_trend_returns_data(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/public/trends/{cheerios_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["product_name"] == "Cheerios 18oz"
|
||||
assert len(data["data_points"]) >= 3
|
||||
|
||||
async def test_public_trend_no_auth_needed(self, client, seed_data):
|
||||
"""Confirm no Authorization header is required."""
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/public/trends/{cheerios_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPublicStoreComparison:
|
||||
"""Public store comparison endpoint."""
|
||||
|
||||
async def test_store_comparison(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(
|
||||
"/public/store-comparison",
|
||||
params=[("product_ids", cheerios_id)],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "products" in data
|
||||
assert len(data["products"]) >= 1
|
||||
|
||||
async def test_store_comparison_rejects_more_than_20_ids(self, client):
|
||||
"""max_length=20 guard: 21 product IDs must return 422."""
|
||||
too_many = [("product_ids", str(uuid.uuid4())) for _ in range(21)]
|
||||
resp = await client.get("/public/store-comparison", params=too_many)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPublicInflation:
|
||||
"""Public inflation index endpoint."""
|
||||
|
||||
async def test_inflation_returns_index(self, client, seed_data):
|
||||
resp = await client.get("/public/inflation")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "cartsnitch_index" in data
|
||||
assert "cpi_baseline" in data
|
||||
assert "categories" in data
|
||||
@@ -0,0 +1,87 @@
|
||||
"""E2E: Purchase listing, detail, and stats against real DB fixtures."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.test_e2e.conftest import ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseList:
|
||||
"""List and filter a user's purchases."""
|
||||
|
||||
async def test_list_user_purchases(self, client, seed_data):
|
||||
resp = await client.get("/purchases", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 2
|
||||
store_names = [p["store_name"] for p in data]
|
||||
assert "Meijer" in store_names
|
||||
assert "Kroger" in store_names
|
||||
|
||||
async def test_filter_purchases_by_store(self, client, seed_data):
|
||||
meijer_id = str(seed_data["stores"]["meijer"].id)
|
||||
resp = await client.get(
|
||||
"/purchases", params={"store_id": meijer_id}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert all(p["store_name"] == "Meijer" for p in data)
|
||||
|
||||
async def test_purchases_require_auth(self, client, seed_data):
|
||||
resp = await client.get("/purchases")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseDetail:
|
||||
"""Retrieve individual purchase with line items."""
|
||||
|
||||
async def test_get_purchase_detail(self, client, seed_data):
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["store_name"] == "Meijer"
|
||||
assert data["total"] == 23.45
|
||||
assert len(data["line_items"]) == 2
|
||||
item_names = [li["name"] for li in data["line_items"]]
|
||||
assert "Cheerios 18oz Box" in item_names
|
||||
assert "Meijer Whole Milk 1gal" in item_names
|
||||
|
||||
async def test_line_item_amounts_correct(self, client, seed_data):
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
cheerios_item = next(li for li in data["line_items"] if "Cheerios" in li["name"])
|
||||
assert cheerios_item["unit_price"] == 4.79
|
||||
assert cheerios_item["quantity"] == 1.0
|
||||
assert cheerios_item["total_price"] == 4.79
|
||||
|
||||
async def test_purchase_not_found(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
f"/purchases/{ZERO_UUID}",
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseStats:
|
||||
"""Verify spending aggregation across purchases."""
|
||||
|
||||
async def test_purchase_stats_totals(self, client, seed_data):
|
||||
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["purchase_count"] == 2
|
||||
# 23.45 + 15.78 = 39.23
|
||||
assert abs(data["total_spent"] - 39.23) < 0.01
|
||||
|
||||
async def test_purchase_stats_by_store(self, client, seed_data):
|
||||
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
assert "Meijer" in data["by_store"]
|
||||
assert "Kroger" in data["by_store"]
|
||||
assert abs(data["by_store"]["Meijer"] - 23.45) < 0.01
|
||||
assert abs(data["by_store"]["Kroger"] - 15.78) < 0.01
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Tests for EncryptedJSON TypeDecorator and session_data encryption."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import Fernet
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import column, create_engine, table, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.models import Base
|
||||
from cartsnitch_api.models.store import Store
|
||||
from cartsnitch_api.models.user import User, UserStoreAccount
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
eng = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(eng)
|
||||
yield eng
|
||||
eng.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(engine):
|
||||
factory = sessionmaker(bind=engine)
|
||||
with factory() as sess:
|
||||
yield sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(session):
|
||||
s = Store(name="Test Store", slug="test-store")
|
||||
session.add(s)
|
||||
session.commit()
|
||||
session.refresh(s)
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user(session):
|
||||
u = User(email="alice@example.com", hashed_password="fakehash")
|
||||
session.add(u)
|
||||
session.commit()
|
||||
session.refresh(u)
|
||||
return u
|
||||
|
||||
|
||||
class TestEncryptedJSONType:
|
||||
"""Unit tests for the EncryptedJSON TypeDecorator."""
|
||||
|
||||
def test_round_trip(self, session, user, store):
|
||||
"""Data written via the ORM comes back as the original dict."""
|
||||
original = {"token": "abc123", "cookies": {"session_id": "xyz"}}
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data == original
|
||||
|
||||
def test_stored_value_is_encrypted(self, session, user, store):
|
||||
"""The raw value in the DB should be a Fernet token, not plaintext JSON."""
|
||||
original = {"secret": "do-not-leak"}
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
# Use a raw table construct to bypass TypeDecorator on read
|
||||
raw_table = table("user_store_accounts", column("id"), column("session_data"))
|
||||
raw = session.execute(raw_table.select().where(raw_table.c.id == str(account.id))).first()
|
||||
# If UUID matching fails with str, try bytes format
|
||||
if raw is None:
|
||||
raw = session.execute(
|
||||
text("SELECT session_data FROM user_store_accounts LIMIT 1")
|
||||
).scalar_one()
|
||||
else:
|
||||
raw = raw[1]
|
||||
|
||||
assert raw != json.dumps(original)
|
||||
assert raw.startswith("gAAAAA")
|
||||
|
||||
# Verify we can decrypt the raw value manually
|
||||
f = Fernet(settings.fernet_key.encode())
|
||||
decrypted = json.loads(f.decrypt(raw.encode()))
|
||||
assert decrypted == original
|
||||
|
||||
def test_null_round_trip(self, session, user, store):
|
||||
"""NULL session_data stays NULL."""
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=None)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data is None
|
||||
|
||||
def test_empty_dict_round_trip(self, session, user, store):
|
||||
"""Empty dict round-trips correctly."""
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={})
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data == {}
|
||||
|
||||
def test_update_session_data(self, session, user, store):
|
||||
"""Updating session_data re-encrypts the new value."""
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={"v": 1})
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
account.session_data = {"v": 2, "new_field": True}
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data == {"v": 2, "new_field": True}
|
||||
|
||||
|
||||
class TestEncryptionKeyValidation:
|
||||
"""Test that invalid/missing keys are caught at startup."""
|
||||
|
||||
def test_invalid_fernet_key_rejected(self, monkeypatch):
|
||||
"""Settings validation rejects a bad key."""
|
||||
monkeypatch.setenv("CARTSNITCH_FERNET_KEY", "not-a-valid-key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
from cartsnitch_api.config import Settings
|
||||
|
||||
Settings()
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Conftest for middleware tests — re-enables rate limiting after global disable."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.config import settings as cartsnitch_settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_rate_limiting():
|
||||
"""Re-enable rate limiting after the global disable_rate_limiting fixture runs.
|
||||
|
||||
The root conftest disables rate limiting for all tests to prevent 429
|
||||
interference. Middleware tests need it active to verify headers and
|
||||
enforcement. This fixture runs after the root fixture (more local = later
|
||||
in setup order) so True is the effective value during the test body.
|
||||
"""
|
||||
cartsnitch_settings.rate_limit_enabled = True
|
||||
yield
|
||||
cartsnitch_settings.rate_limit_enabled = False
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Tests for structured error responses and error monitoring."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_404_returns_structured_error(client):
|
||||
"""Non-existent route should return structured error."""
|
||||
resp = await client.get("/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
body = resp.json()
|
||||
assert "detail" in body
|
||||
assert "code" in body
|
||||
assert body["code"] == "NOT_FOUND"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_error_returns_422_with_field_errors(client):
|
||||
"""Invalid request body should return structured validation errors."""
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "not-an-email", "password": "short", "display_name": ""},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
body = resp.json()
|
||||
assert body["code"] == "VALIDATION_ERROR"
|
||||
assert "errors" in body
|
||||
assert isinstance(body["errors"], list)
|
||||
assert len(body["errors"]) > 0
|
||||
# Each error should have field, message, type
|
||||
for err in body["errors"]:
|
||||
assert "field" in err
|
||||
assert "message" in err
|
||||
assert "type" in err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_stats_requires_service_key(client):
|
||||
"""Error stats endpoint should require X-Service-Key."""
|
||||
resp = await client.get("/internal/error-stats")
|
||||
assert resp.status_code == 422 # Missing required header
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_stats_with_valid_key(client):
|
||||
"""Error stats endpoint returns monitoring data with valid key."""
|
||||
resp = await client.get(
|
||||
"/internal/error-stats",
|
||||
headers={"X-Service-Key": "change-me-in-production"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "error_counts" in body
|
||||
assert "recent_5xx_count" in body
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter
|
||||
|
||||
|
||||
class TestSlidingWindowCounter:
|
||||
def test_allows_within_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
|
||||
for i in range(5):
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 4 - i
|
||||
|
||||
def test_blocks_over_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
|
||||
for _ in range(3):
|
||||
counter.is_allowed("test-key")
|
||||
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
assert remaining == 0
|
||||
assert retry > 0
|
||||
|
||||
def test_separate_keys(self):
|
||||
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60)
|
||||
# Fill key-a
|
||||
counter.is_allowed("key-a")
|
||||
counter.is_allowed("key-a")
|
||||
allowed_a, _, _ = counter.is_allowed("key-a")
|
||||
assert allowed_a is False
|
||||
|
||||
# key-b should still be allowed
|
||||
allowed_b, remaining, _ = counter.is_allowed("key-b")
|
||||
assert allowed_b is True
|
||||
assert remaining == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_returns_429(client):
|
||||
"""Public endpoint should return 429 after limit exceeded."""
|
||||
# The default limit is 60/min — we won't hit it in normal tests,
|
||||
# but we verify the middleware adds rate limit headers.
|
||||
resp = await client.get("/public/inflation")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_skips_rate_limit(client):
|
||||
"""Health endpoint should not have rate limit headers."""
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert "x-ratelimit-limit" not in resp.headers
|
||||
@@ -0,0 +1,376 @@
|
||||
"""Tests for SQLAlchemy ORM models."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from cartsnitch_api.constants import (
|
||||
AccountStatus,
|
||||
DiscountType,
|
||||
PriceSource,
|
||||
ProductCategory,
|
||||
SizeUnit,
|
||||
StoreSlug,
|
||||
)
|
||||
from cartsnitch_api.models import (
|
||||
Coupon,
|
||||
NormalizedProduct,
|
||||
PriceHistory,
|
||||
Purchase,
|
||||
PurchaseItem,
|
||||
ShrinkflationEvent,
|
||||
Store,
|
||||
StoreLocation,
|
||||
User,
|
||||
UserStoreAccount,
|
||||
)
|
||||
|
||||
|
||||
class TestTableCreation:
|
||||
"""Verify all expected tables are created."""
|
||||
|
||||
def test_all_tables_exist(self, engine):
|
||||
inspector = inspect(engine)
|
||||
table_names = set(inspector.get_table_names())
|
||||
expected = {
|
||||
"stores",
|
||||
"store_locations",
|
||||
"users",
|
||||
"user_store_accounts",
|
||||
"purchases",
|
||||
"purchase_items",
|
||||
"normalized_products",
|
||||
"price_history",
|
||||
"coupons",
|
||||
"shrinkflation_events",
|
||||
}
|
||||
assert expected.issubset(table_names)
|
||||
|
||||
def test_ten_tables_total(self, engine):
|
||||
inspector = inspect(engine)
|
||||
assert len(inspector.get_table_names()) == 10
|
||||
|
||||
|
||||
class TestUUIDPrimaryKeys:
|
||||
"""All models use UUID PKs."""
|
||||
|
||||
def test_store_uuid_pk(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Meijer",
|
||||
slug=StoreSlug.MEIJER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.commit()
|
||||
assert isinstance(store.id, uuid.UUID)
|
||||
|
||||
def test_user_uuid_pk(self, session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
assert isinstance(user.id, uuid.UUID)
|
||||
|
||||
|
||||
class TestStoreModel:
|
||||
def test_store_slug_enum(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Kroger",
|
||||
slug=StoreSlug.KROGER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.commit()
|
||||
assert store.slug == StoreSlug.KROGER
|
||||
|
||||
def test_store_unique_slug(self, session):
|
||||
s1 = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
s2 = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target Duplicate",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(s1)
|
||||
session.commit()
|
||||
session.add(s2)
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
session.commit()
|
||||
session.rollback()
|
||||
|
||||
|
||||
class TestStoreLocationModel:
|
||||
def test_store_location_fields(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Meijer",
|
||||
slug=StoreSlug.MEIJER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.flush()
|
||||
loc = StoreLocation(
|
||||
id=uuid.uuid4(),
|
||||
store_id=store.id,
|
||||
address="123 Main St",
|
||||
city="Ann Arbor",
|
||||
state="MI",
|
||||
zip="48104",
|
||||
lat=42.2808,
|
||||
lng=-83.7430,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(loc)
|
||||
session.commit()
|
||||
assert loc.city == "Ann Arbor"
|
||||
assert loc.lat == pytest.approx(42.2808)
|
||||
|
||||
|
||||
class TestUserStoreAccountModel:
|
||||
def test_account_status_enum(self, session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@test.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Kroger",
|
||||
slug=StoreSlug.KROGER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([user, store])
|
||||
session.flush()
|
||||
acct = UserStoreAccount(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
status=AccountStatus.ACTIVE,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(acct)
|
||||
session.commit()
|
||||
assert acct.status == AccountStatus.ACTIVE
|
||||
|
||||
def test_unique_user_store_constraint(self, session):
|
||||
"""One account per user per store."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="unique@test.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([user, store])
|
||||
session.flush()
|
||||
a1 = UserStoreAccount(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
status=AccountStatus.ACTIVE,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
a2 = UserStoreAccount(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
status=AccountStatus.EXPIRED,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(a1)
|
||||
session.commit()
|
||||
session.add(a2)
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
session.commit()
|
||||
session.rollback()
|
||||
|
||||
|
||||
class TestPurchaseModel:
|
||||
def test_purchase_with_items(self, session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="buyer@test.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Meijer",
|
||||
slug=StoreSlug.MEIJER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([user, store])
|
||||
session.flush()
|
||||
purchase = Purchase(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
receipt_id="RCP-001",
|
||||
purchase_date=date(2026, 3, 15),
|
||||
total=Decimal("42.50"),
|
||||
ingested_at=datetime.now(UTC),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(purchase)
|
||||
session.flush()
|
||||
item = PurchaseItem(
|
||||
id=uuid.uuid4(),
|
||||
purchase_id=purchase.id,
|
||||
product_name_raw="Meijer Whole Milk 1 Gallon",
|
||||
upc="0041250000001",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("3.49"),
|
||||
extended_price=Decimal("3.49"),
|
||||
)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
assert item.product_name_raw == "Meijer Whole Milk 1 Gallon"
|
||||
assert item.unit_price == Decimal("3.49")
|
||||
|
||||
|
||||
class TestNormalizedProductModel:
|
||||
def test_product_with_upc_variants(self, session):
|
||||
product = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Whole Milk, 1 Gallon",
|
||||
category=ProductCategory.DAIRY,
|
||||
brand="Store Brand",
|
||||
size="128",
|
||||
size_unit=SizeUnit.FL_OZ,
|
||||
upc_variants=["0041250000001", "0041250000002"],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(product)
|
||||
session.commit()
|
||||
assert product.category == ProductCategory.DAIRY
|
||||
assert product.size_unit == SizeUnit.FL_OZ
|
||||
|
||||
|
||||
class TestPriceHistoryModel:
|
||||
def test_price_source_enum(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Kroger",
|
||||
slug=StoreSlug.KROGER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
product = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Eggs, Large, 12ct",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([store, product])
|
||||
session.flush()
|
||||
ph = PriceHistory(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 15),
|
||||
regular_price=Decimal("4.99"),
|
||||
sale_price=Decimal("3.99"),
|
||||
source=PriceSource.RECEIPT,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(ph)
|
||||
session.commit()
|
||||
assert ph.source == PriceSource.RECEIPT
|
||||
assert ph.regular_price == Decimal("4.99")
|
||||
|
||||
|
||||
class TestCouponModel:
|
||||
def test_coupon_discount_types(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.flush()
|
||||
coupon = Coupon(
|
||||
id=uuid.uuid4(),
|
||||
store_id=store.id,
|
||||
title="$2 off eggs",
|
||||
discount_type=DiscountType.FIXED,
|
||||
discount_value=Decimal("2.00"),
|
||||
requires_clip=True,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(coupon)
|
||||
session.commit()
|
||||
assert coupon.discount_type == DiscountType.FIXED
|
||||
assert coupon.discount_value == Decimal("2.00")
|
||||
|
||||
|
||||
class TestShrinkflationEventModel:
|
||||
def test_shrinkflation_event(self, session):
|
||||
product = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Cereal, Honey Oats",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(product)
|
||||
session.flush()
|
||||
event = ShrinkflationEvent(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=product.id,
|
||||
detected_date=date(2026, 3, 10),
|
||||
old_size="18",
|
||||
new_size="15.4",
|
||||
old_unit=SizeUnit.OZ,
|
||||
new_unit=SizeUnit.OZ,
|
||||
price_at_old_size=Decimal("4.99"),
|
||||
price_at_new_size=Decimal("4.99"),
|
||||
confidence=Decimal("0.95"),
|
||||
notes="Size reduced by 14.4%, price unchanged",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(event)
|
||||
session.commit()
|
||||
assert event.confidence == Decimal("0.95")
|
||||
assert event.old_unit == SizeUnit.OZ
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Verify all expected routes are present in the OpenAPI spec."""
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from cartsnitch_api.main import app
|
||||
|
||||
EXPECTED_ROUTES = [
|
||||
# Auth (6)
|
||||
("post", "/auth/register"),
|
||||
("post", "/auth/login"),
|
||||
("post", "/auth/refresh"),
|
||||
("get", "/auth/me"),
|
||||
("patch", "/auth/me"),
|
||||
("delete", "/auth/me"),
|
||||
# Stores (4)
|
||||
("get", "/stores"),
|
||||
("get", "/me/stores"),
|
||||
("post", "/me/stores/{store_slug}/connect"),
|
||||
("delete", "/me/stores/{store_slug}"),
|
||||
# Purchases (3)
|
||||
("get", "/purchases"),
|
||||
("get", "/purchases/stats"),
|
||||
("get", "/purchases/{purchase_id}"),
|
||||
# Products (3)
|
||||
("get", "/products"),
|
||||
("get", "/products/{product_id}"),
|
||||
("get", "/products/{product_id}/prices"),
|
||||
# Prices (3)
|
||||
("get", "/prices/trends"),
|
||||
("get", "/prices/increases"),
|
||||
("get", "/prices/comparison"),
|
||||
# Coupons (2)
|
||||
("get", "/coupons"),
|
||||
("get", "/coupons/relevant"),
|
||||
# Shopping (2)
|
||||
("post", "/shopping/optimize"),
|
||||
("get", "/shopping/lists"),
|
||||
# Alerts (3)
|
||||
("get", "/alerts"),
|
||||
("get", "/alerts/settings"),
|
||||
("put", "/alerts/settings"),
|
||||
# Scraping (2)
|
||||
("post", "/scraping/{store_slug}/sync"),
|
||||
("get", "/scraping/status"),
|
||||
# Public (3)
|
||||
("get", "/public/trends/{product_id}"),
|
||||
("get", "/public/store-comparison"),
|
||||
("get", "/public/inflation"),
|
||||
# Health (1)
|
||||
("get", "/health"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_routes_in_openapi():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/openapi.json")
|
||||
assert resp.status_code == 200
|
||||
spec = resp.json()
|
||||
paths = spec["paths"]
|
||||
|
||||
registered = set()
|
||||
for path, methods in paths.items():
|
||||
for method in methods:
|
||||
if method in ("get", "post", "put", "delete", "patch"):
|
||||
registered.add((method, path))
|
||||
|
||||
missing = []
|
||||
for method, path in EXPECTED_ROUTES:
|
||||
if (method, path) not in registered:
|
||||
missing.append(f"{method.upper()} {path}")
|
||||
|
||||
assert not missing, "Missing routes in OpenAPI spec:\n" + "\n".join(missing)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_count():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/openapi.json")
|
||||
spec = resp.json()
|
||||
paths = spec["paths"]
|
||||
|
||||
count = 0
|
||||
for _path, methods in paths.items():
|
||||
for method in methods:
|
||||
if method in ("get", "post", "put", "delete", "patch"):
|
||||
count += 1
|
||||
|
||||
assert count == 33, f"Expected 33 routes, found {count}"
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Integration tests for alert endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_alerts_empty(client, auth_headers):
|
||||
"""No purchases means no alerts."""
|
||||
resp = await client.get("/alerts", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_alert_settings(client, auth_headers):
|
||||
resp = await client.get("/alerts/settings", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["price_increase_threshold_pct"] == 5.0
|
||||
assert data["shrinkflation_enabled"] is True
|
||||
assert data["email_notifications"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_alert_settings_returns_501(client, auth_headers):
|
||||
resp = await client.put(
|
||||
"/alerts/settings",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"price_increase_threshold_pct": 10.0,
|
||||
"shrinkflation_enabled": False,
|
||||
"email_notifications": True,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 501
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Integration tests for coupon endpoints."""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import Coupon, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def coupon_data(db_engine, auth_headers):
|
||||
"""Seed stores and coupons."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Target", slug="target")
|
||||
session.add(store)
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
|
||||
coupon = Coupon(
|
||||
store_id=store.id,
|
||||
title="$2 off laundry",
|
||||
description="$2 off any laundry detergent",
|
||||
discount_value=Decimal("2.00"),
|
||||
discount_type="fixed",
|
||||
valid_from=date(2026, 1, 1),
|
||||
valid_to=date(2026, 12, 31),
|
||||
)
|
||||
session.add(coupon)
|
||||
await session.commit()
|
||||
|
||||
return {"store": store, "coupon": coupon, "headers": auth_headers}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_coupons(client, coupon_data):
|
||||
resp = await client.get("/coupons", headers=coupon_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_coupons_by_store(client, coupon_data):
|
||||
store_id = str(coupon_data["store"].id)
|
||||
resp = await client.get(f"/coupons?store_id={store_id}", headers=coupon_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relevant_coupons_empty(client, auth_headers):
|
||||
"""No purchases means no relevant coupons."""
|
||||
resp = await client.get("/coupons/relevant", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Integration tests for price endpoints."""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def price_data(db_engine, auth_headers):
|
||||
"""Seed products with price history showing an increase."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Walmart", slug="walmart")
|
||||
product = NormalizedProduct(
|
||||
canonical_name="Tide Pods 42ct",
|
||||
category="household",
|
||||
brand="Tide",
|
||||
)
|
||||
session.add_all([store, product])
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
await session.refresh(product)
|
||||
|
||||
# Two price points — second is higher (increase)
|
||||
ph1 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 2, 1),
|
||||
regular_price=Decimal("12.99"),
|
||||
source="receipt",
|
||||
)
|
||||
ph2 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 1),
|
||||
regular_price=Decimal("14.49"),
|
||||
source="receipt",
|
||||
)
|
||||
session.add_all([ph1, ph2])
|
||||
await session.commit()
|
||||
|
||||
return {"product": product, "store": store, "headers": auth_headers}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_trends(client, price_data):
|
||||
resp = await client.get("/prices/trends", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["product_name"] == "Tide Pods 42ct"
|
||||
assert len(data[0]["data_points"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_trends_by_category(client, price_data):
|
||||
resp = await client.get("/prices/trends?category=household", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
resp = await client.get("/prices/trends?category=nonexistent", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_increases(client, price_data):
|
||||
resp = await client.get("/prices/increases", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
increase = data[0]
|
||||
assert increase["old_price"] == 12.99
|
||||
assert increase["new_price"] == 14.49
|
||||
assert increase["increase_pct"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_comparison(client, price_data):
|
||||
pid = str(price_data["product"].id)
|
||||
resp = await client.get(f"/prices/comparison?product_ids={pid}", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["product_name"] == "Tide Pods 42ct"
|
||||
assert len(data[0]["prices"]) >= 1
|
||||
@@ -0,0 +1,94 @@
|
||||
"""Integration tests for product endpoints."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def product_data(db_engine, auth_headers):
|
||||
"""Seed products and price history."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Meijer", slug="meijer")
|
||||
product = NormalizedProduct(
|
||||
canonical_name="Cheerios 18oz",
|
||||
category="pantry",
|
||||
brand="General Mills",
|
||||
upc_variants=["016000275263"],
|
||||
)
|
||||
session.add_all([store, product])
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
await session.refresh(product)
|
||||
|
||||
ph1 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 1),
|
||||
regular_price=Decimal("4.99"),
|
||||
source="receipt",
|
||||
)
|
||||
ph2 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 10),
|
||||
regular_price=Decimal("5.49"),
|
||||
source="receipt",
|
||||
)
|
||||
session.add_all([ph1, ph2])
|
||||
await session.commit()
|
||||
|
||||
return {"product": product, "store": store, "headers": auth_headers}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_products(client, product_data):
|
||||
resp = await client.get("/products", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["name"] == "Cheerios 18oz"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_products(client, product_data):
|
||||
resp = await client.get("/products?q=Cheerios", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
resp = await client.get("/products?q=nonexistent", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_product_detail(client, product_data):
|
||||
pid = str(product_data["product"].id)
|
||||
resp = await client.get(f"/products/{pid}", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "Cheerios 18oz"
|
||||
assert data["brand"] == "General Mills"
|
||||
assert len(data["prices_by_store"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_product_not_found(client, auth_headers):
|
||||
resp = await client.get(f"/products/{uuid.uuid4()}", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_product_prices(client, product_data):
|
||||
pid = str(product_data["product"].id)
|
||||
resp = await client.get(f"/products/{pid}/prices", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["product_name"] == "Cheerios 18oz"
|
||||
assert len(data["data_points"]) == 2
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Integration tests for public endpoints (no auth)."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def public_data(db_engine):
|
||||
"""Seed data for public endpoints."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Target", slug="target")
|
||||
product = NormalizedProduct(
|
||||
canonical_name="Skippy PB 16oz",
|
||||
category="pantry",
|
||||
brand="Skippy",
|
||||
)
|
||||
session.add_all([store, product])
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
await session.refresh(product)
|
||||
|
||||
ph = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 5),
|
||||
regular_price=Decimal("3.99"),
|
||||
source="receipt",
|
||||
)
|
||||
session.add(ph)
|
||||
await session.commit()
|
||||
|
||||
return {"product": product, "store": store}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_trend(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["product_name"] == "Skippy PB 16oz"
|
||||
assert len(data["data_points"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_trend_not_found(client):
|
||||
resp = await client.get(f"/public/trends/{uuid.uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_store_comparison(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/store-comparison?product_ids={pid}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["products"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_inflation(client, public_data):
|
||||
resp = await client.get("/public/inflation")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "categories" in data
|
||||
assert "cartsnitch_index" in data
|
||||
@@ -0,0 +1,114 @@
|
||||
"""Integration tests for purchase endpoints."""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def purchase_data(db_engine):
|
||||
"""Seed a user, store, purchase, items, and a valid session."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
user = User(
|
||||
email="buyer@example.com",
|
||||
hashed_password="not-used-with-better-auth",
|
||||
display_name="Buyer",
|
||||
)
|
||||
store = Store(name="Kroger", slug="kroger")
|
||||
session.add_all([user, store])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(store)
|
||||
|
||||
purchase = Purchase(
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
receipt_id="receipt-001",
|
||||
purchase_date=date(2026, 3, 10),
|
||||
total=Decimal("42.50"),
|
||||
)
|
||||
session.add(purchase)
|
||||
await session.commit()
|
||||
await session.refresh(purchase)
|
||||
|
||||
item = PurchaseItem(
|
||||
purchase_id=purchase.id,
|
||||
product_name_raw="Organic Milk 1gal",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("5.99"),
|
||||
extended_price=Decimal("5.99"),
|
||||
)
|
||||
session.add(item)
|
||||
await session.commit()
|
||||
|
||||
# Create a session token directly in the sessions table
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO sessions (id, token, user_id, expires_at, created_at, updated_at) "
|
||||
"VALUES (:id, :token, :user_id, :expires_at, :created_at, :updated_at)"
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"user_id": str(user.id),
|
||||
"expires_at": expires,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"store": store,
|
||||
"purchase": purchase,
|
||||
"headers": {"Cookie": f"better-auth.session_token={session_token}"},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_purchases(client, purchase_data):
|
||||
resp = await client.get("/purchases", headers=purchase_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["store_name"] == "Kroger"
|
||||
assert data[0]["total"] == 42.50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_purchase_detail(client, purchase_data):
|
||||
pid = str(purchase_data["purchase"].id)
|
||||
resp = await client.get(f"/purchases/{pid}", headers=purchase_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["line_items"]) == 1
|
||||
assert data["line_items"][0]["name"] == "Organic Milk 1gal"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_purchase_not_found(client, auth_headers):
|
||||
resp = await client.get(f"/purchases/{uuid.uuid4()}", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_purchase_stats(client, purchase_data):
|
||||
resp = await client.get("/purchases/stats", headers=purchase_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total_spent"] == 42.50
|
||||
assert data["purchase_count"] == 1
|
||||
assert "Kroger" in data["by_store"]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Integration tests for store endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.models import Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def seeded_store(db_engine):
|
||||
"""Insert a test store directly into the DB."""
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Meijer", slug="meijer", logo_url=None, website_url=None)
|
||||
session.add(store)
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_stores(client, seeded_store):
|
||||
resp = await client.get("/stores")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["slug"] == "meijer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_stores_empty(client, auth_headers):
|
||||
resp = await client.get("/me/stores", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_and_disconnect_store(client, auth_headers, seeded_store):
|
||||
# Connect
|
||||
resp = await client.post(
|
||||
"/me/stores/meijer/connect",
|
||||
headers=auth_headers,
|
||||
json={"credentials": None},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["connected"] is True
|
||||
|
||||
# List should show connected
|
||||
resp = await client.get("/me/stores", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
# Disconnect
|
||||
resp = await client.delete("/me/stores/meijer", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# List should be empty again
|
||||
resp = await client.get("/me/stores", headers=auth_headers)
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_nonexistent_store(client, auth_headers):
|
||||
resp = await client.post(
|
||||
"/me/stores/nonexistent/connect",
|
||||
headers=auth_headers,
|
||||
json={},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_duplicate_store(client, auth_headers, seeded_store):
|
||||
await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
|
||||
resp = await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
|
||||
assert resp.status_code == 409
|
||||
@@ -0,0 +1,11 @@
|
||||
# Required: Generate with `openssl rand -base64 32`
|
||||
BETTER_AUTH_SECRET=change-me-in-production-min-32-chars!!
|
||||
|
||||
# Base URL of the auth service
|
||||
BETTER_AUTH_URL=http://localhost:3001
|
||||
|
||||
# Shared PostgreSQL database
|
||||
DATABASE_URL=postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch
|
||||
|
||||
# Port the auth service listens on
|
||||
PORT=3001
|
||||
@@ -0,0 +1,17 @@
|
||||
FROM node:22-alpine AS builder
|
||||
WORKDIR /app
|
||||
COPY package.json package-lock.json* ./
|
||||
RUN npm ci
|
||||
COPY tsconfig.json ./
|
||||
COPY src/ src/
|
||||
RUN npm run build
|
||||
|
||||
FROM node:22-alpine
|
||||
WORKDIR /app
|
||||
ENV NODE_ENV=production
|
||||
COPY package.json package-lock.json* ./
|
||||
RUN npm ci --omit=dev
|
||||
COPY --from=builder /app/dist/ dist/
|
||||
USER 101
|
||||
EXPOSE 3001
|
||||
CMD ["node", "dist/index.js"]
|
||||
Generated
+1754
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user