feat: Redis-backed rate limiting with stricter auth limits

- Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60)
  settings to config.py
- Refactor rate_limit.py to use protocol/ABC pattern with InMemorySlidingWindow
  and RedisSlidingWindow implementations
- Add RedisSlidingWindow using sorted sets for distributed rate limiting
- Add auth_strict_limiter for /auth/* POST endpoints (5 req/min per IP)
- Fall back to in-memory when Redis is unavailable
- Update tests to cover new functionality

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
Paperclip
2026-04-14 15:46:52 +00:00
parent 06c099594a
commit 26f3415eab
3 changed files with 277 additions and 73 deletions
+6 -1
View File
@@ -33,6 +33,9 @@ class Settings(BaseSettings):
rate_limit_requests: int = 60 rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60 rate_limit_window_seconds: int = 60
rate_limit_enabled: bool = True rate_limit_enabled: bool = True
rate_limit_auth_requests: int = 5
rate_limit_auth_window_seconds: int = 60
rate_limit_redis_enabled: bool = True
_PLACEHOLDER_VALUES = {"change-me-in-production"} _PLACEHOLDER_VALUES = {"change-me-in-production"}
@@ -72,7 +75,9 @@ class Settings(BaseSettings):
def normalize_database_url(self): def normalize_database_url(self):
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver.""" """Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
if self.database_url.startswith("postgresql://"): if self.database_url.startswith("postgresql://"):
self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1) self.database_url = self.database_url.replace(
"postgresql://", "postgresql+asyncpg://", 1
)
return self return self
+137 -21
View File
@@ -4,19 +4,35 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
""" """
import asyncio
import hashlib import hashlib
import logging
import time import time
import uuid
from collections import defaultdict from collections import defaultdict
from threading import Lock from threading import Lock
from typing import Protocol, runtime_checkable
import redis.asyncio as redis
from fastapi import FastAPI, Request, status from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from cartsnitch_api.config import settings from cartsnitch_api.config import settings
logger = logging.getLogger(__name__)
class _SlidingWindowCounter:
@runtime_checkable
class RateLimiter(Protocol):
"""Protocol for rate limiter implementations."""
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
...
class InMemorySlidingWindow:
"""Thread-safe in-memory sliding window rate limiter.""" """Thread-safe in-memory sliding window rate limiter."""
def __init__(self, max_requests: int, window_seconds: int) -> None: def __init__(self, max_requests: int, window_seconds: int) -> None:
@@ -25,13 +41,12 @@ class _SlidingWindowCounter:
self._hits: dict[str, list[float]] = defaultdict(list) self._hits: dict[str, list[float]] = defaultdict(list)
self._lock = Lock() self._lock = Lock()
def is_allowed(self, key: str) -> tuple[bool, int, int]: async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after).""" """Check if request is allowed. Returns (allowed, remaining, retry_after)."""
now = time.monotonic() now = time.monotonic()
cutoff = now - self.window_seconds cutoff = now - self.window_seconds
with self._lock: with self._lock:
# Prune expired entries
self._hits[key] = [t for t in self._hits[key] if t > cutoff] self._hits[key] = [t for t in self._hits[key] if t > cutoff]
current_count = len(self._hits[key]) current_count = len(self._hits[key])
@@ -44,15 +59,101 @@ class _SlidingWindowCounter:
return True, remaining, 0 return True, remaining, 0
# Module-level counters — one for public (per-IP), one for auth (per-token) class RedisSlidingWindow:
_public_limiter = _SlidingWindowCounter( """Redis-backed sliding window rate limiter using sorted sets."""
max_requests=settings.rate_limit_requests,
window_seconds=settings.rate_limit_window_seconds, def __init__(self, client: redis.Redis, max_requests: int, window_seconds: int) -> None:
) self.client = client
_auth_limiter = _SlidingWindowCounter( self.max_requests = max_requests
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users self.window_seconds = window_seconds
window_seconds=settings.rate_limit_window_seconds,
) async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed using Redis sorted sets. Returns (allowed, remaining, retry_after)."""
now_ms = int(time.time() * 1000)
window_ms = self.window_seconds * 1000
cutoff = now_ms - window_ms
try:
async with self.client.pipeline(transaction=True) as pipe:
pipe.zremrangebyscore(key, 0, cutoff)
pipe.zcard(key)
await pipe.execute()
current_count = await self.client.zcard(key)
if current_count >= self.max_requests:
results = await self.client.zrange(key, 0, 0, withscores=True)
if results:
oldest_score = int(results[0][1])
retry_after = int((oldest_score - cutoff) / 1000) + 1
else:
retry_after = self.window_seconds
return False, 0, retry_after
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
async with self.client.pipeline(transaction=True) as pipe:
pipe.zadd(key, {member: now_ms})
pipe.expire(key, self.window_seconds)
await pipe.execute()
remaining = self.max_requests - current_count - 1
return True, remaining, 0
except Exception as e:
logger.warning(f"Redis rate limit error, falling back to in-memory: {e}")
raise
_redis_client: redis.Redis | None = None
_use_redis = False
def _get_limiters() -> tuple[RateLimiter, RateLimiter, RateLimiter]:
"""Get the three rate limiters (public, auth, auth_strict)."""
global _redis_client, _use_redis
if _use_redis and _redis_client is not None:
return (
RedisSlidingWindow(
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
),
RedisSlidingWindow(
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
),
RedisSlidingWindow(
_redis_client,
settings.rate_limit_auth_requests,
settings.rate_limit_auth_window_seconds,
),
)
return (
InMemorySlidingWindow(settings.rate_limit_requests, settings.rate_limit_window_seconds),
InMemorySlidingWindow(settings.rate_limit_requests * 5, settings.rate_limit_window_seconds),
InMemorySlidingWindow(
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
),
)
def _init_redis() -> None:
"""Initialize Redis connection at module load."""
global _redis_client, _use_redis
if not settings.rate_limit_redis_enabled:
logger.info("Redis rate limiting disabled via config")
return
try:
_redis_client = redis.from_url(settings.redis_url)
asyncio.get_event_loop().run_until_complete(_redis_client.ping())
_use_redis = True
logger.info("Redis rate limiting enabled")
except Exception as e:
logger.warning(f"Redis unavailable for rate limiting, using in-memory: {e}")
_use_redis = False
_init_redis()
def _get_client_ip(request: Request) -> str: def _get_client_ip(request: Request) -> str:
@@ -63,30 +164,45 @@ def _get_client_ip(request: Request) -> str:
return request.client.host if request.client else "unknown" return request.client.host if request.client else "unknown"
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]: def _get_rate_limit_key(request: Request) -> tuple[str, RateLimiter]:
"""Determine rate limit key and which limiter to use.""" """Determine rate limit key and which limiter to use."""
if request.url.path.startswith("/public"): public_limiter, auth_limiter, auth_strict_limiter = _get_limiters()
return f"ip:{_get_client_ip(request)}", _public_limiter
if request.url.path.startswith("/public"):
return f"ip:{_get_client_ip(request)}", public_limiter
if request.url.path.startswith("/auth/") and request.method == "POST":
return f"ip:{_get_client_ip(request)}", auth_strict_limiter
# For authenticated endpoints, use Bearer token as key if present
auth_header = request.headers.get("authorization", "") auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):
token = auth_header[7:] token = auth_header[7:]
token_hash = hashlib.sha256(token.encode()).hexdigest() token_hash = hashlib.sha256(token.encode()).hexdigest()
return f"token:{token_hash}", _auth_limiter return f"token:{token_hash}", auth_limiter
# Fallback to IP for unauthenticated non-public endpoints return f"ip:{_get_client_ip(request)}", public_limiter
return f"ip:{_get_client_ip(request)}", _public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
# Skip rate limiting when disabled (e.g. in tests) or for health checks
if not settings.rate_limit_enabled or request.url.path == "/health": if not settings.rate_limit_enabled or request.url.path == "/health":
return await call_next(request) return await call_next(request)
key, limiter = _get_rate_limit_key(request) key, limiter = _get_rate_limit_key(request)
allowed, remaining, retry_after = limiter.is_allowed(key)
try:
allowed, remaining, retry_after = await limiter.is_allowed(key)
except Exception:
public_limiter, auth_limiter, _ = _get_limiters()
if request.url.path.startswith("/auth/") and request.method == "POST":
limiter = auth_limiter
elif request.url.path.startswith("/public"):
limiter = public_limiter
elif request.headers.get("authorization", "").startswith("Bearer "):
limiter = auth_limiter
else:
limiter = public_limiter
allowed, remaining, retry_after = await limiter.is_allowed(key)
if not allowed: if not allowed:
return JSONResponse( return JSONResponse(
+134 -51
View File
@@ -1,52 +1,157 @@
"""Tests for rate limiting middleware.""" """Tests for rate limiting middleware."""
from unittest.mock import MagicMock from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key from cartsnitch_api.config import settings
from cartsnitch_api.middleware.rate_limit import (
InMemorySlidingWindow,
RateLimitMiddleware,
_get_client_ip,
_get_rate_limit_key,
_init_redis,
_use_redis,
)
class TestSlidingWindowCounter: class TestInMemorySlidingWindow:
def test_allows_within_limit(self): def test_allows_within_limit(self):
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60) limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
for i in range(5): for i in range(5):
allowed, remaining, retry = counter.is_allowed("test-key") allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is True assert allowed is True
assert remaining == 4 - i assert remaining == 4 - i
def test_blocks_over_limit(self): def test_blocks_over_limit(self):
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60) limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
for _ in range(3): for _ in range(3):
counter.is_allowed("test-key") limiter.is_allowed("test-key")
allowed, remaining, retry = counter.is_allowed("test-key") allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is False assert allowed is False
assert remaining == 0 assert remaining == 0
assert retry > 0 assert retry > 0
def test_separate_keys(self): def test_separate_keys(self):
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60) limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
# Fill key-a limiter.is_allowed("key-a")
counter.is_allowed("key-a") limiter.is_allowed("key-a")
counter.is_allowed("key-a") allowed_a, _, _ = limiter.is_allowed("key-a")
allowed_a, _, _ = counter.is_allowed("key-a")
assert allowed_a is False assert allowed_a is False
# key-b should still be allowed allowed_b, remaining, _ = limiter.is_allowed("key-b")
allowed_b, remaining, _ = counter.is_allowed("key-b")
assert allowed_b is True assert allowed_b is True
assert remaining == 1 assert remaining == 1
@pytest.mark.asyncio class TestGetRateLimitKey:
async def test_rate_limit_returns_429(client): def _make_request(
"""Public endpoint should return 429 after limit exceeded.""" self,
# The default limit is 60/min — we won't hit it in normal tests, path: str = "/purchases",
# but we verify the middleware adds rate limit headers. method: str = "GET",
resp = await client.get("/public/inflation") auth_header: str = "",
assert "x-ratelimit-limit" in resp.headers headers: dict | None = None,
assert "x-ratelimit-remaining" in resp.headers ) -> MagicMock:
req = MagicMock()
req.url.path = path
req.method = method
req.headers = dict(headers) if headers else {}
if auth_header:
req.headers["authorization"] = auth_header
return req
def test_public_path_uses_public_limiter(self):
req = self._make_request("/public/inflation")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests
def test_auth_post_path_uses_strict_limiter(self):
req = self._make_request("/auth/login", method="POST")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_auth_requests
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
def test_auth_get_path_uses_auth_limiter(self):
req = self._make_request("/auth/me", method="GET")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_authenticated_token_uses_auth_limiter(self):
req = self._make_request("/purchases", auth_header="Bearer token123")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("token:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key
class TestGetClientIp:
def test_x_forwarded_for_single(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_multiple(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_with_port(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_no_forwarded_header(self):
req = MagicMock()
req.headers = {}
req.client.host = "127.0.0.1"
assert _get_client_ip(req) == "127.0.0.1"
def test_no_client(self):
req = MagicMock()
req.headers = {}
req.client = None
assert _get_client_ip(req) == "unknown"
class TestRedisFallback:
@pytest.mark.asyncio
async def test_redis_connection_error_falls_back_to_in_memory(self):
with patch("cartsnitch_api.middleware.rate_limit._use_redis", True):
with patch("cartsnitch_api.middleware.rate_limit._redis_client") as mock_client:
mock_client.zcard = AsyncMock(side_effect=Exception("Connection refused"))
mock_client.zrange = AsyncMock(return_value=[])
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 2
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -54,33 +159,11 @@ async def test_health_skips_rate_limit(client):
"""Health endpoint should not have rate limit headers.""" """Health endpoint should not have rate limit headers."""
resp = await client.get("/health") resp = await client.get("/health")
assert resp.status_code == 200 assert resp.status_code == 200
assert "x-ratelimit-limit" not in resp.headers
class TestGetRateLimitKey: @pytest.mark.asyncio
def _make_request(self, auth_header: str = "") -> MagicMock: async def test_rate_limit_headers_present(client):
req = MagicMock() """Public endpoint should have rate limit headers."""
req.url.path = "/purchases" resp = await client.get("/public/inflation")
req.headers = {"authorization": auth_header} if auth_header else {} assert "x-ratelimit-limit" in resp.headers
return req assert "x-ratelimit-remaining" in resp.headers
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("Bearer token_alpha_12345")
req2 = self._make_request("Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("Bearer same_token_value_abc")
req2 = self._make_request("Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request(f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key