Files
cartsnitch-fork-test/api/src/cartsnitch_api/middleware/rate_limit.py
T
2026-04-14 11:36:17 +00:00

113 lines
4.1 KiB
Python

"""Rate limiting middleware for public and authenticated endpoints.
Uses an in-memory sliding window; no Redis/DragonflyDB backend is wired up in this module, so counters live in process memory and limits apply per worker process.
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
"""
import hashlib
import time
from collections import defaultdict
from threading import Lock
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from cartsnitch_api.config import settings
class _SlidingWindowCounter:
    """Thread-safe in-memory sliding window rate limiter.

    For every key the counter keeps the ``time.monotonic()`` timestamps of
    the requests admitted during the last ``window_seconds``. A request is
    admitted while fewer than ``max_requests`` timestamps are inside the
    window.

    Keys whose timestamps have all expired are evicted, both on access and
    by a periodic sweep. Memory is therefore bounded by the number of
    clients active within roughly one window, not by every IP address or
    token ever seen.
    """

    def __init__(self, max_requests: int, window_seconds: int) -> None:
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # key -> admitted-request timestamps, oldest first.
        self._hits: dict[str, list[float]] = defaultdict(list)
        self._lock = Lock()
        # Monotonic time at which idle keys were last swept out of ``_hits``.
        self._last_sweep = time.monotonic()

    def _evict_idle_keys(self, now: float, cutoff: float) -> None:
        """Drop keys with no hit newer than *cutoff*. Caller must hold the lock.

        Runs at most once per window, so the O(number of keys) scan is
        amortised across many requests.
        """
        if now - self._last_sweep < max(self.window_seconds, 1):
            return
        self._last_sweep = now
        # Timestamps are appended in increasing order, so the last one is newest.
        stale = [k for k, hits in self._hits.items() if not hits or hits[-1] <= cutoff]
        for k in stale:
            del self._hits[k]

    def is_allowed(self, key: str) -> tuple[bool, int, int]:
        """Check whether a request for *key* is allowed, recording it if so.

        Returns:
            ``(allowed, remaining, retry_after)``:

            - ``remaining`` is how many further requests will be admitted in
              the current window.
            - ``retry_after`` is the whole number of seconds to wait before
              retrying. It is ``0`` when the request is allowed.
        """
        now = time.monotonic()
        cutoff = now - self.window_seconds
        with self._lock:
            self._evict_idle_keys(now, cutoff)
            # Only timestamps newer than the cutoff still count against the key.
            hits = [t for t in self._hits.get(key, ()) if t > cutoff]
            if len(hits) >= self.max_requests:
                if hits:
                    self._hits[key] = hits
                    # The oldest hit leaves the window (hits[0] - cutoff) seconds from now.
                    retry_after = int(hits[0] - cutoff) + 1
                else:
                    # max_requests <= 0: nothing is ever admitted and there is
                    # no oldest hit to wait for, so don't keep an empty entry.
                    self._hits.pop(key, None)
                    retry_after = max(int(self.window_seconds), 1)
                return False, 0, retry_after
            hits.append(now)
            self._hits[key] = hits
            return True, self.max_requests - len(hits), 0
# Module-level counters: one for public traffic (keyed per IP), one for
# authenticated traffic (keyed per token).
# Authenticated callers get this multiple of the public per-window budget
# (e.g. 300 requests per window when settings.rate_limit_requests is 60).
_AUTH_LIMIT_MULTIPLIER = 5

