8a4c194e39
- Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60) settings - Add rate_limit_redis_enabled flag for opt-in Redis usage - Refactor _SlidingWindowCounter into InMemorySlidingWindow class - Add RedisSlidingWindow using sorted sets with fallback to in-memory - Add third _auth_strict_limiter for POST /auth/* paths (5 req/min) - Add protocol-based backend selection at module load time - Update tests for auth strict limiter and Redis fallback behavior Co-Authored-By: Paperclip <noreply@paperclip.ing>
192 lines
6.8 KiB
Python
192 lines
6.8 KiB
Python
"""Tests for rate limiting middleware."""
|
|
|
|
import time
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from cartsnitch_api.config import settings
|
|
from cartsnitch_api.middleware.rate_limit import (
|
|
InMemorySlidingWindow,
|
|
RedisSlidingWindow,
|
|
_get_client_ip,
|
|
_get_rate_limit_key,
|
|
)
|
|
|
|
|
|
class TestInMemorySlidingWindow:
|
|
def test_allows_within_limit(self):
|
|
limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
|
|
for i in range(5):
|
|
allowed, remaining, retry = limiter.is_allowed("test-key")
|
|
assert allowed is True
|
|
assert remaining == 4 - i
|
|
|
|
def test_blocks_over_limit(self):
|
|
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
|
|
for _ in range(3):
|
|
limiter.is_allowed("test-key")
|
|
|
|
allowed, remaining, retry = limiter.is_allowed("test-key")
|
|
assert allowed is False
|
|
assert remaining == 0
|
|
assert retry > 0
|
|
|
|
def test_separate_keys(self):
|
|
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
|
|
limiter.is_allowed("key-a")
|
|
limiter.is_allowed("key-a")
|
|
allowed_a, _, _ = limiter.is_allowed("key-a")
|
|
assert allowed_a is False
|
|
|
|
allowed_b, remaining, _ = limiter.is_allowed("key-b")
|
|
assert allowed_b is True
|
|
assert remaining == 1
|
|
|
|
def test_resets_after_window_expires(self):
|
|
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
|
|
for _ in range(2):
|
|
limiter.is_allowed("test-key")
|
|
allowed, remaining, _ = limiter.is_allowed("test-key")
|
|
assert allowed is False
|
|
|
|
time.sleep(1.1)
|
|
allowed, remaining, _ = limiter.is_allowed("test-key")
|
|
assert allowed is True
|
|
assert remaining == 1
|
|
|
|
|
|
class TestGetClientIp:
|
|
def test_x_forwarded_for_single(self):
|
|
req = MagicMock()
|
|
req.headers = {"x-forwarded-for": "192.168.1.1"}
|
|
req.client = None
|
|
assert _get_client_ip(req) == "192.168.1.1"
|
|
|
|
def test_x_forwarded_for_multiple(self):
|
|
req = MagicMock()
|
|
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
|
|
req.client = None
|
|
assert _get_client_ip(req) == "192.168.1.1"
|
|
|
|
def test_x_forwarded_for_with_port(self):
|
|
req = MagicMock()
|
|
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
|
|
req.client = None
|
|
assert _get_client_ip(req) == "192.168.1.1"
|
|
|
|
def test_no_forwarded_header(self):
|
|
req = MagicMock()
|
|
req.headers = {}
|
|
req.client.host = "127.0.0.1"
|
|
assert _get_client_ip(req) == "127.0.0.1"
|
|
|
|
def test_no_client(self):
|
|
req = MagicMock()
|
|
req.headers = {}
|
|
req.client = None
|
|
assert _get_client_ip(req) == "unknown"
|
|
|
|
|
|
class TestGetRateLimitKey:
|
|
def _make_request(
|
|
self,
|
|
path: str = "/purchases",
|
|
method: str = "GET",
|
|
auth_header: str = "",
|
|
headers: dict | None = None,
|
|
) -> MagicMock:
|
|
req = MagicMock()
|
|
req.url.path = path
|
|
req.method = method
|
|
req.headers = dict(headers) if headers else {}
|
|
if auth_header:
|
|
req.headers["authorization"] = auth_header
|
|
return req
|
|
|
|
def test_public_path_uses_public_limiter(self):
|
|
req = self._make_request("/public/inflation")
|
|
key, limiter = _get_rate_limit_key(req)
|
|
assert key.startswith("ip:")
|
|
assert limiter.max_requests == settings.rate_limit_requests
|
|
|
|
def test_auth_post_path_uses_strict_limiter(self):
|
|
req = self._make_request("/auth/login", method="POST")
|
|
key, limiter = _get_rate_limit_key(req)
|
|
assert key.startswith("ip:")
|
|
assert limiter.max_requests == settings.rate_limit_auth_requests
|
|
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
|
|
|
|
def test_auth_get_path_uses_auth_limiter(self):
|
|
req = self._make_request("/auth/me", method="GET")
|
|
key, limiter = _get_rate_limit_key(req)
|
|
assert key.startswith("ip:")
|
|
assert limiter.max_requests == settings.rate_limit_requests * 5
|
|
|
|
def test_authenticated_token_uses_auth_limiter(self):
|
|
req = self._make_request("/purchases", auth_header="Bearer token123")
|
|
key, limiter = _get_rate_limit_key(req)
|
|
assert key.startswith("token:")
|
|
assert limiter.max_requests == settings.rate_limit_requests * 5
|
|
|
|
def test_distinct_tokens_produce_distinct_keys(self):
|
|
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
|
|
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
|
|
key1, _ = _get_rate_limit_key(req1)
|
|
key2, _ = _get_rate_limit_key(req2)
|
|
assert key1 != key2
|
|
|
|
def test_same_token_produces_same_key(self):
|
|
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
|
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
|
key1, _ = _get_rate_limit_key(req1)
|
|
key2, _ = _get_rate_limit_key(req2)
|
|
assert key1 == key2
|
|
|
|
def test_key_does_not_contain_raw_token_suffix(self):
|
|
raw_token = "my_secret_jwt_token_xyz"
|
|
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
|
|
key, _ = _get_rate_limit_key(req)
|
|
assert raw_token[-16:] not in key
|
|
assert raw_token not in key
|
|
|
|
|
|
class TestRedisSlidingWindowFallback:
|
|
@pytest.mark.asyncio
|
|
async def test_fallback_on_redis_connection_error(self):
|
|
mock_redis = AsyncMock()
|
|
mock_redis.pipeline.return_value = AsyncMock()
|
|
pipe_mock = AsyncMock()
|
|
pipe_mock.execute.side_effect = Exception("Connection refused")
|
|
mock_redis.pipeline.return_value = pipe_mock
|
|
|
|
limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
|
|
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
|
assert allowed is True
|
|
assert remaining == 4
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_fallback_on_redis_error_during_pipeline(self):
|
|
mock_redis = AsyncMock()
|
|
pipe_mock = AsyncMock()
|
|
pipe_mock.execute.side_effect = Exception("Redis error")
|
|
mock_redis.pipeline.return_value = pipe_mock
|
|
|
|
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
|
|
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
|
assert allowed is True
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_rate_limit_returns_429(client):
|
|
resp = await client.get("/public/inflation")
|
|
assert "x-ratelimit-limit" in resp.headers
|
|
assert "x-ratelimit-remaining" in resp.headers
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_health_skips_rate_limit(client):
|
|
resp = await client.get("/health")
|
|
assert resp.status_code == 200
|
|
assert "x-ratelimit-limit" not in resp.headers
|