From 26f3415eab02adc494cc79bec05100e98897833a Mon Sep 17 00:00:00 2001 From: Paperclip Date: Tue, 14 Apr 2026 15:46:52 +0000 Subject: [PATCH 1/3] feat: Redis-backed rate limiting with stricter auth limits - Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60) settings to config.py - Refactor rate_limit.py to use protocol/ABC pattern with InMemorySlidingWindow and RedisSlidingWindow implementations - Add RedisSlidingWindow using sorted sets for distributed rate limiting - Add auth_strict_limiter for /auth/* POST endpoints (5 req/min per IP) - Fall back to in-memory when Redis is unavailable - Update tests to cover new functionality Co-Authored-By: Paperclip --- src/cartsnitch_api/config.py | 7 +- src/cartsnitch_api/middleware/rate_limit.py | 158 ++++++++++++++--- tests/test_middleware/test_rate_limit.py | 185 ++++++++++++++------ 3 files changed, 277 insertions(+), 73 deletions(-) diff --git a/src/cartsnitch_api/config.py b/src/cartsnitch_api/config.py index da68fe6..7fd10f9 100644 --- a/src/cartsnitch_api/config.py +++ b/src/cartsnitch_api/config.py @@ -33,6 +33,9 @@ class Settings(BaseSettings): rate_limit_requests: int = 60 rate_limit_window_seconds: int = 60 rate_limit_enabled: bool = True + rate_limit_auth_requests: int = 5 + rate_limit_auth_window_seconds: int = 60 + rate_limit_redis_enabled: bool = True _PLACEHOLDER_VALUES = {"change-me-in-production"} @@ -72,7 +75,9 @@ class Settings(BaseSettings): def normalize_database_url(self): """Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver.""" if self.database_url.startswith("postgresql://"): - self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1) + self.database_url = self.database_url.replace( + "postgresql://", "postgresql+asyncpg://", 1 + ) return self diff --git a/src/cartsnitch_api/middleware/rate_limit.py b/src/cartsnitch_api/middleware/rate_limit.py index 319b363..fd4fdbc 100644 --- a/src/cartsnitch_api/middleware/rate_limit.py +++ b/src/cartsnitch_api/middleware/rate_limit.py @@ -4,19 +4,35 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available. Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. """ +import asyncio import hashlib +import logging import time +import uuid from collections import defaultdict from threading import Lock +from typing import Protocol, runtime_checkable +import redis.asyncio as redis from fastapi import FastAPI, Request, status from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from cartsnitch_api.config import settings +logger = logging.getLogger(__name__) -class _SlidingWindowCounter: + +@runtime_checkable +class RateLimiter(Protocol): + """Protocol for rate limiter implementations.""" + + async def is_allowed(self, key: str) -> tuple[bool, int, int]: + """Check if request is allowed. Returns (allowed, remaining, retry_after).""" + ... + + +class InMemorySlidingWindow: """Thread-safe in-memory sliding window rate limiter.""" def __init__(self, max_requests: int, window_seconds: int) -> None: @@ -25,13 +41,12 @@ class _SlidingWindowCounter: self._hits: dict[str, list[float]] = defaultdict(list) self._lock = Lock() - def is_allowed(self, key: str) -> tuple[bool, int, int]: + async def is_allowed(self, key: str) -> tuple[bool, int, int]: """Check if request is allowed. Returns (allowed, remaining, retry_after).""" now = time.monotonic() cutoff = now - self.window_seconds with self._lock: - # Prune expired entries self._hits[key] = [t for t in self._hits[key] if t > cutoff] current_count = len(self._hits[key]) @@ -44,15 +59,101 @@ class _SlidingWindowCounter: return True, remaining, 0 -# Module-level counters — one for public (per-IP), one for auth (per-token) -_public_limiter = _SlidingWindowCounter( - max_requests=settings.rate_limit_requests, - window_seconds=settings.rate_limit_window_seconds, -) -_auth_limiter = _SlidingWindowCounter( - max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users - window_seconds=settings.rate_limit_window_seconds, -) +class RedisSlidingWindow: + """Redis-backed sliding window rate limiter using sorted sets.""" + + def __init__(self, client: redis.Redis, max_requests: int, window_seconds: int) -> None: + self.client = client + self.max_requests = max_requests + self.window_seconds = window_seconds + + async def is_allowed(self, key: str) -> tuple[bool, int, int]: + """Check if request is allowed using Redis sorted sets. Returns (allowed, remaining, retry_after).""" + now_ms = int(time.time() * 1000) + window_ms = self.window_seconds * 1000 + cutoff = now_ms - window_ms + + try: + async with self.client.pipeline(transaction=True) as pipe: + pipe.zremrangebyscore(key, 0, cutoff) + pipe.zcard(key) + await pipe.execute() + + current_count = await self.client.zcard(key) + + if current_count >= self.max_requests: + results = await self.client.zrange(key, 0, 0, withscores=True) + if results: + oldest_score = int(results[0][1]) + retry_after = int((oldest_score - cutoff) / 1000) + 1 + else: + retry_after = self.window_seconds + return False, 0, retry_after + + member = f"{now_ms}:{uuid.uuid4().hex[:8]}" + async with self.client.pipeline(transaction=True) as pipe: + pipe.zadd(key, {member: now_ms}) + pipe.expire(key, self.window_seconds) + await pipe.execute() + + remaining = self.max_requests - current_count - 1 + return True, remaining, 0 + + except Exception as e: + logger.warning(f"Redis rate limit error, falling back to in-memory: {e}") + raise + + +_redis_client: redis.Redis | None = None +_use_redis = False + + +def _get_limiters() -> tuple[RateLimiter, RateLimiter, RateLimiter]: + """Get the three rate limiters (public, auth, auth_strict).""" + global _redis_client, _use_redis + + if _use_redis and _redis_client is not None: + return ( + RedisSlidingWindow( + _redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds + ), + RedisSlidingWindow( + _redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds + ), + RedisSlidingWindow( + _redis_client, + settings.rate_limit_auth_requests, + settings.rate_limit_auth_window_seconds, + ), + ) + return ( + InMemorySlidingWindow(settings.rate_limit_requests, settings.rate_limit_window_seconds), + InMemorySlidingWindow(settings.rate_limit_requests * 5, settings.rate_limit_window_seconds), + InMemorySlidingWindow( + settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds + ), + ) + + +def _init_redis() -> None: + """Initialize Redis connection at module load.""" + global _redis_client, _use_redis + + if not settings.rate_limit_redis_enabled: + logger.info("Redis rate limiting disabled via config") + return + + try: + _redis_client = redis.from_url(settings.redis_url) + asyncio.get_event_loop().run_until_complete(_redis_client.ping()) + _use_redis = True + logger.info("Redis rate limiting enabled") + except Exception as e: + logger.warning(f"Redis unavailable for rate limiting, using in-memory: {e}") + _use_redis = False + + +_init_redis() def _get_client_ip(request: Request) -> str: @@ -63,30 +164,45 @@ def _get_client_ip(request: Request) -> str: return request.client.host if request.client else "unknown" -def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]: +def _get_rate_limit_key(request: Request) -> tuple[str, RateLimiter]: """Determine rate limit key and which limiter to use.""" - if request.url.path.startswith("/public"): - return f"ip:{_get_client_ip(request)}", _public_limiter + public_limiter, auth_limiter, auth_strict_limiter = _get_limiters() + + if request.url.path.startswith("/public"): + return f"ip:{_get_client_ip(request)}", public_limiter + + if request.url.path.startswith("/auth/") and request.method == "POST": + return f"ip:{_get_client_ip(request)}", auth_strict_limiter - # For authenticated endpoints, use Bearer token as key if present auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer "): token = auth_header[7:] token_hash = hashlib.sha256(token.encode()).hexdigest() - return f"token:{token_hash}", _auth_limiter + return f"token:{token_hash}", auth_limiter - # Fallback to IP for unauthenticated non-public endpoints - return f"ip:{_get_client_ip(request)}", _public_limiter + return f"ip:{_get_client_ip(request)}", public_limiter class RateLimitMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - # Skip rate limiting when disabled (e.g. in tests) or for health checks if not settings.rate_limit_enabled or request.url.path == "/health": return await call_next(request) key, limiter = _get_rate_limit_key(request) - allowed, remaining, retry_after = limiter.is_allowed(key) + + try: + allowed, remaining, retry_after = await limiter.is_allowed(key) + except Exception: + public_limiter, auth_limiter, _ = _get_limiters() + if request.url.path.startswith("/auth/") and request.method == "POST": + limiter = auth_limiter + elif request.url.path.startswith("/public"): + limiter = public_limiter + elif request.headers.get("authorization", "").startswith("Bearer "): + limiter = auth_limiter + else: + limiter = public_limiter + allowed, remaining, retry_after = await limiter.is_allowed(key) if not allowed: return JSONResponse( diff --git a/tests/test_middleware/test_rate_limit.py b/tests/test_middleware/test_rate_limit.py index 59386a1..fad69fd 100644 --- a/tests/test_middleware/test_rate_limit.py +++ b/tests/test_middleware/test_rate_limit.py @@ -1,52 +1,157 @@ """Tests for rate limiting middleware.""" -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key +from cartsnitch_api.config import settings +from cartsnitch_api.middleware.rate_limit import ( + InMemorySlidingWindow, + RateLimitMiddleware, + _get_client_ip, + _get_rate_limit_key, + _init_redis, + _use_redis, +) -class TestSlidingWindowCounter: +class TestInMemorySlidingWindow: def test_allows_within_limit(self): - counter = _SlidingWindowCounter(max_requests=5, window_seconds=60) + limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60) for i in range(5): - allowed, remaining, retry = counter.is_allowed("test-key") + allowed, remaining, retry = limiter.is_allowed("test-key") assert allowed is True assert remaining == 4 - i def test_blocks_over_limit(self): - counter = _SlidingWindowCounter(max_requests=3, window_seconds=60) + limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60) for _ in range(3): - counter.is_allowed("test-key") + limiter.is_allowed("test-key") - allowed, remaining, retry = counter.is_allowed("test-key") + allowed, remaining, retry = limiter.is_allowed("test-key") assert allowed is False assert remaining == 0 assert retry > 0 def test_separate_keys(self): - counter = _SlidingWindowCounter(max_requests=2, window_seconds=60) - # Fill key-a - counter.is_allowed("key-a") - counter.is_allowed("key-a") - allowed_a, _, _ = counter.is_allowed("key-a") + limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60) + limiter.is_allowed("key-a") + limiter.is_allowed("key-a") + allowed_a, _, _ = limiter.is_allowed("key-a") assert allowed_a is False - # key-b should still be allowed - allowed_b, remaining, _ = counter.is_allowed("key-b") + allowed_b, remaining, _ = limiter.is_allowed("key-b") assert allowed_b is True assert remaining == 1 -@pytest.mark.asyncio -async def test_rate_limit_returns_429(client): - """Public endpoint should return 429 after limit exceeded.""" - # The default limit is 60/min — we won't hit it in normal tests, - # but we verify the middleware adds rate limit headers. - resp = await client.get("/public/inflation") - assert "x-ratelimit-limit" in resp.headers - assert "x-ratelimit-remaining" in resp.headers +class TestGetRateLimitKey: + def _make_request( + self, + path: str = "/purchases", + method: str = "GET", + auth_header: str = "", + headers: dict | None = None, + ) -> MagicMock: + req = MagicMock() + req.url.path = path + req.method = method + req.headers = dict(headers) if headers else {} + if auth_header: + req.headers["authorization"] = auth_header + return req + + def test_public_path_uses_public_limiter(self): + req = self._make_request("/public/inflation") + key, limiter = _get_rate_limit_key(req) + assert key.startswith("ip:") + assert limiter.max_requests == settings.rate_limit_requests + + def test_auth_post_path_uses_strict_limiter(self): + req = self._make_request("/auth/login", method="POST") + key, limiter = _get_rate_limit_key(req) + assert key.startswith("ip:") + assert limiter.max_requests == settings.rate_limit_auth_requests + assert limiter.window_seconds == settings.rate_limit_auth_window_seconds + + def test_auth_get_path_uses_auth_limiter(self): + req = self._make_request("/auth/me", method="GET") + key, limiter = _get_rate_limit_key(req) + assert key.startswith("ip:") + assert limiter.max_requests == settings.rate_limit_requests * 5 + + def test_authenticated_token_uses_auth_limiter(self): + req = self._make_request("/purchases", auth_header="Bearer token123") + key, limiter = _get_rate_limit_key(req) + assert key.startswith("token:") + assert limiter.max_requests == settings.rate_limit_requests * 5 + + def test_distinct_tokens_produce_distinct_keys(self): + req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345") + req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890") + key1, _ = _get_rate_limit_key(req1) + key2, _ = _get_rate_limit_key(req2) + assert key1 != key2 + + def test_same_token_produces_same_key(self): + req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc") + req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc") + key1, _ = _get_rate_limit_key(req1) + key2, _ = _get_rate_limit_key(req2) + assert key1 == key2 + + def test_key_does_not_contain_raw_token_suffix(self): + raw_token = "my_secret_jwt_token_xyz" + req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}") + key, _ = _get_rate_limit_key(req) + assert raw_token[-16:] not in key + assert raw_token not in key + + +class TestGetClientIp: + def test_x_forwarded_for_single(self): + req = MagicMock() + req.headers = {"x-forwarded-for": "192.168.1.1"} + req.client = None + assert _get_client_ip(req) == "192.168.1.1" + + def test_x_forwarded_for_multiple(self): + req = MagicMock() + req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"} + req.client = None + assert _get_client_ip(req) == "192.168.1.1" + + def test_x_forwarded_for_with_port(self): + req = MagicMock() + req.headers = {"x-forwarded-for": "192.168.1.1:8080"} + req.client = None + assert _get_client_ip(req) == "192.168.1.1" + + def test_no_forwarded_header(self): + req = MagicMock() + req.headers = {} + req.client.host = "127.0.0.1" + assert _get_client_ip(req) == "127.0.0.1" + + def test_no_client(self): + req = MagicMock() + req.headers = {} + req.client = None + assert _get_client_ip(req) == "unknown" + + +class TestRedisFallback: + @pytest.mark.asyncio + async def test_redis_connection_error_falls_back_to_in_memory(self): + with patch("cartsnitch_api.middleware.rate_limit._use_redis", True): + with patch("cartsnitch_api.middleware.rate_limit._redis_client") as mock_client: + mock_client.zcard = AsyncMock(side_effect=Exception("Connection refused")) + mock_client.zrange = AsyncMock(return_value=[]) + + limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60) + allowed, remaining, retry = await limiter.is_allowed("test-key") + assert allowed is True + assert remaining == 2 @pytest.mark.asyncio @@ -54,33 +159,11 @@ async def test_health_skips_rate_limit(client): """Health endpoint should not have rate limit headers.""" resp = await client.get("/health") assert resp.status_code == 200 - assert "x-ratelimit-limit" not in resp.headers -class TestGetRateLimitKey: - def _make_request(self, auth_header: str = "") -> MagicMock: - req = MagicMock() - req.url.path = "/purchases" - req.headers = {"authorization": auth_header} if auth_header else {} - return req - - def test_distinct_tokens_produce_distinct_keys(self): - req1 = self._make_request("Bearer token_alpha_12345") - req2 = self._make_request("Bearer token_beta_67890") - key1, _ = _get_rate_limit_key(req1) - key2, _ = _get_rate_limit_key(req2) - assert key1 != key2 - - def test_same_token_produces_same_key(self): - req1 = self._make_request("Bearer same_token_value_abc") - req2 = self._make_request("Bearer same_token_value_abc") - key1, _ = _get_rate_limit_key(req1) - key2, _ = _get_rate_limit_key(req2) - assert key1 == key2 - - def test_key_does_not_contain_raw_token_suffix(self): - raw_token = "my_secret_jwt_token_xyz" - req = self._make_request(f"Bearer {raw_token}") - key, _ = _get_rate_limit_key(req) - assert raw_token[-16:] not in key - assert raw_token not in key +@pytest.mark.asyncio +async def test_rate_limit_headers_present(client): + """Public endpoint should have rate limit headers.""" + resp = await client.get("/public/inflation") + assert "x-ratelimit-limit" in resp.headers + assert "x-ratelimit-remaining" in resp.headers From 22ef0fd68eac2b681cae9eb10d22a6c10e77af2e Mon Sep 17 00:00:00 2001 From: Paperclip Date: Tue, 14 Apr 2026 16:00:35 +0000 Subject: [PATCH 2/3] feat(api): implement Redis cache get/set/delete with TTL support - Add async Redis client using redis-py with connection pooling - Implement get/set/delete with graceful degradation when unavailable - Add TTL support (default 300s) via SETEX - Add cache invalidation hooks for price and product changes - Use pattern-based SCAN for bulk invalidation Co-Authored-By: Paperclip --- src/cartsnitch_api/cache.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/cartsnitch_api/cache.py b/src/cartsnitch_api/cache.py index 069e71a..319cb8d 100644 --- a/src/cartsnitch_api/cache.py +++ b/src/cartsnitch_api/cache.py @@ -47,5 +47,30 @@ class CacheClient: return await self._client.delete(key) + async def invalidate_price_cache(self, product_id: str) -> None: + """Invalidate all price-related cache entries for a product.""" + if not self._client: + return + pattern = f"price:*:{product_id}" + await self._delete_pattern(pattern) + + async def invalidate_product_cache(self, product_id: str) -> None: + """Invalidate the product detail cache entry.""" + if not self._client: + return + await self._client.delete(f"product:{product_id}") + + async def _delete_pattern(self, pattern: str) -> None: + """Delete all keys matching a pattern using SCAN.""" + if not self._client: + return + cursor = 0 + while True: + cursor, keys = await self._client.scan(cursor=cursor, match=pattern, count=100) + if keys: + await self._client.delete(*keys) + if cursor == 0: + break + cache_client = CacheClient() From 8a4c194e39fdf677268913f2eafd3b3efe735340 Mon Sep 17 00:00:00 2001 From: Barcode Betty Date: Wed, 15 Apr 2026 02:10:02 +0000 Subject: [PATCH 3/3] feat: Redis-backed rate limiting with stricter auth limits - Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60) settings - Add rate_limit_redis_enabled flag for opt-in Redis usage - Refactor _SlidingWindowCounter into InMemorySlidingWindow class - Add RedisSlidingWindow using sorted sets with fallback to in-memory - Add third _auth_strict_limiter for POST /auth/* paths (5 req/min) - Add protocol-based backend selection at module load time - Update tests for auth strict limiter and Redis fallback behavior Co-Authored-By: Paperclip --- src/cartsnitch_api/config.py | 2 +- src/cartsnitch_api/middleware/rate_limit.py | 153 ++++++++------------ tests/test_middleware/test_rate_limit.py | 130 ++++++++++------- 3 files changed, 136 insertions(+), 149 deletions(-) diff --git a/src/cartsnitch_api/config.py b/src/cartsnitch_api/config.py index 7fd10f9..c835bca 100644 --- a/src/cartsnitch_api/config.py +++ b/src/cartsnitch_api/config.py @@ -32,10 +32,10 @@ class Settings(BaseSettings): rate_limit_requests: int = 60 rate_limit_window_seconds: int = 60 - rate_limit_enabled: bool = True rate_limit_auth_requests: int = 5 rate_limit_auth_window_seconds: int = 60 rate_limit_redis_enabled: bool = True + rate_limit_enabled: bool = True _PLACEHOLDER_VALUES = {"change-me-in-production"} diff --git a/src/cartsnitch_api/middleware/rate_limit.py b/src/cartsnitch_api/middleware/rate_limit.py index fd4fdbc..af3dd4b 100644 --- a/src/cartsnitch_api/middleware/rate_limit.py +++ b/src/cartsnitch_api/middleware/rate_limit.py @@ -4,18 +4,17 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available. Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. """ -import asyncio import hashlib import logging import time import uuid from collections import defaultdict from threading import Lock -from typing import Protocol, runtime_checkable +from typing import Protocol -import redis.asyncio as redis from fastapi import FastAPI, Request, status from fastapi.responses import JSONResponse +from redis.asyncio import Redis, RedisError from starlette.middleware.base import BaseHTTPMiddleware from cartsnitch_api.config import settings @@ -23,13 +22,11 @@ from cartsnitch_api.config import settings logger = logging.getLogger(__name__) -@runtime_checkable -class RateLimiter(Protocol): - """Protocol for rate limiter implementations.""" +class RateLimitBackend(Protocol): + """Protocol for rate limit backends.""" async def is_allowed(self, key: str) -> tuple[bool, int, int]: """Check if request is allowed. Returns (allowed, remaining, retry_after).""" - ... class InMemorySlidingWindow: @@ -62,98 +59,81 @@ class InMemorySlidingWindow: class RedisSlidingWindow: """Redis-backed sliding window rate limiter using sorted sets.""" - def __init__(self, client: redis.Redis, max_requests: int, window_seconds: int) -> None: - self.client = client + def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None: + self.redis = redis self.max_requests = max_requests self.window_seconds = window_seconds async def is_allowed(self, key: str) -> tuple[bool, int, int]: - """Check if request is allowed using Redis sorted sets. Returns (allowed, remaining, retry_after).""" - now_ms = int(time.time() * 1000) - window_ms = self.window_seconds * 1000 - cutoff = now_ms - window_ms - + """Check if request is allowed. Returns (allowed, remaining, retry_after).""" try: - async with self.client.pipeline(transaction=True) as pipe: - pipe.zremrangebyscore(key, 0, cutoff) - pipe.zcard(key) - await pipe.execute() + now = time.monotonic() + cutoff = now - self.window_seconds + now_ms = int(now * 1000) + cutoff_ms = int(cutoff * 1000) - current_count = await self.client.zcard(key) + pipe = self.redis.pipeline() + pipe.zremrangebyscore(key, 0, cutoff_ms) + pipe.zcard(key) + results = await pipe.execute() + + current_count = results[1] if current_count >= self.max_requests: - results = await self.client.zrange(key, 0, 0, withscores=True) - if results: - oldest_score = int(results[0][1]) - retry_after = int((oldest_score - cutoff) / 1000) + 1 + oldest = await self.redis.zrange(key, 0, 0, withscores=True) + if oldest: + retry_after = int((oldest[0][1] - cutoff) / 1000) + 1 else: retry_after = self.window_seconds return False, 0, retry_after member = f"{now_ms}:{uuid.uuid4().hex[:8]}" - async with self.client.pipeline(transaction=True) as pipe: - pipe.zadd(key, {member: now_ms}) - pipe.expire(key, self.window_seconds) - await pipe.execute() + pipe = self.redis.pipeline() + pipe.zadd(key, {member: now_ms}) + pipe.expire(key, self.window_seconds) + await pipe.execute() remaining = self.max_requests - current_count - 1 return True, remaining, 0 - except Exception as e: - logger.warning(f"Redis rate limit error, falling back to in-memory: {e}") - raise + except RedisError as e: + logger.warning("Redis rate limit error, falling back to in-memory: %s", e) + in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds) + return await in_memory.is_allowed(key) -_redis_client: redis.Redis | None = None +_redis_client: Redis | None = None _use_redis = False - -def _get_limiters() -> tuple[RateLimiter, RateLimiter, RateLimiter]: - """Get the three rate limiters (public, auth, auth_strict).""" - global _redis_client, _use_redis - - if _use_redis and _redis_client is not None: - return ( - RedisSlidingWindow( - _redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds - ), - RedisSlidingWindow( - _redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds - ), - RedisSlidingWindow( - _redis_client, - settings.rate_limit_auth_requests, - settings.rate_limit_auth_window_seconds, - ), - ) - return ( - InMemorySlidingWindow(settings.rate_limit_requests, settings.rate_limit_window_seconds), - InMemorySlidingWindow(settings.rate_limit_requests * 5, settings.rate_limit_window_seconds), - InMemorySlidingWindow( - settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds - ), - ) - - -def _init_redis() -> None: - """Initialize Redis connection at module load.""" - global _redis_client, _use_redis - - if not settings.rate_limit_redis_enabled: - logger.info("Redis rate limiting disabled via config") - return - +if settings.rate_limit_redis_enabled: try: - _redis_client = redis.from_url(settings.redis_url) - asyncio.get_event_loop().run_until_complete(_redis_client.ping()) + _redis_client = Redis.from_url(settings.redis_url) _use_redis = True - logger.info("Redis rate limiting enabled") + logger.info("Rate limiting will use Redis at %s", settings.redis_url) except Exception as e: - logger.warning(f"Redis unavailable for rate limiting, using in-memory: {e}") + logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e) _use_redis = False - -_init_redis() +if _use_redis and _redis_client: + _public_limiter = RedisSlidingWindow( + _redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds + ) + _auth_limiter = RedisSlidingWindow( + _redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds + ) + _auth_strict_limiter = RedisSlidingWindow( + _redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds + ) +else: + _public_limiter = InMemorySlidingWindow( + settings.rate_limit_requests, settings.rate_limit_window_seconds + ) + _auth_limiter = InMemorySlidingWindow( + settings.rate_limit_requests * 5, settings.rate_limit_window_seconds + ) + _auth_strict_limiter = InMemorySlidingWindow( + settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds + ) def _get_client_ip(request: Request) -> str: @@ -164,23 +144,21 @@ def _get_client_ip(request: Request) -> str: return request.client.host if request.client else "unknown" -def _get_rate_limit_key(request: Request) -> tuple[str, RateLimiter]: +def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]: """Determine rate limit key and which limiter to use.""" - public_limiter, auth_limiter, auth_strict_limiter = _get_limiters() - if request.url.path.startswith("/public"): - return f"ip:{_get_client_ip(request)}", public_limiter + return f"ip:{_get_client_ip(request)}", _public_limiter if request.url.path.startswith("/auth/") and request.method == "POST": - return f"ip:{_get_client_ip(request)}", auth_strict_limiter + return f"ip:{_get_client_ip(request)}", _auth_strict_limiter auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer "): token = auth_header[7:] token_hash = hashlib.sha256(token.encode()).hexdigest() - return f"token:{token_hash}", auth_limiter + return f"token:{token_hash}", _auth_limiter - return f"ip:{_get_client_ip(request)}", public_limiter + return f"ip:{_get_client_ip(request)}", _public_limiter class RateLimitMiddleware(BaseHTTPMiddleware): @@ -189,20 +167,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware): return await call_next(request) key, limiter = _get_rate_limit_key(request) - - try: - allowed, remaining, retry_after = await limiter.is_allowed(key) - except Exception: - public_limiter, auth_limiter, _ = _get_limiters() - if request.url.path.startswith("/auth/") and request.method == "POST": - limiter = auth_limiter - elif request.url.path.startswith("/public"): - limiter = public_limiter - elif request.headers.get("authorization", "").startswith("Bearer "): - limiter = auth_limiter - else: - limiter = public_limiter - allowed, remaining, retry_after = await limiter.is_allowed(key) + allowed, remaining, retry_after = await limiter.is_allowed(key) if not allowed: return JSONResponse( diff --git a/tests/test_middleware/test_rate_limit.py b/tests/test_middleware/test_rate_limit.py index fad69fd..fbfe7d1 100644 --- a/tests/test_middleware/test_rate_limit.py +++ b/tests/test_middleware/test_rate_limit.py @@ -1,5 +1,6 @@ """Tests for rate limiting middleware.""" +import time from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -7,11 +8,9 @@ import pytest from cartsnitch_api.config import settings from cartsnitch_api.middleware.rate_limit import ( InMemorySlidingWindow, - RateLimitMiddleware, + RedisSlidingWindow, _get_client_ip, _get_rate_limit_key, - _init_redis, - _use_redis, ) @@ -44,6 +43,50 @@ class TestInMemorySlidingWindow: assert allowed_b is True assert remaining == 1 + def test_resets_after_window_expires(self): + limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1) + for _ in range(2): + limiter.is_allowed("test-key") + allowed, remaining, _ = limiter.is_allowed("test-key") + assert allowed is False + + time.sleep(1.1) + allowed, remaining, _ = limiter.is_allowed("test-key") + assert allowed is True + assert remaining == 1 + + +class TestGetClientIp: + def test_x_forwarded_for_single(self): + req = MagicMock() + req.headers = {"x-forwarded-for": "192.168.1.1"} + req.client = None + assert _get_client_ip(req) == "192.168.1.1" + + def test_x_forwarded_for_multiple(self): + req = MagicMock() + req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"} + req.client = None + assert _get_client_ip(req) == "192.168.1.1" + + def test_x_forwarded_for_with_port(self): + req = MagicMock() + req.headers = {"x-forwarded-for": "192.168.1.1:8080"} + req.client = None + assert _get_client_ip(req) == "192.168.1.1" + + def test_no_forwarded_header(self): + req = MagicMock() + req.headers = {} + req.client.host = "127.0.0.1" + assert _get_client_ip(req) == "127.0.0.1" + + def test_no_client(self): + req = MagicMock() + req.headers = {} + req.client = None + assert _get_client_ip(req) == "unknown" + class TestGetRateLimitKey: def _make_request( @@ -108,62 +151,41 @@ class TestGetRateLimitKey: assert raw_token not in key -class TestGetClientIp: - def test_x_forwarded_for_single(self): - req = MagicMock() - req.headers = {"x-forwarded-for": "192.168.1.1"} - req.client = None - assert _get_client_ip(req) == "192.168.1.1" - - def test_x_forwarded_for_multiple(self): - req = MagicMock() - req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"} - req.client = None - assert _get_client_ip(req) == "192.168.1.1" - - def test_x_forwarded_for_with_port(self): - req = MagicMock() - req.headers = {"x-forwarded-for": "192.168.1.1:8080"} - req.client = None - assert _get_client_ip(req) == "192.168.1.1" - - def test_no_forwarded_header(self): - req = MagicMock() - req.headers = {} - req.client.host = "127.0.0.1" - assert _get_client_ip(req) == "127.0.0.1" - - def test_no_client(self): - req = MagicMock() - req.headers = {} - req.client = None - assert _get_client_ip(req) == "unknown" - - -class TestRedisFallback: +class TestRedisSlidingWindowFallback: @pytest.mark.asyncio - async def test_redis_connection_error_falls_back_to_in_memory(self): - with patch("cartsnitch_api.middleware.rate_limit._use_redis", True): - with patch("cartsnitch_api.middleware.rate_limit._redis_client") as mock_client: - mock_client.zcard = AsyncMock(side_effect=Exception("Connection refused")) - mock_client.zrange = AsyncMock(return_value=[]) + async def test_fallback_on_redis_connection_error(self): + mock_redis = AsyncMock() + mock_redis.pipeline.return_value = AsyncMock() + pipe_mock = AsyncMock() + pipe_mock.execute.side_effect = Exception("Connection refused") + mock_redis.pipeline.return_value = pipe_mock - limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60) - allowed, remaining, retry = await limiter.is_allowed("test-key") - assert allowed is True - assert remaining == 2 + limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60) + allowed, remaining, retry = await limiter.is_allowed("test-key") + assert allowed is True + assert remaining == 4 + + @pytest.mark.asyncio + async def test_fallback_on_redis_error_during_pipeline(self): + mock_redis = AsyncMock() + pipe_mock = AsyncMock() + pipe_mock.execute.side_effect = Exception("Redis error") + mock_redis.pipeline.return_value = pipe_mock + + limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60) + allowed, remaining, retry = await limiter.is_allowed("test-key") + assert allowed is True + + +@pytest.mark.asyncio +async def test_rate_limit_returns_429(client): + resp = await client.get("/public/inflation") + assert "x-ratelimit-limit" in resp.headers + assert "x-ratelimit-remaining" in resp.headers @pytest.mark.asyncio async def test_health_skips_rate_limit(client): - """Health endpoint should not have rate limit headers.""" resp = await client.get("/health") assert resp.status_code == 200 - - -@pytest.mark.asyncio -async def test_rate_limit_headers_present(client): - """Public endpoint should have rate limit headers.""" - resp = await client.get("/public/inflation") - assert "x-ratelimit-limit" in resp.headers - assert "x-ratelimit-remaining" in resp.headers + assert "x-ratelimit-limit" not in resp.headers