Compare commits

..

1 Commits

Author SHA1 Message Date
CartSnitch Engineer Bot c116d0bc8a feat(ci): add deploy-uat job for UAT environment
Mirrors deploy-dev job but targets apps/overlays/uat. Both deploy-dev
and deploy-uat run in parallel after all build jobs complete.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-03 13:23:38 +00:00
38 changed files with 216 additions and 1258 deletions
+33 -111
View File
@@ -2,9 +2,9 @@ name: CI
on:
push:
branches: [main, dev, uat]
branches: [main]
pull_request:
branches: [main, dev, uat]
branches: [main]
concurrency:
group: ci-${{ github.ref }}
@@ -99,11 +99,10 @@ jobs:
build-and-push:
runs-on: runners-cartsnitch
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
needs: [lint, test, e2e]
outputs:
calver_tag: ${{ steps.calver.outputs.version }}
sha_tag: sha-${{ github.sha }}
steps:
- uses: actions/checkout@v4
with:
@@ -127,14 +126,14 @@ jobs:
echo "CalVer tag: $VERSION"
- name: Log in to Docker Hub
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to GHCR
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
@@ -147,7 +146,7 @@ jobs:
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=sha,prefix=sha-,format=long
type=sha,prefix=sha-
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
@@ -155,7 +154,7 @@ jobs:
uses: docker/build-push-action@v6
with:
context: .
push: ${{ github.event_name == 'push' }}
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
target: prod
@@ -170,11 +169,10 @@ jobs:
build-and-push-auth:
runs-on: runners-cartsnitch
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
needs: [lint, test, e2e]
outputs:
calver_tag: ${{ steps.calver.outputs.version }}
sha_tag: sha-${{ github.sha }}
steps:
- uses: actions/checkout@v4
with:
@@ -197,14 +195,14 @@ jobs:
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
- name: Log in to Docker Hub
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to GHCR
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
@@ -217,7 +215,7 @@ jobs:
with:
images: ${{ env.REGISTRY }}/${{ env.AUTH_IMAGE_NAME }}
tags: |
type=sha,prefix=sha-,format=long
type=sha,prefix=sha-
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
@@ -226,17 +224,16 @@ jobs:
with:
context: ./auth
file: ./auth/Dockerfile
push: ${{ github.event_name == 'push' }}
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-and-push-receiptwitness:
runs-on: runners-cartsnitch
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
needs: [lint, test]
outputs:
calver_tag: ${{ steps.calver.outputs.version }}
sha_tag: sha-${{ github.sha }}
steps:
- uses: actions/checkout@v4
with:
@@ -254,14 +251,14 @@ jobs:
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
- name: Log in to Docker Hub
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to GHCR
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
@@ -274,7 +271,7 @@ jobs:
with:
images: ${{ env.REGISTRY }}/${{ env.RECEIPTWITNESS_IMAGE_NAME }}
tags: |
type=sha,prefix=sha-,format=long
type=sha,prefix=sha-
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
@@ -283,17 +280,16 @@ jobs:
with:
context: .
file: ./receiptwitness/Dockerfile
push: ${{ github.event_name == 'push' }}
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-and-push-api:
runs-on: runners-cartsnitch
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
needs: [lint, test]
outputs:
calver_tag: ${{ steps.calver.outputs.version }}
sha_tag: sha-${{ github.sha }}
steps:
- uses: actions/checkout@v4
with:
@@ -311,14 +307,14 @@ jobs:
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
- name: Log in to Docker Hub
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to GHCR
if: github.event_name == 'push'
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
@@ -331,23 +327,23 @@ jobs:
with:
images: ${{ env.REGISTRY }}/${{ env.API_IMAGE_NAME }}
tags: |
type=sha,prefix=sha-,format=long
type=sha,prefix=sha-
type=raw,value=${{ steps.calver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
- name: Build and push API Docker image
uses: docker/build-push-action@v6
with:
context: ./api
context: .
file: ./api/Dockerfile
push: ${{ github.event_name == 'push' }}
push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
deploy-dev:
runs-on: runners-cartsnitch
needs: [build-and-push, build-and-push-auth, build-and-push-receiptwitness, build-and-push-api]
if: always() && !cancelled() && github.event_name == 'push' && (github.ref == 'refs/heads/dev' || github.ref == 'refs/heads/main')
if: always() && !cancelled() && github.event_name == 'push' && github.ref == 'refs/heads/main'
steps:
- name: Generate GitHub App token
id: app-token
@@ -372,65 +368,29 @@ jobs:
- name: Install kustomize
uses: imranismail/setup-kustomize@v2
- name: Determine image tag for frontend
id: frontend_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
- name: Update frontend image tag
if: needs.build-and-push.result == 'success'
run: |
cd infra/apps/overlays/dev
kustomize edit set image ghcr.io/cartsnitch/cartsnitch:${{ steps.frontend_tag.outputs.tag }}
- name: Determine image tag for auth
id: auth_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push-auth.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push-auth.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
kustomize edit set image ghcr.io/cartsnitch/cartsnitch:${{ needs.build-and-push.outputs.calver_tag }}
- name: Update auth image tag
if: needs.build-and-push-auth.result == 'success'
run: |
cd infra/apps/overlays/dev
kustomize edit set image ghcr.io/cartsnitch/auth:${{ steps.auth_tag.outputs.tag }}
- name: Determine image tag for receiptwitness
id: receiptwitness_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push-receiptwitness.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push-receiptwitness.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
kustomize edit set image ghcr.io/cartsnitch/auth:${{ needs.build-and-push-auth.outputs.calver_tag }}
- name: Update receiptwitness image tag
if: needs.build-and-push-receiptwitness.result == 'success'
run: |
cd infra/apps/overlays/dev
kustomize edit set image ghcr.io/cartsnitch/receiptwitness:${{ steps.receiptwitness_tag.outputs.tag }}
- name: Determine image tag for api
id: api_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push-api.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push-api.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
kustomize edit set image ghcr.io/cartsnitch/receiptwitness:${{ needs.build-and-push-receiptwitness.outputs.calver_tag }}
- name: Update api image tag
if: needs.build-and-push-api.result == 'success'
run: |
cd infra/apps/overlays/dev
kustomize edit set image ghcr.io/cartsnitch/api:${{ steps.api_tag.outputs.tag }}
kustomize edit set image ghcr.io/cartsnitch/api:${{ needs.build-and-push-api.outputs.calver_tag }}
- name: Commit and push to infra
run: |
@@ -439,13 +399,12 @@ jobs:
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
git add apps/overlays/dev/kustomization.yaml
git commit -m "ci(dev): update cartsnitch, auth, receiptwitness, and api images"
git pull --rebase origin main
git push origin main
deploy-uat:
runs-on: runners-cartsnitch
needs: [build-and-push, build-and-push-auth, build-and-push-receiptwitness, build-and-push-api]
if: always() && !cancelled() && github.event_name == 'push' && (github.ref == 'refs/heads/uat' || github.ref == 'refs/heads/main')
if: always() && !cancelled() && github.event_name == 'push' && github.ref == 'refs/heads/main'
steps:
- name: Generate GitHub App token
id: app-token
@@ -470,65 +429,29 @@ jobs:
- name: Install kustomize
uses: imranismail/setup-kustomize@v2
- name: Determine image tag for frontend
id: frontend_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
- name: Update frontend image tag
if: needs.build-and-push.result == 'success'
run: |
cd infra/apps/overlays/uat
kustomize edit set image ghcr.io/cartsnitch/cartsnitch:${{ steps.frontend_tag.outputs.tag }}
- name: Determine image tag for auth
id: auth_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push-auth.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push-auth.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
kustomize edit set image ghcr.io/cartsnitch/cartsnitch:${{ needs.build-and-push.outputs.calver_tag }}
- name: Update auth image tag
if: needs.build-and-push-auth.result == 'success'
run: |
cd infra/apps/overlays/uat
kustomize edit set image ghcr.io/cartsnitch/auth:${{ steps.auth_tag.outputs.tag }}
- name: Determine image tag for receiptwitness
id: receiptwitness_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push-receiptwitness.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push-receiptwitness.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
kustomize edit set image ghcr.io/cartsnitch/auth:${{ needs.build-and-push-auth.outputs.calver_tag }}
- name: Update receiptwitness image tag
if: needs.build-and-push-receiptwitness.result == 'success'
run: |
cd infra/apps/overlays/uat
kustomize edit set image ghcr.io/cartsnitch/receiptwitness:${{ steps.receiptwitness_tag.outputs.tag }}
- name: Determine image tag for api
id: api_tag
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tag=${{ needs.build-and-push-api.outputs.calver_tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ needs.build-and-push-api.outputs.sha_tag }}" >> "$GITHUB_OUTPUT"
fi
kustomize edit set image ghcr.io/cartsnitch/receiptwitness:${{ needs.build-and-push-receiptwitness.outputs.calver_tag }}
- name: Update api image tag
if: needs.build-and-push-api.result == 'success'
run: |
cd infra/apps/overlays/uat
kustomize edit set image ghcr.io/cartsnitch/api:${{ steps.api_tag.outputs.tag }}
kustomize edit set image ghcr.io/cartsnitch/api:${{ needs.build-and-push-api.outputs.calver_tag }}
- name: Commit and push to infra
run: |
@@ -537,5 +460,4 @@ jobs:
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
git add apps/overlays/uat/kustomization.yaml
git commit -m "ci(uat): update cartsnitch, auth, receiptwitness, and api images"
git pull --rebase origin main
git push origin main
+1 -5
View File
@@ -12,14 +12,10 @@ RUN pip install --no-cache-dir --prefix=/install .
FROM python:3.12-slim AS prod
RUN apt-get update && apt-get install -y --no-install-recommends libpq5 && rm -rf /var/lib/apt/lists/*
WORKDIR /app
RUN adduser --system --group --uid 1000 app
COPY --from=build /install /usr/local
COPY src/ ./src/
COPY alembic.ini ./
COPY alembic/ ./alembic/
USER 1000
EXPOSE 8000
@@ -27,4 +23,4 @@ EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=3s \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"
CMD ["sh", "-c", "python -m alembic upgrade head && uvicorn cartsnitch_api.main:app --host 0.0.0.0 --port 8000"]
CMD ["uvicorn", "cartsnitch_api.main:app", "--host", "0.0.0.0", "--port", "8000"]
+2 -14
View File
@@ -18,7 +18,7 @@ if not db_url:
"CARTSNITCH_DATABASE_URL_SYNC must be set. "
"Example: postgresql://user:pass@localhost:5432/cartsnitch"
)
config.set_main_option("sqlalchemy.url", db_url.replace("%", "%%"))
config.set_main_option("sqlalchemy.url", db_url)
target_metadata = Base.metadata
@@ -31,7 +31,6 @@ def run_migrations_offline() -> None:
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
version_table_column_width=128,
)
with context.begin_transaction():
context.run_migrations()
@@ -45,20 +44,9 @@ def run_migrations_online() -> None:
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata, version_table_column_width=128)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
# Create any tables defined in models but not yet created by migrations.
# This bootstraps fresh databases that have no legacy schema.
# checkfirst=True ensures this is a no-op on existing databases.
try:
Base.metadata.create_all(bind=connection, checkfirst=True)
connection.commit()
except Exception as exc:
import logging
logging.getLogger("alembic.env").warning(
"create_all failed (non-fatal, migrations should handle table creation): %s", exc
)
if context.is_offline_mode():
@@ -33,21 +33,6 @@ def _is_fernet_token(value: str) -> bool:
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
# Fresh DB — table created by Base.metadata.create_all with correct TEXT type
if not inspector.has_table("user_store_accounts"):
return
# Already migrated? Skip if session_data is already TEXT (not JSON)
cols = {c["name"]: c for c in inspector.get_columns("user_store_accounts")}
if "session_data" not in cols:
return
col_type = str(cols["session_data"]["type"]).lower()
if "text" in col_type and "json" not in col_type:
return # already TEXT — nothing to do
# Change column type from JSON to TEXT to hold Fernet ciphertext
op.alter_column(
"user_store_accounts",
@@ -58,6 +43,7 @@ def upgrade() -> None:
postgresql_using="session_data::text",
)
conn = op.get_bind()
rows = conn.execute(
text("SELECT id, session_data FROM user_store_accounts WHERE session_data IS NOT NULL")
).fetchall()
+65 -78
View File
@@ -21,94 +21,81 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
# --- Extend users table for Better-Auth compatibility ---
# Guard: on a fresh DB Base.metadata.create_all (called in env.py after migrations)
# creates the users table with all columns, so migration 002 must not re-run add_column.
if inspector.has_table("users"):
existing_user_cols = [c["name"] for c in inspector.get_columns("users")]
if "email_verified" not in existing_user_cols:
op.add_column("users", sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false"))
if "image" not in existing_user_cols:
op.add_column("users", sa.Column("image", sa.Text(), nullable=True))
op.add_column("users", sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false"))
op.add_column("users", sa.Column("image", sa.Text(), nullable=True))
# --- Create sessions table ---
if not inspector.has_table("sessions"):
op.create_table(
"sessions",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("token", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("ip_address", sa.Text(), nullable=True),
sa.Column("user_agent", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_sessions_token", "sessions", ["token"], unique=True)
op.create_index("ix_sessions_user_id", "sessions", ["user_id"])
op.create_table(
"sessions",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("token", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("ip_address", sa.Text(), nullable=True),
sa.Column("user_agent", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_sessions_token", "sessions", ["token"], unique=True)
op.create_index("ix_sessions_user_id", "sessions", ["user_id"])
# --- Create accounts table ---
if not inspector.has_table("accounts"):
op.create_table(
"accounts",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("account_id", sa.Text(), nullable=False),
sa.Column("provider_id", sa.Text(), nullable=False),
sa.Column("access_token", sa.Text(), nullable=True),
sa.Column("refresh_token", sa.Text(), nullable=True),
sa.Column("access_token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("refresh_token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("scope", sa.Text(), nullable=True),
sa.Column("id_token", sa.Text(), nullable=True),
sa.Column("password", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_accounts_user_id", "accounts", ["user_id"])
op.create_table(
"accounts",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("account_id", sa.Text(), nullable=False),
sa.Column("provider_id", sa.Text(), nullable=False),
sa.Column("access_token", sa.Text(), nullable=True),
sa.Column("refresh_token", sa.Text(), nullable=True),
sa.Column("access_token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("refresh_token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("scope", sa.Text(), nullable=True),
sa.Column("id_token", sa.Text(), nullable=True),
sa.Column("password", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_accounts_user_id", "accounts", ["user_id"])
# --- Create verifications table ---
if not inspector.has_table("verifications"):
op.create_table(
"verifications",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("identifier", sa.Text(), nullable=False),
sa.Column("value", sa.Text(), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"verifications",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("identifier", sa.Text(), nullable=False),
sa.Column("value", sa.Text(), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# --- Migrate existing password hashes to accounts table ---
# Only run on existing (non-fresh) DBs that already have users table with data
if inspector.has_table("users"):
users = conn.execute(
text("SELECT id, hashed_password FROM users WHERE hashed_password IS NOT NULL")
).fetchall()
# For each user with a hashed_password, create a 'credential' account row
conn = op.get_bind()
users = conn.execute(
text("SELECT id, hashed_password FROM users WHERE hashed_password IS NOT NULL")
).fetchall()
for user_id, hashed_password in users:
user_id_str = str(user_id)
conn.execute(
text(
"INSERT INTO accounts (id, user_id, account_id, provider_id, password, created_at, updated_at) "
"VALUES (gen_random_uuid()::text, :user_id, :account_id, 'credential', :password, now(), now())"
),
{"user_id": user_id_str, "account_id": user_id_str, "password": hashed_password},
)
for user_id, hashed_password in users:
user_id_str = str(user_id)
conn.execute(
text(
"INSERT INTO accounts (id, user_id, account_id, provider_id, password, created_at, updated_at) "
"VALUES (gen_random_uuid()::text, :user_id, :account_id, 'credential', :password, now(), now())"
),
{"user_id": user_id_str, "account_id": user_id_str, "password": hashed_password},
)
def downgrade() -> None:
op.execute(text("DROP INDEX IF EXISTS ix_accounts_user_id"))
op.execute(text("DROP TABLE IF EXISTS verifications"))
op.execute(text("DROP TABLE IF EXISTS accounts"))
op.execute(text("DROP INDEX IF EXISTS ix_sessions_user_id"))
op.execute(text("DROP INDEX IF EXISTS ix_sessions_token"))
op.execute(text("DROP TABLE IF EXISTS sessions"))
op.execute(text("ALTER TABLE users DROP COLUMN IF EXISTS image"))
op.execute(text("ALTER TABLE users DROP COLUMN IF EXISTS email_verified"))
op.drop_table("verifications")
op.drop_table("accounts")
op.drop_index("ix_sessions_user_id", table_name="sessions")
op.drop_index("ix_sessions_token", table_name="sessions")
op.drop_table("sessions")
op.drop_column("users", "image")
op.drop_column("users", "email_verified")
@@ -19,25 +19,8 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
# Fresh DB — nothing to alter
if not inspector.has_table("users"):
return
cols = {c["name"]: c for c in inspector.get_columns("users")}
if "hashed_password" in cols and not cols["hashed_password"]["nullable"]:
op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=True)
op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=True)
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
if not inspector.has_table("users"):
return
cols = {c["name"]: c for c in inspector.get_columns("users")}
if "hashed_password" in cols and cols["hashed_password"]["nullable"]:
op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=False)
op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=False)
+1 -15
View File
@@ -25,21 +25,7 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
# Fresh DB — no tables yet, nothing to convert
if not inspector.has_table("users"):
return
# Check if already TEXT (Base.metadata.create_all uses TEXT for fresh DB)
users_cols = {c["name"]: c for c in inspector.get_columns("users")}
if "id" in users_cols:
id_type = str(users_cols["id"]["type"]).lower()
if "text" in id_type and "uuid" not in id_type:
return # already TEXT — nothing to do
# Step 1: Drop existing FK constraints (ignore if they don't exist)
# Step 1: Drop existing FK constraints
op.execute(text("ALTER TABLE user_store_accounts DROP CONSTRAINT IF EXISTS user_store_accounts_user_id_fkey"))
op.execute(text("ALTER TABLE purchases DROP CONSTRAINT IF EXISTS purchases_user_id_fkey"))
@@ -18,15 +18,6 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
# Guard: on a fresh DB Base.metadata.create_all creates users table with the column already present
if not inspector.has_table("users"):
return
existing_cols = [c["name"] for c in inspector.get_columns("users")]
if "email_inbound_token" in existing_cols:
return
# Add column nullable first so existing rows can be backfilled
op.add_column(
"users",
@@ -34,10 +25,11 @@ def upgrade() -> None:
)
# Backfill existing users with unique tokens
result = conn.execute(sa.text("SELECT id FROM users WHERE email_inbound_token IS NULL"))
connection = op.get_bind()
result = connection.execute(sa.text("SELECT id FROM users WHERE email_inbound_token IS NULL"))
for (user_id,) in result:
token = secrets.token_urlsafe(16)
conn.execute(
connection.execute(
sa.text("UPDATE users SET email_inbound_token = :token WHERE id = :id"),
{"token": token, "id": user_id},
)
@@ -1,42 +0,0 @@
"""Add server_default to users.email_inbound_token.
Revision ID: 006_email_inbound_token_server_default
Revises: 005_add_email_inbound_token
Create Date: 2026-04-04
"""
import sqlalchemy as sa
from alembic import op
revision = "006_email_inbound_token_server_default"
down_revision = "005_add_email_inbound_token"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
# Guard: on a fresh DB Base.metadata.create_all already sets the server_default
if not inspector.has_table("users"):
return
cols = {c["name"]: c for c in inspector.get_columns("users")}
if "email_inbound_token" not in cols:
return
if cols["email_inbound_token"].get("default") is not None:
return
op.alter_column(
"users",
"email_inbound_token",
server_default=sa.text(
"replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_')"
),
)
def downgrade() -> None:
op.alter_column(
"users",
"email_inbound_token",
server_default=None,
)
@@ -1,47 +0,0 @@
"""Bootstrap users table on fresh databases.
On fresh databases, migrations 001-006 skip users-table operations because
the table does not exist yet. Base.metadata.create_all() in env.py is meant
to handle this, but if it fails (import errors, etc.) the table is never
created. This migration creates the users table with raw SQL as a safety net.
Revision ID: 007_bootstrap_users_table
Revises: 006_email_inbound_token_server_default
Create Date: 2026-04-04
"""
import sqlalchemy as sa
from sqlalchemy import text
from alembic import op
revision = "007_bootstrap_users_table"
down_revision = "006_email_inbound_token_server_default"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
if inspector.has_table("users"):
return # Table already exists (non-fresh DB or create_all already ran)
conn.execute(text("""
CREATE TABLE users (
id TEXT PRIMARY KEY,
email VARCHAR(255) NOT NULL UNIQUE,
hashed_password VARCHAR(255),
display_name VARCHAR(100),
email_verified BOOLEAN NOT NULL DEFAULT false,
image TEXT,
email_inbound_token VARCHAR(22) NOT NULL UNIQUE
DEFAULT replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_'),
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
"""))
def downgrade() -> None:
op.execute(text("DROP TABLE IF EXISTS users"))
@@ -1,210 +0,0 @@
"""Create domain tables (stores, purchases, coupons, etc.).
Revision ID: 008_create_domain_tables
Revises: 007_bootstrap_users_table
Create Date: 2026-04-04
"""
import sqlalchemy as sa
from sqlalchemy import text
from alembic import op
revision = "008_create_domain_tables"
down_revision = "007_bootstrap_users_table"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
# 1. stores
if not inspector.has_table("stores"):
op.create_table(
"stores",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("name", sa.String(100), nullable=False),
sa.Column("slug", sa.String(20), nullable=False, unique=True),
sa.Column("logo_url", sa.String(500), nullable=True),
sa.Column("website_url", sa.String(500), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 2. store_locations
if not inspector.has_table("store_locations"):
op.create_table(
"store_locations",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
sa.Column("address", sa.String(300), nullable=False),
sa.Column("city", sa.String(100), nullable=False),
sa.Column("state", sa.String(2), nullable=False),
sa.Column("zip", sa.String(10), nullable=False),
sa.Column("lat", sa.Float(), nullable=True),
sa.Column("lng", sa.Float(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 3. normalized_products
if not inspector.has_table("normalized_products"):
op.create_table(
"normalized_products",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("canonical_name", sa.String(300), nullable=False),
sa.Column("category", sa.String(50), nullable=True),
sa.Column("subcategory", sa.String(100), nullable=True),
sa.Column("brand", sa.String(200), nullable=True),
sa.Column("size", sa.String(50), nullable=True),
sa.Column("size_unit", sa.String(10), nullable=True),
sa.Column("upc_variants", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 4. purchases
if not inspector.has_table("purchases"):
op.create_table(
"purchases",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("user_id", sa.Text(), sa.ForeignKey("users.id"), nullable=False),
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
sa.Column("store_location_id", sa.Uuid(), sa.ForeignKey("store_locations.id"), nullable=True),
sa.Column("receipt_id", sa.String(200), nullable=False),
sa.Column("purchase_date", sa.Date(), nullable=False),
sa.Column("total", sa.Numeric(10, 2), nullable=False),
sa.Column("subtotal", sa.Numeric(10, 2), nullable=True),
sa.Column("tax", sa.Numeric(10, 2), nullable=True),
sa.Column("savings_total", sa.Numeric(10, 2), nullable=True),
sa.Column("source_url", sa.String(500), nullable=True),
sa.Column("raw_data", sa.JSON(), nullable=True),
sa.Column("ingested_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"),
sa.Index("ix_purchases_user_store", "user_id", "store_id"),
)
# 5. purchase_items
if not inspector.has_table("purchase_items"):
op.create_table(
"purchase_items",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("purchase_id", sa.Uuid(), sa.ForeignKey("purchases.id"), nullable=False),
sa.Column("product_name_raw", sa.String(300), nullable=False),
sa.Column("upc", sa.String(20), nullable=True),
sa.Column("quantity", sa.Numeric(10, 3), nullable=False),
sa.Column("unit_price", sa.Numeric(10, 2), nullable=False),
sa.Column("extended_price", sa.Numeric(10, 2), nullable=False),
sa.Column("regular_price", sa.Numeric(10, 2), nullable=True),
sa.Column("sale_price", sa.Numeric(10, 2), nullable=True),
sa.Column("coupon_discount", sa.Numeric(10, 2), nullable=True),
sa.Column("loyalty_discount", sa.Numeric(10, 2), nullable=True),
sa.Column("category_raw", sa.String(100), nullable=True),
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 6. coupons
if not inspector.has_table("coupons"):
op.create_table(
"coupons",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=True),
sa.Column("title", sa.String(300), nullable=False),
sa.Column("description", sa.String(1000), nullable=True),
sa.Column("discount_type", sa.String(20), nullable=False),
sa.Column("discount_value", sa.Numeric(10, 2), nullable=True),
sa.Column("min_purchase", sa.Numeric(10, 2), nullable=True),
sa.Column("valid_from", sa.Date(), nullable=True),
sa.Column("valid_to", sa.Date(), nullable=True),
sa.Column("requires_clip", sa.Boolean(), server_default=text("false"), nullable=False),
sa.Column("coupon_code", sa.String(100), nullable=True),
sa.Column("source_url", sa.String(500), nullable=True),
sa.Column("scraped_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 7. price_history
if not inspector.has_table("price_history"):
op.create_table(
"price_history",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=False),
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
sa.Column("observed_date", sa.Date(), nullable=False),
sa.Column("regular_price", sa.Numeric(10, 2), nullable=False),
sa.Column("sale_price", sa.Numeric(10, 2), nullable=True),
sa.Column("loyalty_price", sa.Numeric(10, 2), nullable=True),
sa.Column("coupon_price", sa.Numeric(10, 2), nullable=True),
sa.Column("source", sa.String(20), nullable=False),
sa.Column("purchase_item_id", sa.Uuid(), sa.ForeignKey("purchase_items.id"), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Index("ix_price_history_product_store_date", "normalized_product_id", "store_id", "observed_date"),
)
# 8. shrinkflation_events
if not inspector.has_table("shrinkflation_events"):
op.create_table(
"shrinkflation_events",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=False),
sa.Column("detected_date", sa.Date(), nullable=False),
sa.Column("old_size", sa.String(50), nullable=False),
sa.Column("new_size", sa.String(50), nullable=False),
sa.Column("old_unit", sa.String(10), nullable=True),
sa.Column("new_unit", sa.String(10), nullable=True),
sa.Column("price_at_old_size", sa.Numeric(10, 2), nullable=True),
sa.Column("price_at_new_size", sa.Numeric(10, 2), nullable=True),
sa.Column("confidence", sa.Numeric(3, 2), server_default=text("1.00"), nullable=False),
sa.Column("notes", sa.String(1000), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
)
# 9. user_store_accounts
if not inspector.has_table("user_store_accounts"):
op.create_table(
"user_store_accounts",
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
sa.Column("user_id", sa.Text(), sa.ForeignKey("users.id"), nullable=False),
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
sa.Column("session_data", sa.JSON(), nullable=True),
sa.Column("session_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("status", sa.String(20), server_default=text("'active'"), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),
)
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
if inspector.has_table("user_store_accounts"):
op.drop_table("user_store_accounts")
if inspector.has_table("shrinkflation_events"):
op.drop_table("shrinkflation_events")
if inspector.has_table("price_history"):
op.drop_table("price_history")
if inspector.has_table("coupons"):
op.drop_table("coupons")
if inspector.has_table("purchase_items"):
op.drop_table("purchase_items")
if inspector.has_table("purchases"):
op.drop_table("purchases")
if inspector.has_table("normalized_products"):
op.drop_table("normalized_products")
if inspector.has_table("store_locations"):
op.drop_table("store_locations")
if inspector.has_table("stores"):
op.drop_table("stores")
+5 -11
View File
@@ -19,15 +19,12 @@ bearer_scheme = HTTPBearer(auto_error=False)
# Better-Auth session cookie name
SESSION_COOKIE_NAME = "better-auth.session_token"
# Secure prefix used by better-auth on HTTPS deployments
SECURE_SESSION_COOKIE_NAME = "__Secure-better-auth.session_token"
async def _validate_session_token(token: str, db: AsyncSession) -> str:
"""Validate a Better-Auth session token against the sessions table.
Better-Auth stores the raw token in the DB. The cookie/Bearer header
carries the same raw token, so we compare directly.
Returns the user_id (as str) if the session is valid and not expired.
"""
result = await db.execute(
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
@@ -68,17 +65,14 @@ async def get_current_user(
"""
token: str | None = None
# 1. Check session cookie — prefer __Secure- variant (HTTPS) over plain (HTTP dev)
cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(SESSION_COOKIE_NAME)
# 1. Check session cookie
cookie_token = request.cookies.get(SESSION_COOKIE_NAME)
if cookie_token:
# Better-Auth cookie format is "token.sessionId" — extract just the token part
token = cookie_token.split(".")[0] if "." in cookie_token else cookie_token
token = cookie_token
# 2. Fall back to Bearer header
if not token and credentials:
# Callers might pass the compound value here too
raw = credentials.credentials
token = raw.split(".")[0] if "." in raw else raw
token = credentials.credentials
if not token:
raise HTTPException(
+25
View File
@@ -22,6 +22,11 @@ from cartsnitch_api.services.auth import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
class EmailInAddressResponse(BaseModel):
email_address: str
instructions: str
@router.get("/me", response_model=UserResponse)
async def get_me(
user_id: str = Depends(get_current_user),
@@ -65,3 +70,23 @@ async def delete_me(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
) from None
@router.get("/me/email-in-address", response_model=EmailInAddressResponse)
async def get_email_in_address(
user_id: str = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(select(User.email_inbound_token).where(User.id == user_id))
token = result.scalar_one_or_none()
if not token:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Email inbound token not found"
) from None
return EmailInAddressResponse(
email_address=f"receipts+{token}@receipts.cartsnitch.com",
instructions=(
"Forward your digital receipt emails to this address. "
"We currently support Meijer, Kroger, and Target receipt emails."
),
)
+8 -33
View File
@@ -1,51 +1,26 @@
"""Redis/DragonflyDB caching helpers."""
import redis.asyncio as redis
from cartsnitch_api.config import settings
class CacheClient:
"""Redis/DragonflyDB caching with connection pooling.
"""Stub for Redis/DragonflyDB caching.
Will be used for expensive queries: price trends, product comparisons.
Cache invalidation via Redis pub/sub events from other services.
"""
def __init__(self) -> None:
self._pool: redis.ConnectionPool | None = None
self._client: redis.Redis | None = None
async def initialize(self) -> None:
"""Initialize the Redis connection pool."""
self._pool = redis.ConnectionPool.from_url(
settings.redis_url,
max_connections=20,
decode_responses=True,
)
self._client = redis.Redis(connection_pool=self._pool)
async def close(self) -> None:
"""Close the Redis connection pool."""
if self._client:
await self._client.aclose()
if self._pool:
await self._pool.aclose()
self.url = settings.redis_url
async def get(self, key: str) -> str | None:
if not self._client:
return None
return await self._client.get(key)
# TODO: implement with redis-py async
return None
async def set(self, key: str, value: str, ttl_seconds: int = 300) -> None:
if not self._client:
return
await self._client.set(key, value, ex=ttl_seconds)
# TODO: implement with redis-py async
pass
async def delete(self, key: str) -> None:
if not self._client:
return
await self._client.delete(key)
cache_client = CacheClient()
# TODO: implement with redis-py async
pass
+8 -39
View File
@@ -1,25 +1,23 @@
import base64
from pydantic import AliasChoices, Field, model_validator
from pydantic import model_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
model_config = {"env_prefix": "CARTSNITCH_"}
database_url: str = Field(
default="postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
validation_alias=AliasChoices("CARTSNITCH_DATABASE_URL", "DATABASE_URL"),
)
database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
redis_url: str = "redis://localhost:6379/0"
jwt_secret_key: str
jwt_secret_key: str = "change-me-in-production"
jwt_algorithm: str = "HS256"
jwt_access_token_expire_minutes: int = 15
jwt_refresh_token_expire_days: int = 7
service_key: str
fernet_key: str
service_key: str = "change-me-in-production"
# Valid Fernet key for local dev — MUST be overridden in production
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
auth_service_url: str = "http://auth:3001"
@@ -32,31 +30,11 @@ class Settings(BaseSettings):
rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60
rate_limit_auth_requests: int = 5
rate_limit_auth_window_seconds: int = 60
rate_limit_redis_enabled: bool = True
rate_limit_enabled: bool = True
_PLACEHOLDER_VALUES = {"change-me-in-production"}
@model_validator(mode="after")
def validate_secrets(self):
if not self.jwt_secret_key or self.jwt_secret_key in self._PLACEHOLDER_VALUES:
raise ValueError(
"CARTSNITCH_JWT_SECRET_KEY must be set to a secure value. "
'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
if not self.service_key or self.service_key in self._PLACEHOLDER_VALUES:
raise ValueError(
"CARTSNITCH_SERVICE_KEY must be set to a secure value. "
'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
if not self.fernet_key or self.fernet_key in self._PLACEHOLDER_VALUES:
raise ValueError(
"CARTSNITCH_FERNET_KEY must be set to a valid Fernet key. "
"Generate one with: python -c "
"'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'"
)
def validate_fernet_key(self):
"""Validate fernet_key is a valid 32-byte url-safe base64 key at startup."""
try:
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
if len(decoded) != 32:
@@ -71,14 +49,5 @@ class Settings(BaseSettings):
) from None
return self
@model_validator(mode="after")
def normalize_database_url(self):
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
if self.database_url.startswith("postgresql://"):
self.database_url = self.database_url.replace(
"postgresql://", "postgresql+asyncpg://", 1
)
return self
settings = Settings()
+1 -13
View File
@@ -6,14 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
from cartsnitch_api.config import settings
engine = create_async_engine(
settings.database_url,
echo=False,
pool_size=10,
max_overflow=20,
pool_pre_ping=True,
pool_recycle=3600,
)
engine = create_async_engine(settings.database_url, echo=False)
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
@@ -21,8 +14,3 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI dependency that yields an async DB session."""
async with async_session_factory() as session:
yield session
async def dispose_engine() -> None:
"""Dispose the database engine, closing all pooled connections."""
await engine.dispose()
+2 -5
View File
@@ -5,8 +5,6 @@ from contextlib import asynccontextmanager
from fastapi import APIRouter, FastAPI
from cartsnitch_api.auth.routes import router as auth_router
from cartsnitch_api.cache import cache_client
from cartsnitch_api.database import dispose_engine
from cartsnitch_api.middleware.cors import add_cors_middleware
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
@@ -25,10 +23,9 @@ from cartsnitch_api.routes.user import router as user_router
@asynccontextmanager
async def lifespan(app: FastAPI):
await cache_client.initialize()
# TODO: initialize DB session pool, Redis connection, service clients
yield
await cache_client.close()
await dispose_engine()
# TODO: cleanup connections
def create_app() -> FastAPI:
+2 -2
View File
@@ -11,6 +11,6 @@ def add_cors_middleware(app: FastAPI) -> None:
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With"],
allow_methods=["*"],
allow_headers=["*"],
)
+19 -101
View File
@@ -4,32 +4,18 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
"""
import hashlib
import logging
import time
import uuid
from collections import defaultdict
from threading import Lock
from typing import Protocol
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from redis.asyncio import Redis, RedisError
from starlette.middleware.base import BaseHTTPMiddleware
from cartsnitch_api.config import settings
logger = logging.getLogger(__name__)
class RateLimitBackend(Protocol):
"""Protocol for rate limit backends."""
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
class InMemorySlidingWindow:
class _SlidingWindowCounter:
"""Thread-safe in-memory sliding window rate limiter."""
def __init__(self, max_requests: int, window_seconds: int) -> None:
@@ -38,12 +24,13 @@ class InMemorySlidingWindow:
self._hits: dict[str, list[float]] = defaultdict(list)
self._lock = Lock()
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
now = time.monotonic()
cutoff = now - self.window_seconds
with self._lock:
# Prune expired entries
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
current_count = len(self._hits[key])
@@ -56,84 +43,15 @@ class InMemorySlidingWindow:
return True, remaining, 0
class RedisSlidingWindow:
"""Redis-backed sliding window rate limiter using sorted sets."""
def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None:
self.redis = redis
self.max_requests = max_requests
self.window_seconds = window_seconds
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
try:
now = time.monotonic()
cutoff = now - self.window_seconds
now_ms = int(now * 1000)
cutoff_ms = int(cutoff * 1000)
pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, cutoff_ms)
pipe.zcard(key)
results = await pipe.execute()
current_count = results[1]
if current_count >= self.max_requests:
oldest = await self.redis.zrange(key, 0, 0, withscores=True)
if oldest:
retry_after = int((oldest[0][1] - cutoff) / 1000) + 1
else:
retry_after = self.window_seconds
return False, 0, retry_after
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
pipe = self.redis.pipeline()
pipe.zadd(key, {member: now_ms})
pipe.expire(key, self.window_seconds)
await pipe.execute()
remaining = self.max_requests - current_count - 1
return True, remaining, 0
except RedisError as e:
logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds)
return await in_memory.is_allowed(key)
_redis_client: Redis | None = None
_use_redis = False
if settings.rate_limit_redis_enabled:
try:
_redis_client = Redis.from_url(settings.redis_url)
_use_redis = True
logger.info("Rate limiting will use Redis at %s", settings.redis_url)
except Exception as e:
logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
_use_redis = False
if _use_redis and _redis_client:
_public_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
else:
_public_limiter = InMemorySlidingWindow(
settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = InMemorySlidingWindow(
settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = InMemorySlidingWindow(
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
# Module-level counters — one for public (per-IP), one for auth (per-token)
_public_limiter = _SlidingWindowCounter(
max_requests=settings.rate_limit_requests,
window_seconds=settings.rate_limit_window_seconds,
)
_auth_limiter = _SlidingWindowCounter(
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
window_seconds=settings.rate_limit_window_seconds,
)
def _get_client_ip(request: Request) -> str:
@@ -144,30 +62,30 @@ def _get_client_ip(request: Request) -> str:
return request.client.host if request.client else "unknown"
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
"""Determine rate limit key and which limiter to use."""
if request.url.path.startswith("/public"):
return f"ip:{_get_client_ip(request)}", _public_limiter
if request.url.path.startswith("/auth/") and request.method == "POST":
return f"ip:{_get_client_ip(request)}", _auth_strict_limiter
# For authenticated endpoints, use Bearer token as key if present
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
token_hash = hashlib.sha256(token.encode()).hexdigest()
return f"token:{token_hash}", _auth_limiter
# Use last 16 chars of token as key to avoid storing full tokens
return f"token:{token[-16:]}", _auth_limiter
# Fallback to IP for unauthenticated non-public endpoints
return f"ip:{_get_client_ip(request)}", _public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Skip rate limiting when disabled (e.g. in tests) or for health checks
if not settings.rate_limit_enabled or request.url.path == "/health":
return await call_next(request)
key, limiter = _get_rate_limit_key(request)
allowed, remaining, retry_after = await limiter.is_allowed(key)
allowed, remaining, retry_after = limiter.is_allowed(key)
if not allowed:
return JSONResponse(
+2 -10
View File
@@ -4,8 +4,7 @@ import secrets
from datetime import datetime
from typing import TYPE_CHECKING
import sqlalchemy as sa
from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import AccountStatus
@@ -24,20 +23,13 @@ class User(TimestampMixin, Base):
id: Mapped[str] = mapped_column(Text, primary_key=True)
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
hashed_password: Mapped[str | None] = mapped_column(String(255), nullable=True)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
display_name: Mapped[str | None] = mapped_column(String(100))
email_verified: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default="false"
)
image: Mapped[str | None] = mapped_column(Text, nullable=True)
email_inbound_token: Mapped[str] = mapped_column(
String(22),
nullable=False,
unique=True,
default=lambda: secrets.token_urlsafe(16),
server_default=sa.text(
"replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_')"
),
)
# Relationships
+1 -7
View File
@@ -19,13 +19,7 @@ async def get_email_in_address(
svc = AuthService(db)
try:
email_address = await svc.get_email_in_address(user_id)
return EmailInAddressResponse(
email_address=email_address,
instructions=(
"Forward your digital receipt emails to this address. "
"We currently support Meijer, Kroger, and Target receipt emails."
),
)
return EmailInAddressResponse(email_address=email_address)
except LookupError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
-1
View File
@@ -24,7 +24,6 @@ class UserResponse(BaseModel):
class EmailInAddressResponse(BaseModel):
email_address: str
instructions: str
# ---------- Stores ----------
+1 -1
View File
@@ -76,4 +76,4 @@ class AuthService:
if not user:
raise LookupError("User not found")
return f"receipts+{user.email_inbound_token}@receipts.cartsnitch.com"
return f"{user.email_inbound_token}@email.cartsnitch.com"
+8 -36
View File
@@ -19,25 +19,6 @@ from cartsnitch_api.database import get_db
from cartsnitch_api.main import create_app
from cartsnitch_api.models import Base
TEST_JWT_SECRET = secrets.token_urlsafe(32)
TEST_SERVICE_KEY = secrets.token_urlsafe(32)
TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
@pytest.fixture(autouse=True)
def setup_test_settings():
original_jwt = cartsnitch_settings.jwt_secret_key
original_service = cartsnitch_settings.service_key
original_fernet = cartsnitch_settings.fernet_key
cartsnitch_settings.jwt_secret_key = TEST_JWT_SECRET
cartsnitch_settings.service_key = TEST_SERVICE_KEY
cartsnitch_settings.fernet_key = TEST_FERNET_KEY
yield
cartsnitch_settings.jwt_secret_key = original_jwt
cartsnitch_settings.service_key = original_service
cartsnitch_settings.fernet_key = original_fernet
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
@@ -79,8 +60,7 @@ async def db_engine():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Create Better-Auth tables (not managed by SQLAlchemy models)
await conn.execute(
text("""
await conn.execute(text("""
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
token TEXT NOT NULL UNIQUE,
@@ -91,10 +71,8 @@ async def db_engine():
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
)
""")
)
await conn.execute(
text("""
"""))
await conn.execute(text("""
CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
@@ -110,10 +88,8 @@ async def db_engine():
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
)
""")
)
await conn.execute(
text("""
"""))
await conn.execute(text("""
CREATE TABLE IF NOT EXISTS verifications (
id TEXT PRIMARY KEY,
identifier TEXT NOT NULL,
@@ -122,8 +98,7 @@ async def db_engine():
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
)
""")
)
"""))
yield engine
@@ -158,13 +133,10 @@ async def client(db_engine):
app.dependency_overrides.clear()
async def _create_test_user_and_session(
client: AsyncClient, db_engine, **user_overrides
) -> tuple[dict, str]:
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
"""Create a test user and a valid session directly in the DB.
Returns (user_dict, session_token). Better-Auth stores the raw token
in the DB, so we insert it as-is.
Returns (user_dict, session_token).
"""
user_id = str(uuid.uuid4())
email = user_overrides.get("email", "test@example.com")
@@ -71,56 +71,6 @@ async def test_delete_me(client, auth_headers):
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_get_me_compound_cookie(client, db_engine):
"""Compound cookie value (token.sessionId) must be parsed to extract the token part."""
from tests.conftest import _create_test_user_and_session
_, session_token = await _create_test_user_and_session(
client, db_engine, email="compound@example.com", display_name="Compound User"
)
compound = f"{session_token}.B0atkJCFxK1rZlwWPMK97nVO2LnyDun7"
resp = await client.get(
"/auth/me",
headers={"Cookie": f"better-auth.session_token={compound}"},
)
assert resp.status_code == 200
assert resp.json()["email"] == "compound@example.com"
@pytest.mark.asyncio
async def test_get_me_raw_token_cookie(client, db_engine):
"""Raw token (no dot) in cookie must still work — regression guard."""
from tests.conftest import _create_test_user_and_session
_, session_token = await _create_test_user_and_session(
client, db_engine, email="rawcookie@example.com", display_name="Raw Cookie User"
)
resp = await client.get(
"/auth/me",
headers={"Cookie": f"better-auth.session_token={session_token}"},
)
assert resp.status_code == 200
assert resp.json()["email"] == "rawcookie@example.com"
@pytest.mark.asyncio
async def test_get_me_compound_bearer(client, db_engine):
"""Compound Bearer token (token.sessionId) must be parsed to extract the token part."""
from tests.conftest import _create_test_user_and_session
_, session_token = await _create_test_user_and_session(
client, db_engine, email="compoundbearer@example.com", display_name="Compound Bearer User"
)
compound = f"{session_token}.B0atkJCFxK1rZlwWPMK97nVO2LnyDun7"
resp = await client.get(
"/auth/me",
headers={"Authorization": f"Bearer {compound}"},
)
assert resp.status_code == 200
assert resp.json()["email"] == "compoundbearer@example.com"
@pytest.mark.asyncio
async def test_expired_session_rejected(client, db_engine):
"""Expired sessions must be rejected."""
-48
View File
@@ -1,48 +0,0 @@
"""Tests for Settings config, specifically the database_url env var fallback."""
import os
from cartsnitch_api.config import Settings
def test_database_url_prefers_cartsnitch_prefix():
"""CARTSNITCH_DATABASE_URL takes precedence over DATABASE_URL."""
env = {
"CARTSNITCH_DATABASE_URL": "postgresql+asyncpg://user1:pass1@host1:5432/db1",
"DATABASE_URL": "postgresql://user2:pass2@host2:5432/db2",
}
settings = Settings(**env)
assert settings.database_url == "postgresql+asyncpg://user1:pass1@host1:5432/db1"
def test_database_url_falls_back_to_database_url():
"""When CARTSNITCH_DATABASE_URL is absent, DATABASE_URL is accepted."""
env = {
"DATABASE_URL": "postgresql://user:pass@dbhost:5432/mydb",
}
settings = Settings(**env)
assert settings.database_url == "postgresql+asyncpg://user:pass@dbhost:5432/mydb"
def test_database_url_normalizes_plain_postgresql_prefix():
"""DATABASE_URL with plain postgresql:// is normalized to postgresql+asyncpg://."""
env = {
"DATABASE_URL": "postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
}
settings = Settings(**env)
assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
def test_database_url_preserves_asyncpg_prefix():
"""CARTSNITCH_DATABASE_URL with postgresql+asyncpg:// is left unchanged."""
env = {
"CARTSNITCH_DATABASE_URL": "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
}
settings = Settings(**env)
assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
def test_database_url_default():
"""When neither env var is set, the hardcoded default is used."""
settings = Settings()
assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
+5 -5
View File
@@ -1,4 +1,4 @@
"""Tests for GET /api/v1/me/email-in-address endpoint."""
"""Tests for GET /auth/me/email-in-address endpoint."""
import pytest
from httpx import AsyncClient
@@ -8,7 +8,7 @@ from httpx import AsyncClient
async def test_get_email_in_address_authenticated(client: AsyncClient, auth_headers: dict):
"""Authenticated user gets their email-in address."""
response = await client.get(
"/api/v1/me/email-in-address",
"/auth/me/email-in-address",
headers=auth_headers,
)
@@ -27,7 +27,7 @@ async def test_get_email_in_address_authenticated(client: AsyncClient, auth_head
@pytest.mark.asyncio
async def test_get_email_in_address_unauthenticated(client: AsyncClient):
"""Unauthenticated request returns 401."""
response = await client.get("/api/v1/me/email-in-address")
response = await client.get("/auth/me/email-in-address")
assert response.status_code == 401
@@ -35,7 +35,7 @@ async def test_get_email_in_address_unauthenticated(client: AsyncClient):
async def test_get_email_in_address_invalid_token(client: AsyncClient):
"""Invalid JWT token returns 401."""
response = await client.get(
"/api/v1/me/email-in-address",
"/auth/me/email-in-address",
headers={"Authorization": "Bearer invalid-token-xyz"},
)
assert response.status_code == 401
@@ -45,7 +45,7 @@ async def test_get_email_in_address_invalid_token(client: AsyncClient):
async def test_email_address_format(client: AsyncClient, auth_headers: dict):
"""Email address format is receipts+{22-char-urlsafe-token}@receipts.cartsnitch.com."""
response = await client.get(
"/api/v1/me/email-in-address",
"/auth/me/email-in-address",
headers=auth_headers,
)
+18 -154
View File
@@ -1,184 +1,47 @@
"""Tests for rate limiting middleware."""
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cartsnitch_api.config import settings
from cartsnitch_api.middleware.rate_limit import (
InMemorySlidingWindow,
RedisSlidingWindow,
_get_client_ip,
_get_rate_limit_key,
)
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter
class TestInMemorySlidingWindow:
class TestSlidingWindowCounter:
def test_allows_within_limit(self):
limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
for i in range(5):
allowed, remaining, retry = limiter.is_allowed("test-key")
allowed, remaining, retry = counter.is_allowed("test-key")
assert allowed is True
assert remaining == 4 - i
def test_blocks_over_limit(self):
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
for _ in range(3):
limiter.is_allowed("test-key")
counter.is_allowed("test-key")
allowed, remaining, retry = limiter.is_allowed("test-key")
allowed, remaining, retry = counter.is_allowed("test-key")
assert allowed is False
assert remaining == 0
assert retry > 0
def test_separate_keys(self):
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
limiter.is_allowed("key-a")
limiter.is_allowed("key-a")
allowed_a, _, _ = limiter.is_allowed("key-a")
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60)
# Fill key-a
counter.is_allowed("key-a")
counter.is_allowed("key-a")
allowed_a, _, _ = counter.is_allowed("key-a")
assert allowed_a is False
allowed_b, remaining, _ = limiter.is_allowed("key-b")
# key-b should still be allowed
allowed_b, remaining, _ = counter.is_allowed("key-b")
assert allowed_b is True
assert remaining == 1
def test_resets_after_window_expires(self):
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
for _ in range(2):
limiter.is_allowed("test-key")
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is False
time.sleep(1.1)
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 1
class TestGetClientIp:
def test_x_forwarded_for_single(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_multiple(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_with_port(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_no_forwarded_header(self):
req = MagicMock()
req.headers = {}
req.client.host = "127.0.0.1"
assert _get_client_ip(req) == "127.0.0.1"
def test_no_client(self):
req = MagicMock()
req.headers = {}
req.client = None
assert _get_client_ip(req) == "unknown"
class TestGetRateLimitKey:
def _make_request(
self,
path: str = "/purchases",
method: str = "GET",
auth_header: str = "",
headers: dict | None = None,
) -> MagicMock:
req = MagicMock()
req.url.path = path
req.method = method
req.headers = dict(headers) if headers else {}
if auth_header:
req.headers["authorization"] = auth_header
return req
def test_public_path_uses_public_limiter(self):
req = self._make_request("/public/inflation")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests
def test_auth_post_path_uses_strict_limiter(self):
req = self._make_request("/auth/login", method="POST")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_auth_requests
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
def test_auth_get_path_uses_auth_limiter(self):
req = self._make_request("/auth/me", method="GET")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_authenticated_token_uses_auth_limiter(self):
req = self._make_request("/purchases", auth_header="Bearer token123")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("token:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key
class TestRedisSlidingWindowFallback:
@pytest.mark.asyncio
async def test_fallback_on_redis_connection_error(self):
mock_redis = AsyncMock()
mock_redis.pipeline.return_value = AsyncMock()
pipe_mock = AsyncMock()
pipe_mock.execute.side_effect = Exception("Connection refused")
mock_redis.pipeline.return_value = pipe_mock
limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 4
@pytest.mark.asyncio
async def test_fallback_on_redis_error_during_pipeline(self):
mock_redis = AsyncMock()
pipe_mock = AsyncMock()
pipe_mock.execute.side_effect = Exception("Redis error")
mock_redis.pipeline.return_value = pipe_mock
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True
@pytest.mark.asyncio
async def test_rate_limit_returns_429(client):
"""Public endpoint should return 429 after limit exceeded."""
# The default limit is 60/min — we won't hit it in normal tests,
# but we verify the middleware adds rate limit headers.
resp = await client.get("/public/inflation")
assert "x-ratelimit-limit" in resp.headers
assert "x-ratelimit-remaining" in resp.headers
@@ -186,6 +49,7 @@ async def test_rate_limit_returns_429(client):
@pytest.mark.asyncio
async def test_health_skips_rate_limit(client):
"""Health endpoint should not have rate limit headers."""
resp = await client.get("/health")
assert resp.status_code == 200
assert "x-ratelimit-limit" not in resp.headers
-1
View File
@@ -95,6 +95,5 @@ export const auth = betterAuth({
"https://cartsnitch.com",
"https://cartsnitch.farh.net",
"https://cartsnitch.dev.farh.net",
"https://cartsnitch.uat.farh.net",
],
});
+1 -1
View File
@@ -14,7 +14,7 @@ if config.config_file_name is not None:
db_url = os.environ.get("CARTSNITCH_DATABASE_URL_SYNC")
if db_url:
config.set_main_option("sqlalchemy.url", db_url.replace("%", "%%"))
config.set_main_option("sqlalchemy.url", db_url)
target_metadata = Base.metadata
@@ -1,37 +0,0 @@
"""Add email_inbound_token to users.
Revision ID: 001_add_email_inbound_token
Revises:
Create Date: 2026-04-02
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "001_add_email_inbound_token"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.add_column("users", sa.Column("email_inbound_token", sa.String(22), nullable=True))
op.create_unique_constraint("uq_users_email_inbound_token", "users", ["email_inbound_token"])
# Backfill existing users with generated tokens (PostgreSQL)
op.execute(
"UPDATE users SET email_inbound_token = "
"substring(replace(gen_random_uuid()::text, '-', ''), 1, 22) "
"WHERE email_inbound_token IS NULL"
)
# Alter to non-nullable
op.alter_column("users", "email_inbound_token", nullable=False)
def downgrade() -> None:
op.drop_constraint("uq_users_email_inbound_token", "users", type_="unique")
op.drop_column("users", "email_inbound_token")
+1 -11
View File
@@ -1,11 +1,10 @@
"""User and UserStoreAccount models."""
import secrets
import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, String, Text, UniqueConstraint, text
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_common.constants import AccountStatus
@@ -22,15 +21,6 @@ class User(UUIDPrimaryKeyMixin, TimestampMixin, Base):
__tablename__ = "users"
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
email_inbound_token: Mapped[str] = mapped_column(
String(22),
nullable=False,
unique=True,
default=lambda: secrets.token_urlsafe(16),
server_default=text(
"replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_')"
),
)
hashed_password: Mapped[str | None] = mapped_column(String(255), nullable=True)
display_name: Mapped[str | None] = mapped_column(String(100))
email_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="false")
@@ -20,7 +20,6 @@ class UserRead(BaseModel):
id: uuid.UUID
email: str
display_name: str | None
email_inbound_token: str
created_at: datetime
updated_at: datetime
-34
View File
@@ -147,40 +147,6 @@ class TestStoreLocationModel:
assert loc.lat == pytest.approx(42.2808)
class TestUserModel:
def test_email_inbound_token_auto_populated(self, session):
user = User(
id=uuid.uuid4(),
email="token_test@example.com",
hashed_password="hashed",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(user)
session.commit()
assert user.email_inbound_token is not None
assert len(user.email_inbound_token) == 22
def test_email_inbound_token_unique(self, session):
user1 = User(
id=uuid.uuid4(),
email="user1@example.com",
hashed_password="hashed",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
user2 = User(
id=uuid.uuid4(),
email="user2@example.com",
hashed_password="hashed",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add_all([user1, user2])
session.commit()
assert user1.email_inbound_token != user2.email_inbound_token
class TestUserStoreAccountModel:
def test_account_status_enum(self, session):
user = User(
-6
View File
@@ -9,12 +9,6 @@ server {
gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript image/svg+xml;
gzip_min_length 256;
# Security headers
add_header X-Frame-Options "SAMEORIGIN" always;
add_header X-Content-Type-Options "nosniff" always;
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
add_header Content-Security-Policy "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self'; connect-src 'self' https://*.cartsnitch.com https://*.farh.net; frame-ancestors 'self'" always;
# Health endpoint for K8s probes
location /health {
access_log off;
+1 -34
View File
@@ -1,12 +1,8 @@
"""Service-specific configuration for ReceiptWitness."""
from pydantic import model_validator
from pydantic_settings import BaseSettings
_PLACEHOLDER_VALUES = {"change-me-in-production"}
class ReceiptWitnessSettings(BaseSettings):
model_config = {"env_prefix": "RW_"}
@@ -34,34 +30,5 @@ class ReceiptWitnessSettings(BaseSettings):
# Mailgun inbound email webhook
mailgun_webhook_signing_key: str = ""
@model_validator(mode="after")
def validate_required_vars(self):
errors = []
if not self.session_encryption_key or self.session_encryption_key in _PLACEHOLDER_VALUES:
errors.append(
"RW_SESSION_ENCRYPTION_KEY must be set to a secure value. "
'Generate one with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"'
)
if self.notifications_enabled and not self.resend_api_key:
errors.append(
"RW_RESEND_API_KEY must be set when RW_NOTIFICATIONS_ENABLED=true. "
"Get an API key from https://resend.com/api-keys"
)
if errors:
raise ValueError(
"ReceiptWitness startup failed — missing required config:\n"
+ "\n".join(f" - {e}" for e in errors)
)
return self
class _LazySettings:
_instance: ReceiptWitnessSettings | None = None
def __getattr__(self, name: str):
if _LazySettings._instance is None:
_LazySettings._instance = ReceiptWitnessSettings()
return getattr(_LazySettings._instance, name)
settings = _LazySettings()
settings = ReceiptWitnessSettings()
-4
View File
@@ -1,16 +1,12 @@
"""Shared test fixtures."""
import json
import os
from pathlib import Path
import pytest
FIXTURES_DIR = Path(__file__).parent / "fixtures"
os.environ.setdefault("RW_SESSION_ENCRYPTION_KEY", "test-secret-key-for-unit-tests-only-32bytes!")
os.environ.setdefault("RW_MAILGUN_WEBHOOK_SIGNING_KEY", "test-mailgun-signing-key")
@pytest.fixture
def meijer_receipt_data() -> dict:
-46
View File
@@ -1,46 +0,0 @@
import pytest
from receiptwitness.config import ReceiptWitnessSettings
def test_valid_config():
s = ReceiptWitnessSettings(
session_encryption_key="7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
)
assert s.session_encryption_key
def test_missing_session_encryption_key_raises():
with pytest.raises(ValueError, match="RW_SESSION_ENCRYPTION_KEY"):
ReceiptWitnessSettings(session_encryption_key="")
def test_placeholder_session_encryption_key_raises():
with pytest.raises(ValueError, match="RW_SESSION_ENCRYPTION_KEY"):
ReceiptWitnessSettings(session_encryption_key="change-me-in-production")
def test_notifications_enabled_without_resend_key_raises():
with pytest.raises(ValueError, match="RW_RESEND_API_KEY"):
ReceiptWitnessSettings(
session_encryption_key="7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=",
notifications_enabled=True,
resend_api_key="",
)
def test_notifications_disabled_without_resend_key_ok():
s = ReceiptWitnessSettings(
session_encryption_key="7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=",
notifications_enabled=False,
resend_api_key="",
)
assert s.notifications_enabled is False
def test_notifications_enabled_with_resend_key_ok():
s = ReceiptWitnessSettings(
session_encryption_key="7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=",
notifications_enabled=True,
resend_api_key="re_test_1234567890",
)
assert s.resend_api_key == "re_test_1234567890"