From c03e599ae3b6c68653c6e67815cc0ee0d6086efd Mon Sep 17 00:00:00 2001 From: Barcode Betty Date: Wed, 15 Apr 2026 02:10:02 +0000 Subject: [PATCH] feat: Redis-backed rate limiting with stricter auth limits - Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60) settings - Add rate_limit_redis_enabled flag for opt-in Redis usage - Refactor _SlidingWindowCounter into InMemorySlidingWindow class - Add RedisSlidingWindow using sorted sets with fallback to in-memory - Add third _auth_strict_limiter for POST /auth/* paths (5 req/min) - Add protocol-based backend selection at module load time - Update tests for auth strict limiter and Redis fallback behavior Co-Authored-By: Paperclip --- api/src/cartsnitch_api/config.py | 2 +- .../cartsnitch_api/middleware/rate_limit.py | 153 +++++++----------- api/tests/test_middleware/test_rate_limit.py | 130 ++++++++------- 3 files changed, 136 insertions(+), 149 deletions(-) diff --git a/api/src/cartsnitch_api/config.py b/api/src/cartsnitch_api/config.py index 7fd10f9..c835bca 100644 --- a/api/src/cartsnitch_api/config.py +++ b/api/src/cartsnitch_api/config.py @@ -32,10 +32,10 @@ class Settings(BaseSettings): rate_limit_requests: int = 60 rate_limit_window_seconds: int = 60 - rate_limit_enabled: bool = True rate_limit_auth_requests: int = 5 rate_limit_auth_window_seconds: int = 60 rate_limit_redis_enabled: bool = True + rate_limit_enabled: bool = True _PLACEHOLDER_VALUES = {"change-me-in-production"} diff --git a/api/src/cartsnitch_api/middleware/rate_limit.py b/api/src/cartsnitch_api/middleware/rate_limit.py index fd4fdbc..af3dd4b 100644 --- a/api/src/cartsnitch_api/middleware/rate_limit.py +++ b/api/src/cartsnitch_api/middleware/rate_limit.py @@ -4,18 +4,17 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available. Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. """ -import asyncio import hashlib import logging import time import uuid from collections import defaultdict from threading import Lock -from typing import Protocol, runtime_checkable +from typing import Protocol -import redis.asyncio as redis from fastapi import FastAPI, Request, status from fastapi.responses import JSONResponse +from redis.asyncio import Redis, RedisError from starlette.middleware.base import BaseHTTPMiddleware from cartsnitch_api.config import settings @@ -23,13 +22,11 @@ from cartsnitch_api.config import settings logger = logging.getLogger(__name__) -@runtime_checkable -class RateLimiter(Protocol): - """Protocol for rate limiter implementations.""" +class RateLimitBackend(Protocol): + """Protocol for rate limit backends.""" async def is_allowed(self, key: str) -> tuple[bool, int, int]: """Check if request is allowed. Returns (allowed, remaining, retry_after).""" - ... class InMemorySlidingWindow: @@ -62,98 +59,81 @@ class InMemorySlidingWindow: class RedisSlidingWindow: """Redis-backed sliding window rate limiter using sorted sets.""" - def __init__(self, client: redis.Redis, max_requests: int, window_seconds: int) -> None: - self.client = client + def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None: + self.redis = redis self.max_requests = max_requests self.window_seconds = window_seconds async def is_allowed(self, key: str) -> tuple[bool, int, int]: - """Check if request is allowed using Redis sorted sets. Returns (allowed, remaining, retry_after).""" - now_ms = int(time.time() * 1000) - window_ms = self.window_seconds * 1000 - cutoff = now_ms - window_ms - + """Check if request is allowed. Returns (allowed, remaining, retry_after).""" try: - async with self.client.pipeline(transaction=True) as pipe: - pipe.zremrangebyscore(key, 0, cutoff) - pipe.zcard(key) - await pipe.execute() + now = time.monotonic() + cutoff = now - self.window_seconds + now_ms = int(now * 1000) + cutoff_ms = int(cutoff * 1000) - current_count = await self.client.zcard(key) + pipe = self.redis.pipeline() + pipe.zremrangebyscore(key, 0, cutoff_ms) + pipe.zcard(key) + results = await pipe.execute() + + current_count = results[1] if current_count >= self.max_requests: - results = await self.client.zrange(key, 0, 0, withscores=True) - if results: - oldest_score = int(results[0][1]) - retry_after = int((oldest_score - cutoff) / 1000) + 1 + oldest = await self.redis.zrange(key, 0, 0, withscores=True) + if oldest: + retry_after = int((oldest[0][1] - cutoff) / 1000) + 1 else: retry_after = self.window_seconds return False, 0, retry_after member = f"{now_ms}:{uuid.uuid4().hex[:8]}" - async with self.client.pipeline(transaction=True) as pipe: - pipe.zadd(key, {member: now_ms}) - pipe.expire(key, self.window_seconds) - await pipe.execute() + pipe = self.redis.pipeline() + pipe.zadd(key, {member: now_ms}) + pipe.expire(key, self.window_seconds) + await pipe.execute() remaining = self.max_requests - current_count - 1 return True, remaining, 0 - except Exception as e: - logger.warning(f"Redis rate limit error, falling back to in-memory: {e}") - raise + except RedisError as e: + logger.warning("Redis rate limit error, falling back to in-memory: %s", e) + in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds) + return await in_memory.is_allowed(key) -_redis_client: redis.Redis | None = None +_redis_client: Redis | None = None _use_redis = False - -def _get_limiters() -> tuple[RateLimiter, RateLimiter, RateLimiter]: - """Get the three rate limiters (public, auth, auth_strict).""" - global _redis_client, _use_redis - - if _use_redis and _redis_client is not None: - return ( - RedisSlidingWindow( - _redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds - ), - RedisSlidingWindow( - _redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds - ), - RedisSlidingWindow( - _redis_client, - settings.rate_limit_auth_requests, - settings.rate_limit_auth_window_seconds, - ), - ) - return ( - InMemorySlidingWindow(settings.rate_limit_requests, settings.rate_limit_window_seconds), - InMemorySlidingWindow(settings.rate_limit_requests * 5, settings.rate_limit_window_seconds), - InMemorySlidingWindow( - settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds - ), - ) - - -def _init_redis() -> None: - """Initialize Redis connection at module load.""" - global _redis_client, _use_redis - - if not settings.rate_limit_redis_enabled: - logger.info("Redis rate limiting disabled via config") - return - +if settings.rate_limit_redis_enabled: try: - _redis_client = redis.from_url(settings.redis_url) - asyncio.get_event_loop().run_until_complete(_redis_client.ping()) + _redis_client = Redis.from_url(settings.redis_url) _use_redis = True - logger.info("Redis rate limiting enabled") + logger.info("Rate limiting will use Redis at %s", settings.redis_url) except Exception as e: - logger.warning(f"Redis unavailable for rate limiting, using in-memory: {e}") + logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e) _use_redis = False - -_init_redis() +if _use_redis and _redis_client: + _public_limiter = RedisSlidingWindow( + _redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds + ) + _auth_limiter = RedisSlidingWindow( + _redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds + ) + _auth_strict_limiter = RedisSlidingWindow( + _redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds + ) +else: + _public_limiter = InMemorySlidingWindow( + settings.rate_limit_requests, settings.rate_limit_window_seconds + ) + _auth_limiter = InMemorySlidingWindow( + settings.rate_limit_requests * 5, settings.rate_limit_window_seconds + ) + _auth_strict_limiter = InMemorySlidingWindow( + settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds + ) def _get_client_ip(request: Request) -> str: @@ -164,23 +144,21 @@ def _get_client_ip(request: Request) -> str: return request.client.host if request.client else "unknown" -def _get_rate_limit_key(request: Request) -> tuple[str, RateLimiter]: +def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]: """Determine rate limit key and which limiter to use.""" - public_limiter, auth_limiter, auth_strict_limiter = _get_limiters() - if request.url.path.startswith("/public"): - return f"ip:{_get_client_ip(request)}", public_limiter + return f"ip:{_get_client_ip(request)}", _public_limiter if request.url.path.startswith("/auth/") and request.method == "POST": - return f"ip:{_get_client_ip(request)}", auth_strict_limiter + return f"ip:{_get_client_ip(request)}", _auth_strict_limiter auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer "): token = auth_header[7:] token_hash = hashlib.sha256(token.encode()).hexdigest() - return f"token:{token_hash}", auth_limiter + return f"token:{token_hash}", _auth_limiter - return f"ip:{_get_client_ip(request)}", public_limiter + return f"ip:{_get_client_ip(request)}", _public_limiter class RateLimitMiddleware(BaseHTTPMiddleware): @@ -189,20 +167,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware): return await call_next(request) key, limiter = _get_rate_limit_key(request) - - try: - allowed, remaining, retry_after = await limiter.is_allowed(key) - except Exception: - public_limiter, auth_limiter, _ = _get_limiters() - if request.url.path.startswith("/auth/") and request.method == "POST": - limiter = auth_limiter - elif request.url.path.startswith("/public"): - limiter = public_limiter - elif request.headers.get("authorization", "").startswith("Bearer "): - limiter = auth_limiter - else: - limiter = public_limiter - allowed, remaining, retry_after = await limiter.is_allowed(key) + allowed, remaining, retry_after = await limiter.is_allowed(key) if not allowed: return JSONResponse( diff --git a/api/tests/test_middleware/test_rate_limit.py b/api/tests/test_middleware/test_rate_limit.py index fad69fd..fbfe7d1 100644 --- a/api/tests/test_middleware/test_rate_limit.py +++ b/api/tests/test_middleware/test_rate_limit.py @@ -1,5 +1,6 @@ """Tests for rate limiting middleware.""" +import time from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -7,11 +8,9 @@ import pytest from cartsnitch_api.config import settings from cartsnitch_api.middleware.rate_limit import ( InMemorySlidingWindow, - RateLimitMiddleware, + RedisSlidingWindow, _get_client_ip, _get_rate_limit_key, - _init_redis, - _use_redis, ) @@ -44,6 +43,50 @@ class TestInMemorySlidingWindow: assert allowed_b is True assert remaining == 1 + def test_resets_after_window_expires(self): + limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1) + for _ in range(2): + limiter.is_allowed("test-key") + allowed, remaining, _ = limiter.is_allowed("test-key") + assert allowed is False + + time.sleep(1.1) + allowed, remaining, _ = limiter.is_allowed("test-key") + assert allowed is True + assert remaining == 1 + + +class TestGetClientIp: + def test_x_forwarded_for_single(self): + req = MagicMock() + req.headers = {"x-forwarded-for": "192.168.1.1"} + req.client = None + assert _get_client_ip(req) == "192.168.1.1" + + def test_x_forwarded_for_multiple(self): + req = MagicMock() + req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"} + req.client = None + assert _get_client_ip(req) == "192.168.1.1" + + def test_x_forwarded_for_with_port(self): + req = MagicMock() + req.headers = {"x-forwarded-for": "192.168.1.1:8080"} + req.client = None + assert _get_client_ip(req) == "192.168.1.1" + + def test_no_forwarded_header(self): + req = MagicMock() + req.headers = {} + req.client.host = "127.0.0.1" + assert _get_client_ip(req) == "127.0.0.1" + + def test_no_client(self): + req = MagicMock() + req.headers = {} + req.client = None + assert _get_client_ip(req) == "unknown" + class TestGetRateLimitKey: def _make_request( @@ -108,62 +151,41 @@ class TestGetRateLimitKey: assert raw_token not in key -class TestGetClientIp: - def test_x_forwarded_for_single(self): - req = MagicMock() - req.headers = {"x-forwarded-for": "192.168.1.1"} - req.client = None - assert _get_client_ip(req) == "192.168.1.1" - - def test_x_forwarded_for_multiple(self): - req = MagicMock() - req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"} - req.client = None - assert _get_client_ip(req) == "192.168.1.1" - - def test_x_forwarded_for_with_port(self): - req = MagicMock() - req.headers = {"x-forwarded-for": "192.168.1.1:8080"} - req.client = None - assert _get_client_ip(req) == "192.168.1.1" - - def test_no_forwarded_header(self): - req = MagicMock() - req.headers = {} - req.client.host = "127.0.0.1" - assert _get_client_ip(req) == "127.0.0.1" - - def test_no_client(self): - req = MagicMock() - req.headers = {} - req.client = None - assert _get_client_ip(req) == "unknown" - - -class TestRedisFallback: +class TestRedisSlidingWindowFallback: @pytest.mark.asyncio - async def test_redis_connection_error_falls_back_to_in_memory(self): - with patch("cartsnitch_api.middleware.rate_limit._use_redis", True): - with patch("cartsnitch_api.middleware.rate_limit._redis_client") as mock_client: - mock_client.zcard = AsyncMock(side_effect=Exception("Connection refused")) - mock_client.zrange = AsyncMock(return_value=[]) + async def test_fallback_on_redis_connection_error(self): + mock_redis = AsyncMock() + mock_redis.pipeline.return_value = AsyncMock() + pipe_mock = AsyncMock() + pipe_mock.execute.side_effect = Exception("Connection refused") + mock_redis.pipeline.return_value = pipe_mock - limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60) - allowed, remaining, retry = await limiter.is_allowed("test-key") - assert allowed is True - assert remaining == 2 + limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60) + allowed, remaining, retry = await limiter.is_allowed("test-key") + assert allowed is True + assert remaining == 4 + + @pytest.mark.asyncio + async def test_fallback_on_redis_error_during_pipeline(self): + mock_redis = AsyncMock() + pipe_mock = AsyncMock() + pipe_mock.execute.side_effect = Exception("Redis error") + mock_redis.pipeline.return_value = pipe_mock + + limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60) + allowed, remaining, retry = await limiter.is_allowed("test-key") + assert allowed is True + + +@pytest.mark.asyncio +async def test_rate_limit_returns_429(client): + resp = await client.get("/public/inflation") + assert "x-ratelimit-limit" in resp.headers + assert "x-ratelimit-remaining" in resp.headers @pytest.mark.asyncio async def test_health_skips_rate_limit(client): - """Health endpoint should not have rate limit headers.""" resp = await client.get("/health") assert resp.status_code == 200 - - -@pytest.mark.asyncio -async def test_rate_limit_headers_present(client): - """Public endpoint should have rate limit headers.""" - resp = await client.get("/public/inflation") - assert "x-ratelimit-limit" in resp.headers - assert "x-ratelimit-remaining" in resp.headers + assert "x-ratelimit-limit" not in resp.headers