feat: Redis-backed rate limiting with stricter auth limits

- Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60)
  settings to config.py
- Refactor rate_limit.py to use protocol/ABC pattern with InMemorySlidingWindow
  and RedisSlidingWindow implementations
- Add RedisSlidingWindow using sorted sets for distributed rate limiting
- Add auth_strict_limiter for /auth/* POST endpoints (5 req/min per IP)
- Fall back to in-memory when Redis is unavailable
- Update tests to cover new functionality

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
Paperclip
2026-04-14 15:46:52 +00:00
parent 06c099594a
commit 26f3415eab
3 changed files with 277 additions and 73 deletions
+6 -1
View File
@@ -33,6 +33,9 @@ class Settings(BaseSettings):
rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60
rate_limit_enabled: bool = True
rate_limit_auth_requests: int = 5
rate_limit_auth_window_seconds: int = 60
rate_limit_redis_enabled: bool = True
_PLACEHOLDER_VALUES = {"change-me-in-production"}
@@ -72,7 +75,9 @@ class Settings(BaseSettings):
def normalize_database_url(self):
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
if self.database_url.startswith("postgresql://"):
self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1)
self.database_url = self.database_url.replace(
"postgresql://", "postgresql+asyncpg://", 1
)
return self
+137 -21
View File
@@ -4,19 +4,35 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
"""
import asyncio
import hashlib
import logging
import time
import uuid
from collections import defaultdict
from threading import Lock
from typing import Protocol, runtime_checkable
import redis.asyncio as redis
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from cartsnitch_api.config import settings
logger = logging.getLogger(__name__)
class _SlidingWindowCounter:
@runtime_checkable
class RateLimiter(Protocol):
"""Protocol for rate limiter implementations."""
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
...
class InMemorySlidingWindow:
"""Thread-safe in-memory sliding window rate limiter."""
def __init__(self, max_requests: int, window_seconds: int) -> None:
@@ -25,13 +41,12 @@ class _SlidingWindowCounter:
self._hits: dict[str, list[float]] = defaultdict(list)
self._lock = Lock()
def is_allowed(self, key: str) -> tuple[bool, int, int]:
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
now = time.monotonic()
cutoff = now - self.window_seconds
with self._lock:
# Prune expired entries
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
current_count = len(self._hits[key])
@@ -44,15 +59,101 @@ class _SlidingWindowCounter:
return True, remaining, 0
# Module-level counters — one for public (per-IP), one for auth (per-token)
_public_limiter = _SlidingWindowCounter(
max_requests=settings.rate_limit_requests,
window_seconds=settings.rate_limit_window_seconds,
)
_auth_limiter = _SlidingWindowCounter(
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
window_seconds=settings.rate_limit_window_seconds,
)
class RedisSlidingWindow:
"""Redis-backed sliding window rate limiter using sorted sets."""
def __init__(self, client: redis.Redis, max_requests: int, window_seconds: int) -> None:
self.client = client
self.max_requests = max_requests
self.window_seconds = window_seconds
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed using Redis sorted sets. Returns (allowed, remaining, retry_after)."""
now_ms = int(time.time() * 1000)
window_ms = self.window_seconds * 1000
cutoff = now_ms - window_ms
try:
async with self.client.pipeline(transaction=True) as pipe:
pipe.zremrangebyscore(key, 0, cutoff)
pipe.zcard(key)
await pipe.execute()
current_count = await self.client.zcard(key)
if current_count >= self.max_requests:
results = await self.client.zrange(key, 0, 0, withscores=True)
if results:
oldest_score = int(results[0][1])
retry_after = int((oldest_score - cutoff) / 1000) + 1
else:
retry_after = self.window_seconds
return False, 0, retry_after
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
async with self.client.pipeline(transaction=True) as pipe:
pipe.zadd(key, {member: now_ms})
pipe.expire(key, self.window_seconds)
await pipe.execute()
remaining = self.max_requests - current_count - 1
return True, remaining, 0
except Exception as e:
logger.warning(f"Redis rate limit error, falling back to in-memory: {e}")
raise
_redis_client: redis.Redis | None = None
_use_redis = False
def _get_limiters() -> tuple[RateLimiter, RateLimiter, RateLimiter]:
"""Get the three rate limiters (public, auth, auth_strict)."""
global _redis_client, _use_redis
if _use_redis and _redis_client is not None:
return (
RedisSlidingWindow(
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
),
RedisSlidingWindow(
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
),
RedisSlidingWindow(
_redis_client,
settings.rate_limit_auth_requests,
settings.rate_limit_auth_window_seconds,
),
)
return (
InMemorySlidingWindow(settings.rate_limit_requests, settings.rate_limit_window_seconds),
InMemorySlidingWindow(settings.rate_limit_requests * 5, settings.rate_limit_window_seconds),
InMemorySlidingWindow(
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
),
)
def _init_redis() -> None:
"""Initialize Redis connection at module load."""
global _redis_client, _use_redis
if not settings.rate_limit_redis_enabled:
logger.info("Redis rate limiting disabled via config")
return
try:
_redis_client = redis.from_url(settings.redis_url)
asyncio.get_event_loop().run_until_complete(_redis_client.ping())
_use_redis = True
logger.info("Redis rate limiting enabled")
except Exception as e:
logger.warning(f"Redis unavailable for rate limiting, using in-memory: {e}")
_use_redis = False
_init_redis()
def _get_client_ip(request: Request) -> str:
@@ -63,30 +164,45 @@ def _get_client_ip(request: Request) -> str:
return request.client.host if request.client else "unknown"
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimiter]:
"""Determine rate limit key and which limiter to use."""
if request.url.path.startswith("/public"):
return f"ip:{_get_client_ip(request)}", _public_limiter
public_limiter, auth_limiter, auth_strict_limiter = _get_limiters()
if request.url.path.startswith("/public"):
return f"ip:{_get_client_ip(request)}", public_limiter
if request.url.path.startswith("/auth/") and request.method == "POST":
return f"ip:{_get_client_ip(request)}", auth_strict_limiter
# For authenticated endpoints, use Bearer token as key if present
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
token_hash = hashlib.sha256(token.encode()).hexdigest()
return f"token:{token_hash}", _auth_limiter
return f"token:{token_hash}", auth_limiter
# Fallback to IP for unauthenticated non-public endpoints
return f"ip:{_get_client_ip(request)}", _public_limiter
return f"ip:{_get_client_ip(request)}", public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Skip rate limiting when disabled (e.g. in tests) or for health checks
if not settings.rate_limit_enabled or request.url.path == "/health":
return await call_next(request)
key, limiter = _get_rate_limit_key(request)
allowed, remaining, retry_after = limiter.is_allowed(key)
try:
allowed, remaining, retry_after = await limiter.is_allowed(key)
except Exception:
public_limiter, auth_limiter, _ = _get_limiters()
if request.url.path.startswith("/auth/") and request.method == "POST":
limiter = auth_limiter
elif request.url.path.startswith("/public"):
limiter = public_limiter
elif request.headers.get("authorization", "").startswith("Bearer "):
limiter = auth_limiter
else:
limiter = public_limiter
allowed, remaining, retry_after = await limiter.is_allowed(key)
if not allowed:
return JSONResponse(
+134 -51
View File
@@ -1,52 +1,157 @@
"""Tests for rate limiting middleware."""
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key
from cartsnitch_api.config import settings
from cartsnitch_api.middleware.rate_limit import (
InMemorySlidingWindow,
RateLimitMiddleware,
_get_client_ip,
_get_rate_limit_key,
_init_redis,
_use_redis,
)
class TestSlidingWindowCounter:
class TestInMemorySlidingWindow:
def test_allows_within_limit(self):
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
for i in range(5):
allowed, remaining, retry = counter.is_allowed("test-key")
allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 4 - i
def test_blocks_over_limit(self):
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
for _ in range(3):
counter.is_allowed("test-key")
limiter.is_allowed("test-key")
allowed, remaining, retry = counter.is_allowed("test-key")
allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is False
assert remaining == 0
assert retry > 0
def test_separate_keys(self):
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60)
# Fill key-a
counter.is_allowed("key-a")
counter.is_allowed("key-a")
allowed_a, _, _ = counter.is_allowed("key-a")
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
limiter.is_allowed("key-a")
limiter.is_allowed("key-a")
allowed_a, _, _ = limiter.is_allowed("key-a")
assert allowed_a is False
# key-b should still be allowed
allowed_b, remaining, _ = counter.is_allowed("key-b")
allowed_b, remaining, _ = limiter.is_allowed("key-b")
assert allowed_b is True
assert remaining == 1
@pytest.mark.asyncio
async def test_rate_limit_returns_429(client):
"""Public endpoint should return 429 after limit exceeded."""
# The default limit is 60/min — we won't hit it in normal tests,
# but we verify the middleware adds rate limit headers.
resp = await client.get("/public/inflation")
assert "x-ratelimit-limit" in resp.headers
assert "x-ratelimit-remaining" in resp.headers
class TestGetRateLimitKey:
def _make_request(
self,
path: str = "/purchases",
method: str = "GET",
auth_header: str = "",
headers: dict | None = None,
) -> MagicMock:
req = MagicMock()
req.url.path = path
req.method = method
req.headers = dict(headers) if headers else {}
if auth_header:
req.headers["authorization"] = auth_header
return req
def test_public_path_uses_public_limiter(self):
req = self._make_request("/public/inflation")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests
def test_auth_post_path_uses_strict_limiter(self):
req = self._make_request("/auth/login", method="POST")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_auth_requests
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
def test_auth_get_path_uses_auth_limiter(self):
req = self._make_request("/auth/me", method="GET")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_authenticated_token_uses_auth_limiter(self):
req = self._make_request("/purchases", auth_header="Bearer token123")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("token:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key
class TestGetClientIp:
def test_x_forwarded_for_single(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_multiple(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_with_port(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_no_forwarded_header(self):
req = MagicMock()
req.headers = {}
req.client.host = "127.0.0.1"
assert _get_client_ip(req) == "127.0.0.1"
def test_no_client(self):
req = MagicMock()
req.headers = {}
req.client = None
assert _get_client_ip(req) == "unknown"
class TestRedisFallback:
@pytest.mark.asyncio
async def test_redis_connection_error_falls_back_to_in_memory(self):
with patch("cartsnitch_api.middleware.rate_limit._use_redis", True):
with patch("cartsnitch_api.middleware.rate_limit._redis_client") as mock_client:
mock_client.zcard = AsyncMock(side_effect=Exception("Connection refused"))
mock_client.zrange = AsyncMock(return_value=[])
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 2
@pytest.mark.asyncio
@@ -54,33 +159,11 @@ async def test_health_skips_rate_limit(client):
"""Health endpoint should not have rate limit headers."""
resp = await client.get("/health")
assert resp.status_code == 200
assert "x-ratelimit-limit" not in resp.headers
class TestGetRateLimitKey:
def _make_request(self, auth_header: str = "") -> MagicMock:
req = MagicMock()
req.url.path = "/purchases"
req.headers = {"authorization": auth_header} if auth_header else {}
return req
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("Bearer token_alpha_12345")
req2 = self._make_request("Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("Bearer same_token_value_abc")
req2 = self._make_request("Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request(f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key
@pytest.mark.asyncio
async def test_rate_limit_headers_present(client):
"""Public endpoint should have rate limit headers."""
resp = await client.get("/public/inflation")
assert "x-ratelimit-limit" in resp.headers
assert "x-ratelimit-remaining" in resp.headers