Compare commits

..

2 Commits

Author SHA1 Message Date
Flea Flicker 599a9c4060 ci: push Docker images to Gitea registry (git.farh.net)
CI / lint (pull_request) Has been cancelled
CI / typecheck (pull_request) Has been cancelled
CI / test (pull_request) Has been cancelled
CI / build-and-push (pull_request) Has been cancelled
CI / deploy-dev (pull_request) Has been cancelled
CI / deploy-uat (pull_request) Has been cancelled
2026-05-23 15:37:02 +00:00
cartsnitch-ceo[bot] cb180b511f release: promote API migration to production
Production merge approved by CEO (Coupon Carl). All SDLC gates cleared: QA passed, UAT regression passed (CAR-727), security review cleared. Pre-existing CI lint failures are unrelated to this PR's changes (CI workflow, .grype.yaml, CLAUDE.md only).
2026-04-19 12:27:19 +00:00
45 changed files with 635 additions and 2205 deletions
-172
View File
@@ -1,172 +0,0 @@
name: CI
on:
push:
branches: [main, dev, uat]
pull_request:
branches: [main, dev, uat]
concurrency:
group: ci-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: write
packages: write
env:
REGISTRY: git.farh.net
IMAGE_NAME: cartsnitch/api
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.12"
- run: pip install ruff
- name: Ruff lint
run: ruff check .
- name: Ruff format check
run: ruff format --check .
typecheck:
runs-on: ubuntu-latest
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.12"
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential
- run: pip install -e ".[dev]" mypy
- name: Type check
run: mypy src/cartsnitch_api
test:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:15-alpine
env:
POSTGRES_USER: cartsnitch
POSTGRES_PASSWORD: cartsnitch_test
POSTGRES_DB: cartsnitch_test
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
redis:
image: redis:7-alpine
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
CARTSNITCH_DATABASE_URL: postgresql+asyncpg://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test
CARTSNITCH_REDIS_URL: redis://localhost:6379/0
CARTSNITCH_JWT_SECRET_KEY: test-secret-do-not-use-in-prod
CARTSNITCH_SERVICE_KEY: test-service-key-do-not-use-in-prod
CARTSNITCH_FERNET_KEY: wXWQsC0FZlhSz2t_tfVQjNUSP8vgAGG3o3pkjrX8Bw0=
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.12"
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential
- run: pip install -e ".[dev]"
- name: Run tests
run: pytest --tb=short -q
build-and-push:
if: github.event_name == 'push'
runs-on: ubuntu-latest
needs: [lint, test]
outputs:
calver_tag: ${{ steps.calver.outputs.version }}
sha_tag: sha-${{ github.sha }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generate CalVer tag
id: calver
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
run: |
DATE_TAG=$(date -u +%Y.%m.%d)
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
if [ -z "$EXISTING" ]; then
VERSION="$DATE_TAG"
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then
VERSION="${DATE_TAG}.2"
else
BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//")
VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"
fi
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
echo "CalVer tag: $VERSION"
- name: Log in to Gitea Container Registry
run: echo "${{ secrets.REGISTRY_TOKEN }}" | docker login git.farh.net -u ${{ github.actor }} --password-stdin
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=sha,prefix=sha-,format=long
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
- name: Build Docker image
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile
load: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
APT_CACHE_BUST=${{ github.run_id }}
- name: Scan api image for vulnerabilities
uses: anchore/scan-action@v5
id: scan
env:
GRYPE_CONFIG: .grype.yaml
with:
image: "${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${{ github.sha }}"
fail-build: true
severity-cutoff: high
only-fixed: "true"
output-format: sarif
- name: Push Docker image
if: github.event_name == 'push'
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Create git tag
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
run: |
git tag "v${{ steps.calver.outputs.version }}"
git push origin "v${{ steps.calver.outputs.version }}"
+298
View File
@@ -0,0 +1,298 @@
name: CI
on:
push:
branches: [main, dev]
pull_request:
branches: [main, dev]
concurrency:
group: ci-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: write
packages: write
env:
REGISTRY: git.farh.net
IMAGE_NAME: cartsnitch/api
jobs:
lint:
runs-on: runners-cartsnitch
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: pip
- run: pip install ruff
- name: Ruff lint
run: ruff check .
- name: Ruff format check
run: ruff format --check .
typecheck:
runs-on: runners-cartsnitch
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: pip
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential
- run: pip install -e ".[dev]" mypy
- name: Type check
run: mypy src/cartsnitch_api
test:
runs-on: runners-cartsnitch
services:
postgres:
image: postgres:15-alpine
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
env:
POSTGRES_USER: cartsnitch
POSTGRES_PASSWORD: cartsnitch_test
POSTGRES_DB: cartsnitch_test
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
redis:
image: redis:7-alpine
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
CARTSNITCH_DATABASE_URL: postgresql+asyncpg://cartsnitch:cartsnitch_test@localhost:5432/cartsnitch_test
CARTSNITCH_REDIS_URL: redis://localhost:6379/0
CARTSNITCH_JWT_SECRET_KEY: test-secret-do-not-use-in-prod
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: pip
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libpq-dev build-essential
- run: pip install -e ".[dev]"
- name: Run tests
run: pytest --tb=short -q
build-and-push:
runs-on: runners-cartsnitch
needs: [lint, test]
outputs:
calver_tag: ${{ steps.calver.outputs.version }}
sha_tag: sha-${{ github.sha }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generate CalVer tag
id: calver
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
run: |
DATE_TAG=$(date -u +%Y.%m.%d)
EXISTING=$(git tag -l "v${DATE_TAG}*" | sort -V | tail -1)
if [ -z "$EXISTING" ]; then
VERSION="$DATE_TAG"
elif [ "$EXISTING" = "v${DATE_TAG}" ]; then
VERSION="${DATE_TAG}.2"
else
BUILD_NUM=$(echo "$EXISTING" | sed "s/v${DATE_TAG}\.//")
VERSION="${DATE_TAG}.$((BUILD_NUM + 1))"
fi
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
echo "CalVer tag: $VERSION"
- name: Log in to Gitea Container Registry
if: github.event_name == 'push'
uses: docker/login-action@v3
with:
registry: git.farh.net
username: cartsnitch
password: ${{ secrets.GITEA_TOKEN }}
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=sha,prefix=sha-,format=long
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
- name: Build Docker image
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile
load: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
APT_CACHE_BUST=${{ github.run_id }}
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Scan api image for vulnerabilities
uses: anchore/scan-action@v5
id: scan
env:
GRYPE_CONFIG: .grype.yaml
with:
image: "${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${{ github.sha }}"
fail-build: true
severity-cutoff: high
only-fixed: "true"
output-format: sarif
- name: Upload api scan results to GitHub Security
uses: github/codeql-action/upload-sarif@v3
if: always()
with:
sarif_file: ${{ steps.scan.outputs.sarif }}
- name: Push Docker image
if: github.event_name == 'push'
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
APT_CACHE_BUST=${{ github.run_id }}
cache-from: type=gha
- name: Create git tag
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
run: |
git tag "v${{ steps.calver.outputs.version }}"
git push origin "v${{ steps.calver.outputs.version }}"
deploy-dev:
runs-on: runners-cartsnitch
needs: [build-and-push]
if: always() && !cancelled() && github.event_name == 'push' && (github.ref == 'refs/heads/dev' || github.ref == 'refs/heads/main')
steps:
- name: Generate GitHub App token
id: app-token
uses: actions/create-github-app-token@v1
with:
app-id: ${{ secrets.CARTSNITCH_APP_ID }}
private-key: ${{ secrets.CARTSNITCH_APP_PRIVATE_KEY }}
owner: ${{ github.repository_owner }}
repositories: infra
- name: Checkout infra repo
uses: actions/checkout@v4
with:
repository: cartsnitch/infra
token: ${{ steps.app-token.outputs.token }}
ref: main
path: infra
- name: Install kubectl
uses: azure/setup-kubectl@v4
- name: Install kustomize
uses: imranismail/setup-kustomize@v2
- name: Determine image tag
id: api_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
- name: Update api image tag
if: needs.build-and-push.result == 'success'
run: |
cd infra/apps/overlays/dev
kustomize edit set image ghcr.io/cartsnitch/api:${{ steps.api_tag.outputs.tag }}
- name: Commit and push to infra
run: |
cd infra
git config user.name "cartsnitch-ci[bot]"
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
git add apps/overlays/dev/kustomization.yaml
git commit -m "ci(dev): update api image"
git pull --rebase origin main
git push origin main
deploy-uat:
runs-on: runners-cartsnitch
needs: [build-and-push]
if: always() && !cancelled() && github.event_name == 'push' && (github.ref == 'refs/heads/uat' || github.ref == 'refs/heads/main')
steps:
- name: Generate GitHub App token
id: app-token
uses: actions/create-github-app-token@v1
with:
app-id: ${{ secrets.CARTSNITCH_APP_ID }}
private-key: ${{ secrets.CARTSNITCH_APP_PRIVATE_KEY }}
owner: ${{ github.repository_owner }}
repositories: infra
- name: Checkout infra repo
uses: actions/checkout@v4
with:
repository: cartsnitch/infra
token: ${{ steps.app-token.outputs.token }}
ref: main
path: infra
- name: Install kubectl
uses: azure/setup-kubectl@v4
- name: Install kustomize
uses: imranismail/setup-kustomize@v2
- name: Determine image tag
id: api_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
- name: Update api image tag
if: needs.build-and-push.result == 'success'
run: |
cd infra/apps/overlays/uat
kustomize edit set image ghcr.io/cartsnitch/api:${{ steps.api_tag.outputs.tag }}
- name: Commit and push to infra
run: |
cd infra
git config user.name "cartsnitch-ci[bot]"
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
git add apps/overlays/uat/kustomization.yaml
git commit -m "ci(uat): update api image"
git pull --rebase origin main
git push origin main
-11
View File
@@ -1,11 +0,0 @@
{
"mcpServers": {
"gitea": {
"type": "http",
"url": "https://git-mcp.farh.net/mcp",
"headers": {
"Authorization": "Bearer ${GITEA_TOKEN}"
}
}
}
}
+1 -6
View File
@@ -45,11 +45,7 @@ def run_migrations_online() -> None:
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
version_table_column_width=128,
)
context.configure(connection=connection, target_metadata=target_metadata, version_table_column_width=128)
with context.begin_transaction():
context.run_migrations()
# Create any tables defined in models but not yet created by migrations.
@@ -60,7 +56,6 @@ def run_migrations_online() -> None:
connection.commit()
except Exception as exc:
import logging
logging.getLogger("alembic.env").warning(
"create_all failed (non-fatal, migrations should handle table creation): %s", exc
)
+9 -44
View File
@@ -30,10 +30,7 @@ def upgrade() -> None:
if inspector.has_table("users"):
existing_user_cols = [c["name"] for c in inspector.get_columns("users")]
if "email_verified" not in existing_user_cols:
op.add_column(
"users",
sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false"),
)
op.add_column("users", sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false"))
if "image" not in existing_user_cols:
op.add_column("users", sa.Column("image", sa.Text(), nullable=True))
@@ -47,18 +44,8 @@ def upgrade() -> None:
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("ip_address", sa.Text(), nullable=True),
sa.Column("user_agent", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_sessions_token", "sessions", ["token"], unique=True)
@@ -79,18 +66,8 @@ def upgrade() -> None:
sa.Column("scope", sa.Text(), nullable=True),
sa.Column("id_token", sa.Text(), nullable=True),
sa.Column("password", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_accounts_user_id", "accounts", ["user_id"])
@@ -103,18 +80,8 @@ def upgrade() -> None:
sa.Column("identifier", sa.Text(), nullable=False),
sa.Column("value", sa.Text(), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
@@ -129,10 +96,8 @@ def upgrade() -> None:
user_id_str = str(user_id)
conn.execute(
text(
"INSERT INTO accounts "
"(id, user_id, account_id, provider_id, password, created_at, updated_at) "
"VALUES (gen_random_uuid()::text, :user_id, :account_id, "
"'credential', :password, now(), now())"
"INSERT INTO accounts (id, user_id, account_id, provider_id, password, created_at, updated_at) "
"VALUES (gen_random_uuid()::text, :user_id, :account_id, 'credential', :password, now(), now())"
),
{"user_id": user_id_str, "account_id": user_id_str, "password": hashed_password},
)
+2 -12
View File
@@ -40,12 +40,7 @@ def upgrade() -> None:
return # already TEXT — nothing to do
# Step 1: Drop existing FK constraints (ignore if they don't exist)
op.execute(
text(
"ALTER TABLE user_store_accounts "
"DROP CONSTRAINT IF EXISTS user_store_accounts_user_id_fkey"
)
)
op.execute(text("ALTER TABLE user_store_accounts DROP CONSTRAINT IF EXISTS user_store_accounts_user_id_fkey"))
op.execute(text("ALTER TABLE purchases DROP CONSTRAINT IF EXISTS purchases_user_id_fkey"))
# Step 2: Alter users.id from uuid to text
@@ -94,12 +89,7 @@ def upgrade() -> None:
def downgrade() -> None:
# Drop FK constraints
op.execute(
text(
"ALTER TABLE user_store_accounts "
"DROP CONSTRAINT IF EXISTS user_store_accounts_user_id_fkey"
)
)
op.execute(text("ALTER TABLE user_store_accounts DROP CONSTRAINT IF EXISTS user_store_accounts_user_id_fkey"))
op.execute(text("ALTER TABLE purchases DROP CONSTRAINT IF EXISTS purchases_user_id_fkey"))
# Revert users.id from text to uuid
@@ -20,7 +20,7 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
# Guard: on fresh DB, Base.metadata.create_all already has the column
# Guard: on a fresh DB Base.metadata.create_all creates users table with the column already present
if not inspector.has_table("users"):
return
existing_cols = [c["name"] for c in inspector.get_columns("users")]
@@ -6,7 +6,6 @@ Create Date: 2026-04-04
"""
import sqlalchemy as sa
from alembic import op
revision = "006_email_inbound_token_server_default"
@@ -30,8 +29,7 @@ def upgrade() -> None:
"users",
"email_inbound_token",
server_default=sa.text(
"replace(replace(trim(trailing '=' from "
"encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_')"
"replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_')"
),
)
+3 -13
View File
@@ -27,8 +27,7 @@ def upgrade() -> None:
if inspector.has_table("users"):
return # Table already exists (non-fresh DB or create_all already ran)
conn.execute(
text("""
conn.execute(text("""
CREATE TABLE users (
id TEXT PRIMARY KEY,
email VARCHAR(255) NOT NULL UNIQUE,
@@ -37,20 +36,11 @@ def upgrade() -> None:
email_verified BOOLEAN NOT NULL DEFAULT false,
image TEXT,
email_inbound_token VARCHAR(22) NOT NULL UNIQUE
DEFAULT (
replace(
replace(
trim(trailing '=' from encode(gen_random_bytes(16), 'base64')),
'+', '-'
),
'/', '_'
)
),
DEFAULT replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_'),
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
)
"""))
def downgrade() -> None:
+26 -150
View File
@@ -29,18 +29,8 @@ def upgrade() -> None:
sa.Column("slug", sa.String(20), nullable=False, unique=True),
sa.Column("logo_url", sa.String(500), nullable=True),
sa.Column("website_url", sa.String(500), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 2. store_locations
@@ -55,18 +45,8 @@ def upgrade() -> None:
sa.Column("zip", sa.String(10), nullable=False),
sa.Column("lat", sa.Float(), nullable=True),
sa.Column("lng", sa.Float(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 3. normalized_products
@@ -81,18 +61,8 @@ def upgrade() -> None:
sa.Column("size", sa.String(50), nullable=True),
sa.Column("size_unit", sa.String(10), nullable=True),
sa.Column("upc_variants", sa.JSON(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 4. purchases
@@ -102,9 +72,7 @@ def upgrade() -> None:
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("user_id", sa.Text(), sa.ForeignKey("users.id"), nullable=False),
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
sa.Column(
"store_location_id", sa.Uuid(), sa.ForeignKey("store_locations.id"), nullable=True
),
sa.Column("store_location_id", sa.Uuid(), sa.ForeignKey("store_locations.id"), nullable=True),
sa.Column("receipt_id", sa.String(200), nullable=False),
sa.Column("purchase_date", sa.Date(), nullable=False),
sa.Column("total", sa.Numeric(10, 2), nullable=False),
@@ -113,24 +81,9 @@ def upgrade() -> None:
sa.Column("savings_total", sa.Numeric(10, 2), nullable=True),
sa.Column("source_url", sa.String(500), nullable=True),
sa.Column("raw_data", sa.JSON(), nullable=True),
sa.Column(
"ingested_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("ingested_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"),
sa.Index("ix_purchases_user_store", "user_id", "store_id"),
)
@@ -151,24 +104,9 @@ def upgrade() -> None:
sa.Column("coupon_discount", sa.Numeric(10, 2), nullable=True),
sa.Column("loyalty_discount", sa.Numeric(10, 2), nullable=True),
sa.Column("category_raw", sa.String(100), nullable=True),
sa.Column(
"normalized_product_id",
sa.Uuid(),
sa.ForeignKey("normalized_products.id"),
nullable=True,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 6. coupons
@@ -177,12 +115,7 @@ def upgrade() -> None:
"coupons",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
sa.Column(
"normalized_product_id",
sa.Uuid(),
sa.ForeignKey("normalized_products.id"),
nullable=True,
),
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=True),
sa.Column("title", sa.String(300), nullable=False),
sa.Column("description", sa.String(1000), nullable=True),
sa.Column("discount_type", sa.String(20), nullable=False),
@@ -194,18 +127,8 @@ def upgrade() -> None:
sa.Column("coupon_code", sa.String(100), nullable=True),
sa.Column("source_url", sa.String(500), nullable=True),
sa.Column("scraped_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 7. price_history
@@ -213,12 +136,7 @@ def upgrade() -> None:
op.create_table(
"price_history",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column(
"normalized_product_id",
sa.Uuid(),
sa.ForeignKey("normalized_products.id"),
nullable=False,
),
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=False),
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
sa.Column("observed_date", sa.Date(), nullable=False),
sa.Column("regular_price", sa.Numeric(10, 2), nullable=False),
@@ -226,27 +144,10 @@ def upgrade() -> None:
sa.Column("loyalty_price", sa.Numeric(10, 2), nullable=True),
sa.Column("coupon_price", sa.Numeric(10, 2), nullable=True),
sa.Column("source", sa.String(20), nullable=False),
sa.Column(
"purchase_item_id", sa.Uuid(), sa.ForeignKey("purchase_items.id"), nullable=True
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Index(
"ix_price_history_product_store_date",
"normalized_product_id",
"store_id",
"observed_date",
),
sa.Column("purchase_item_id", sa.Uuid(), sa.ForeignKey("purchase_items.id"), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Index("ix_price_history_product_store_date", "normalized_product_id", "store_id", "observed_date"),
)
# 8. shrinkflation_events
@@ -254,12 +155,7 @@ def upgrade() -> None:
op.create_table(
"shrinkflation_events",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column(
"normalized_product_id",
sa.Uuid(),
sa.ForeignKey("normalized_products.id"),
nullable=False,
),
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=False),
sa.Column("detected_date", sa.Date(), nullable=False),
sa.Column("old_size", sa.String(50), nullable=False),
sa.Column("new_size", sa.String(50), nullable=False),
@@ -269,18 +165,8 @@ def upgrade() -> None:
sa.Column("price_at_new_size", sa.Numeric(10, 2), nullable=True),
sa.Column("confidence", sa.Numeric(3, 2), server_default=text("1.00"), nullable=False),
sa.Column("notes", sa.String(1000), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 9. user_store_accounts
@@ -294,18 +180,8 @@ def upgrade() -> None:
sa.Column("session_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("status", sa.String(20), server_default=text("'active'"), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),
)
@@ -6,7 +6,6 @@ Create Date: 2026-04-14
"""
import sqlalchemy as sa
from alembic import op
revision = "009_add_gin_index_upc_variants"
+1 -7
View File
@@ -5,8 +5,7 @@ Sessions are verified by querying the shared sessions table directly.
"""
from datetime import UTC, datetime
from fastapi import Depends, Header, HTTPException, Request, status
from fastapi import Cookie, Depends, Header, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
@@ -43,11 +42,6 @@ async def _validate_session_token(token: str, db: AsyncSession) -> str:
)
user_id, expires_at = row
# SQLite stores TIMESTAMP as TEXT and returns it as a string via raw
# SQL — normalise to a tz-aware datetime here so the comparison below
# works regardless of driver.
if isinstance(expires_at, str):
expires_at = datetime.fromisoformat(expires_at)
if expires_at.tzinfo is None:
# Treat naive datetimes as UTC
expires_at = expires_at.replace(tzinfo=UTC)
+2 -2
View File
@@ -4,8 +4,8 @@ import bcrypt
def hash_password(password: str) -> str:
return str(bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode())
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def verify_password(plain_password: str, hashed_password: str) -> bool:
return bool(bcrypt.checkpw(plain_password.encode(), hashed_password.encode()))
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
+3
View File
@@ -6,10 +6,13 @@ endpoints that query our own user data from the shared database.
"""
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from cartsnitch_api.auth.dependencies import get_current_user
from cartsnitch_api.database import get_db
from cartsnitch_api.models import User
from cartsnitch_api.schemas import (
UpdateUserRequest,
UserResponse,
+1 -6
View File
@@ -35,12 +35,7 @@ class CacheClient:
async def get(self, key: str) -> str | None:
if not self._client:
return None
value: str | bytes | None = await self._client.get(key)
if value is None:
return None
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return value
return await self._client.get(key)
async def set(self, key: str, value: str, ttl_seconds: int = 300) -> None:
if not self._client:
+2 -7
View File
@@ -23,12 +23,7 @@ class Settings(BaseSettings):
auth_service_url: str = "http://auth:3001"
cors_origins: list[str] = [
"http://localhost:3000",
"https://cartsnitch.com",
"https://dev.cartsnitch.com",
"https://uat.cartsnitch.com",
]
cors_origins: list[str] = ["http://localhost:3000", "https://cartsnitch.com"]
receiptwitness_url: str = "http://receiptwitness:8001"
stickershock_url: str = "http://stickershock:8002"
@@ -86,4 +81,4 @@ class Settings(BaseSettings):
return self
settings = Settings() # type: ignore[call-arg]
settings = Settings()
+8 -16
View File
@@ -6,22 +6,14 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
from cartsnitch_api.config import settings
def _build_engine_kwargs() -> dict:
url = settings.database_url
kwargs: dict = {"echo": False}
if not url.startswith("sqlite"):
kwargs.update(
pool_size=10,
max_overflow=20,
pool_timeout=30,
pool_pre_ping=True,
pool_recycle=3600,
)
return kwargs
engine = create_async_engine(settings.database_url, **_build_engine_kwargs())
engine = create_async_engine(
settings.database_url,
echo=False,
pool_size=10,
max_overflow=20,
pool_pre_ping=True,
pool_recycle=3600,
)
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+2 -3
View File
@@ -6,10 +6,11 @@ from fastapi import APIRouter, FastAPI
from cartsnitch_api.auth.routes import router as auth_router
from cartsnitch_api.cache import cache_client
from cartsnitch_api.middleware.audit import add_audit_middleware
from cartsnitch_api.database import dispose_engine
from cartsnitch_api.middleware.cors import add_cors_middleware
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
from cartsnitch_api.middleware.audit import add_audit_middleware
from cartsnitch_api.routes.alerts import router as alerts_router
from cartsnitch_api.routes.coupons import router as coupons_router
from cartsnitch_api.routes.health import router as health_router
@@ -25,8 +26,6 @@ from cartsnitch_api.routes.user import router as user_router
@asynccontextmanager
async def lifespan(app: FastAPI):
from cartsnitch_api.database import dispose_engine
await cache_client.initialize()
yield
await cache_client.close()
+1 -8
View File
@@ -25,9 +25,6 @@ logger = logging.getLogger(__name__)
class RateLimitBackend(Protocol):
"""Protocol for rate limit backends."""
max_requests: int
window_seconds: int
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
@@ -85,8 +82,7 @@ class RedisSlidingWindow:
if current_count >= self.max_requests:
oldest = await self.redis.zrange(key, 0, 0, withscores=True)
if oldest:
oldest_score = float(oldest[0][1])
retry_after = int((oldest_score - cutoff) / 1000) + 1
retry_after = int((oldest[0][1] - cutoff) / 1000) + 1
else:
retry_after = self.window_seconds
return False, 0, retry_after
@@ -108,9 +104,6 @@ class RedisSlidingWindow:
_redis_client: Redis | None = None
_use_redis = False
_public_limiter: RateLimitBackend
_auth_limiter: RateLimitBackend
_auth_strict_limiter: RateLimitBackend
if settings.rate_limit_redis_enabled:
try:
+4 -3
View File
@@ -26,7 +26,9 @@ class User(TimestampMixin, Base):
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
hashed_password: Mapped[str | None] = mapped_column(String(255), nullable=True)
display_name: Mapped[str | None] = mapped_column(String(100))
email_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="false")
email_verified: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default="false"
)
image: Mapped[str | None] = mapped_column(Text, nullable=True)
email_inbound_token: Mapped[str] = mapped_column(
String(22),
@@ -34,8 +36,7 @@ class User(TimestampMixin, Base):
unique=True,
default=lambda: secrets.token_urlsafe(16),
server_default=sa.text(
"replace(replace(trim(trailing '=' from "
"encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_')"
"replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_')"
),
)
+3 -27
View File
@@ -1,40 +1,16 @@
"""Health check and error metrics endpoints."""
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import APIRouter, Depends
from cartsnitch_api.auth.dependencies import verify_service_key
from cartsnitch_api.database import get_db
from cartsnitch_api.middleware.error_handler import get_error_monitor
logger = logging.getLogger(__name__)
router = APIRouter(tags=["health"])
@router.get("/health")
async def health(db: AsyncSession = Depends(get_db)):
"""Liveness + DB connectivity probe.
Returns HTTP 200 when the API process is responsive *and* the database
is reachable, so Kubernetes readiness probes can correctly route traffic
away from pods that have lost their database connection.
Returns HTTP 503 when the database is unreachable so K8s marks the pod
unhealthy and stops sending traffic to it.
"""
try:
await db.execute(text("SELECT 1"))
except Exception as exc:
logger.exception("Health check failed: database unreachable")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={"status": "unavailable", "database": "disconnected"},
) from exc
return {"status": "ok", "database": "connected"}
async def health():
return {"status": "ok"}
@router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)])
+1 -1
View File
@@ -16,7 +16,7 @@ class UpdateUserRequest(BaseModel):
class UserResponse(BaseModel):
id: UUID
id: str
email: str
display_name: str
created_at: datetime
+6 -145
View File
@@ -10,113 +10,15 @@ from datetime import UTC, datetime, timedelta
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import String, TypeDecorator, Uuid, create_engine, event, text
from sqlalchemy import create_engine, event, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.types import CHAR
from cartsnitch_api.config import settings as cartsnitch_settings
from cartsnitch_api.database import get_db
from cartsnitch_api.main import create_app
from cartsnitch_api.middleware import rate_limit as _rate_limit_module
from cartsnitch_api.models import Base
class _StringUUID(TypeDecorator):
"""TypeDecorator that lets Text/String/UUID columns accept uuid.UUID on bind.
SQLite has no native UUID type — passing a ``uuid.UUID`` raises
``type 'UUID' is not supported``. This stores UUID values as their hex
string in the DB, accepts either uuid.UUID or str at bind time, and
returns uuid.UUID on read so existing test assertions like
``isinstance(store.id, uuid.UUID)`` still work.
"""
impl = CHAR(36)
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return None
if isinstance(value, uuid.UUID):
return str(value)
return str(value)
def process_result_value(self, value, dialect):
if value is None:
return None
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(value)
def _set_timestamp_defaults(mapper, connection, target):
"""Populate created_at/updated_at and missing PK IDs for SQLite.
SQLite can't bind ``uuid.UUID`` objects to Text/String columns, and has
no server-side default for ``func.now()`` or ``gen_random_uuid()``. We
strip those server_defaults elsewhere; this listener fills in
Python-side timestamp defaults at insert time, generates IDs for PK
columns that have no default, and populates ``func.now()`` columns
whose server_default was stripped (e.g. ``ingested_at``). UUID values
for non-PK columns are converted by the ``_StringUUID`` TypeDecorator.
"""
now = datetime.now(UTC)
for col in mapper.columns:
key = col.key
if key in ("created_at", "updated_at"):
if getattr(target, key, None) is None:
setattr(target, key, now)
continue
if col.primary_key and getattr(target, key, None) is None:
setattr(target, key, str(uuid.uuid4()))
continue
if getattr(col, "_sqlite_default_now", False) and getattr(target, key, None) is None:
setattr(target, key, now)
def _adapt_columns_for_sqlite():
"""Strip Postgres-only server_defaults and adapt UUID columns for SQLite.
Must be called BEFORE ``Base.metadata.create_all`` so the DDL reflects
the adapted column types.
"""
for tbl in Base.metadata.tables.values():
for col in tbl.columns.values():
# Strip PostgreSQL-specific function server_defaults (gen_random_uuid,
# gen_random_bytes, now()) but keep simple string-literal defaults
# like ``server_default="false"`` since they work in SQLite.
sd = col.server_default
if sd is not None:
sd_text = str(sd.arg) if hasattr(sd, "arg") else str(sd)
sd_text = sd_text.lower()
if any(x in sd_text for x in ["gen_random_uuid", "gen_random_bytes", "now()"]):
col.server_default = None
if "now()" in sd_text and not col.nullable:
col._sqlite_default_now = True # type: ignore[attr-defined]
# Replace UUID column types with a SQLite-compatible TypeDecorator
if isinstance(col.type, Uuid):
col.type = _StringUUID()
# Text/String PK columns without a default need the _StringUUID type
# so the before_insert listener can generate hex-string IDs.
if col.primary_key and col.default is None and col.server_default is None:
if not isinstance(col.type, _StringUUID):
col.type = _StringUUID()
# FK columns that may receive uuid.UUID values from test code
if col.foreign_keys and not col.primary_key and isinstance(col.type, String):
col.type = _StringUUID()
def _register_event_listeners():
"""Attach before_insert listener to every mapped class."""
for cls in Base.registry._class_registry.values():
if hasattr(cls, "__mapper__"):
event.listen(cls, "before_insert", _set_timestamp_defaults)
TEST_JWT_SECRET = secrets.token_urlsafe(32)
TEST_SERVICE_KEY = secrets.token_urlsafe(32)
TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
@@ -141,52 +43,16 @@ TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
@pytest.fixture(autouse=True)
def disable_rate_limiting():
"""Disable rate limiting for all tests to prevent 429 interference.
The rate_limit module creates its Redis client at import time when
``settings.rate_limit_redis_enabled`` is true. We can't undo that by
flipping the setting inside the fixture — the client and the
Redis-backed limiters are already constructed. So we swap them out
for the in-memory limiters directly on the module, which also
prevents "Event loop is closed" errors when the redis client tries
to disconnect after the test event loop ends.
"""
"""Disable rate limiting for all tests to prevent 429 interference."""
cartsnitch_settings.rate_limit_enabled = False
cartsnitch_settings.rate_limit_redis_enabled = False
original_public = _rate_limit_module._public_limiter
original_auth = _rate_limit_module._auth_limiter
original_auth_strict = _rate_limit_module._auth_strict_limiter
_rate_limit_module._redis_client = None
_rate_limit_module._use_redis = False
_rate_limit_module._public_limiter = _rate_limit_module.InMemorySlidingWindow(
cartsnitch_settings.rate_limit_requests, cartsnitch_settings.rate_limit_window_seconds
)
_rate_limit_module._auth_limiter = _rate_limit_module.InMemorySlidingWindow(
cartsnitch_settings.rate_limit_requests * 5, cartsnitch_settings.rate_limit_window_seconds
)
_rate_limit_module._auth_strict_limiter = _rate_limit_module.InMemorySlidingWindow(
cartsnitch_settings.rate_limit_auth_requests,
cartsnitch_settings.rate_limit_auth_window_seconds,
)
yield
cartsnitch_settings.rate_limit_enabled = True
cartsnitch_settings.rate_limit_redis_enabled = True
_rate_limit_module._public_limiter = original_public
_rate_limit_module._auth_limiter = original_auth
_rate_limit_module._auth_strict_limiter = original_auth_strict
@pytest.fixture
def engine():
"""Sync in-memory SQLite engine for model unit tests.
Strips PostgreSQL-specific server_default expressions, replaces UUID
column types with a SQLite-compatible TypeDecorator, and registers a
before_insert event listener to populate timestamps.
"""
"""Sync in-memory SQLite engine for model unit tests."""
eng = create_engine("sqlite:///:memory:")
_adapt_columns_for_sqlite()
_register_event_listeners()
Base.metadata.create_all(eng)
yield eng
eng.dispose()
@@ -210,11 +76,9 @@ async def db_engine():
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
_adapt_columns_for_sqlite()
_register_event_listeners()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Create Better-Auth tables (not managed by SQLAlchemy models)
await conn.execute(
text("""
CREATE TABLE IF NOT EXISTS sessions (
@@ -313,10 +177,8 @@ async def _create_test_user_and_session(
async with db_engine.begin() as conn:
await conn.execute(
text(
"INSERT INTO users (id, email, hashed_password, display_name, "
"email_verified, email_inbound_token, created_at, updated_at) "
"VALUES (:id, :email, :hashed_password, :display_name, "
":email_verified, :email_inbound_token, :created_at, :updated_at)"
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
"VALUES (:id, :email, :hashed_password, :display_name, :email_verified, :created_at, :updated_at)"
),
{
"id": user_id,
@@ -324,7 +186,6 @@ async def _create_test_user_and_session(
"hashed_password": "not-used-with-better-auth",
"display_name": display_name,
"email_verified": False,
"email_inbound_token": secrets.token_urlsafe(16),
"created_at": now,
"updated_at": now,
},
+2 -4
View File
@@ -138,9 +138,8 @@ async def test_expired_session_rejected(client, db_engine):
async with db_engine.begin() as conn:
await conn.execute(
text(
"INSERT INTO users (id, email, hashed_password, display_name, "
"email_verified, email_inbound_token, created_at, updated_at) "
"VALUES (:id, :email, :hp, :dn, :ev, :token, :ca, :ua)"
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
),
{
"id": user_id,
@@ -148,7 +147,6 @@ async def test_expired_session_rejected(client, db_engine):
"hp": "unused",
"dn": "Expired User",
"ev": False,
"token": secrets.token_urlsafe(16),
"ca": now,
"ua": now,
},
+6 -15
View File
@@ -1,5 +1,7 @@
"""Tests for Settings config, specifically the database_url env var fallback."""
import os
from cartsnitch_api.config import Settings
@@ -28,10 +30,7 @@ def test_database_url_normalizes_plain_postgresql_prefix():
"DATABASE_URL": "postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
}
settings = Settings(**env)
assert (
settings.database_url
== "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
)
assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
def test_database_url_preserves_asyncpg_prefix():
@@ -40,18 +39,10 @@ def test_database_url_preserves_asyncpg_prefix():
"CARTSNITCH_DATABASE_URL": "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
}
settings = Settings(**env)
assert (
settings.database_url
== "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
)
assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
def test_database_url_default(monkeypatch):
def test_database_url_default():
"""When neither env var is set, the hardcoded default is used."""
monkeypatch.delenv("CARTSNITCH_DATABASE_URL", raising=False)
monkeypatch.delenv("DATABASE_URL", raising=False)
settings = Settings()
assert (
settings.database_url
== "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
)
assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
+2 -2
View File
@@ -195,7 +195,7 @@ async def seed_data(db_engine, auth_headers):
discount_type="fixed",
discount_value=Decimal("1.00"),
valid_from=today - timedelta(days=7),
valid_to=date.today() + timedelta(days=30),
valid_to=today + timedelta(days=30),
)
coupon2 = Coupon(
store_id=kroger.id,
@@ -205,7 +205,7 @@ async def seed_data(db_engine, auth_headers):
discount_type="percent",
discount_value=Decimal("10.00"),
valid_from=today - timedelta(days=3),
valid_to=date.today() + timedelta(days=14),
valid_to=today + timedelta(days=14),
)
session.add_all([coupon1, coupon2])
await session.flush()
+12 -14
View File
@@ -65,9 +65,8 @@ class TestSessionValidation:
async with db_engine.begin() as conn:
await conn.execute(
text(
"INSERT INTO users (id, email, hashed_password, display_name, "
"email_verified, email_inbound_token, created_at, updated_at) "
"VALUES (:id, :email, :hp, :dn, :ev, :token, :ca, :ua)"
"INSERT INTO users (id, email, hashed_password, display_name, email_verified, created_at, updated_at) "
"VALUES (:id, :email, :hp, :dn, :ev, :ca, :ua)"
),
{
"id": user_id,
@@ -75,7 +74,6 @@ class TestSessionValidation:
"hp": "unused",
"dn": "Expired User",
"ev": False,
"token": secrets.token_urlsafe(16),
"ca": now,
"ua": now,
},
@@ -109,13 +107,13 @@ class TestAuthProtectedEndpoints:
@pytest.mark.parametrize(
"method,path",
[
("GET", "/api/v1/purchases"),
("GET", "/api/v1/products"),
("GET", "/api/v1/prices/trends"),
("GET", "/api/v1/prices/increases"),
("GET", "/api/v1/coupons"),
("GET", "/api/v1/alerts"),
("GET", "/api/v1/me/stores"),
("GET", "/purchases"),
("GET", "/products"),
("GET", "/prices/trends"),
("GET", "/prices/increases"),
("GET", "/coupons"),
("GET", "/alerts"),
("GET", "/me/stores"),
],
)
async def test_endpoints_require_auth(self, client, db_engine, method, path):
@@ -136,7 +134,7 @@ class TestCrossUserDataIsolation:
)
user_b_headers = {"Cookie": f"better-auth.session_token={session_token}"}
resp = await client.get(f"/api/v1/purchases/{purchase_id}", headers=user_b_headers)
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
assert resp.status_code in (403, 404), (
"User B should not be able to access User A's purchase"
)
@@ -148,7 +146,7 @@ class TestCrossUserDataIsolation:
)
user_c_headers = {"Cookie": f"better-auth.session_token={session_token}"}
resp = await client.get("/api/v1/purchases", headers=user_c_headers)
resp = await client.get("/purchases", headers=user_c_headers)
assert resp.status_code == 200
assert len(resp.json()) == 0, "New user should have no purchases"
@@ -159,6 +157,6 @@ class TestCrossUserDataIsolation:
)
user_d_headers = {"Cookie": f"better-auth.session_token={session_token}"}
resp = await client.get("/api/v1/me/stores", headers=user_d_headers)
resp = await client.get("/me/stores", headers=user_d_headers)
assert resp.status_code == 200
assert len(resp.json()) == 0, "New user should have no connected stores"
+12 -12
View File
@@ -10,23 +10,23 @@ class TestStoreConnectToPurchaseFlow:
async def test_connect_store_then_list(self, client, seed_data):
headers = seed_data["headers"]
# Connect to Meijer
resp = await client.post("/api/v1/me/stores/meijer/connect", json={}, headers=headers)
resp = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
assert resp.status_code in (200, 201)
# Verify store appears in user's connected stores
stores = await client.get("/api/v1/me/stores", headers=headers)
stores = await client.get("/me/stores", headers=headers)
assert stores.status_code == 200
slugs = [s["store"]["slug"] for s in stores.json()]
assert "meijer" in slugs
async def test_disconnect_store(self, client, seed_data):
headers = seed_data["headers"]
await client.post("/api/v1/me/stores/kroger/connect", json={}, headers=headers)
resp = await client.delete("/api/v1/me/stores/kroger", headers=headers)
await client.post("/me/stores/kroger/connect", json={}, headers=headers)
resp = await client.delete("/me/stores/kroger", headers=headers)
assert resp.status_code in (200, 204)
# Verify store no longer in connected list
stores = await client.get("/api/v1/me/stores", headers=headers)
stores = await client.get("/me/stores", headers=headers)
slugs = [s["store"]["slug"] for s in stores.json()]
assert "kroger" not in slugs
@@ -41,7 +41,7 @@ class TestPurchaseToPriceFlow:
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
# Get purchase detail
purchase = await client.get(f"/api/v1/purchases/{purchase_id}", headers=headers)
purchase = await client.get(f"/purchases/{purchase_id}", headers=headers)
assert purchase.status_code == 200
items = purchase.json()["line_items"]
@@ -50,7 +50,7 @@ class TestPurchaseToPriceFlow:
assert len(product_ids) >= 1
for pid in product_ids:
product = await client.get(f"/api/v1/products/{pid}", headers=headers)
product = await client.get(f"/products/{pid}", headers=headers)
assert product.status_code == 200
assert len(product.json()["prices_by_store"]) >= 1
@@ -61,7 +61,7 @@ class TestCouponFlow:
async def test_list_all_coupons(self, client, seed_data):
headers = seed_data["headers"]
resp = await client.get("/api/v1/coupons", headers=headers)
resp = await client.get("/coupons", headers=headers)
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 2
@@ -71,7 +71,7 @@ class TestCouponFlow:
async def test_filter_coupons_by_store(self, client, seed_data):
headers = seed_data["headers"]
meijer_id = str(seed_data["stores"]["meijer"].id)
resp = await client.get("/api/v1/coupons", params={"store_id": meijer_id}, headers=headers)
resp = await client.get("/coupons", params={"store_id": meijer_id}, headers=headers)
assert resp.status_code == 200
data = resp.json()
assert all(c["store_name"] == "Meijer" for c in data)
@@ -79,7 +79,7 @@ class TestCouponFlow:
async def test_relevant_coupons_for_user(self, client, seed_data):
"""User bought Cheerios, so the Cheerios coupon should be relevant."""
headers = seed_data["headers"]
resp = await client.get("/api/v1/coupons/relevant", headers=headers)
resp = await client.get("/coupons/relevant", headers=headers)
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1, "Expected at least one relevant coupon for user with purchases"
@@ -94,7 +94,7 @@ class TestAlertFlow:
async def test_list_alerts(self, client, seed_data):
"""User bought Cheerios which has a shrinkflation event — may appear as alert."""
headers = seed_data["headers"]
resp = await client.get("/api/v1/alerts", headers=headers)
resp = await client.get("/alerts", headers=headers)
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
@@ -107,7 +107,7 @@ class TestAlertFlow:
async def test_alert_settings_default(self, client, seed_data):
headers = seed_data["headers"]
resp = await client.get("/api/v1/alerts/settings", headers=headers)
resp = await client.get("/alerts/settings", headers=headers)
assert resp.status_code == 200
data = resp.json()
assert "price_increase_threshold_pct" in data
+9 -16
View File
@@ -6,12 +6,6 @@ from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID
@pytest.mark.asyncio
@pytest.mark.skip(
reason=(
"/auth/register, /auth/login, /auth/refresh are handled by "
"the Better-Auth service, not this gateway"
)
)
class TestRegistrationErrors:
"""Validation errors during user registration."""
@@ -53,7 +47,6 @@ class TestRegistrationErrors:
@pytest.mark.asyncio
@pytest.mark.skip(reason="/auth/login is handled by the Better-Auth service, not this gateway")
class TestLoginErrors:
"""Login failure modes."""
@@ -85,15 +78,15 @@ class TestNotFoundErrors:
"""404 responses for missing resources."""
async def test_product_not_found(self, client, seed_data):
resp = await client.get(f"/api/v1/products/{ZERO_UUID}", headers=seed_data["headers"])
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
assert resp.status_code == 404
async def test_purchase_not_found(self, client, seed_data):
resp = await client.get(f"/api/v1/purchases/{ZERO_UUID}", headers=seed_data["headers"])
resp = await client.get(f"/purchases/{ZERO_UUID}", headers=seed_data["headers"])
assert resp.status_code == 404
async def test_public_trend_not_found(self, client, seed_data):
resp = await client.get(f"/api/v1/public/trends/{ZERO_UUID}")
resp = await client.get(f"/public/trends/{ZERO_UUID}")
assert resp.status_code == 404
@@ -102,15 +95,15 @@ class TestMalformedInput:
"""Invalid UUID formats and bad query params."""
async def test_invalid_uuid_product(self, client, seed_data):
resp = await client.get(f"/api/v1/products/{BAD_UUID}", headers=seed_data["headers"])
resp = await client.get(f"/products/{BAD_UUID}", headers=seed_data["headers"])
assert resp.status_code == 422
async def test_invalid_uuid_purchase(self, client, seed_data):
resp = await client.get(f"/api/v1/purchases/{BAD_UUID}", headers=seed_data["headers"])
resp = await client.get(f"/purchases/{BAD_UUID}", headers=seed_data["headers"])
assert resp.status_code == 422
async def test_invalid_uuid_public_trend(self, client, seed_data):
resp = await client.get(f"/api/v1/public/trends/{BAD_UUID}")
resp = await client.get(f"/public/trends/{BAD_UUID}")
assert resp.status_code == 422
@@ -120,7 +113,7 @@ class TestStoreConnectionErrors:
async def test_connect_nonexistent_store(self, client, seed_data):
resp = await client.post(
"/api/v1/me/stores/nonexistent-store/connect",
"/me/stores/nonexistent-store/connect",
json={},
headers=seed_data["headers"],
)
@@ -128,7 +121,7 @@ class TestStoreConnectionErrors:
async def test_connect_store_twice(self, client, seed_data):
headers = seed_data["headers"]
first = await client.post("/api/v1/me/stores/meijer/connect", json={}, headers=headers)
first = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
assert first.status_code in (200, 201)
second = await client.post("/api/v1/me/stores/meijer/connect", json={}, headers=headers)
second = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
assert second.status_code == 409
+8 -8
View File
@@ -8,7 +8,7 @@ class TestPriceTrends:
"""Verify price trend aggregation against seeded history."""
async def test_trends_returns_all_products(self, client, seed_data):
resp = await client.get("/api/v1/prices/trends", headers=seed_data["headers"])
resp = await client.get("/prices/trends", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
product_names = [t["product_name"] for t in data]
@@ -17,7 +17,7 @@ class TestPriceTrends:
async def test_trends_filter_by_category(self, client, seed_data):
resp = await client.get(
"/api/v1/prices/trends", params={"category": "dairy"}, headers=seed_data["headers"]
"/prices/trends", params={"category": "dairy"}, headers=seed_data["headers"]
)
assert resp.status_code == 200
data = resp.json()
@@ -27,7 +27,7 @@ class TestPriceTrends:
assert trend["product_name"] == "Whole Milk 1gal"
async def test_trends_contain_data_points(self, client, seed_data):
resp = await client.get("/api/v1/prices/trends", headers=seed_data["headers"])
resp = await client.get("/prices/trends", headers=seed_data["headers"])
data = resp.json()
cheerios_trend = next(t for t in data if t["product_name"] == "Cheerios 18oz")
assert len(cheerios_trend["data_points"]) >= 3
@@ -38,7 +38,7 @@ class TestPriceIncreases:
"""Detect price increases from seeded price history."""
async def test_increases_detected(self, client, seed_data):
resp = await client.get("/api/v1/prices/increases", headers=seed_data["headers"])
resp = await client.get("/prices/increases", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
# Cheerios at Meijer went from 3.99 → 4.29 → 4.79
@@ -52,7 +52,7 @@ class TestPriceIncreases:
async def test_stable_prices_not_flagged(self, client, seed_data):
"""Kroger Cheerios price is stable at $4.49 — should not appear as increase."""
resp = await client.get("/api/v1/prices/increases", headers=seed_data["headers"])
resp = await client.get("/prices/increases", headers=seed_data["headers"])
data = resp.json()
kroger_increases = [
inc
@@ -69,7 +69,7 @@ class TestPriceComparison:
async def test_compare_cheerios_across_stores(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(
"/api/v1/prices/comparison",
"/prices/comparison",
params={"product_ids": cheerios_id},
headers=seed_data["headers"],
)
@@ -84,14 +84,14 @@ class TestPriceComparison:
async def test_compare_requires_product_ids(self, client, seed_data):
"""product_ids is required — omitting it must return 422."""
resp = await client.get("/api/v1/prices/comparison", headers=seed_data["headers"])
resp = await client.get("/prices/comparison", headers=seed_data["headers"])
assert resp.status_code == 422
async def test_compare_multiple_products(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
milk_id = str(seed_data["products"]["milk"].id)
resp = await client.get(
"/api/v1/prices/comparison",
"/prices/comparison",
params=[("product_ids", cheerios_id), ("product_ids", milk_id)],
headers=seed_data["headers"],
)
+8 -12
View File
@@ -10,7 +10,7 @@ class TestProductSearch:
"""Search and filter products against seeded data."""
async def test_list_all_products(self, client, seed_data):
resp = await client.get("/api/v1/products", headers=seed_data["headers"])
resp = await client.get("/products", headers=seed_data["headers"])
assert resp.status_code == 200
products = resp.json()
names = [p["name"] for p in products]
@@ -19,9 +19,7 @@ class TestProductSearch:
assert "Chicken Breast 1lb" in names
async def test_search_by_name(self, client, seed_data):
resp = await client.get(
"/api/v1/products", params={"q": "cheerios"}, headers=seed_data["headers"]
)
resp = await client.get("/products", params={"q": "cheerios"}, headers=seed_data["headers"])
assert resp.status_code == 200
products = resp.json()
assert len(products) >= 1
@@ -29,7 +27,7 @@ class TestProductSearch:
async def test_search_by_category(self, client, seed_data):
resp = await client.get(
"/api/v1/products", params={"category": "dairy"}, headers=seed_data["headers"]
"/products", params={"category": "dairy"}, headers=seed_data["headers"]
)
assert resp.status_code == 200
products = resp.json()
@@ -38,7 +36,7 @@ class TestProductSearch:
async def test_search_no_results(self, client, seed_data):
resp = await client.get(
"/api/v1/products", params={"q": "nonexistentxyz"}, headers=seed_data["headers"]
"/products", params={"q": "nonexistentxyz"}, headers=seed_data["headers"]
)
assert resp.status_code == 200
assert resp.json() == []
@@ -50,7 +48,7 @@ class TestProductLookup:
async def test_get_product_detail_with_prices(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(f"/api/v1/products/{cheerios_id}", headers=seed_data["headers"])
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "Cheerios 18oz"
@@ -64,20 +62,18 @@ class TestProductLookup:
async def test_product_prices_reflect_latest(self, client, seed_data):
"""The latest Meijer price for Cheerios should be 4.79 (the increase)."""
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(f"/api/v1/products/{cheerios_id}", headers=seed_data["headers"])
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
data = resp.json()
meijer_price = next(p for p in data["prices_by_store"] if p["store_name"] == "Meijer")
assert meijer_price["current_price"] == 4.79
async def test_product_not_found(self, client, seed_data):
resp = await client.get(f"/api/v1/products/{ZERO_UUID}", headers=seed_data["headers"])
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
assert resp.status_code == 404
async def test_product_price_history(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(
f"/api/v1/products/{cheerios_id}/prices", headers=seed_data["headers"]
)
resp = await client.get(f"/products/{cheerios_id}/prices", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data["data_points"]) >= 3 # At least the 3 Meijer observations
+6 -6
View File
@@ -11,16 +11,16 @@ class TestPublicTrends:
async def test_public_trend_returns_data(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(f"/api/v1/public/trends/{cheerios_id}")
resp = await client.get(f"/public/trends/{cheerios_id}")
assert resp.status_code == 200
data = resp.json()
assert data["product_name"] == "Cheerios 18oz"
assert len(data["data_points"]) >= 2
assert len(data["data_points"]) >= 3
async def test_public_trend_no_auth_needed(self, client, seed_data):
"""Confirm no Authorization header is required."""
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(f"/api/v1/public/trends/{cheerios_id}")
resp = await client.get(f"/public/trends/{cheerios_id}")
assert resp.status_code == 200
@@ -31,7 +31,7 @@ class TestPublicStoreComparison:
async def test_store_comparison(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(
"/api/v1/public/store-comparison",
"/public/store-comparison",
params=[("product_ids", cheerios_id)],
)
assert resp.status_code == 200
@@ -42,7 +42,7 @@ class TestPublicStoreComparison:
async def test_store_comparison_rejects_more_than_20_ids(self, client):
"""max_length=20 guard: 21 product IDs must return 422."""
too_many = [("product_ids", str(uuid.uuid4())) for _ in range(21)]
resp = await client.get("/api/v1/public/store-comparison", params=too_many)
resp = await client.get("/public/store-comparison", params=too_many)
assert resp.status_code == 422
@@ -51,7 +51,7 @@ class TestPublicInflation:
"""Public inflation index endpoint."""
async def test_inflation_returns_index(self, client, seed_data):
resp = await client.get("/api/v1/public/inflation")
resp = await client.get("/public/inflation")
assert resp.status_code == 200
data = resp.json()
assert "cartsnitch_index" in data
+8 -8
View File
@@ -10,7 +10,7 @@ class TestPurchaseList:
"""List and filter a user's purchases."""
async def test_list_user_purchases(self, client, seed_data):
resp = await client.get("/api/v1/purchases", headers=seed_data["headers"])
resp = await client.get("/purchases", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 2
@@ -21,7 +21,7 @@ class TestPurchaseList:
async def test_filter_purchases_by_store(self, client, seed_data):
meijer_id = str(seed_data["stores"]["meijer"].id)
resp = await client.get(
"/api/v1/purchases", params={"store_id": meijer_id}, headers=seed_data["headers"]
"/purchases", params={"store_id": meijer_id}, headers=seed_data["headers"]
)
assert resp.status_code == 200
data = resp.json()
@@ -29,7 +29,7 @@ class TestPurchaseList:
assert all(p["store_name"] == "Meijer" for p in data)
async def test_purchases_require_auth(self, client, seed_data):
resp = await client.get("/api/v1/purchases")
resp = await client.get("/purchases")
assert resp.status_code in (401, 403)
@@ -39,7 +39,7 @@ class TestPurchaseDetail:
async def test_get_purchase_detail(self, client, seed_data):
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
resp = await client.get(f"/api/v1/purchases/{purchase_id}", headers=seed_data["headers"])
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["store_name"] == "Meijer"
@@ -51,7 +51,7 @@ class TestPurchaseDetail:
async def test_line_item_amounts_correct(self, client, seed_data):
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
resp = await client.get(f"/api/v1/purchases/{purchase_id}", headers=seed_data["headers"])
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
data = resp.json()
cheerios_item = next(li for li in data["line_items"] if "Cheerios" in li["name"])
assert cheerios_item["unit_price"] == 4.79
@@ -60,7 +60,7 @@ class TestPurchaseDetail:
async def test_purchase_not_found(self, client, seed_data):
resp = await client.get(
f"/api/v1/purchases/{ZERO_UUID}",
f"/purchases/{ZERO_UUID}",
headers=seed_data["headers"],
)
assert resp.status_code == 404
@@ -71,7 +71,7 @@ class TestPurchaseStats:
"""Verify spending aggregation across purchases."""
async def test_purchase_stats_totals(self, client, seed_data):
resp = await client.get("/api/v1/purchases/stats", headers=seed_data["headers"])
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["purchase_count"] == 2
@@ -79,7 +79,7 @@ class TestPurchaseStats:
assert abs(data["total_spent"] - 39.23) < 0.01
async def test_purchase_stats_by_store(self, client, seed_data):
resp = await client.get("/api/v1/purchases/stats", headers=seed_data["headers"])
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
data = resp.json()
assert "Meijer" in data["by_store"]
assert "Kroger" in data["by_store"]
+18 -1
View File
@@ -5,13 +5,30 @@ import json
import pytest
from cryptography.fernet import Fernet
from pydantic import ValidationError
from sqlalchemy import column, table, text
from sqlalchemy import column, create_engine, table, text
from sqlalchemy.orm import sessionmaker
from cartsnitch_api.config import settings
from cartsnitch_api.models import Base
from cartsnitch_api.models.store import Store
from cartsnitch_api.models.user import User, UserStoreAccount
@pytest.fixture
def engine():
eng = create_engine("sqlite:///:memory:")
Base.metadata.create_all(eng)
yield eng
eng.dispose()
@pytest.fixture
def session(engine):
factory = sessionmaker(bind=engine)
with factory() as sess:
yield sess
@pytest.fixture
def store(session):
s = Store(name="Test Store", slug="test-store")
+5 -10
View File
@@ -2,8 +2,6 @@
import pytest
from cartsnitch_api.config import settings
@pytest.mark.asyncio
async def test_404_returns_structured_error(client):
@@ -17,14 +15,11 @@ async def test_404_returns_structured_error(client):
@pytest.mark.asyncio
async def test_validation_error_returns_422_with_field_errors(client, auth_headers):
async def test_validation_error_returns_422_with_field_errors(client):
"""Invalid request body should return structured validation errors."""
# Use the auth/me PATCH endpoint with an invalid email — Pydantic will
# return 422 with structured field errors before any DB lookup runs.
resp = await client.patch(
"/auth/me",
json={"email": "not-an-email"},
headers=auth_headers,
resp = await client.post(
"/auth/register",
json={"email": "not-an-email", "password": "short", "display_name": ""},
)
assert resp.status_code == 422
body = resp.json()
@@ -51,7 +46,7 @@ async def test_error_stats_with_valid_key(client):
"""Error stats endpoint returns monitoring data with valid key."""
resp = await client.get(
"/internal/error-stats",
headers={"X-Service-Key": settings.service_key},
headers={"X-Service-Key": "change-me-in-production"},
)
assert resp.status_code == 200
body = resp.json()
+26 -39
View File
@@ -1,7 +1,7 @@
"""Tests for rate limiting middleware."""
import time
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -15,47 +15,43 @@ from cartsnitch_api.middleware.rate_limit import (
class TestInMemorySlidingWindow:
@pytest.mark.asyncio
async def test_allows_within_limit(self):
def test_allows_within_limit(self):
limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
for i in range(5):
allowed, remaining, retry = await limiter.is_allowed("test-key")
allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 4 - i
@pytest.mark.asyncio
async def test_blocks_over_limit(self):
def test_blocks_over_limit(self):
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
for _ in range(3):
await limiter.is_allowed("test-key")
limiter.is_allowed("test-key")
allowed, remaining, retry = await limiter.is_allowed("test-key")
allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is False
assert remaining == 0
assert retry > 0
@pytest.mark.asyncio
async def test_separate_keys(self):
def test_separate_keys(self):
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
await limiter.is_allowed("key-a")
await limiter.is_allowed("key-a")
allowed_a, _, _ = await limiter.is_allowed("key-a")
limiter.is_allowed("key-a")
limiter.is_allowed("key-a")
allowed_a, _, _ = limiter.is_allowed("key-a")
assert allowed_a is False
allowed_b, remaining, _ = await limiter.is_allowed("key-b")
allowed_b, remaining, _ = limiter.is_allowed("key-b")
assert allowed_b is True
assert remaining == 1
@pytest.mark.asyncio
async def test_resets_after_window_expires(self):
def test_resets_after_window_expires(self):
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
for _ in range(2):
await limiter.is_allowed("test-key")
allowed, remaining, _ = await limiter.is_allowed("test-key")
limiter.is_allowed("test-key")
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is False
time.sleep(1.1)
allowed, remaining, _ = await limiter.is_allowed("test-key")
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 1
@@ -77,7 +73,7 @@ class TestGetClientIp:
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1:8080"
assert _get_client_ip(req) == "192.168.1.1"
def test_no_forwarded_header(self):
req = MagicMock()
@@ -125,7 +121,7 @@ class TestGetRateLimitKey:
req = self._make_request("/auth/me", method="GET")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_authenticated_token_uses_auth_limiter(self):
req = self._make_request("/purchases", auth_header="Bearer token123")
@@ -158,15 +154,11 @@ class TestGetRateLimitKey:
class TestRedisSlidingWindowFallback:
@pytest.mark.asyncio
async def test_fallback_on_redis_connection_error(self):
mock_redis = MagicMock()
from redis.exceptions import RedisError
async def raise_on_execute(*args, **kwargs):
raise RedisError("Connection refused")
pipe_mock = MagicMock()
pipe_mock.execute = raise_on_execute
mock_redis.pipeline = MagicMock(return_value=pipe_mock)
mock_redis = AsyncMock()
mock_redis.pipeline.return_value = AsyncMock()
pipe_mock = AsyncMock()
pipe_mock.execute.side_effect = Exception("Connection refused")
mock_redis.pipeline.return_value = pipe_mock
limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
@@ -175,15 +167,10 @@ class TestRedisSlidingWindowFallback:
@pytest.mark.asyncio
async def test_fallback_on_redis_error_during_pipeline(self):
mock_redis = MagicMock()
from redis.exceptions import RedisError
async def raise_on_execute(*args, **kwargs):
raise RedisError("Redis error")
pipe_mock = MagicMock()
pipe_mock.execute = raise_on_execute
mock_redis.pipeline = MagicMock(return_value=pipe_mock)
mock_redis = AsyncMock()
pipe_mock = AsyncMock()
pipe_mock.execute.side_effect = Exception("Redis error")
mock_redis.pipeline.return_value = pipe_mock
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
+31 -27
View File
@@ -6,44 +6,48 @@ from httpx import ASGITransport, AsyncClient
from cartsnitch_api.main import app
EXPECTED_ROUTES = [
# Auth (3 — register/login/refresh are handled by Better-Auth service)
# Auth (7)
("post", "/auth/register"),
("post", "/auth/login"),
("post", "/auth/refresh"),
("get", "/auth/me"),
("patch", "/auth/me"),
("delete", "/auth/me"),
("get", "/auth/me/email-in-address"),
# Stores (4)
("get", "/api/v1/stores"),
("get", "/api/v1/me/stores"),
("post", "/api/v1/me/stores/{store_slug}/connect"),
("delete", "/api/v1/me/stores/{store_slug}"),
("get", "/stores"),
("get", "/me/stores"),
("post", "/me/stores/{store_slug}/connect"),
("delete", "/me/stores/{store_slug}"),
# Purchases (3)
("get", "/api/v1/purchases"),
("get", "/api/v1/purchases/stats"),
("get", "/api/v1/purchases/{purchase_id}"),
("get", "/purchases"),
("get", "/purchases/stats"),
("get", "/purchases/{purchase_id}"),
# Products (3)
("get", "/api/v1/products"),
("get", "/api/v1/products/{product_id}"),
("get", "/api/v1/products/{product_id}/prices"),
("get", "/products"),
("get", "/products/{product_id}"),
("get", "/products/{product_id}/prices"),
# Prices (3)
("get", "/api/v1/prices/trends"),
("get", "/api/v1/prices/increases"),
("get", "/api/v1/prices/comparison"),
("get", "/prices/trends"),
("get", "/prices/increases"),
("get", "/prices/comparison"),
# Coupons (2)
("get", "/api/v1/coupons"),
("get", "/api/v1/coupons/relevant"),
("get", "/coupons"),
("get", "/coupons/relevant"),
# Shopping (2)
("post", "/api/v1/shopping/optimize"),
("get", "/api/v1/shopping/lists"),
("post", "/shopping/optimize"),
("get", "/shopping/lists"),
# Alerts (3)
("get", "/api/v1/alerts"),
("get", "/api/v1/alerts/settings"),
("put", "/api/v1/alerts/settings"),
("get", "/alerts"),
("get", "/alerts/settings"),
("put", "/alerts/settings"),
# Scraping (2)
("post", "/api/v1/scraping/{store_slug}/sync"),
("get", "/api/v1/scraping/status"),
("post", "/scraping/{store_slug}/sync"),
("get", "/scraping/status"),
# Public (3)
("get", "/api/v1/public/trends/{product_id}"),
("get", "/api/v1/public/store-comparison"),
("get", "/api/v1/public/inflation"),
("get", "/public/trends/{product_id}"),
("get", "/public/store-comparison"),
("get", "/public/inflation"),
# Health (1)
("get", "/health"),
]
@@ -86,4 +90,4 @@ async def test_route_count():
if method in ("get", "post", "put", "delete", "patch"):
count += 1
assert count == 31, f"Expected 31 routes, found {count}"
assert count == 34, f"Expected 34 routes, found {count}"
+3 -3
View File
@@ -6,14 +6,14 @@ import pytest
@pytest.mark.asyncio
async def test_list_alerts_empty(client, auth_headers):
"""No purchases means no alerts."""
resp = await client.get("/api/v1/alerts", headers=auth_headers)
resp = await client.get("/alerts", headers=auth_headers)
assert resp.status_code == 200
assert resp.json() == []
@pytest.mark.asyncio
async def test_get_alert_settings(client, auth_headers):
resp = await client.get("/api/v1/alerts/settings", headers=auth_headers)
resp = await client.get("/alerts/settings", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
assert data["price_increase_threshold_pct"] == 5.0
@@ -24,7 +24,7 @@ async def test_get_alert_settings(client, auth_headers):
@pytest.mark.asyncio
async def test_update_alert_settings_returns_501(client, auth_headers):
resp = await client.put(
"/api/v1/alerts/settings",
"/alerts/settings",
headers=auth_headers,
json={
"price_increase_threshold_pct": 10.0,
+3 -3
View File
@@ -36,7 +36,7 @@ async def coupon_data(db_engine, auth_headers):
@pytest.mark.asyncio
async def test_list_coupons(client, coupon_data):
resp = await client.get("/api/v1/coupons", headers=coupon_data["headers"])
resp = await client.get("/coupons", headers=coupon_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
@@ -45,7 +45,7 @@ async def test_list_coupons(client, coupon_data):
@pytest.mark.asyncio
async def test_list_coupons_by_store(client, coupon_data):
store_id = str(coupon_data["store"].id)
resp = await client.get(f"/api/v1/coupons?store_id={store_id}", headers=coupon_data["headers"])
resp = await client.get(f"/coupons?store_id={store_id}", headers=coupon_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) >= 1
@@ -53,6 +53,6 @@ async def test_list_coupons_by_store(client, coupon_data):
@pytest.mark.asyncio
async def test_relevant_coupons_empty(client, auth_headers):
"""No purchases means no relevant coupons."""
resp = await client.get("/api/v1/coupons/relevant", headers=auth_headers)
resp = await client.get("/coupons/relevant", headers=auth_headers)
assert resp.status_code == 200
assert resp.json() == []
+5 -11
View File
@@ -48,7 +48,7 @@ async def price_data(db_engine, auth_headers):
@pytest.mark.asyncio
async def test_price_trends(client, price_data):
resp = await client.get("/api/v1/prices/trends", headers=price_data["headers"])
resp = await client.get("/prices/trends", headers=price_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
@@ -58,22 +58,18 @@ async def test_price_trends(client, price_data):
@pytest.mark.asyncio
async def test_price_trends_by_category(client, price_data):
resp = await client.get(
"/api/v1/prices/trends?category=household", headers=price_data["headers"]
)
resp = await client.get("/prices/trends?category=household", headers=price_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) == 1
resp = await client.get(
"/api/v1/prices/trends?category=nonexistent", headers=price_data["headers"]
)
resp = await client.get("/prices/trends?category=nonexistent", headers=price_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) == 0
@pytest.mark.asyncio
async def test_price_increases(client, price_data):
resp = await client.get("/api/v1/prices/increases", headers=price_data["headers"])
resp = await client.get("/prices/increases", headers=price_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
@@ -86,9 +82,7 @@ async def test_price_increases(client, price_data):
@pytest.mark.asyncio
async def test_price_comparison(client, price_data):
pid = str(price_data["product"].id)
resp = await client.get(
f"/api/v1/prices/comparison?product_ids={pid}", headers=price_data["headers"]
)
resp = await client.get(f"/prices/comparison?product_ids={pid}", headers=price_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
+6 -6
View File
@@ -49,7 +49,7 @@ async def product_data(db_engine, auth_headers):
@pytest.mark.asyncio
async def test_list_products(client, product_data):
resp = await client.get("/api/v1/products", headers=product_data["headers"])
resp = await client.get("/products", headers=product_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
@@ -58,11 +58,11 @@ async def test_list_products(client, product_data):
@pytest.mark.asyncio
async def test_search_products(client, product_data):
resp = await client.get("/api/v1/products?q=Cheerios", headers=product_data["headers"])
resp = await client.get("/products?q=Cheerios", headers=product_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) == 1
resp = await client.get("/api/v1/products?q=nonexistent", headers=product_data["headers"])
resp = await client.get("/products?q=nonexistent", headers=product_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) == 0
@@ -70,7 +70,7 @@ async def test_search_products(client, product_data):
@pytest.mark.asyncio
async def test_get_product_detail(client, product_data):
pid = str(product_data["product"].id)
resp = await client.get(f"/api/v1/products/{pid}", headers=product_data["headers"])
resp = await client.get(f"/products/{pid}", headers=product_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "Cheerios 18oz"
@@ -80,14 +80,14 @@ async def test_get_product_detail(client, product_data):
@pytest.mark.asyncio
async def test_get_product_not_found(client, auth_headers):
resp = await client.get(f"/api/v1/products/{uuid.uuid4()}", headers=auth_headers)
resp = await client.get(f"/products/{uuid.uuid4()}", headers=auth_headers)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_get_product_prices(client, product_data):
pid = str(product_data["product"].id)
resp = await client.get(f"/api/v1/products/{pid}/prices", headers=product_data["headers"])
resp = await client.get(f"/products/{pid}/prices", headers=product_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["product_name"] == "Cheerios 18oz"
+77 -9
View File
@@ -1,7 +1,7 @@
"""Integration tests for public endpoints (no auth)."""
import uuid
from datetime import date, timedelta
from datetime import date
from decimal import Decimal
import pytest
@@ -29,7 +29,7 @@ async def public_data(db_engine):
ph = PriceHistory(
normalized_product_id=product.id,
store_id=store.id,
observed_date=date.today() - timedelta(days=30),
observed_date=date(2026, 3, 5),
regular_price=Decimal("3.99"),
source="receipt",
)
@@ -42,7 +42,7 @@ async def public_data(db_engine):
@pytest.mark.asyncio
async def test_public_trend(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/api/v1/public/trends/{pid}")
resp = await client.get(f"/public/trends/{pid}")
assert resp.status_code == 200
data = resp.json()
assert data["product_name"] == "Skippy PB 16oz"
@@ -51,14 +51,14 @@ async def test_public_trend(client, public_data):
@pytest.mark.asyncio
async def test_public_trend_not_found(client):
resp = await client.get(f"/api/v1/public/trends/{uuid.uuid4()}")
resp = await client.get(f"/public/trends/{uuid.uuid4()}")
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_public_store_comparison(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/api/v1/public/store-comparison?product_ids={pid}")
resp = await client.get(f"/public/store-comparison?product_ids={pid}")
assert resp.status_code == 200
data = resp.json()
assert len(data["products"]) == 1
@@ -66,7 +66,7 @@ async def test_public_store_comparison(client, public_data):
@pytest.mark.asyncio
async def test_public_inflation(client, public_data):
resp = await client.get("/api/v1/public/inflation")
resp = await client.get("/public/inflation")
assert resp.status_code == 200
data = resp.json()
assert "categories" in data
@@ -75,7 +75,7 @@ async def test_public_inflation(client, public_data):
@pytest.mark.asyncio
async def test_trend_invalid_uuid(client):
resp = await client.get("/api/v1/public/trends/not-a-uuid")
resp = await client.get("/public/trends/not-a-uuid")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@@ -84,7 +84,7 @@ async def test_trend_invalid_uuid(client):
@pytest.mark.asyncio
async def test_trend_days_zero(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/api/v1/public/trends/{pid}?days=0")
resp = await client.get(f"/public/trends/{pid}?days=0")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@@ -93,7 +93,75 @@ async def test_trend_days_zero(client, public_data):
@pytest.mark.asyncio
async def test_trend_days_negative(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/api/v1/public/trends/{pid}?days=-1")
resp = await client.get(f"/public/trends/{pid}?days=-1")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_over_max(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=999")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_valid(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=30")
assert resp.status_code == 200
assert "product_name" in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_empty_list(client):
resp = await client.get("/public/store-comparison")
assert resp.status_code == 400
assert "detail" in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_category_xss(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(
f"/public/store-comparison?product_ids={pid}&category=<script>alert(1)</script>"
)
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_category_sql_injection(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/store-comparison?product_ids={pid}&category='; DROP TABLE--")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_inflation_invalid_period(client, public_data):
resp = await client.get("/public/inflation?period=10years")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_inflation_valid_periods(client, public_data):
for period in ["all-time", "1y", "6m", "3m", "1m"]:
resp = await client.get(f"/public/inflation?period={period}")
assert resp.status_code == 200, f"period={period} failed"
@pytest.mark.asyncio
async def test_inflation_category_too_long(client, public_data):
long_category = "x" * 200
resp = await client.get(f"/public/inflation?category={long_category}")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
+4 -4
View File
@@ -80,7 +80,7 @@ async def purchase_data(db_engine):
@pytest.mark.asyncio
async def test_list_purchases(client, purchase_data):
resp = await client.get("/api/v1/purchases", headers=purchase_data["headers"])
resp = await client.get("/purchases", headers=purchase_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
@@ -91,7 +91,7 @@ async def test_list_purchases(client, purchase_data):
@pytest.mark.asyncio
async def test_get_purchase_detail(client, purchase_data):
pid = str(purchase_data["purchase"].id)
resp = await client.get(f"/api/v1/purchases/{pid}", headers=purchase_data["headers"])
resp = await client.get(f"/purchases/{pid}", headers=purchase_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data["line_items"]) == 1
@@ -100,13 +100,13 @@ async def test_get_purchase_detail(client, purchase_data):
@pytest.mark.asyncio
async def test_get_purchase_not_found(client, auth_headers):
resp = await client.get(f"/api/v1/purchases/{uuid.uuid4()}", headers=auth_headers)
resp = await client.get(f"/purchases/{uuid.uuid4()}", headers=auth_headers)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_purchase_stats(client, purchase_data):
resp = await client.get("/api/v1/purchases/stats", headers=purchase_data["headers"])
resp = await client.get("/purchases/stats", headers=purchase_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["total_spent"] == 42.50
+9 -9
View File
@@ -21,7 +21,7 @@ async def seeded_store(db_engine):
@pytest.mark.asyncio
async def test_list_stores(client, seeded_store):
resp = await client.get("/api/v1/stores")
resp = await client.get("/stores")
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
@@ -30,7 +30,7 @@ async def test_list_stores(client, seeded_store):
@pytest.mark.asyncio
async def test_list_user_stores_empty(client, auth_headers):
resp = await client.get("/api/v1/me/stores", headers=auth_headers)
resp = await client.get("/me/stores", headers=auth_headers)
assert resp.status_code == 200
assert resp.json() == []
@@ -39,7 +39,7 @@ async def test_list_user_stores_empty(client, auth_headers):
async def test_connect_and_disconnect_store(client, auth_headers, seeded_store):
# Connect
resp = await client.post(
"/api/v1/me/stores/meijer/connect",
"/me/stores/meijer/connect",
headers=auth_headers,
json={"credentials": None},
)
@@ -47,23 +47,23 @@ async def test_connect_and_disconnect_store(client, auth_headers, seeded_store):
assert resp.json()["connected"] is True
# List should show connected
resp = await client.get("/api/v1/me/stores", headers=auth_headers)
resp = await client.get("/me/stores", headers=auth_headers)
assert resp.status_code == 200
assert len(resp.json()) == 1
# Disconnect
resp = await client.delete("/api/v1/me/stores/meijer", headers=auth_headers)
resp = await client.delete("/me/stores/meijer", headers=auth_headers)
assert resp.status_code == 204
# List should be empty again
resp = await client.get("/api/v1/me/stores", headers=auth_headers)
resp = await client.get("/me/stores", headers=auth_headers)
assert resp.json() == []
@pytest.mark.asyncio
async def test_connect_nonexistent_store(client, auth_headers):
resp = await client.post(
"/api/v1/me/stores/nonexistent/connect",
"/me/stores/nonexistent/connect",
headers=auth_headers,
json={},
)
@@ -72,6 +72,6 @@ async def test_connect_nonexistent_store(client, auth_headers):
@pytest.mark.asyncio
async def test_connect_duplicate_store(client, auth_headers, seeded_store):
await client.post("/api/v1/me/stores/meijer/connect", headers=auth_headers, json={})
resp = await client.post("/api/v1/me/stores/meijer/connect", headers=auth_headers, json={})
await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
resp = await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
assert resp.status_code == 409
Generated
-1348
View File
File diff suppressed because it is too large Load Diff