_public_limiter = _SlidingWindowCounter(
    max_requests=settings.rate_limit_requests,
    window_seconds=settings.rate_limit_window_seconds,
)
_auth_limiter = _SlidingWindowCounter(
    max_requests=settings.rate_limit_requests * _AUTH_LIMIT_MULTIPLIER,
    window_seconds=settings.rate_limit_window_seconds,
)
def _get_client_ip(request: Request, trusted_proxies: int = 1) -> str:
    """Return the client IP used as the per-IP rate-limit key.

    ``X-Forwarded-For`` is a comma-separated list, and each proxy *appends*
    the address it received the request from. Only the right-most entries
    are therefore written by our own infrastructure. Everything to their
    left is whatever the client chose to send.

    The left-most entry must not be used as the key: a caller could supply
    a new value on every request and bypass per-IP limiting entirely.

    NOTE(review): confirm the deployment's proxy topology. If the app is
    reachable without a proxy, the whole header is client-controlled.

    Args:
        request: Incoming request.
        trusted_proxies: Number of reverse proxies in front of the app that
            append to ``X-Forwarded-For``. The client address is the entry
            this many positions from the right. Defaults to 1 (a single
            reverse proxy). Pass 0 to ignore the header.

    Returns:
        One of, in order of preference:

        - the forwarded client address;
        - the socket peer address, when the header is absent, empty, or
          ignored;
        - ``"unknown"``, if neither is available.
    """
    forwarded = request.headers.get("x-forwarded-for", "")
    if trusted_proxies > 0:
        hops = [hop.strip() for hop in forwarded.split(",") if hop.strip()]
        if hops:
            # Fewer hops than trusted proxies: fall back to the oldest entry.
            return hops[-min(trusted_proxies, len(hops))]
    return request.client.host if request.client else "unknown"
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
    """Determine the rate-limit key and which limiter applies to *request*.

    - Paths starting with ``/public`` are always limited per client IP with
      the public limiter.
    - Other requests carrying a non-empty ``Bearer`` token are limited per
      token (SHA-256 of the token) with the authenticated limiter.
    - Everything else falls back to per-IP limiting with the public limiter.

    NOTE: the token is not validated here. Any string sent as a Bearer token
    selects the authenticated limiter and its own bucket.
    """
    if request.url.path.startswith("/public"):
        return f"ip:{_get_client_ip(request)}", _public_limiter

    # Auth schemes are case-insensitive (RFC 7235 §2.1), so "bearer x" counts too.
    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    token = token.strip()
    if scheme.lower() == "bearer" and token:
        # Hash so raw credentials are never held as dictionary keys.
        token_hash = hashlib.sha256(token.encode()).hexdigest()
        return f"token:{token_hash}", _auth_limiter

    # An empty "Bearer " header must not put every such caller into one
    # shared, higher-limit bucket; treat it like an unauthenticated request.
    return f"ip:{_get_client_ip(request)}", _public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Apply the module's sliding-window limits to every incoming request.

    Rejected requests get a 429 JSON body plus ``Retry-After`` and
    ``X-RateLimit-*`` headers. Admitted requests are passed downstream and
    the rate-limit headers are added to their response.
    """

    async def dispatch(self, request: Request, call_next):
        # Limiting is bypassed when switched off in settings (e.g. tests)
        # and for the health-check endpoint.
        exempt = not settings.rate_limit_enabled or request.url.path == "/health"
        if exempt:
            return await call_next(request)

        key, limiter = _get_rate_limit_key(request)
        allowed, remaining, retry_after = limiter.is_allowed(key)
        limit_value = str(limiter.max_requests)

        if allowed:
            response = await call_next(request)
            response.headers["X-RateLimit-Limit"] = limit_value
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            return response

        rejection_headers = {
            "Retry-After": str(retry_after),
            "X-RateLimit-Limit": limit_value,
            "X-RateLimit-Remaining": "0",
        }
        rejection_body = {
            "detail": "Rate limit exceeded",
            "code": "RATE_LIMITED",
        }
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content=rejection_body,
            headers=rejection_headers,
        )
def add_rate_limit_middleware(app: FastAPI) -> None:
    """Register ``RateLimitMiddleware`` on the given FastAPI application."""
    middleware_cls = RateLimitMiddleware
    app.add_middleware(middleware_cls)