Compare commits
68 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a66583b883 | |||
| 4a7d5131fc | |||
| 56b1ff9a36 | |||
| b660336897 | |||
| af713f422b | |||
| 55ab0b7ceb | |||
| 93a94e9777 | |||
| 1bb669f3ca | |||
| 82978f072b | |||
| 9ba745b5a9 | |||
| c13e640864 | |||
| c7b7494151 | |||
| f023480100 | |||
| 9acaf5e83a | |||
| 4e10c75fd0 | |||
| ffdc26cce5 | |||
| 2e96e8f0a7 | |||
| 88ac74e94c | |||
| 66279716ba | |||
| 15ab4ed38c | |||
| fbd77a9434 | |||
| fef5e86645 | |||
| cf39ed1dcd | |||
| 53ffef0ed1 | |||
| cfad4eab37 | |||
| d8e7a416d2 | |||
| c03e599ae3 | |||
| f051e4b4af | |||
| c715c0e47a | |||
| c968088a3f | |||
| 2b32bfdfe1 | |||
| 16200c5500 | |||
| 1803d09095 | |||
| e29bad9a39 | |||
| 349b519a00 | |||
| 7fc524b593 | |||
| 4e139dc4b6 | |||
| 1ce5d738d1 | |||
| 4c217757c3 | |||
| 121dc5724e | |||
| ee45400c7c | |||
| 6481cf03e4 | |||
| 37c75c3887 | |||
| 8a0b2c03a1 | |||
| aa893d9cc1 | |||
| 91c062130c | |||
| 0aef2455fd | |||
| 6602b8c105 | |||
| dbbc8d2e7b | |||
| 1267caf43c | |||
| 015401861a | |||
| 9891e1aefb | |||
| 69ad161e36 | |||
| 485f890df3 | |||
| bf3ed0ede3 | |||
| 3f41eb7346 | |||
| 6cbd1ef298 | |||
| 94214f762e | |||
| 562c6ef6f6 | |||
| ccc8189d88 | |||
| 86594e4a8e | |||
| c2f1a83c1d | |||
| 6f8e5a9577 | |||
| bbfa816e57 | |||
| 5904eb03a2 | |||
| 87b6433ff7 | |||
| d7c9938f7e | |||
| 02434060ee |
@@ -166,6 +166,8 @@ jobs:
|
||||
- name: Scan frontend image for vulnerabilities
|
||||
uses: anchore/scan-action@v5
|
||||
id: scan
|
||||
env:
|
||||
GRYPE_CONFIG: .grype.yaml
|
||||
with:
|
||||
image: "${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${{ github.sha }}"
|
||||
fail-build: true
|
||||
@@ -263,6 +265,8 @@ jobs:
|
||||
- name: Scan auth image for vulnerabilities
|
||||
uses: anchore/scan-action@v5
|
||||
id: scan
|
||||
env:
|
||||
GRYPE_CONFIG: .grype.yaml
|
||||
with:
|
||||
image: "${{ env.REGISTRY }}/${{ env.AUTH_IMAGE_NAME }}:sha-${{ github.sha }}"
|
||||
fail-build: true
|
||||
@@ -343,12 +347,16 @@ jobs:
|
||||
load: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
APT_CACHE_BUST=${{ github.run_id }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Scan receiptwitness image for vulnerabilities
|
||||
uses: anchore/scan-action@v5
|
||||
id: scan
|
||||
env:
|
||||
GRYPE_CONFIG: .grype.yaml
|
||||
with:
|
||||
image: "${{ env.REGISTRY }}/${{ env.RECEIPTWITNESS_IMAGE_NAME }}:sha-${{ github.sha }}"
|
||||
fail-build: true
|
||||
@@ -371,6 +379,8 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
APT_CACHE_BUST=${{ github.run_id }}
|
||||
cache-from: type=gha
|
||||
|
||||
build-and-push-api:
|
||||
@@ -429,12 +439,16 @@ jobs:
|
||||
load: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
APT_CACHE_BUST=${{ github.run_id }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Scan api image for vulnerabilities
|
||||
uses: anchore/scan-action@v5
|
||||
id: scan
|
||||
env:
|
||||
GRYPE_CONFIG: .grype.yaml
|
||||
with:
|
||||
image: "${{ env.REGISTRY }}/${{ env.API_IMAGE_NAME }}:sha-${{ github.sha }}"
|
||||
fail-build: true
|
||||
@@ -457,6 +471,8 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
APT_CACHE_BUST=${{ github.run_id }}
|
||||
cache-from: type=gha
|
||||
|
||||
deploy-dev:
|
||||
@@ -553,6 +569,7 @@ jobs:
|
||||
git config user.name "cartsnitch-ci[bot]"
|
||||
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
|
||||
git add apps/overlays/dev/kustomization.yaml
|
||||
git diff --cached --quiet && echo "No image changes to deploy" && exit 0
|
||||
git commit -m "ci(dev): update cartsnitch, auth, receiptwitness, and api images"
|
||||
git pull --rebase origin main
|
||||
git push origin main
|
||||
@@ -651,6 +668,7 @@ jobs:
|
||||
git config user.name "cartsnitch-ci[bot]"
|
||||
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
|
||||
git add apps/overlays/uat/kustomization.yaml
|
||||
git diff --cached --quiet && echo "No image changes to deploy" && exit 0
|
||||
git commit -m "ci(uat): update cartsnitch, auth, receiptwitness, and api images"
|
||||
git pull --rebase origin main
|
||||
git push origin main
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
ignore:
|
||||
# Python 3.12 CVEs — only fixed in 3.13+, cannot upgrade major version safely
|
||||
- vulnerability: CVE-2025-13836
|
||||
- vulnerability: CVE-2026-4519
|
||||
@@ -1,5 +1,6 @@
|
||||
FROM python:3.12-slim AS build
|
||||
|
||||
ARG APT_CACHE_BUST=0
|
||||
RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends \
|
||||
libpq-dev \
|
||||
build-essential \
|
||||
@@ -12,6 +13,7 @@ RUN pip install --no-cache-dir --prefix=/install .
|
||||
|
||||
FROM python:3.12-slim AS prod
|
||||
|
||||
ARG APT_CACHE_BUST=0
|
||||
RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends libpq5 && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -47,5 +47,30 @@ class CacheClient:
|
||||
return
|
||||
await self._client.delete(key)
|
||||
|
||||
async def invalidate_price_cache(self, product_id: str) -> None:
|
||||
"""Invalidate all price-related cache entries for a product."""
|
||||
if not self._client:
|
||||
return
|
||||
pattern = f"price:*:{product_id}"
|
||||
await self._delete_pattern(pattern)
|
||||
|
||||
async def invalidate_product_cache(self, product_id: str) -> None:
|
||||
"""Invalidate the product detail cache entry."""
|
||||
if not self._client:
|
||||
return
|
||||
await self._client.delete(f"product:{product_id}")
|
||||
|
||||
async def _delete_pattern(self, pattern: str) -> None:
|
||||
"""Delete all keys matching a pattern using SCAN."""
|
||||
if not self._client:
|
||||
return
|
||||
cursor = 0
|
||||
while True:
|
||||
cursor, keys = await self._client.scan(cursor=cursor, match=pattern, count=100)
|
||||
if keys:
|
||||
await self._client.delete(*keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
|
||||
cache_client = CacheClient()
|
||||
|
||||
@@ -32,6 +32,9 @@ class Settings(BaseSettings):
|
||||
|
||||
rate_limit_requests: int = 60
|
||||
rate_limit_window_seconds: int = 60
|
||||
rate_limit_auth_requests: int = 5
|
||||
rate_limit_auth_window_seconds: int = 60
|
||||
rate_limit_redis_enabled: bool = True
|
||||
rate_limit_enabled: bool = True
|
||||
|
||||
_PLACEHOLDER_VALUES = {"change-me-in-production"}
|
||||
@@ -72,7 +75,9 @@ class Settings(BaseSettings):
|
||||
def normalize_database_url(self):
|
||||
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
|
||||
if self.database_url.startswith("postgresql://"):
|
||||
self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
self.database_url = self.database_url.replace(
|
||||
"postgresql://", "postgresql+asyncpg://", 1
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@@ -5,18 +5,31 @@ Per-IP limiting on public endpoints, per-token limiting on authenticated endpoin
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
from typing import Protocol
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis, RedisError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class _SlidingWindowCounter:
|
||||
|
||||
class RateLimitBackend(Protocol):
|
||||
"""Protocol for rate limit backends."""
|
||||
|
||||
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
|
||||
|
||||
class InMemorySlidingWindow:
|
||||
"""Thread-safe in-memory sliding window rate limiter."""
|
||||
|
||||
def __init__(self, max_requests: int, window_seconds: int) -> None:
|
||||
@@ -25,13 +38,12 @@ class _SlidingWindowCounter:
|
||||
self._hits: dict[str, list[float]] = defaultdict(list)
|
||||
self._lock = Lock()
|
||||
|
||||
def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
with self._lock:
|
||||
# Prune expired entries
|
||||
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
|
||||
|
||||
current_count = len(self._hits[key])
|
||||
@@ -44,15 +56,84 @@ class _SlidingWindowCounter:
|
||||
return True, remaining, 0
|
||||
|
||||
|
||||
# Module-level counters — one for public (per-IP), one for auth (per-token)
|
||||
_public_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests,
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
_auth_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
class RedisSlidingWindow:
|
||||
"""Redis-backed sliding window rate limiter using sorted sets."""
|
||||
|
||||
def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None:
|
||||
self.redis = redis
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
|
||||
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
try:
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
now_ms = int(now * 1000)
|
||||
cutoff_ms = int(cutoff * 1000)
|
||||
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.zremrangebyscore(key, 0, cutoff_ms)
|
||||
pipe.zcard(key)
|
||||
results = await pipe.execute()
|
||||
|
||||
current_count = results[1]
|
||||
|
||||
if current_count >= self.max_requests:
|
||||
oldest = await self.redis.zrange(key, 0, 0, withscores=True)
|
||||
if oldest:
|
||||
retry_after = int((oldest[0][1] - cutoff) / 1000) + 1
|
||||
else:
|
||||
retry_after = self.window_seconds
|
||||
return False, 0, retry_after
|
||||
|
||||
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.zadd(key, {member: now_ms})
|
||||
pipe.expire(key, self.window_seconds)
|
||||
await pipe.execute()
|
||||
|
||||
remaining = self.max_requests - current_count - 1
|
||||
return True, remaining, 0
|
||||
|
||||
except RedisError as e:
|
||||
logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
|
||||
in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds)
|
||||
return await in_memory.is_allowed(key)
|
||||
|
||||
|
||||
_redis_client: Redis | None = None
|
||||
_use_redis = False
|
||||
|
||||
if settings.rate_limit_redis_enabled:
|
||||
try:
|
||||
_redis_client = Redis.from_url(settings.redis_url)
|
||||
_use_redis = True
|
||||
logger.info("Rate limiting will use Redis at %s", settings.redis_url)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
|
||||
_use_redis = False
|
||||
|
||||
if _use_redis and _redis_client:
|
||||
_public_limiter = RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_limiter = RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_strict_limiter = RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
|
||||
)
|
||||
else:
|
||||
_public_limiter = InMemorySlidingWindow(
|
||||
settings.rate_limit_requests, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_limiter = InMemorySlidingWindow(
|
||||
settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_strict_limiter = InMemorySlidingWindow(
|
||||
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
|
||||
)
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
@@ -63,30 +144,30 @@ def _get_client_ip(request: Request) -> str:
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
|
||||
"""Determine rate limit key and which limiter to use."""
|
||||
if request.url.path.startswith("/public"):
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
# For authenticated endpoints, use Bearer token as key if present
|
||||
if request.url.path.startswith("/auth/") and request.method == "POST":
|
||||
return f"ip:{_get_client_ip(request)}", _auth_strict_limiter
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
return f"token:{token_hash}", _auth_limiter
|
||||
|
||||
# Fallback to IP for unauthenticated non-public endpoints
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip rate limiting when disabled (e.g. in tests) or for health checks
|
||||
if not settings.rate_limit_enabled or request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
key, limiter = _get_rate_limit_key(request)
|
||||
allowed, remaining, retry_after = limiter.is_allowed(key)
|
||||
allowed, remaining, retry_after = await limiter.is_allowed(key)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
|
||||
@@ -1,49 +1,184 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.middleware.rate_limit import (
|
||||
InMemorySlidingWindow,
|
||||
RedisSlidingWindow,
|
||||
_get_client_ip,
|
||||
_get_rate_limit_key,
|
||||
)
|
||||
|
||||
|
||||
class TestSlidingWindowCounter:
|
||||
class TestInMemorySlidingWindow:
|
||||
def test_allows_within_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
|
||||
limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
|
||||
for i in range(5):
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
allowed, remaining, retry = limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 4 - i
|
||||
|
||||
def test_blocks_over_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
|
||||
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
|
||||
for _ in range(3):
|
||||
counter.is_allowed("test-key")
|
||||
limiter.is_allowed("test-key")
|
||||
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
allowed, remaining, retry = limiter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
assert remaining == 0
|
||||
assert retry > 0
|
||||
|
||||
def test_separate_keys(self):
|
||||
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60)
|
||||
# Fill key-a
|
||||
counter.is_allowed("key-a")
|
||||
counter.is_allowed("key-a")
|
||||
allowed_a, _, _ = counter.is_allowed("key-a")
|
||||
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
|
||||
limiter.is_allowed("key-a")
|
||||
limiter.is_allowed("key-a")
|
||||
allowed_a, _, _ = limiter.is_allowed("key-a")
|
||||
assert allowed_a is False
|
||||
|
||||
# key-b should still be allowed
|
||||
allowed_b, remaining, _ = counter.is_allowed("key-b")
|
||||
allowed_b, remaining, _ = limiter.is_allowed("key-b")
|
||||
assert allowed_b is True
|
||||
assert remaining == 1
|
||||
|
||||
def test_resets_after_window_expires(self):
|
||||
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
|
||||
for _ in range(2):
|
||||
limiter.is_allowed("test-key")
|
||||
allowed, remaining, _ = limiter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
|
||||
time.sleep(1.1)
|
||||
allowed, remaining, _ = limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 1
|
||||
|
||||
|
||||
class TestGetClientIp:
|
||||
def test_x_forwarded_for_single(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_x_forwarded_for_multiple(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_x_forwarded_for_with_port(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_no_forwarded_header(self):
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client.host = "127.0.0.1"
|
||||
assert _get_client_ip(req) == "127.0.0.1"
|
||||
|
||||
def test_no_client(self):
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "unknown"
|
||||
|
||||
|
||||
class TestGetRateLimitKey:
|
||||
def _make_request(
|
||||
self,
|
||||
path: str = "/purchases",
|
||||
method: str = "GET",
|
||||
auth_header: str = "",
|
||||
headers: dict | None = None,
|
||||
) -> MagicMock:
|
||||
req = MagicMock()
|
||||
req.url.path = path
|
||||
req.method = method
|
||||
req.headers = dict(headers) if headers else {}
|
||||
if auth_header:
|
||||
req.headers["authorization"] = auth_header
|
||||
return req
|
||||
|
||||
def test_public_path_uses_public_limiter(self):
|
||||
req = self._make_request("/public/inflation")
|
||||
key, limiter = _get_rate_limit_key(req)
|
||||
assert key.startswith("ip:")
|
||||
assert limiter.max_requests == settings.rate_limit_requests
|
||||
|
||||
def test_auth_post_path_uses_strict_limiter(self):
|
||||
req = self._make_request("/auth/login", method="POST")
|
||||
key, limiter = _get_rate_limit_key(req)
|
||||
assert key.startswith("ip:")
|
||||
assert limiter.max_requests == settings.rate_limit_auth_requests
|
||||
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
|
||||
|
||||
def test_auth_get_path_uses_auth_limiter(self):
|
||||
req = self._make_request("/auth/me", method="GET")
|
||||
key, limiter = _get_rate_limit_key(req)
|
||||
assert key.startswith("ip:")
|
||||
assert limiter.max_requests == settings.rate_limit_requests * 5
|
||||
|
||||
def test_authenticated_token_uses_auth_limiter(self):
|
||||
req = self._make_request("/purchases", auth_header="Bearer token123")
|
||||
key, limiter = _get_rate_limit_key(req)
|
||||
assert key.startswith("token:")
|
||||
assert limiter.max_requests == settings.rate_limit_requests * 5
|
||||
|
||||
def test_distinct_tokens_produce_distinct_keys(self):
|
||||
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
|
||||
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 != key2
|
||||
|
||||
def test_same_token_produces_same_key(self):
|
||||
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
||||
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 == key2
|
||||
|
||||
def test_key_does_not_contain_raw_token_suffix(self):
|
||||
raw_token = "my_secret_jwt_token_xyz"
|
||||
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
|
||||
key, _ = _get_rate_limit_key(req)
|
||||
assert raw_token[-16:] not in key
|
||||
assert raw_token not in key
|
||||
|
||||
|
||||
class TestRedisSlidingWindowFallback:
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_redis_connection_error(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline.return_value = AsyncMock()
|
||||
pipe_mock = AsyncMock()
|
||||
pipe_mock.execute.side_effect = Exception("Connection refused")
|
||||
mock_redis.pipeline.return_value = pipe_mock
|
||||
|
||||
limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
|
||||
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_redis_error_during_pipeline(self):
|
||||
mock_redis = AsyncMock()
|
||||
pipe_mock = AsyncMock()
|
||||
pipe_mock.execute.side_effect = Exception("Redis error")
|
||||
mock_redis.pipeline.return_value = pipe_mock
|
||||
|
||||
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
|
||||
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_returns_429(client):
|
||||
"""Public endpoint should return 429 after limit exceeded."""
|
||||
# The default limit is 60/min — we won't hit it in normal tests,
|
||||
# but we verify the middleware adds rate limit headers.
|
||||
resp = await client.get("/public/inflation")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
@@ -51,36 +186,6 @@ async def test_rate_limit_returns_429(client):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_skips_rate_limit(client):
|
||||
"""Health endpoint should not have rate limit headers."""
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert "x-ratelimit-limit" not in resp.headers
|
||||
|
||||
|
||||
class TestGetRateLimitKey:
|
||||
def _make_request(self, auth_header: str = "") -> MagicMock:
|
||||
req = MagicMock()
|
||||
req.url.path = "/purchases"
|
||||
req.headers = {"authorization": auth_header} if auth_header else {}
|
||||
return req
|
||||
|
||||
def test_distinct_tokens_produce_distinct_keys(self):
|
||||
req1 = self._make_request("Bearer token_alpha_12345")
|
||||
req2 = self._make_request("Bearer token_beta_67890")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 != key2
|
||||
|
||||
def test_same_token_produces_same_key(self):
|
||||
req1 = self._make_request("Bearer same_token_value_abc")
|
||||
req2 = self._make_request("Bearer same_token_value_abc")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 == key2
|
||||
|
||||
def test_key_does_not_contain_raw_token_suffix(self):
|
||||
raw_token = "my_secret_jwt_token_xyz"
|
||||
req = self._make_request(f"Bearer {raw_token}")
|
||||
key, _ = _get_rate_limit_key(req)
|
||||
assert raw_token[-16:] not in key
|
||||
assert raw_token not in key
|
||||
|
||||
+1
-1
@@ -37,7 +37,7 @@ export const auth = betterAuth({
|
||||
maxPasswordLength: 128,
|
||||
password: {
|
||||
hash: async (password: string) => {
|
||||
return bcrypt.hash(password, 10);
|
||||
return bcrypt.hash(password, 12);
|
||||
},
|
||||
verify: async (data: { hash: string; password: string }) => {
|
||||
return bcrypt.compare(data.password, data.hash);
|
||||
|
||||
+89
-1
@@ -1,4 +1,4 @@
|
||||
import { test as base, expect } from "@playwright/test";
|
||||
import { test as base, expect, type Page } from "@playwright/test";
|
||||
import AxeBuilder from "@axe-core/playwright";
|
||||
|
||||
export const test = base.extend<{ axeCheck: void }>({
|
||||
@@ -10,3 +10,91 @@ export const test = base.extend<{ axeCheck: void }>({
|
||||
});
|
||||
|
||||
export { expect } from "@playwright/test";
|
||||
|
||||
const MOCK_USER_ID = "mock_user_123";
|
||||
const MOCK_SESSION_ID = "mock_session_456";
|
||||
|
||||
async function mockAuthRoutes(page: Page, authenticated = false) {
|
||||
await page.route(/.*\/auth\/sign-up\/email.*/, async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
token: null,
|
||||
user: {
|
||||
id: MOCK_USER_ID,
|
||||
email: "mock@cartsnitch.test",
|
||||
name: "Mock User",
|
||||
emailVerified: true,
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
},
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
await page.route(/.*\/auth\/sign-in\/email.*/, async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
redirect: false,
|
||||
token: "mock_token_123",
|
||||
user: {
|
||||
id: MOCK_USER_ID,
|
||||
email: "mock@cartsnitch.test",
|
||||
name: "Mock User",
|
||||
emailVerified: true,
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
},
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
await page.route(/.*\/auth\/get-session.*/, async (route) => {
|
||||
if (authenticated) {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
session: {
|
||||
id: MOCK_SESSION_ID,
|
||||
expiresAt: new Date(Date.now() + 7 * 24 * 60 * 60 * 1000).toISOString(),
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
ipAddress: null,
|
||||
userAgent: null,
|
||||
},
|
||||
user: {
|
||||
id: MOCK_USER_ID,
|
||||
email: "mock@cartsnitch.test",
|
||||
name: "Mock User",
|
||||
emailVerified: true,
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
},
|
||||
}),
|
||||
});
|
||||
} else {
|
||||
await route.fulfill({
|
||||
status: 401,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ error: "Unauthorized" }),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
export async function mockSessionDelayed(page: Page, delayMs = 3000) {
|
||||
await page.route(/.*\/auth\/get-session.*/, async (route) => {
|
||||
await new Promise((r) => setTimeout(r, delayMs));
|
||||
await route.fulfill({
|
||||
status: 401,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ error: "Unauthorized" }),
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export { mockAuthRoutes };
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import { test, expect } from '@playwright/test';
|
||||
import { mockAuthRoutes } from '../fixtures';
|
||||
|
||||
const uniqueEmail = () => `betty+e2e-${Date.now()}@cartsnitch.test`;
|
||||
|
||||
test.describe('J1: Registration and Login', () => {
|
||||
test('can register a new account and lands on dashboard', async ({ page }) => {
|
||||
test('can register a new account and see check your email screen', async ({ page }) => {
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/register');
|
||||
await page.fill('[placeholder="Full Name"]', 'Betty Tester');
|
||||
await page.fill('[placeholder="Email"]', uniqueEmail());
|
||||
await page.fill('[placeholder="Password (min. 8 characters)"]', 'TestPass123!');
|
||||
await page.click('button[type="submit"]');
|
||||
|
||||
// With VITE_MOCK_AUTH=true the app navigates to "/" on success
|
||||
await expect(page).toHaveURL('http://localhost:5173/');
|
||||
await expect(page.getByRole('heading', { name: /cart/i })).toBeVisible();
|
||||
await expect(page.getByRole('heading', { name: /check your email/i })).toBeVisible();
|
||||
});
|
||||
|
||||
test('shows validation error when registration fields are empty', async ({ page }) => {
|
||||
@@ -31,22 +31,9 @@ test.describe('J1: Registration and Login', () => {
|
||||
});
|
||||
|
||||
test('can sign in with credentials and land on dashboard', async ({ page }) => {
|
||||
// Register first so we have a real account
|
||||
const email = uniqueEmail();
|
||||
await page.goto('/register');
|
||||
await page.fill('[placeholder="Full Name"]', 'Login Betty');
|
||||
await page.fill('[placeholder="Email"]', email);
|
||||
await page.fill('[placeholder="Password (min. 8 characters)"]', 'TestPass123!');
|
||||
await page.click('button[type="submit"]');
|
||||
await expect(page).toHaveURL('http://localhost:5173/');
|
||||
|
||||
// Sign out by clearing the mock session (reload with no session)
|
||||
await page.goto('/');
|
||||
await page.reload();
|
||||
|
||||
// Now sign in
|
||||
await mockAuthRoutes(page, true);
|
||||
await page.goto('/login');
|
||||
await page.fill('[placeholder="Email"]', email);
|
||||
await page.fill('[placeholder="Email"]', 'test@cartsnitch.test');
|
||||
await page.fill('[placeholder="Password"]', 'TestPass123!');
|
||||
await page.click('button[type="submit"]');
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { test, expect } from '@playwright/test';
|
||||
import { mockAuthRoutes, mockSessionDelayed } from '../fixtures';
|
||||
|
||||
test.describe('J8: Unauthenticated Access', () => {
|
||||
test('redirects /dashboard (/) to /login when not authenticated', async ({ page }) => {
|
||||
// No session cookie — start fresh
|
||||
await page.context().clearCookies();
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/');
|
||||
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
@@ -11,7 +11,7 @@ test.describe('J8: Unauthenticated Access', () => {
|
||||
});
|
||||
|
||||
test('redirects /purchases to /login when not authenticated', async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/purchases');
|
||||
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
@@ -19,7 +19,7 @@ test.describe('J8: Unauthenticated Access', () => {
|
||||
});
|
||||
|
||||
test('redirects /products to /login when not authenticated', async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/products');
|
||||
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
@@ -27,7 +27,7 @@ test.describe('J8: Unauthenticated Access', () => {
|
||||
});
|
||||
|
||||
test('redirects /coupons to /login when not authenticated', async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/coupons');
|
||||
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
@@ -35,15 +35,9 @@ test.describe('J8: Unauthenticated Access', () => {
|
||||
});
|
||||
|
||||
test('shows loading spinner while auth session is pending', async ({ page }) => {
|
||||
// Intercept but don't respond — session stays pending
|
||||
await page.context().clearCookies();
|
||||
await page.request.fetch('/api/auth/session', {
|
||||
method: 'GET',
|
||||
});
|
||||
|
||||
// Just navigate to a protected route — ProtectedRoute will show spinner while session is pending
|
||||
await mockSessionDelayed(page, 3000);
|
||||
await page.goto('/purchases');
|
||||
// Spinner is visible briefly; once resolved, should redirect to login
|
||||
await expect(page.locator('.animate-spin')).toBeVisible({ timeout: 2000 });
|
||||
await expect(page).toHaveURL(/\/login/, { timeout: 10_000 });
|
||||
});
|
||||
});
|
||||
|
||||
+2
-2
@@ -1,8 +1,8 @@
|
||||
import { test, expect } from './fixtures';
|
||||
import { test, expect, mockAuthRoutes } from './fixtures';
|
||||
|
||||
test('app loads', async ({ page }) => {
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/');
|
||||
// Unauthenticated users are redirected to /login
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
await expect(page.getByRole('heading', { name: /CartSnitch/i })).toBeVisible();
|
||||
});
|
||||
|
||||
@@ -9,7 +9,7 @@ export default defineConfig({
|
||||
},
|
||||
],
|
||||
webServer: {
|
||||
command: 'VITE_MOCK_AUTH=true npm run dev',
|
||||
command: 'npm run dev',
|
||||
url: 'http://localhost:5173',
|
||||
reuseExistingServer: !process.env.CI,
|
||||
},
|
||||
|
||||
@@ -5,6 +5,7 @@ WORKDIR /app
|
||||
|
||||
# build-essential and libpq-dev are needed to compile any C-extension wheels
|
||||
# (e.g. psycopg2 fallback). No git needed — common/ is copied from the repo root.
|
||||
ARG APT_CACHE_BUST=0
|
||||
RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends \
|
||||
libpq-dev \
|
||||
build-essential \
|
||||
@@ -25,6 +26,7 @@ FROM python:3.12-slim AS prod
|
||||
WORKDIR /app
|
||||
|
||||
# Install Playwright system dependencies for Chromium
|
||||
ARG APT_CACHE_BUST=0
|
||||
RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends \
|
||||
libnss3 \
|
||||
libatk1.0-0 \
|
||||
|
||||
@@ -1,25 +1,8 @@
|
||||
import { useEffect } from 'react'
|
||||
import { Navigate, Outlet } from 'react-router-dom'
|
||||
import { authClient } from '../lib/auth-client.ts'
|
||||
import { useAuthStore } from '../stores/auth.ts'
|
||||
|
||||
export function ProtectedRoute() {
|
||||
const isMockAuth = import.meta.env.VITE_MOCK_AUTH === 'true'
|
||||
const { data: session, isPending } = authClient.useSession()
|
||||
const isAuthenticated = useAuthStore((s) => s.isAuthenticated)
|
||||
const setAuthenticated = useAuthStore((s) => s.setAuthenticated)
|
||||
|
||||
useEffect(() => {
|
||||
if (!isMockAuth) {
|
||||
setAuthenticated(!!session)
|
||||
}
|
||||
}, [session, setAuthenticated, isMockAuth])
|
||||
|
||||
// In mock auth mode, rely on Zustand store (set by Login/Register pages)
|
||||
if (isMockAuth) {
|
||||
if (!isAuthenticated) return <Navigate to="/login" replace />
|
||||
return <Outlet />
|
||||
}
|
||||
|
||||
if (isPending) {
|
||||
return (
|
||||
|
||||
@@ -79,21 +79,21 @@ function AuthenticatedDashboard({ userName }: { userName: string }) {
|
||||
<div className="rounded-xl bg-white p-4 shadow-sm">
|
||||
<p className="text-xs font-medium text-gray-500">Watching</p>
|
||||
<p className="mt-1 text-2xl font-bold text-gray-900">{watchingAlerts.length}</p>
|
||||
<p className="text-xs text-gray-400">price alerts</p>
|
||||
<p className="text-xs text-gray-600">price alerts</p>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white p-4 shadow-sm">
|
||||
<p className="text-xs font-medium text-gray-500">This Month</p>
|
||||
<p className="mt-1 text-2xl font-bold text-gray-900">
|
||||
${recentPurchases.reduce((sum, p) => sum + p.total, 0).toFixed(0)}
|
||||
</p>
|
||||
<p className="text-xs text-gray-400">grocery spend</p>
|
||||
<p className="text-xs text-gray-600">grocery spend</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Price trend sparklines */}
|
||||
<section className="mt-6">
|
||||
<h2 className="mb-3 text-lg font-semibold text-gray-700">Price Trends</h2>
|
||||
<div className="rounded-xl bg-white p-4 shadow-sm text-center text-sm text-gray-400">
|
||||
<div className="rounded-xl bg-white p-4 shadow-sm text-center text-sm text-gray-600">
|
||||
Connect a store to see price trends
|
||||
</div>
|
||||
</section>
|
||||
|
||||
+1
-8
@@ -1,7 +1,6 @@
|
||||
import { useState } from 'react'
|
||||
import { Link, useNavigate } from 'react-router-dom'
|
||||
import { authClient } from '../lib/auth-client.ts'
|
||||
import { useAuthStore } from '../stores/auth.ts'
|
||||
|
||||
export function Login() {
|
||||
const [email, setEmail] = useState('')
|
||||
@@ -9,7 +8,6 @@ export function Login() {
|
||||
const [error, setError] = useState('')
|
||||
const [loading, setLoading] = useState(false)
|
||||
const navigate = useNavigate()
|
||||
const setAuthenticated = useAuthStore((s) => s.setAuthenticated)
|
||||
|
||||
async function handleSubmit(e: React.FormEvent) {
|
||||
e.preventDefault()
|
||||
@@ -40,12 +38,7 @@ export function Login() {
|
||||
setError('Sign in failed. Please try again.')
|
||||
}
|
||||
} catch {
|
||||
if (import.meta.env.VITE_MOCK_AUTH === 'true') {
|
||||
setAuthenticated(true)
|
||||
navigate('/')
|
||||
} else {
|
||||
setError('Invalid email or password. Please try again.')
|
||||
}
|
||||
setError('Invalid email or password. Please try again.')
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
+2
-10
@@ -1,7 +1,6 @@
|
||||
import { useState } from 'react'
|
||||
import { Link, useNavigate } from 'react-router-dom'
|
||||
import { Link } from 'react-router-dom'
|
||||
import { authClient } from '../lib/auth-client.ts'
|
||||
import { useAuthStore } from '../stores/auth.ts'
|
||||
|
||||
export function Register() {
|
||||
const [name, setName] = useState('')
|
||||
@@ -12,8 +11,6 @@ export function Register() {
|
||||
const [registrationComplete, setRegistrationComplete] = useState(false)
|
||||
const [resendLoading, setResendLoading] = useState(false)
|
||||
const [resendMessage, setResendMessage] = useState('')
|
||||
const navigate = useNavigate()
|
||||
const setAuthenticated = useAuthStore((s) => s.setAuthenticated)
|
||||
|
||||
async function handleSubmit(e: React.FormEvent) {
|
||||
e.preventDefault()
|
||||
@@ -43,12 +40,7 @@ export function Register() {
|
||||
|
||||
setRegistrationComplete(true)
|
||||
} catch {
|
||||
if (import.meta.env.VITE_MOCK_AUTH === 'true') {
|
||||
setAuthenticated(true)
|
||||
navigate('/')
|
||||
} else {
|
||||
setError('Registration failed. Please try again.')
|
||||
}
|
||||
setError('Registration failed. Please try again.')
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user