feat: Redis-backed rate limiting with stricter auth limits

- Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60) settings
- Add rate_limit_redis_enabled flag for opt-in Redis usage
- Refactor _SlidingWindowCounter into InMemorySlidingWindow class
- Add RedisSlidingWindow using sorted sets with fallback to in-memory
- Add third _auth_strict_limiter for POST /auth/* paths (5 req/min)
- Add protocol-based backend selection at module load time
- Update tests for auth strict limiter and Redis fallback behavior

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
Barcode Betty
2026-04-15 02:10:02 +00:00
parent 26f3415eab
commit 8a4c194e39
3 changed files with 136 additions and 149 deletions
+1 -1
View File
@@ -32,10 +32,10 @@ class Settings(BaseSettings):
rate_limit_requests: int = 60 rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60 rate_limit_window_seconds: int = 60
rate_limit_enabled: bool = True
rate_limit_auth_requests: int = 5 rate_limit_auth_requests: int = 5
rate_limit_auth_window_seconds: int = 60 rate_limit_auth_window_seconds: int = 60
rate_limit_redis_enabled: bool = True rate_limit_redis_enabled: bool = True
rate_limit_enabled: bool = True
_PLACEHOLDER_VALUES = {"change-me-in-production"} _PLACEHOLDER_VALUES = {"change-me-in-production"}
+59 -94
View File
@@ -4,18 +4,17 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
""" """
import asyncio
import hashlib import hashlib
import logging import logging
import time import time
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from threading import Lock from threading import Lock
from typing import Protocol, runtime_checkable from typing import Protocol
import redis.asyncio as redis
from fastapi import FastAPI, Request, status from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from redis.asyncio import Redis, RedisError
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from cartsnitch_api.config import settings from cartsnitch_api.config import settings
@@ -23,13 +22,11 @@ from cartsnitch_api.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@runtime_checkable class RateLimitBackend(Protocol):
class RateLimiter(Protocol): """Protocol for rate limit backends."""
"""Protocol for rate limiter implementations."""
async def is_allowed(self, key: str) -> tuple[bool, int, int]: async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after).""" """Check if request is allowed. Returns (allowed, remaining, retry_after)."""
...
class InMemorySlidingWindow: class InMemorySlidingWindow:
@@ -62,98 +59,81 @@ class InMemorySlidingWindow:
class RedisSlidingWindow: class RedisSlidingWindow:
"""Redis-backed sliding window rate limiter using sorted sets.""" """Redis-backed sliding window rate limiter using sorted sets."""
def __init__(self, client: redis.Redis, max_requests: int, window_seconds: int) -> None: def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None:
self.client = client self.redis = redis
self.max_requests = max_requests self.max_requests = max_requests
self.window_seconds = window_seconds self.window_seconds = window_seconds
async def is_allowed(self, key: str) -> tuple[bool, int, int]: async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed using Redis sorted sets. Returns (allowed, remaining, retry_after).""" """Check if request is allowed. Returns (allowed, remaining, retry_after)."""
now_ms = int(time.time() * 1000)
window_ms = self.window_seconds * 1000
cutoff = now_ms - window_ms
try: try:
async with self.client.pipeline(transaction=True) as pipe: now = time.monotonic()
pipe.zremrangebyscore(key, 0, cutoff) cutoff = now - self.window_seconds
pipe.zcard(key) now_ms = int(now * 1000)
await pipe.execute() cutoff_ms = int(cutoff * 1000)
current_count = await self.client.zcard(key) pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, cutoff_ms)
pipe.zcard(key)
results = await pipe.execute()
current_count = results[1]
if current_count >= self.max_requests: if current_count >= self.max_requests:
results = await self.client.zrange(key, 0, 0, withscores=True) oldest = await self.redis.zrange(key, 0, 0, withscores=True)
if results: if oldest:
oldest_score = int(results[0][1]) retry_after = int((oldest[0][1] - cutoff) / 1000) + 1
retry_after = int((oldest_score - cutoff) / 1000) + 1
else: else:
retry_after = self.window_seconds retry_after = self.window_seconds
return False, 0, retry_after return False, 0, retry_after
member = f"{now_ms}:{uuid.uuid4().hex[:8]}" member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
async with self.client.pipeline(transaction=True) as pipe: pipe = self.redis.pipeline()
pipe.zadd(key, {member: now_ms}) pipe.zadd(key, {member: now_ms})
pipe.expire(key, self.window_seconds) pipe.expire(key, self.window_seconds)
await pipe.execute() await pipe.execute()
remaining = self.max_requests - current_count - 1 remaining = self.max_requests - current_count - 1
return True, remaining, 0 return True, remaining, 0
except Exception as e: except RedisError as e:
logger.warning(f"Redis rate limit error, falling back to in-memory: {e}") logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
raise in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds)
return await in_memory.is_allowed(key)
_redis_client: redis.Redis | None = None _redis_client: Redis | None = None
_use_redis = False _use_redis = False
if settings.rate_limit_redis_enabled:
def _get_limiters() -> tuple[RateLimiter, RateLimiter, RateLimiter]:
"""Get the three rate limiters (public, auth, auth_strict)."""
global _redis_client, _use_redis
if _use_redis and _redis_client is not None:
return (
RedisSlidingWindow(
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
),
RedisSlidingWindow(
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
),
RedisSlidingWindow(
_redis_client,
settings.rate_limit_auth_requests,
settings.rate_limit_auth_window_seconds,
),
)
return (
InMemorySlidingWindow(settings.rate_limit_requests, settings.rate_limit_window_seconds),
InMemorySlidingWindow(settings.rate_limit_requests * 5, settings.rate_limit_window_seconds),
InMemorySlidingWindow(
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
),
)
def _init_redis() -> None:
"""Initialize Redis connection at module load."""
global _redis_client, _use_redis
if not settings.rate_limit_redis_enabled:
logger.info("Redis rate limiting disabled via config")
return
try: try:
_redis_client = redis.from_url(settings.redis_url) _redis_client = Redis.from_url(settings.redis_url)
asyncio.get_event_loop().run_until_complete(_redis_client.ping())
_use_redis = True _use_redis = True
logger.info("Redis rate limiting enabled") logger.info("Rate limiting will use Redis at %s", settings.redis_url)
except Exception as e: except Exception as e:
logger.warning(f"Redis unavailable for rate limiting, using in-memory: {e}") logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
_use_redis = False _use_redis = False
if _use_redis and _redis_client:
_init_redis() _public_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
else:
_public_limiter = InMemorySlidingWindow(
settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = InMemorySlidingWindow(
settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = InMemorySlidingWindow(
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
def _get_client_ip(request: Request) -> str: def _get_client_ip(request: Request) -> str:
@@ -164,23 +144,21 @@ def _get_client_ip(request: Request) -> str:
return request.client.host if request.client else "unknown" return request.client.host if request.client else "unknown"
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimiter]: def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
"""Determine rate limit key and which limiter to use.""" """Determine rate limit key and which limiter to use."""
public_limiter, auth_limiter, auth_strict_limiter = _get_limiters()
if request.url.path.startswith("/public"): if request.url.path.startswith("/public"):
return f"ip:{_get_client_ip(request)}", public_limiter return f"ip:{_get_client_ip(request)}", _public_limiter
if request.url.path.startswith("/auth/") and request.method == "POST": if request.url.path.startswith("/auth/") and request.method == "POST":
return f"ip:{_get_client_ip(request)}", auth_strict_limiter return f"ip:{_get_client_ip(request)}", _auth_strict_limiter
auth_header = request.headers.get("authorization", "") auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):
token = auth_header[7:] token = auth_header[7:]
token_hash = hashlib.sha256(token.encode()).hexdigest() token_hash = hashlib.sha256(token.encode()).hexdigest()
return f"token:{token_hash}", auth_limiter return f"token:{token_hash}", _auth_limiter
return f"ip:{_get_client_ip(request)}", public_limiter return f"ip:{_get_client_ip(request)}", _public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
@@ -189,20 +167,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
return await call_next(request) return await call_next(request)
key, limiter = _get_rate_limit_key(request) key, limiter = _get_rate_limit_key(request)
allowed, remaining, retry_after = await limiter.is_allowed(key)
try:
allowed, remaining, retry_after = await limiter.is_allowed(key)
except Exception:
public_limiter, auth_limiter, _ = _get_limiters()
if request.url.path.startswith("/auth/") and request.method == "POST":
limiter = auth_limiter
elif request.url.path.startswith("/public"):
limiter = public_limiter
elif request.headers.get("authorization", "").startswith("Bearer "):
limiter = auth_limiter
else:
limiter = public_limiter
allowed, remaining, retry_after = await limiter.is_allowed(key)
if not allowed: if not allowed:
return JSONResponse( return JSONResponse(
+76 -54
View File
@@ -1,5 +1,6 @@
"""Tests for rate limiting middleware.""" """Tests for rate limiting middleware."""
import time
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@@ -7,11 +8,9 @@ import pytest
from cartsnitch_api.config import settings from cartsnitch_api.config import settings
from cartsnitch_api.middleware.rate_limit import ( from cartsnitch_api.middleware.rate_limit import (
InMemorySlidingWindow, InMemorySlidingWindow,
RateLimitMiddleware, RedisSlidingWindow,
_get_client_ip, _get_client_ip,
_get_rate_limit_key, _get_rate_limit_key,
_init_redis,
_use_redis,
) )
@@ -44,6 +43,50 @@ class TestInMemorySlidingWindow:
assert allowed_b is True assert allowed_b is True
assert remaining == 1 assert remaining == 1
def test_resets_after_window_expires(self):
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
for _ in range(2):
limiter.is_allowed("test-key")
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is False
time.sleep(1.1)
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 1
class TestGetClientIp:
def test_x_forwarded_for_single(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_multiple(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_with_port(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_no_forwarded_header(self):
req = MagicMock()
req.headers = {}
req.client.host = "127.0.0.1"
assert _get_client_ip(req) == "127.0.0.1"
def test_no_client(self):
req = MagicMock()
req.headers = {}
req.client = None
assert _get_client_ip(req) == "unknown"
class TestGetRateLimitKey: class TestGetRateLimitKey:
def _make_request( def _make_request(
@@ -108,62 +151,41 @@ class TestGetRateLimitKey:
assert raw_token not in key assert raw_token not in key
class TestGetClientIp: class TestRedisSlidingWindowFallback:
def test_x_forwarded_for_single(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_multiple(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_with_port(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_no_forwarded_header(self):
req = MagicMock()
req.headers = {}
req.client.host = "127.0.0.1"
assert _get_client_ip(req) == "127.0.0.1"
def test_no_client(self):
req = MagicMock()
req.headers = {}
req.client = None
assert _get_client_ip(req) == "unknown"
class TestRedisFallback:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_redis_connection_error_falls_back_to_in_memory(self): async def test_fallback_on_redis_connection_error(self):
with patch("cartsnitch_api.middleware.rate_limit._use_redis", True): mock_redis = AsyncMock()
with patch("cartsnitch_api.middleware.rate_limit._redis_client") as mock_client: mock_redis.pipeline.return_value = AsyncMock()
mock_client.zcard = AsyncMock(side_effect=Exception("Connection refused")) pipe_mock = AsyncMock()
mock_client.zrange = AsyncMock(return_value=[]) pipe_mock.execute.side_effect = Exception("Connection refused")
mock_redis.pipeline.return_value = pipe_mock
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60) limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key") allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True assert allowed is True
assert remaining == 2 assert remaining == 4
@pytest.mark.asyncio
async def test_fallback_on_redis_error_during_pipeline(self):
mock_redis = AsyncMock()
pipe_mock = AsyncMock()
pipe_mock.execute.side_effect = Exception("Redis error")
mock_redis.pipeline.return_value = pipe_mock
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True
@pytest.mark.asyncio
async def test_rate_limit_returns_429(client):
resp = await client.get("/public/inflation")
assert "x-ratelimit-limit" in resp.headers
assert "x-ratelimit-remaining" in resp.headers
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_health_skips_rate_limit(client): async def test_health_skips_rate_limit(client):
"""Health endpoint should not have rate limit headers."""
resp = await client.get("/health") resp = await client.get("/health")
assert resp.status_code == 200 assert resp.status_code == 200
assert "x-ratelimit-limit" not in resp.headers
@pytest.mark.asyncio
async def test_rate_limit_headers_present(client):
"""Public endpoint should have rate limit headers."""
resp = await client.get("/public/inflation")
assert "x-ratelimit-limit" in resp.headers
assert "x-ratelimit-remaining" in resp.headers