forked from cartsnitch/cartsnitch
feat: Redis-backed rate limiting with stricter auth limits
- Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60) settings - Add rate_limit_redis_enabled flag for opt-in Redis usage - Refactor _SlidingWindowCounter into InMemorySlidingWindow class - Add RedisSlidingWindow using sorted sets with fallback to in-memory - Add third _auth_strict_limiter for POST /auth/* paths (5 req/min) - Add protocol-based backend selection at module load time - Update tests for auth strict limiter and Redis fallback behavior Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -32,10 +32,10 @@ class Settings(BaseSettings):
|
||||
|
||||
rate_limit_requests: int = 60
|
||||
rate_limit_window_seconds: int = 60
|
||||
rate_limit_enabled: bool = True
|
||||
rate_limit_auth_requests: int = 5
|
||||
rate_limit_auth_window_seconds: int = 60
|
||||
rate_limit_redis_enabled: bool = True
|
||||
rate_limit_enabled: bool = True
|
||||
|
||||
_PLACEHOLDER_VALUES = {"change-me-in-production"}
|
||||
|
||||
|
||||
@@ -4,18 +4,17 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
|
||||
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
from typing import Protocol, runtime_checkable
|
||||
from typing import Protocol
|
||||
|
||||
import redis.asyncio as redis
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis, RedisError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
@@ -23,13 +22,11 @@ from cartsnitch_api.config import settings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class RateLimiter(Protocol):
|
||||
"""Protocol for rate limiter implementations."""
|
||||
class RateLimitBackend(Protocol):
|
||||
"""Protocol for rate limit backends."""
|
||||
|
||||
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
...
|
||||
|
||||
|
||||
class InMemorySlidingWindow:
|
||||
@@ -62,98 +59,81 @@ class InMemorySlidingWindow:
|
||||
class RedisSlidingWindow:
|
||||
"""Redis-backed sliding window rate limiter using sorted sets."""
|
||||
|
||||
def __init__(self, client: redis.Redis, max_requests: int, window_seconds: int) -> None:
|
||||
self.client = client
|
||||
def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None:
|
||||
self.redis = redis
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
|
||||
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed using Redis sorted sets. Returns (allowed, remaining, retry_after)."""
|
||||
now_ms = int(time.time() * 1000)
|
||||
window_ms = self.window_seconds * 1000
|
||||
cutoff = now_ms - window_ms
|
||||
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
try:
|
||||
async with self.client.pipeline(transaction=True) as pipe:
|
||||
pipe.zremrangebyscore(key, 0, cutoff)
|
||||
pipe.zcard(key)
|
||||
await pipe.execute()
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
now_ms = int(now * 1000)
|
||||
cutoff_ms = int(cutoff * 1000)
|
||||
|
||||
current_count = await self.client.zcard(key)
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.zremrangebyscore(key, 0, cutoff_ms)
|
||||
pipe.zcard(key)
|
||||
results = await pipe.execute()
|
||||
|
||||
current_count = results[1]
|
||||
|
||||
if current_count >= self.max_requests:
|
||||
results = await self.client.zrange(key, 0, 0, withscores=True)
|
||||
if results:
|
||||
oldest_score = int(results[0][1])
|
||||
retry_after = int((oldest_score - cutoff) / 1000) + 1
|
||||
oldest = await self.redis.zrange(key, 0, 0, withscores=True)
|
||||
if oldest:
|
||||
retry_after = int((oldest[0][1] - cutoff) / 1000) + 1
|
||||
else:
|
||||
retry_after = self.window_seconds
|
||||
return False, 0, retry_after
|
||||
|
||||
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
|
||||
async with self.client.pipeline(transaction=True) as pipe:
|
||||
pipe.zadd(key, {member: now_ms})
|
||||
pipe.expire(key, self.window_seconds)
|
||||
await pipe.execute()
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.zadd(key, {member: now_ms})
|
||||
pipe.expire(key, self.window_seconds)
|
||||
await pipe.execute()
|
||||
|
||||
remaining = self.max_requests - current_count - 1
|
||||
return True, remaining, 0
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis rate limit error, falling back to in-memory: {e}")
|
||||
raise
|
||||
except RedisError as e:
|
||||
logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
|
||||
in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds)
|
||||
return await in_memory.is_allowed(key)
|
||||
|
||||
|
||||
_redis_client: redis.Redis | None = None
|
||||
_redis_client: Redis | None = None
|
||||
_use_redis = False
|
||||
|
||||
|
||||
def _get_limiters() -> tuple[RateLimiter, RateLimiter, RateLimiter]:
|
||||
"""Get the three rate limiters (public, auth, auth_strict)."""
|
||||
global _redis_client, _use_redis
|
||||
|
||||
if _use_redis and _redis_client is not None:
|
||||
return (
|
||||
RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
|
||||
),
|
||||
RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
|
||||
),
|
||||
RedisSlidingWindow(
|
||||
_redis_client,
|
||||
settings.rate_limit_auth_requests,
|
||||
settings.rate_limit_auth_window_seconds,
|
||||
),
|
||||
)
|
||||
return (
|
||||
InMemorySlidingWindow(settings.rate_limit_requests, settings.rate_limit_window_seconds),
|
||||
InMemorySlidingWindow(settings.rate_limit_requests * 5, settings.rate_limit_window_seconds),
|
||||
InMemorySlidingWindow(
|
||||
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _init_redis() -> None:
|
||||
"""Initialize Redis connection at module load."""
|
||||
global _redis_client, _use_redis
|
||||
|
||||
if not settings.rate_limit_redis_enabled:
|
||||
logger.info("Redis rate limiting disabled via config")
|
||||
return
|
||||
|
||||
if settings.rate_limit_redis_enabled:
|
||||
try:
|
||||
_redis_client = redis.from_url(settings.redis_url)
|
||||
asyncio.get_event_loop().run_until_complete(_redis_client.ping())
|
||||
_redis_client = Redis.from_url(settings.redis_url)
|
||||
_use_redis = True
|
||||
logger.info("Redis rate limiting enabled")
|
||||
logger.info("Rate limiting will use Redis at %s", settings.redis_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis unavailable for rate limiting, using in-memory: {e}")
|
||||
logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
|
||||
_use_redis = False
|
||||
|
||||
|
||||
_init_redis()
|
||||
if _use_redis and _redis_client:
|
||||
_public_limiter = RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_limiter = RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_strict_limiter = RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
|
||||
)
|
||||
else:
|
||||
_public_limiter = InMemorySlidingWindow(
|
||||
settings.rate_limit_requests, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_limiter = InMemorySlidingWindow(
|
||||
settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_strict_limiter = InMemorySlidingWindow(
|
||||
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
|
||||
)
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
@@ -164,23 +144,21 @@ def _get_client_ip(request: Request) -> str:
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimiter]:
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
|
||||
"""Determine rate limit key and which limiter to use."""
|
||||
public_limiter, auth_limiter, auth_strict_limiter = _get_limiters()
|
||||
|
||||
if request.url.path.startswith("/public"):
|
||||
return f"ip:{_get_client_ip(request)}", public_limiter
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
if request.url.path.startswith("/auth/") and request.method == "POST":
|
||||
return f"ip:{_get_client_ip(request)}", auth_strict_limiter
|
||||
return f"ip:{_get_client_ip(request)}", _auth_strict_limiter
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
return f"token:{token_hash}", auth_limiter
|
||||
return f"token:{token_hash}", _auth_limiter
|
||||
|
||||
return f"ip:{_get_client_ip(request)}", public_limiter
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
@@ -189,20 +167,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
key, limiter = _get_rate_limit_key(request)
|
||||
|
||||
try:
|
||||
allowed, remaining, retry_after = await limiter.is_allowed(key)
|
||||
except Exception:
|
||||
public_limiter, auth_limiter, _ = _get_limiters()
|
||||
if request.url.path.startswith("/auth/") and request.method == "POST":
|
||||
limiter = auth_limiter
|
||||
elif request.url.path.startswith("/public"):
|
||||
limiter = public_limiter
|
||||
elif request.headers.get("authorization", "").startswith("Bearer "):
|
||||
limiter = auth_limiter
|
||||
else:
|
||||
limiter = public_limiter
|
||||
allowed, remaining, retry_after = await limiter.is_allowed(key)
|
||||
allowed, remaining, retry_after = await limiter.is_allowed(key)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -7,11 +8,9 @@ import pytest
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.middleware.rate_limit import (
|
||||
InMemorySlidingWindow,
|
||||
RateLimitMiddleware,
|
||||
RedisSlidingWindow,
|
||||
_get_client_ip,
|
||||
_get_rate_limit_key,
|
||||
_init_redis,
|
||||
_use_redis,
|
||||
)
|
||||
|
||||
|
||||
@@ -44,6 +43,50 @@ class TestInMemorySlidingWindow:
|
||||
assert allowed_b is True
|
||||
assert remaining == 1
|
||||
|
||||
def test_resets_after_window_expires(self):
|
||||
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
|
||||
for _ in range(2):
|
||||
limiter.is_allowed("test-key")
|
||||
allowed, remaining, _ = limiter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
|
||||
time.sleep(1.1)
|
||||
allowed, remaining, _ = limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 1
|
||||
|
||||
|
||||
class TestGetClientIp:
|
||||
def test_x_forwarded_for_single(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_x_forwarded_for_multiple(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_x_forwarded_for_with_port(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_no_forwarded_header(self):
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client.host = "127.0.0.1"
|
||||
assert _get_client_ip(req) == "127.0.0.1"
|
||||
|
||||
def test_no_client(self):
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "unknown"
|
||||
|
||||
|
||||
class TestGetRateLimitKey:
|
||||
def _make_request(
|
||||
@@ -108,62 +151,41 @@ class TestGetRateLimitKey:
|
||||
assert raw_token not in key
|
||||
|
||||
|
||||
class TestGetClientIp:
|
||||
def test_x_forwarded_for_single(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_x_forwarded_for_multiple(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_x_forwarded_for_with_port(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_no_forwarded_header(self):
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client.host = "127.0.0.1"
|
||||
assert _get_client_ip(req) == "127.0.0.1"
|
||||
|
||||
def test_no_client(self):
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "unknown"
|
||||
|
||||
|
||||
class TestRedisFallback:
|
||||
class TestRedisSlidingWindowFallback:
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_connection_error_falls_back_to_in_memory(self):
|
||||
with patch("cartsnitch_api.middleware.rate_limit._use_redis", True):
|
||||
with patch("cartsnitch_api.middleware.rate_limit._redis_client") as mock_client:
|
||||
mock_client.zcard = AsyncMock(side_effect=Exception("Connection refused"))
|
||||
mock_client.zrange = AsyncMock(return_value=[])
|
||||
async def test_fallback_on_redis_connection_error(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline.return_value = AsyncMock()
|
||||
pipe_mock = AsyncMock()
|
||||
pipe_mock.execute.side_effect = Exception("Connection refused")
|
||||
mock_redis.pipeline.return_value = pipe_mock
|
||||
|
||||
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
|
||||
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 2
|
||||
limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
|
||||
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_redis_error_during_pipeline(self):
|
||||
mock_redis = AsyncMock()
|
||||
pipe_mock = AsyncMock()
|
||||
pipe_mock.execute.side_effect = Exception("Redis error")
|
||||
mock_redis.pipeline.return_value = pipe_mock
|
||||
|
||||
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
|
||||
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_returns_429(client):
|
||||
resp = await client.get("/public/inflation")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_skips_rate_limit(client):
|
||||
"""Health endpoint should not have rate limit headers."""
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_headers_present(client):
|
||||
"""Public endpoint should have rate limit headers."""
|
||||
resp = await client.get("/public/inflation")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
assert "x-ratelimit-limit" not in resp.headers
|
||||
|
||||
Reference in New Issue
Block a user