feat: Redis-backed rate limiting with stricter auth limits (#194)

feat: Redis-backed rate limiting with stricter auth limits
This commit is contained in:
cartsnitch-ceo[bot]
2026-04-15 03:31:42 +00:00
committed by GitHub
3 changed files with 257 additions and 66 deletions
+6 -1
View File
@@ -32,6 +32,9 @@ class Settings(BaseSettings):
rate_limit_requests: int = 60 rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60 rate_limit_window_seconds: int = 60
rate_limit_auth_requests: int = 5
rate_limit_auth_window_seconds: int = 60
rate_limit_redis_enabled: bool = True
rate_limit_enabled: bool = True rate_limit_enabled: bool = True
_PLACEHOLDER_VALUES = {"change-me-in-production"} _PLACEHOLDER_VALUES = {"change-me-in-production"}
@@ -72,7 +75,9 @@ class Settings(BaseSettings):
def normalize_database_url(self): def normalize_database_url(self):
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver.""" """Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
if self.database_url.startswith("postgresql://"): if self.database_url.startswith("postgresql://"):
self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1) self.database_url = self.database_url.replace(
"postgresql://", "postgresql+asyncpg://", 1
)
return self return self
+98 -17
View File
@@ -5,18 +5,31 @@ Per-IP limiting on public endpoints, per-token limiting on authenticated endpoin
""" """
import hashlib import hashlib
import logging
import time import time
import uuid
from collections import defaultdict from collections import defaultdict
from threading import Lock from threading import Lock
from typing import Protocol
from fastapi import FastAPI, Request, status from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from redis.asyncio import Redis, RedisError
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from cartsnitch_api.config import settings from cartsnitch_api.config import settings
logger = logging.getLogger(__name__)
class _SlidingWindowCounter:
class RateLimitBackend(Protocol):
"""Protocol for rate limit backends."""
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
class InMemorySlidingWindow:
"""Thread-safe in-memory sliding window rate limiter.""" """Thread-safe in-memory sliding window rate limiter."""
def __init__(self, max_requests: int, window_seconds: int) -> None: def __init__(self, max_requests: int, window_seconds: int) -> None:
@@ -25,13 +38,12 @@ class _SlidingWindowCounter:
self._hits: dict[str, list[float]] = defaultdict(list) self._hits: dict[str, list[float]] = defaultdict(list)
self._lock = Lock() self._lock = Lock()
def is_allowed(self, key: str) -> tuple[bool, int, int]: async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after).""" """Check if request is allowed. Returns (allowed, remaining, retry_after)."""
now = time.monotonic() now = time.monotonic()
cutoff = now - self.window_seconds cutoff = now - self.window_seconds
with self._lock: with self._lock:
# Prune expired entries
self._hits[key] = [t for t in self._hits[key] if t > cutoff] self._hits[key] = [t for t in self._hits[key] if t > cutoff]
current_count = len(self._hits[key]) current_count = len(self._hits[key])
@@ -44,15 +56,84 @@ class _SlidingWindowCounter:
return True, remaining, 0 return True, remaining, 0
# Module-level counters — one for public (per-IP), one for auth (per-token) class RedisSlidingWindow:
_public_limiter = _SlidingWindowCounter( """Redis-backed sliding window rate limiter using sorted sets."""
max_requests=settings.rate_limit_requests,
window_seconds=settings.rate_limit_window_seconds, def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None:
) self.redis = redis
_auth_limiter = _SlidingWindowCounter( self.max_requests = max_requests
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users self.window_seconds = window_seconds
window_seconds=settings.rate_limit_window_seconds,
) async def is_allowed(self, key: str) -> tuple[bool, int, int]:
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
try:
now = time.monotonic()
cutoff = now - self.window_seconds
now_ms = int(now * 1000)
cutoff_ms = int(cutoff * 1000)
pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, cutoff_ms)
pipe.zcard(key)
results = await pipe.execute()
current_count = results[1]
if current_count >= self.max_requests:
oldest = await self.redis.zrange(key, 0, 0, withscores=True)
if oldest:
retry_after = int((oldest[0][1] - cutoff) / 1000) + 1
else:
retry_after = self.window_seconds
return False, 0, retry_after
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
pipe = self.redis.pipeline()
pipe.zadd(key, {member: now_ms})
pipe.expire(key, self.window_seconds)
await pipe.execute()
remaining = self.max_requests - current_count - 1
return True, remaining, 0
except RedisError as e:
logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds)
return await in_memory.is_allowed(key)
_redis_client: Redis | None = None
_use_redis = False
if settings.rate_limit_redis_enabled:
try:
_redis_client = Redis.from_url(settings.redis_url)
_use_redis = True
logger.info("Rate limiting will use Redis at %s", settings.redis_url)
except Exception as e:
logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
_use_redis = False
if _use_redis and _redis_client:
_public_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = RedisSlidingWindow(
_redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
else:
_public_limiter = InMemorySlidingWindow(
settings.rate_limit_requests, settings.rate_limit_window_seconds
)
_auth_limiter = InMemorySlidingWindow(
settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
)
_auth_strict_limiter = InMemorySlidingWindow(
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
)
def _get_client_ip(request: Request) -> str: def _get_client_ip(request: Request) -> str:
@@ -63,30 +144,30 @@ def _get_client_ip(request: Request) -> str:
return request.client.host if request.client else "unknown" return request.client.host if request.client else "unknown"
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]: def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
"""Determine rate limit key and which limiter to use.""" """Determine rate limit key and which limiter to use."""
if request.url.path.startswith("/public"): if request.url.path.startswith("/public"):
return f"ip:{_get_client_ip(request)}", _public_limiter return f"ip:{_get_client_ip(request)}", _public_limiter
# For authenticated endpoints, use Bearer token as key if present if request.url.path.startswith("/auth/") and request.method == "POST":
return f"ip:{_get_client_ip(request)}", _auth_strict_limiter
auth_header = request.headers.get("authorization", "") auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):
token = auth_header[7:] token = auth_header[7:]
token_hash = hashlib.sha256(token.encode()).hexdigest() token_hash = hashlib.sha256(token.encode()).hexdigest()
return f"token:{token_hash}", _auth_limiter return f"token:{token_hash}", _auth_limiter
# Fallback to IP for unauthenticated non-public endpoints
return f"ip:{_get_client_ip(request)}", _public_limiter return f"ip:{_get_client_ip(request)}", _public_limiter
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
# Skip rate limiting when disabled (e.g. in tests) or for health checks
if not settings.rate_limit_enabled or request.url.path == "/health": if not settings.rate_limit_enabled or request.url.path == "/health":
return await call_next(request) return await call_next(request)
key, limiter = _get_rate_limit_key(request) key, limiter = _get_rate_limit_key(request)
allowed, remaining, retry_after = limiter.is_allowed(key) allowed, remaining, retry_after = await limiter.is_allowed(key)
if not allowed: if not allowed:
return JSONResponse( return JSONResponse(
+153 -48
View File
@@ -1,49 +1,184 @@
"""Tests for rate limiting middleware.""" """Tests for rate limiting middleware."""
from unittest.mock import MagicMock import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key from cartsnitch_api.config import settings
from cartsnitch_api.middleware.rate_limit import (
InMemorySlidingWindow,
RedisSlidingWindow,
_get_client_ip,
_get_rate_limit_key,
)
class TestSlidingWindowCounter: class TestInMemorySlidingWindow:
def test_allows_within_limit(self): def test_allows_within_limit(self):
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60) limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
for i in range(5): for i in range(5):
allowed, remaining, retry = counter.is_allowed("test-key") allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is True assert allowed is True
assert remaining == 4 - i assert remaining == 4 - i
def test_blocks_over_limit(self): def test_blocks_over_limit(self):
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60) limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
for _ in range(3): for _ in range(3):
counter.is_allowed("test-key") limiter.is_allowed("test-key")
allowed, remaining, retry = counter.is_allowed("test-key") allowed, remaining, retry = limiter.is_allowed("test-key")
assert allowed is False assert allowed is False
assert remaining == 0 assert remaining == 0
assert retry > 0 assert retry > 0
def test_separate_keys(self): def test_separate_keys(self):
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60) limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
# Fill key-a limiter.is_allowed("key-a")
counter.is_allowed("key-a") limiter.is_allowed("key-a")
counter.is_allowed("key-a") allowed_a, _, _ = limiter.is_allowed("key-a")
allowed_a, _, _ = counter.is_allowed("key-a")
assert allowed_a is False assert allowed_a is False
# key-b should still be allowed allowed_b, remaining, _ = limiter.is_allowed("key-b")
allowed_b, remaining, _ = counter.is_allowed("key-b")
assert allowed_b is True assert allowed_b is True
assert remaining == 1 assert remaining == 1
def test_resets_after_window_expires(self):
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
for _ in range(2):
limiter.is_allowed("test-key")
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is False
time.sleep(1.1)
allowed, remaining, _ = limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 1
class TestGetClientIp:
def test_x_forwarded_for_single(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_multiple(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_x_forwarded_for_with_port(self):
req = MagicMock()
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
req.client = None
assert _get_client_ip(req) == "192.168.1.1"
def test_no_forwarded_header(self):
req = MagicMock()
req.headers = {}
req.client.host = "127.0.0.1"
assert _get_client_ip(req) == "127.0.0.1"
def test_no_client(self):
req = MagicMock()
req.headers = {}
req.client = None
assert _get_client_ip(req) == "unknown"
class TestGetRateLimitKey:
def _make_request(
self,
path: str = "/purchases",
method: str = "GET",
auth_header: str = "",
headers: dict | None = None,
) -> MagicMock:
req = MagicMock()
req.url.path = path
req.method = method
req.headers = dict(headers) if headers else {}
if auth_header:
req.headers["authorization"] = auth_header
return req
def test_public_path_uses_public_limiter(self):
req = self._make_request("/public/inflation")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests
def test_auth_post_path_uses_strict_limiter(self):
req = self._make_request("/auth/login", method="POST")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_auth_requests
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
def test_auth_get_path_uses_auth_limiter(self):
req = self._make_request("/auth/me", method="GET")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("ip:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_authenticated_token_uses_auth_limiter(self):
req = self._make_request("/purchases", auth_header="Bearer token123")
key, limiter = _get_rate_limit_key(req)
assert key.startswith("token:")
assert limiter.max_requests == settings.rate_limit_requests * 5
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key
class TestRedisSlidingWindowFallback:
@pytest.mark.asyncio
async def test_fallback_on_redis_connection_error(self):
mock_redis = AsyncMock()
mock_redis.pipeline.return_value = AsyncMock()
pipe_mock = AsyncMock()
pipe_mock.execute.side_effect = Exception("Connection refused")
mock_redis.pipeline.return_value = pipe_mock
limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True
assert remaining == 4
@pytest.mark.asyncio
async def test_fallback_on_redis_error_during_pipeline(self):
mock_redis = AsyncMock()
pipe_mock = AsyncMock()
pipe_mock.execute.side_effect = Exception("Redis error")
mock_redis.pipeline.return_value = pipe_mock
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
allowed, remaining, retry = await limiter.is_allowed("test-key")
assert allowed is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rate_limit_returns_429(client): async def test_rate_limit_returns_429(client):
"""Public endpoint should return 429 after limit exceeded."""
# The default limit is 60/min — we won't hit it in normal tests,
# but we verify the middleware adds rate limit headers.
resp = await client.get("/public/inflation") resp = await client.get("/public/inflation")
assert "x-ratelimit-limit" in resp.headers assert "x-ratelimit-limit" in resp.headers
assert "x-ratelimit-remaining" in resp.headers assert "x-ratelimit-remaining" in resp.headers
@@ -51,36 +186,6 @@ async def test_rate_limit_returns_429(client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_health_skips_rate_limit(client): async def test_health_skips_rate_limit(client):
"""Health endpoint should not have rate limit headers."""
resp = await client.get("/health") resp = await client.get("/health")
assert resp.status_code == 200 assert resp.status_code == 200
assert "x-ratelimit-limit" not in resp.headers assert "x-ratelimit-limit" not in resp.headers
class TestGetRateLimitKey:
def _make_request(self, auth_header: str = "") -> MagicMock:
req = MagicMock()
req.url.path = "/purchases"
req.headers = {"authorization": auth_header} if auth_header else {}
return req
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("Bearer token_alpha_12345")
req2 = self._make_request("Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("Bearer same_token_value_abc")
req2 = self._make_request("Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request(f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key