forked from cartsnitch/api
26f3415eab
- Add rate_limit_auth_requests (5/min) and rate_limit_auth_window_seconds (60) settings to config.py - Refactor rate_limit.py to use protocol/ABC pattern with InMemorySlidingWindow and RedisSlidingWindow implementations - Add RedisSlidingWindow using sorted sets for distributed rate limiting - Add auth_strict_limiter for /auth/* POST endpoints (5 req/min per IP) - Fall back to in-memory when Redis is unavailable - Update tests to cover new functionality Co-Authored-By: Paperclip <noreply@paperclip.ing>
170 lines
6.1 KiB
Python
170 lines
6.1 KiB
Python
"""Tests for rate limiting middleware."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from cartsnitch_api.config import settings
|
|
from cartsnitch_api.middleware.rate_limit import (
|
|
InMemorySlidingWindow,
|
|
RateLimitMiddleware,
|
|
_get_client_ip,
|
|
_get_rate_limit_key,
|
|
_init_redis,
|
|
_use_redis,
|
|
)
|
|
|
|
|
|
class TestInMemorySlidingWindow:
|
|
def test_allows_within_limit(self):
|
|
limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
|
|
for i in range(5):
|
|
allowed, remaining, retry = limiter.is_allowed("test-key")
|
|
assert allowed is True
|
|
assert remaining == 4 - i
|
|
|
|
def test_blocks_over_limit(self):
|
|
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
|
|
for _ in range(3):
|
|
limiter.is_allowed("test-key")
|
|
|
|
allowed, remaining, retry = limiter.is_allowed("test-key")
|
|
assert allowed is False
|
|
assert remaining == 0
|
|
assert retry > 0
|
|
|
|
def test_separate_keys(self):
|
|
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
|
|
limiter.is_allowed("key-a")
|
|
limiter.is_allowed("key-a")
|
|
allowed_a, _, _ = limiter.is_allowed("key-a")
|
|
assert allowed_a is False
|
|
|
|
allowed_b, remaining, _ = limiter.is_allowed("key-b")
|
|
assert allowed_b is True
|
|
assert remaining == 1
|
|
|
|
|
|
class TestGetRateLimitKey:
|
|
def _make_request(
|
|
self,
|
|
path: str = "/purchases",
|
|
method: str = "GET",
|
|
auth_header: str = "",
|
|
headers: dict | None = None,
|
|
) -> MagicMock:
|
|
req = MagicMock()
|
|
req.url.path = path
|
|
req.method = method
|
|
req.headers = dict(headers) if headers else {}
|
|
if auth_header:
|
|
req.headers["authorization"] = auth_header
|
|
return req
|
|
|
|
def test_public_path_uses_public_limiter(self):
|
|
req = self._make_request("/public/inflation")
|
|
key, limiter = _get_rate_limit_key(req)
|
|
assert key.startswith("ip:")
|
|
assert limiter.max_requests == settings.rate_limit_requests
|
|
|
|
def test_auth_post_path_uses_strict_limiter(self):
|
|
req = self._make_request("/auth/login", method="POST")
|
|
key, limiter = _get_rate_limit_key(req)
|
|
assert key.startswith("ip:")
|
|
assert limiter.max_requests == settings.rate_limit_auth_requests
|
|
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
|
|
|
|
def test_auth_get_path_uses_auth_limiter(self):
|
|
req = self._make_request("/auth/me", method="GET")
|
|
key, limiter = _get_rate_limit_key(req)
|
|
assert key.startswith("ip:")
|
|
assert limiter.max_requests == settings.rate_limit_requests * 5
|
|
|
|
def test_authenticated_token_uses_auth_limiter(self):
|
|
req = self._make_request("/purchases", auth_header="Bearer token123")
|
|
key, limiter = _get_rate_limit_key(req)
|
|
assert key.startswith("token:")
|
|
assert limiter.max_requests == settings.rate_limit_requests * 5
|
|
|
|
def test_distinct_tokens_produce_distinct_keys(self):
|
|
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
|
|
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
|
|
key1, _ = _get_rate_limit_key(req1)
|
|
key2, _ = _get_rate_limit_key(req2)
|
|
assert key1 != key2
|
|
|
|
def test_same_token_produces_same_key(self):
|
|
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
|
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
|
key1, _ = _get_rate_limit_key(req1)
|
|
key2, _ = _get_rate_limit_key(req2)
|
|
assert key1 == key2
|
|
|
|
def test_key_does_not_contain_raw_token_suffix(self):
|
|
raw_token = "my_secret_jwt_token_xyz"
|
|
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
|
|
key, _ = _get_rate_limit_key(req)
|
|
assert raw_token[-16:] not in key
|
|
assert raw_token not in key
|
|
|
|
|
|
class TestGetClientIp:
|
|
def test_x_forwarded_for_single(self):
|
|
req = MagicMock()
|
|
req.headers = {"x-forwarded-for": "192.168.1.1"}
|
|
req.client = None
|
|
assert _get_client_ip(req) == "192.168.1.1"
|
|
|
|
def test_x_forwarded_for_multiple(self):
|
|
req = MagicMock()
|
|
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
|
|
req.client = None
|
|
assert _get_client_ip(req) == "192.168.1.1"
|
|
|
|
def test_x_forwarded_for_with_port(self):
|
|
req = MagicMock()
|
|
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
|
|
req.client = None
|
|
assert _get_client_ip(req) == "192.168.1.1"
|
|
|
|
def test_no_forwarded_header(self):
|
|
req = MagicMock()
|
|
req.headers = {}
|
|
req.client.host = "127.0.0.1"
|
|
assert _get_client_ip(req) == "127.0.0.1"
|
|
|
|
def test_no_client(self):
|
|
req = MagicMock()
|
|
req.headers = {}
|
|
req.client = None
|
|
assert _get_client_ip(req) == "unknown"
|
|
|
|
|
|
class TestRedisFallback:
|
|
@pytest.mark.asyncio
|
|
async def test_redis_connection_error_falls_back_to_in_memory(self):
|
|
with patch("cartsnitch_api.middleware.rate_limit._use_redis", True):
|
|
with patch("cartsnitch_api.middleware.rate_limit._redis_client") as mock_client:
|
|
mock_client.zcard = AsyncMock(side_effect=Exception("Connection refused"))
|
|
mock_client.zrange = AsyncMock(return_value=[])
|
|
|
|
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
|
|
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
|
assert allowed is True
|
|
assert remaining == 2
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_health_skips_rate_limit(client):
|
|
"""Health endpoint should not have rate limit headers."""
|
|
resp = await client.get("/health")
|
|
assert resp.status_code == 200
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_rate_limit_headers_present(client):
|
|
"""Public endpoint should have rate limit headers."""
|
|
resp = await client.get("/public/inflation")
|
|
assert "x-ratelimit-limit" in resp.headers
|
|
assert "x-ratelimit-remaining" in resp.headers
|