Files
api/src/cartsnitch_api/middleware/rate_limit.py
T
Barcode Betty 7b595744e1 fix(api): mypy no-redef and no-any-return errors on dev (CAR-1335)
The api typecheck job is continue-on-error but still posts a failure
status that blocks merges. Three pre-existing mypy errors on dev were
inherited by every PR based on it:

1. middleware/rate_limit.py: duplicate 'name already defined' for
   _public_limiter, _auth_limiter, _auth_strict_limiter (declared at
   lines 111-113 and again at 124-126). The second set is redundant
   because actual assignment happens inside the if/else below.
2. cache.py:43 - 'Returning Any' from .get(); the redis client's get()
   return type isn't narrowed to bytes|str, so the final 'return value'
   branch is Any. Wrap with str() to satisfy the declared str|None.
3. middleware/rate_limit.py:150 - 'Returning Any' from _get_client_ip.
   request.headers.get() and request.client.host are typed Any; wrap
   the branches with str() to match the declared str return.

Refs CAR-1335.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-06-09 05:25:41 +00:00

201 lines
7.2 KiB
Python

"""Rate limiting middleware for public and authenticated endpoints.
Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
"""
import hashlib
import logging
import time
import uuid
from collections import defaultdict
from threading import Lock
from typing import Protocol
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from redis.asyncio import Redis, RedisError
from starlette.middleware.base import BaseHTTPMiddleware
from cartsnitch_api.config import settings
logger = logging.getLogger(__name__)
class RateLimitBackend(Protocol):
"""Protocol for rate limit backends."""
max_requests: int
window_seconds: int
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
class InMemorySlidingWindow:
"""Thread-safe in-memory sliding window rate limiter."""
def __init__(self, max_requests: int, window_seconds: int) -> None:
self.max_requests = max_requests
self.window_seconds = window_seconds
self._hits: dict[str, list[float]] = defaultdict(list)
self._lock = Lock()
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
now = time.monotonic()
cutoff = now - self.window_seconds
with self._lock:
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
current_count = len(self._hits[key])
if current_count >= self.max_requests:
retry_after = int(self._hits[key][0] - cutoff) + 1
return False, 0, retry_after
self._hits[key].append(now)
remaining = self.max_requests - current_count - 1
return True, remaining, 0
class RedisSlidingWindow:
"""Redis-backed sliding window rate limiter using sorted sets."""
def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None:
self.redis = redis
self.max_requests = max_requests
self.window_seconds = window_seconds
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
try:
now = time.monotonic()
cutoff = now - self.window_seconds
now_ms = int(now * 1000)
cutoff_ms = int(cutoff * 1000)
pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, cutoff_ms)
pipe.zcard(key)
results = await pipe.execute()
current_count = results[1]
if current_count >= self.max_requests:
oldest = await self.redis.zrange(key, 0, 0, withscores=True)
if oldest:
oldest_score = float(oldest[0][1])
retry_after = int((oldest_score - cutoff) / 1000) + 1
else:
retry_after = self.window_seconds
return False, 0, retry_after
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
pipe = self.redis.pipeline()
pipe.zadd(key, {member: now_ms})
pipe.expire(key, self.window_seconds)
await pipe.execute()
remaining = self.max_requests - current_count - 1
return True, remaining, 0
except RedisError as e:
logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds)
return await in_memory.is_allowed(key)
_redis_client: Redis | None = None
_use_redis = False
_public_limiter: RateLimitBackend
_auth_limiter: RateLimitBackend
_auth_strict_limiter: RateLimitBackend
if settings.rate_limit_redis_enabled:
try:
_redis_client = Redis.from_url(settings.redis_url)
_use_redis = True
logger.info("Rate limiting will use Redis at %s", settings.redis_url)
except Exception as e:
logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
_use_redis = False
if _use_redis and _redis_client:
_public_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
else:
_public_limiter = InMemorySlidingWindow(
settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = InMemorySlidingWindow(
settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = InMemorySlidingWindow(
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
def _get_client_ip(request: Request) -> str:
"""Extract client IP, respecting X-Forwarded-For behind a reverse proxy."""
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return str(forwarded.split(",")[0].strip())
return str(request.client.host) if request.client else "unknown"
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
"""Determine rate limit key and which limiter to use."""
if request.url.path.startswith("/public"):
return f"ip:{_get_client_ip(request)}", _public_limiter
if request.url.path.startswith("/auth/") and request.method == "POST":
return f"ip:{_get_client_ip(request)}", _auth_strict_limiter
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
token_hash = hashlib.sha256(token.encode()).hexdigest()
return f"token:{token_hash}", _auth_limiter
return f"ip:{_get_client_ip(request)}", _public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if not settings.rate_limit_enabled or request.url.path == "/health":
return await call_next(request)
key, limiter = _get_rate_limit_key(request)
allowed, remaining, retry_after = await limiter.is_allowed(key)
if not allowed:
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={
"detail": "Rate limit exceeded",
"code": "RATE_LIMITED",
},
headers={
"Retry-After": str(retry_after),
"X-RateLimit-Limit": str(limiter.max_requests),
"X-RateLimit-Remaining": "0",
},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
def add_rate_limit_middleware(app: FastAPI) -> None:
app.add_middleware(RateLimitMiddleware)