8a4c194e39
- Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60) settings - Add rate_limit_redis_enabled flag for opt-in Redis usage - Refactor _SlidingWindowCounter into InMemorySlidingWindow class - Add RedisSlidingWindow using sorted sets with fallback to in-memory - Add third _auth_strict_limiter for POST /auth/* paths (5 req/min) - Add protocol-based backend selection at module load time - Update tests for auth strict limiter and Redis fallback behavior Co-Authored-By: Paperclip <noreply@paperclip.ing>
194 lines
7.0 KiB
Python
194 lines
7.0 KiB
Python
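"""Rate limiting middleware for public and authenticated endpoints.

Uses an in-memory sliding window by default. When ``rate_limit_redis_enabled``
is set, it uses a Redis/DragonflyDB-backed window instead, and falls back to
in-memory if Redis raises an error.

Requests are limited per client IP on public endpoints, on POST /auth/*
endpoints (using the dedicated ``rate_limit_auth_*`` limit), and whenever no
bearer token is sent. Requests carrying a bearer token are limited per token.
"""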
|
|
|
|
import hashlib
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from collections import defaultdict
|
|
from threading import Lock
|
|
from typing import Protocol
|
|
|
|
from fastapi import FastAPI, Request, status
|
|
from fastapi.responses import JSONResponse
|
|
from redis.asyncio import Redis, RedisError
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from cartsnitch_api.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimitBackend(Protocol):
    """Structural interface shared by the rate-limit backends in this module.

    ``RateLimitMiddleware`` reads ``max_requests`` from whichever backend
    handles a request (for the ``X-RateLimit-Limit`` header). That makes the
    attribute part of the protocol alongside ``is_allowed``.
    """

    max_requests: int
    window_seconds: int

    async def is_allowed(self, key: str) -> tuple[bool, int, int]:
        """Record a request for *key* if it fits within the limit.

        Returns:
            ``(allowed, remaining, retry_after)``:
            - whether the request may proceed;
            - how many further requests remain in the current window;
            - the number of seconds to wait before retrying (0 when allowed).
        """
        ...
|
|
|
|
|
|
class InMemorySlidingWindow:
    """Thread-safe, per-process sliding-window rate limiter.

    For every key it keeps the ``time.monotonic()`` timestamps of requests
    accepted within the last ``window_seconds`` seconds. A request is allowed
    while fewer than ``max_requests`` such timestamps exist.

    Keys whose timestamps have all expired are evicted, at most once per
    window. This keeps memory from growing with every distinct key ever seen
    (e.g. one entry per client IP or bearer-token hash).
    """

    def __init__(self, max_requests: int, window_seconds: int) -> None:
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._hits: dict[str, list[float]] = defaultdict(list)
        self._lock = Lock()
        # Monotonic time of the last full eviction pass over ``_hits``.
        self._last_sweep = time.monotonic()

    def _evict_stale(self, cutoff: float) -> None:
        """Drop keys with no timestamp newer than *cutoff*; caller must hold ``_lock``."""
        # Timestamps are appended in increasing order, so hits[-1] is the newest.
        stale = [k for k, hits in self._hits.items() if not hits or hits[-1] <= cutoff]
        for k in stale:
            del self._hits[k]

    async def is_allowed(self, key: str) -> tuple[bool, int, int]:
        """Check if request is allowed. Returns (allowed, remaining, retry_after).

        An allowed request is recorded and ``retry_after`` is 0. A rejected
        request is not recorded. ``retry_after`` is then the whole number of
        seconds until the oldest recorded request leaves the window, or
        ``window_seconds`` when nothing is recorded (only possible when
        ``max_requests <= 0``).
        """
        now = time.monotonic()
        cutoff = now - self.window_seconds

        with self._lock:
            # Amortised clean-up of keys that have not been seen for a full window.
            if now - self._last_sweep >= self.window_seconds:
                self._evict_stale(cutoff)
                self._last_sweep = now

            # .get() avoids creating an entry for a key that is only being checked.
            hits = [t for t in self._hits.get(key, ()) if t > cutoff]

            if len(hits) >= self.max_requests:
                if hits:
                    self._hits[key] = hits
                    # The oldest hit leaves the window (t0 - cutoff) seconds from now.
                    retry_after = int(hits[0] - cutoff) + 1
                else:
                    # Only reachable when max_requests <= 0: nothing will ever expire.
                    self._hits.pop(key, None)
                    retry_after = self.window_seconds
                return False, 0, retry_after

            hits.append(now)
            self._hits[key] = hits
            return True, self.max_requests - len(hits), 0
|
|
|
|
|
|
class RedisSlidingWindow:
    """Redis-backed sliding-window rate limiter using one sorted set per key.

    Every accepted request is stored as a sorted-set member scored by its
    wall-clock time in milliseconds. Wall-clock time (``time.time``) is used
    because the set is shared by every process talking to the same Redis.
    ``time.monotonic`` has an arbitrary, per-process reference point and
    cannot be compared across processes.

    Keys are namespaced with ``key_prefix``. This keeps limiters that share
    one Redis from sharing a counter, even when callers pass identical keys.

    On ``RedisError`` the check is delegated to a per-instance
    ``InMemorySlidingWindow`` whose state persists across failures. Limiting
    therefore continues (per process) while Redis is unavailable.
    """

    def __init__(
        self,
        redis: Redis,
        max_requests: int,
        window_seconds: int,
        key_prefix: str | None = None,
    ) -> None:
        self.redis = redis
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # The default namespace separates limiters with different limits. Pass
        # an explicit prefix whenever two limiters could have identical limits.
        self.key_prefix = (
            key_prefix
            if key_prefix is not None
            else f"ratelimit:{max_requests}:{window_seconds}"
        )
        # Created once: a fresh fallback per error would have no history and
        # would therefore never reject anything.
        self._fallback = InMemorySlidingWindow(max_requests, window_seconds)

    async def is_allowed(self, key: str) -> tuple[bool, int, int]:
        """Check if request is allowed. Returns (allowed, remaining, retry_after).

        ``retry_after`` is in whole seconds and is 0 when the request is allowed.
        Rejected requests do not consume quota.
        """
        redis_key = f"{self.key_prefix}:{key}"
        window_ms = int(self.window_seconds * 1000)
        now_ms = int(time.time() * 1000)
        cutoff_ms = now_ms - window_ms
        # Unique member so simultaneous requests in the same millisecond all count.
        member = f"{now_ms}:{uuid.uuid4().hex}"

        try:
            # Insert first, then count, inside one MULTI/EXEC transaction. This
            # stops concurrent requests from all passing a check on a stale count.
            pipe = self.redis.pipeline(transaction=True)
            pipe.zremrangebyscore(redis_key, 0, cutoff_ms)
            pipe.zadd(redis_key, {member: now_ms})
            pipe.zcard(redis_key)
            pipe.zrange(redis_key, 0, 0, withscores=True)
            pipe.expire(redis_key, self.window_seconds)
            _, _, current_count, oldest, _ = await pipe.execute()

            if current_count > self.max_requests:
                # Rejected requests must not consume quota: undo the insert.
                await self.redis.zrem(redis_key, member)
                if oldest:
                    # Scores are in ms. The oldest entry leaves the window
                    # (score - cutoff_ms) ms from now.
                    retry_after = max(int((oldest[0][1] - cutoff_ms) / 1000) + 1, 1)
                else:
                    retry_after = self.window_seconds
                return False, 0, retry_after

            # current_count already includes this request.
            return True, self.max_requests - current_count, 0

        except RedisError as e:
            logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
            return await self._fallback.is_allowed(key)


def _redact_url(url: str) -> str:
    """Return *url* with any password replaced by ``***`` so it is safe to log."""
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    if parts.password is None:
        return url
    host = parts.netloc.rpartition("@")[2]
    userinfo = f"{parts.username}:***" if parts.username else ":***"
    return urlunsplit(parts._replace(netloc=f"{userinfo}@{host}"))


_redis_client: Redis | None = None

_use_redis = False

if settings.rate_limit_redis_enabled:
    try:
        # NOTE: redis-py connects lazily, so this mainly validates the URL.
        # Connectivity failures surface later as RedisError inside
        # RedisSlidingWindow.is_allowed, which falls back to in-memory.
        _redis_client = Redis.from_url(settings.redis_url)
        _use_redis = True
        # Never log credentials embedded in the URL.
        logger.info("Rate limiting will use Redis at %s", _redact_url(settings.redis_url))
    except Exception as e:
        logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
        _use_redis = False

if _use_redis and _redis_client is not None:
    # Explicit, distinct prefixes: the public and strict-auth limiters are both
    # fed "ip:<addr>" keys and must not share one Redis counter.
    _public_limiter = RedisSlidingWindow(
        _redis_client,
        settings.rate_limit_requests,
        settings.rate_limit_window_seconds,
        key_prefix="ratelimit:public",
    )
    _auth_limiter = RedisSlidingWindow(
        _redis_client,
        settings.rate_limit_requests * 5,
        settings.rate_limit_window_seconds,
        key_prefix="ratelimit:token",
    )
    _auth_strict_limiter = RedisSlidingWindow(
        _redis_client,
        settings.rate_limit_auth_requests,
        settings.rate_limit_auth_window_seconds,
        key_prefix="ratelimit:auth",
    )
else:
    _public_limiter = InMemorySlidingWindow(
        settings.rate_limit_requests, settings.rate_limit_window_seconds
    )
    _auth_limiter = InMemorySlidingWindow(
        settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
    )
    _auth_strict_limiter = InMemorySlidingWindow(
        settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
    )
|
|
|
|
|
|
def _get_client_ip(request: Request) -> str:
    """Return the address used as the per-IP rate-limit identity.

    If an ``X-Forwarded-For`` header is present and non-empty, returns its
    left-most entry with surrounding whitespace stripped. Otherwise returns
    the socket peer address, or ``"unknown"`` when the request carries no
    client information.

    NOTE(review): the left-most X-Forwarded-For entry is supplied by the
    client and can be forged. A caller can rotate it to evade per-IP limits,
    including the strict POST /auth/ limiter. Confirm that a trusted proxy
    overwrites this header. Otherwise, derive the IP from the entry appended
    by the nearest trusted proxy.
    """
    header_value = request.headers.get("x-forwarded-for")
    if not header_value:
        peer = request.client
        return peer.host if peer else "unknown"
    first_hop, _, _ = header_value.partition(",")
    return first_hop.strip()
|
|
|
|
|
|
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
    """Pick the rate-limit key and the limiter responsible for *request*.

    Routing, in priority order:
    - paths starting with ``/public``: per-IP key, public limiter;
    - ``POST`` to paths starting with ``/auth/``: per-IP key, strict auth limiter;
    - ``Authorization: Bearer <token>``: SHA-256 of the token, authenticated limiter;
    - anything else: per-IP key, public limiter.

    NOTE(review): the bearer token is not validated here. A client can send
    a different made-up token on every request and get a fresh per-token
    bucket each time. Consider also limiting unverified tokens per IP.
    """
    path = request.url.path

    if path.startswith("/public"):
        return f"ip:{_get_client_ip(request)}", _public_limiter

    if request.method == "POST" and path.startswith("/auth/"):
        return f"ip:{_get_client_ip(request)}", _auth_strict_limiter

    scheme_prefix = "Bearer "
    authorization = request.headers.get("authorization", "")
    if not authorization.startswith(scheme_prefix):
        return f"ip:{_get_client_ip(request)}", _public_limiter

    # Hash the raw token so secrets are never used verbatim as storage keys.
    digest = hashlib.sha256(authorization[len(scheme_prefix):].encode()).hexdigest()
    return f"token:{digest}", _auth_limiter
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Enforce the module's sliding-window limits on every incoming request."""

    async def dispatch(self, request: Request, call_next):
        """Reject with 429 when over the limit; otherwise annotate the response.

        Limiting is skipped entirely when ``settings.rate_limit_enabled`` is
        false or the path is exactly ``/health``. Allowed responses receive
        ``X-RateLimit-Limit`` and ``X-RateLimit-Remaining`` headers. Rejected
        requests get a JSON 429 body plus ``Retry-After``.
        """
        exempt = not settings.rate_limit_enabled or request.url.path == "/health"
        if exempt:
            return await call_next(request)

        key, limiter = _get_rate_limit_key(request)
        allowed, remaining, retry_after = await limiter.is_allowed(key)
        limit_header = str(limiter.max_requests)

        if allowed:
            response = await call_next(request)
            response.headers["X-RateLimit-Limit"] = limit_header
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            return response

        body = {"detail": "Rate limit exceeded", "code": "RATE_LIMITED"}
        rejection_headers = {
            "Retry-After": str(retry_after),
            "X-RateLimit-Limit": limit_header,
            "X-RateLimit-Remaining": "0",
        }
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content=body,
            headers=rejection_headers,
        )
|
|
|
|
|
|
def add_rate_limit_middleware(app: FastAPI) -> None:
    """Install :class:`RateLimitMiddleware` on *app*."""
    middleware_cls = RateLimitMiddleware
    app.add_middleware(middleware_cls)
|