chore: promote UAT to production (CAR-662, audit logging middleware)

chore: promote UAT to production (CAR-662, audit logging middleware)
This commit is contained in:
cartsnitch-ceo[bot]
2026-04-15 04:29:39 +00:00
committed by GitHub
10 changed files with 504 additions and 96 deletions
+6 -2
View File
@@ -69,7 +69,9 @@ async def get_current_user(
token: str | None = None token: str | None = None
# 1. Check session cookie — prefer __Secure- variant (HTTPS) over plain (HTTP dev) # 1. Check session cookie — prefer __Secure- variant (HTTPS) over plain (HTTP dev)
cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(SESSION_COOKIE_NAME) cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(
SESSION_COOKIE_NAME
)
if cookie_token: if cookie_token:
# Better-Auth cookie format is "token.sessionId" — extract just the token part # Better-Auth cookie format is "token.sessionId" — extract just the token part
token = cookie_token.split(".")[0] if "." in cookie_token else cookie_token token = cookie_token.split(".")[0] if "." in cookie_token else cookie_token
@@ -86,7 +88,9 @@ async def get_current_user(
detail="Authentication required", detail="Authentication required",
) )
return await _validate_session_token(token, db) user_id = await _validate_session_token(token, db)
request.state.user_id = user_id
return user_id
async def verify_service_key(x_service_key: str = Header()) -> None: async def verify_service_key(x_service_key: str = Header()) -> None:
+25
View File
@@ -47,5 +47,30 @@ class CacheClient:
return return
await self._client.delete(key) await self._client.delete(key)
async def invalidate_price_cache(self, product_id: str) -> None:
"""Invalidate all price-related cache entries for a product."""
if not self._client:
return
pattern = f"price:*:{product_id}"
await self._delete_pattern(pattern)
async def invalidate_product_cache(self, product_id: str) -> None:
"""Invalidate the product detail cache entry."""
if not self._client:
return
await self._client.delete(f"product:{product_id}")
async def _delete_pattern(self, pattern: str) -> None:
"""Delete all keys matching a pattern using SCAN."""
if not self._client:
return
cursor = 0
while True:
cursor, keys = await self._client.scan(cursor=cursor, match=pattern, count=100)
if keys:
await self._client.delete(*keys)
if cursor == 0:
break
cache_client = CacheClient() cache_client = CacheClient()
+6 -1
View File
@@ -32,6 +32,9 @@ class Settings(BaseSettings):
rate_limit_requests: int = 60 rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60 rate_limit_window_seconds: int = 60
rate_limit_auth_requests: int = 5
rate_limit_auth_window_seconds: int = 60
rate_limit_redis_enabled: bool = True
rate_limit_enabled: bool = True rate_limit_enabled: bool = True
_PLACEHOLDER_VALUES = {"change-me-in-production"} _PLACEHOLDER_VALUES = {"change-me-in-production"}
@@ -72,7 +75,9 @@ class Settings(BaseSettings):
def normalize_database_url(self): def normalize_database_url(self):
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver.""" """Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
if self.database_url.startswith("postgresql://"): if self.database_url.startswith("postgresql://"):
self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1) self.database_url = self.database_url.replace(
"postgresql://", "postgresql+asyncpg://", 1
)
return self return self
+2
View File
@@ -10,6 +10,7 @@ from cartsnitch_api.database import dispose_engine
from cartsnitch_api.middleware.cors import add_cors_middleware from cartsnitch_api.middleware.cors import add_cors_middleware
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
from cartsnitch_api.middleware.audit import add_audit_middleware
from cartsnitch_api.routes.alerts import router as alerts_router from cartsnitch_api.routes.alerts import router as alerts_router
from cartsnitch_api.routes.coupons import router as coupons_router from cartsnitch_api.routes.coupons import router as coupons_router
from cartsnitch_api.routes.health import router as health_router from cartsnitch_api.routes.health import router as health_router
@@ -43,6 +44,7 @@ def create_app() -> FastAPI:
add_cors_middleware(app) add_cors_middleware(app)
add_error_monitor_middleware(app) add_error_monitor_middleware(app)
add_rate_limit_middleware(app) add_rate_limit_middleware(app)
add_audit_middleware(app)
# Exception handlers # Exception handlers
add_error_handlers(app) add_error_handlers(app)
+64
View File
@@ -0,0 +1,64 @@
"""Audit logging middleware for sensitive API operations.
Logs structured JSON for POST/PUT/PATCH/DELETE requests and GET /auth/me.
Never logs request bodies, response bodies, Authorization headers, or cookie values.
"""
import json
import logging
import time
from collections.abc import Awaitable, Callable
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger("cartsnitch_api.audit")
HEALTH_PATHS = {"/health", "/healthz", "/ready"}
class AuditMiddleware(BaseHTTPMiddleware):
"""Middleware to log structured audit events for sensitive operations."""
async def dispatch(
self,
request: Request,
call_next: Callable[[Request], Awaitable],
):
if request.method == "OPTIONS" or request.url.path in HEALTH_PATHS:
return await call_next(request)
method = request.method
path = request.url.path
is_sensitive_write = method in {"POST", "PUT", "PATCH", "DELETE"}
is_auth_me_read = method == "GET" and path == "/auth/me"
if not (is_sensitive_write or is_auth_me_read):
return await call_next(request)
start = time.perf_counter()
response = await call_next(request)
duration_ms = (time.perf_counter() - start) * 1000
user_id = getattr(request.state, "user_id", None)
client_ip = request.client.host if request.client else "unknown"
log_entry = {
"event": "audit",
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"user_id": user_id,
"method": method,
"path": path,
"client_ip": client_ip,
"status_code": response.status_code,
"duration_ms": round(duration_ms, 2),
}
logger.info(json.dumps(log_entry))
return response
def add_audit_middleware(app: FastAPI) -> None:
app.add_middleware(AuditMiddleware)
+98 -17
View File
@@ -5,18 +5,31 @@ Per-IP limiting on public endpoints, per-token limiting on authenticated endpoin
""" """
import hashlib import hashlib
import logging
import time import time
import uuid
from collections import defaultdict from collections import defaultdict
from threading import Lock from threading import Lock
from typing import Protocol
from fastapi import FastAPI, Request, status from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from redis.asyncio import Redis, RedisError
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from cartsnitch_api.config import settings from cartsnitch_api.config import settings
logger = logging.getLogger(__name__)
class _SlidingWindowCounter:
class RateLimitBackend(Protocol):
"""Protocol for rate limit backends."""
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
class InMemorySlidingWindow:
"""Thread-safe in-memory sliding window rate limiter.""" """Thread-safe in-memory sliding window rate limiter."""
def __init__(self, max_requests: int, window_seconds: int) -> None: def __init__(self, max_requests: int, window_seconds: int) -> None:
@@ -25,13 +38,12 @@ class _SlidingWindowCounter:
self._hits: dict[str, list[float]] = defaultdict(list) self._hits: dict[str, list[float]] = defaultdict(list)
self._lock = Lock() self._lock = Lock()
def is_allowed(self, key: str) -> tuple[bool, int, int]: async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after).""" """Check if request is allowed. Returns (allowed, remaining, retry_after)."""
now = time.monotonic() now = time.monotonic()
cutoff = now - self.window_seconds cutoff = now - self.window_seconds
with self._lock: with self._lock:
# Prune expired entries
self._hits[key] = [t for t in self._hits[key] if t > cutoff] self._hits[key] = [t for t in self._hits[key] if t > cutoff]
current_count = len(self._hits[key]) current_count = len(self._hits[key])
@@ -44,15 +56,84 @@ class _SlidingWindowCounter:
return True, remaining, 0 return True, remaining, 0
# Module-level counters — one for public (per-IP), one for auth (per-token) class RedisSlidingWindow:
_public_limiter = _SlidingWindowCounter( """Redis-backed sliding window rate limiter using sorted sets."""
max_requests=settings.rate_limit_requests,
window_seconds=settings.rate_limit_window_seconds, def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None:
) self.redis = redis
_auth_limiter = _SlidingWindowCounter( self.max_requests = max_requests
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users self.window_seconds = window_seconds
window_seconds=settings.rate_limit_window_seconds,
) async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
try:
now = time.monotonic()
cutoff = now - self.window_seconds
now_ms = int(now * 1000)
cutoff_ms = int(cutoff * 1000)
pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, cutoff_ms)
pipe.zcard(key)
results = await pipe.execute()
current_count = results[1]
if current_count >= self.max_requests:
oldest = await self.redis.zrange(key, 0, 0, withscores=True)
if oldest:
retry_after = int((oldest[0][1] - cutoff) / 1000) + 1
else:
retry_after = self.window_seconds
return False, 0, retry_after
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
pipe = self.redis.pipeline()
pipe.zadd(key, {member: now_ms})
pipe.expire(key, self.window_seconds)
await pipe.execute()
remaining = self.max_requests - current_count - 1
return True, remaining, 0
except RedisError as e:
logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds)
return await in_memory.is_allowed(key)
_redis_client: Redis | None = None
_use_redis = False
if settings.rate_limit_redis_enabled:
try:
_redis_client = Redis.from_url(settings.redis_url)
_use_redis = True
logger.info("Rate limiting will use Redis at %s", settings.redis_url)
except Exception as e:
logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
_use_redis = False
if _use_redis and _redis_client:
_public_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
else:
_public_limiter = InMemorySlidingWindow(
settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = InMemorySlidingWindow(
settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = InMemorySlidingWindow(
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
def _get_client_ip(request: Request) -> str: def _get_client_ip(request: Request) -> str:
@@ -63,30 +144,30 @@ def _get_client_ip(request: Request) -> str:
return request.client.host if request.client else "unknown" return request.client.host if request.client else "unknown"
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]: def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
"""Determine rate limit key and which limiter to use.""" """Determine rate limit key and which limiter to use."""
if request.url.path.startswith("/public"): if request.url.path.startswith("/public"):
return f"ip:{_get_client_ip(request)}", _public_limiter return f"ip:{_get_client_ip(request)}", _public_limiter
# For authenticated endpoints, use Bearer token as key if present if request.url.path.startswith("/auth/") and request.method == "POST":
return f"ip:{_get_client_ip(request)}", _auth_strict_limiter
auth_header = request.headers.get("authorization", "") auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):
token = auth_header[7:] token = auth_header[7:]
token_hash = hashlib.sha256(token.encode()).hexdigest() token_hash = hashlib.sha256(token.encode()).hexdigest()
return f"token:{token_hash}", _auth_limiter return f"token:{token_hash}", _auth_limiter
# Fallback to IP for unauthenticated non-public endpoints
return f"ip:{_get_client_ip(request)}", _public_limiter return f"ip:{_get_client_ip(request)}", _public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
# Skip rate limiting when disabled (e.g. in tests) or for health checks
if not settings.rate_limit_enabled or request.url.path == "/health": if not settings.rate_limit_enabled or request.url.path == "/health":
return await call_next(request) return await call_next(request)
key, limiter = _get_rate_limit_key(request) key, limiter = _get_rate_limit_key(request)
allowed, remaining, retry_after = limiter.is_allowed(key) allowed, remaining, retry_after = await limiter.is_allowed(key)
if not allowed: if not allowed:
return JSONResponse( return JSONResponse(
+14 -5
View File
@@ -18,10 +18,14 @@ router = APIRouter(prefix="/public", tags=["public"])
@router.get("/trends/{product_id}", response_model=PublicTrendResponse) @router.get("/trends/{product_id}", response_model=PublicTrendResponse)
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)): async def public_price_trend(
product_id: UUID,
days: int = Query(90, ge=1, le=365),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db) svc = PublicService(db)
try: try:
return await svc.get_trend(product_id) return await svc.get_trend(product_id, days=days)
except LookupError: except LookupError:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
@@ -31,6 +35,7 @@ async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse) @router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
async def public_store_comparison( async def public_store_comparison(
product_ids: Annotated[list[UUID], Query(max_length=20)], product_ids: Annotated[list[UUID], Query(max_length=20)],
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if not product_ids: if not product_ids:
@@ -39,10 +44,14 @@ async def public_store_comparison(
detail="At least one product_id is required", detail="At least one product_id is required",
) )
svc = PublicService(db) svc = PublicService(db)
return await svc.get_store_comparison(product_ids) return await svc.get_store_comparison(product_ids, category=category)
@router.get("/inflation", response_model=PublicInflationResponse) @router.get("/inflation", response_model=PublicInflationResponse)
async def public_inflation(db: AsyncSession = Depends(get_db)): async def public_inflation(
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
period: str = Query("all-time", pattern=r"^(all-time|1y|6m|3m|1m)$"),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db) svc = PublicService(db)
return await svc.get_inflation() return await svc.get_inflation(category=category, period=period)
+42 -23
View File
@@ -1,5 +1,6 @@
"""Public service — unauthenticated price transparency endpoints.""" """Public service — unauthenticated price transparency endpoints."""
from datetime import date, timedelta
from uuid import UUID from uuid import UUID
from sqlalchemy import and_, func, select from sqlalchemy import and_, func, select
@@ -13,7 +14,7 @@ class PublicService:
def __init__(self, db: AsyncSession) -> None: def __init__(self, db: AsyncSession) -> None:
self.db = db self.db = db
async def get_trend(self, product_id: UUID) -> dict: async def get_trend(self, product_id: UUID, days: int = 90) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
result = await self.db.execute( result = await self.db.execute(
@@ -23,9 +24,13 @@ class PublicService:
if not product: if not product:
raise LookupError("Product not found") raise LookupError("Product not found")
date_threshold = date.today() - timedelta(days=days)
prices_result = await self.db.execute( prices_result = await self.db.execute(
select(PriceHistory) select(PriceHistory)
.where(PriceHistory.normalized_product_id == product_id) .where(
PriceHistory.normalized_product_id == product_id,
PriceHistory.observed_date >= date_threshold,
)
.options(selectinload(PriceHistory.store)) .options(selectinload(PriceHistory.store))
.order_by(PriceHistory.observed_date) .order_by(PriceHistory.observed_date)
) )
@@ -45,20 +50,25 @@ class PublicService:
], ],
} }
async def get_store_comparison(self, product_ids: list[UUID]) -> dict: async def get_store_comparison(
self, product_ids: list[UUID], category: str | None = None
) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
if not product_ids: if not product_ids:
return {"products": []} return {"products": []}
# Fetch all products in one query product_query = select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
prod_result = await self.db.execute( if category:
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) product_query = product_query.where(NormalizedProduct.category == category)
) prod_result = await self.db.execute(product_query)
products_by_id = {p.id: p for p in prod_result.scalars().all()} products_by_id = {p.id: p for p in prod_result.scalars().all()}
# Latest prices for all requested products in one query if not products_by_id:
subq = latest_price_per_store(product_ids) return {"products": []}
filtered_product_ids = list(products_by_id.keys())
subq = latest_price_per_store(filtered_product_ids)
prices_result = await self.db.execute( prices_result = await self.db.execute(
select(PriceHistory) select(PriceHistory)
.join( .join(
@@ -69,18 +79,17 @@ class PublicService:
PriceHistory.normalized_product_id == subq.c.normalized_product_id, PriceHistory.normalized_product_id == subq.c.normalized_product_id,
), ),
) )
.where(PriceHistory.normalized_product_id.in_(product_ids)) .where(PriceHistory.normalized_product_id.in_(filtered_product_ids))
.options(selectinload(PriceHistory.store)) .options(selectinload(PriceHistory.store))
) )
all_prices = prices_result.scalars().all() all_prices = prices_result.scalars().all()
# Group by product
prices_by_product: dict[UUID, list] = {} prices_by_product: dict[UUID, list] = {}
for ph in all_prices: for ph in all_prices:
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph) prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
products = [] products = []
for pid in product_ids: for pid in filtered_product_ids:
product = products_by_id.get(pid) product = products_by_id.get(pid)
if not product: if not product:
continue continue
@@ -102,19 +111,29 @@ class PublicService:
return {"products": products} return {"products": products}
async def get_inflation(self) -> dict: async def get_inflation(self, category: str | None = None, period: str = "all-time") -> dict:
"""Aggregate price change stats. Compares average prices across periods.""" """Aggregate price change stats. Compares average prices across periods."""
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
# Get average prices grouped by category for recent vs older data date_threshold = None
result = await self.db.execute( if period != "all-time":
select( days_map = {"1y": 365, "6m": 180, "3m": 90, "1m": 30}
NormalizedProduct.category, days = days_map.get(period, 365)
func.avg(PriceHistory.regular_price), date_threshold = date.today() - timedelta(days=days)
)
.join(NormalizedProduct) query = select(
.group_by(NormalizedProduct.category) NormalizedProduct.category,
) func.avg(PriceHistory.regular_price),
).join(NormalizedProduct)
if category:
query = query.where(NormalizedProduct.category == category)
if date_threshold:
query = query.where(PriceHistory.observed_date >= date_threshold)
query = query.group_by(NormalizedProduct.category)
result = await self.db.execute(query)
categories = {} categories = {}
for row in result.all(): for row in result.all():
cat, avg_price = row cat, avg_price = row
@@ -122,7 +141,7 @@ class PublicService:
categories[cat] = float(avg_price) if avg_price else 0.0 categories[cat] = float(avg_price) if avg_price else 0.0
return { return {
"period": "all-time", "period": period,
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1), "cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
"cpi_baseline": 100.0, "cpi_baseline": 100.0,
"categories": categories, "categories": categories,
+153 -48
View File
@@ -1,49 +1,184 @@
"""Tests for rate limiting middleware.""" """Tests for rate limiting middleware."""
from unittest.mock import MagicMock import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key from cartsnitch_api.config import settings
from cartsnitch_api.middleware.rate_limit import (
InMemorySlidingWindow,
RedisSlidingWindow,
_get_client_ip,
_get_rate_limit_key,
)
class TestSlidingWindowCounter: class TestInMemorySlidingWindow:
def test_allows_within_limit(self): def test_allows_within_limit(self):
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60) limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
for i in range(5): for i in range(5):
allowed, remaining, retry = counter.is_allowed("test-key") allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is True assert allowed is True
assert remaining == 4 - i assert remaining == 4 - i
def test_blocks_over_limit(self): def test_blocks_over_limit(self):
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60) limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
for _ in range(3): for _ in range(3):
counter.is_allowed("test-key") limiter.is_allowed("test-key")
allowed, remaining, retry = counter.is_allowed("test-key") allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is False assert allowed is False
assert remaining == 0 assert remaining == 0
assert retry > 0 assert retry > 0
def test_separate_keys(self): def test_separate_keys(self):
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60) limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
# Fill key-a limiter.is_allowed("key-a")
counter.is_allowed("key-a") limiter.is_allowed("key-a")
counter.is_allowed("key-a") allowed_a, _, _ = limiter.is_allowed("key-a")
allowed_a, _, _ = counter.is_allowed("key-a")
assert allowed_a is False assert allowed_a is False
# key-b should still be allowed allowed_b, remaining, _ = limiter.is_allowed("key-b")
allowed_b, remaining, _ = counter.is_allowed("key-b")
assert allowed_b is True assert allowed_b is True
assert remaining == 1 assert remaining == 1
def test_resets_after_window_expires(self):
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
for _ in range(2):
limiter.is_allowed("test-key")
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is False
time.sleep(1.1)
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 1
class TestGetClientIp:
def test_x_forwarded_for_single(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_multiple(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_with_port(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_no_forwarded_header(self):
req = MagicMock()
req.headers = {}
req.client.host = "127.0.0.1"
assert _get_client_ip(req) == "127.0.0.1"
def test_no_client(self):
req = MagicMock()
req.headers = {}
req.client = None
assert _get_client_ip(req) == "unknown"
class TestGetRateLimitKey:
def _make_request(
self,
path: str = "/purchases",
method: str = "GET",
auth_header: str = "",
headers: dict | None = None,
) -> MagicMock:
req = MagicMock()
req.url.path = path
req.method = method
req.headers = dict(headers) if headers else {}
if auth_header:
req.headers["authorization"] = auth_header
return req
def test_public_path_uses_public_limiter(self):
req = self._make_request("/public/inflation")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests
def test_auth_post_path_uses_strict_limiter(self):
req = self._make_request("/auth/login", method="POST")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_auth_requests
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
def test_auth_get_path_uses_auth_limiter(self):
req = self._make_request("/auth/me", method="GET")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_authenticated_token_uses_auth_limiter(self):
req = self._make_request("/purchases", auth_header="Bearer token123")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("token:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key
class TestRedisSlidingWindowFallback:
@pytest.mark.asyncio
async def test_fallback_on_redis_connection_error(self):
mock_redis = AsyncMock()
mock_redis.pipeline.return_value = AsyncMock()
pipe_mock = AsyncMock()
pipe_mock.execute.side_effect = Exception("Connection refused")
mock_redis.pipeline.return_value = pipe_mock
limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 4
@pytest.mark.asyncio
async def test_fallback_on_redis_error_during_pipeline(self):
mock_redis = AsyncMock()
pipe_mock = AsyncMock()
pipe_mock.execute.side_effect = Exception("Redis error")
mock_redis.pipeline.return_value = pipe_mock
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rate_limit_returns_429(client): async def test_rate_limit_returns_429(client):
"""Public endpoint should return 429 after limit exceeded."""
# The default limit is 60/min — we won't hit it in normal tests,
# but we verify the middleware adds rate limit headers.
resp = await client.get("/public/inflation") resp = await client.get("/public/inflation")
assert "x-ratelimit-limit" in resp.headers assert "x-ratelimit-limit" in resp.headers
assert "x-ratelimit-remaining" in resp.headers assert "x-ratelimit-remaining" in resp.headers
@@ -51,36 +186,6 @@ async def test_rate_limit_returns_429(client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_health_skips_rate_limit(client): async def test_health_skips_rate_limit(client):
"""Health endpoint should not have rate limit headers."""
resp = await client.get("/health") resp = await client.get("/health")
assert resp.status_code == 200 assert resp.status_code == 200
assert "x-ratelimit-limit" not in resp.headers assert "x-ratelimit-limit" not in resp.headers
class TestGetRateLimitKey:
def _make_request(self, auth_header: str = "") -> MagicMock:
req = MagicMock()
req.url.path = "/purchases"
req.headers = {"authorization": auth_header} if auth_header else {}
return req
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("Bearer token_alpha_12345")
req2 = self._make_request("Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("Bearer same_token_value_abc")
req2 = self._make_request("Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request(f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key
+94
View File
@@ -71,3 +71,97 @@ async def test_public_inflation(client, public_data):
data = resp.json() data = resp.json()
assert "categories" in data assert "categories" in data
assert "cartsnitch_index" in data assert "cartsnitch_index" in data
@pytest.mark.asyncio
async def test_trend_invalid_uuid(client):
resp = await client.get("/public/trends/not-a-uuid")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_zero(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=0")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_negative(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=-1")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_over_max(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=999")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_valid(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=30")
assert resp.status_code == 200
assert "product_name" in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_empty_list(client):
resp = await client.get("/public/store-comparison")
assert resp.status_code == 400
assert "detail" in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_category_xss(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(
f"/public/store-comparison?product_ids={pid}&category=<script>alert(1)</script>"
)
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_category_sql_injection(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/store-comparison?product_ids={pid}&category='; DROP TABLE--")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_inflation_invalid_period(client, public_data):
resp = await client.get("/public/inflation?period=10years")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_inflation_valid_periods(client, public_data):
for period in ["all-time", "1y", "6m", "3m", "1m"]:
resp = await client.get(f"/public/inflation?period={period}")
assert resp.status_code == 200, f"period={period} failed"
@pytest.mark.asyncio
async def test_inflation_category_too_long(client, public_data):
long_category = "x" * 200
resp = await client.get(f"/public/inflation?category={long_category}")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()