forked from cartsnitch/api
b5df9aba1e
Co-Authored-By: Paperclip <noreply@paperclip.ing>
87 lines
3.0 KiB
Python
87 lines
3.0 KiB
Python
"""Tests for rate limiting middleware."""
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key
|
|
|
|
|
|
class TestSlidingWindowCounter:
    """Unit tests for ``_SlidingWindowCounter.is_allowed``."""

    def test_allows_within_limit(self):
        """Five calls against a limit of five are all allowed.

        The reported remaining count is expected to step down from 4 to 0.
        """
        limiter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
        for expected_remaining in reversed(range(5)):
            permitted, left, _retry_after = limiter.is_allowed("test-key")
            assert permitted is True
            assert left == expected_remaining

    def test_blocks_over_limit(self):
        """The call after the limit is used up is rejected with a retry hint."""
        limiter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
        # Use up the whole allowance for this key first.
        for _attempt in range(3):
            limiter.is_allowed("test-key")

        permitted, left, retry_after = limiter.is_allowed("test-key")
        assert permitted is False
        assert left == 0
        assert retry_after > 0

    def test_separate_keys(self):
        """Exhausting one key must not affect the allowance of another key."""
        limiter = _SlidingWindowCounter(max_requests=2, window_seconds=60)

        # Fill key-a up to its limit; the next call on it must be rejected.
        for _attempt in range(2):
            limiter.is_allowed("key-a")
        permitted_a, _, _ = limiter.is_allowed("key-a")
        assert permitted_a is False

        # key-b has not been used yet, so its first call is still allowed.
        permitted_b, left_b, _ = limiter.is_allowed("key-b")
        assert permitted_b is True
        assert left_b == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_rate_limit_returns_429(client):
    """Check that a public endpoint response carries the rate limit headers.

    Despite the test name, no 429 is triggered here: per the original note the
    default limit is 60/min, which normal tests will not reach. Only the
    presence of the headers added by the middleware is verified.
    """
    response = await client.get("/public/inflation")
    for header_name in ("x-ratelimit-limit", "x-ratelimit-remaining"):
        assert header_name in response.headers
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_health_skips_rate_limit(client):
    """The health endpoint answers 200 and carries no rate limit header."""
    response = await client.get("/health")
    assert response.status_code == 200

    # The limit header must be absent on the health endpoint's response.
    response_headers = response.headers
    assert "x-ratelimit-limit" not in response_headers
|
|
|
|
|
|
class TestGetRateLimitKey:
    """Unit tests for ``_get_rate_limit_key`` using mocked request objects."""

    def _make_request(self, auth_header: str = "") -> MagicMock:
        """Build a mock request for ``/purchases``.

        The ``authorization`` header is set only when ``auth_header`` is
        non-empty; otherwise the headers mapping is empty.
        """
        request = MagicMock()
        request.url.path = "/purchases"
        if auth_header:
            request.headers = {"authorization": auth_header}
        else:
            request.headers = {}
        return request

    def test_distinct_tokens_produce_distinct_keys(self):
        """Two different bearer tokens must map to different keys."""
        request_alpha = self._make_request("Bearer token_alpha_12345")
        request_beta = self._make_request("Bearer token_beta_67890")
        key_alpha, _ = _get_rate_limit_key(request_alpha)
        key_beta, _ = _get_rate_limit_key(request_beta)
        assert key_alpha != key_beta

    def test_same_token_produces_same_key(self):
        """The same bearer token must always map to the same key."""
        shared_header = "Bearer same_token_value_abc"
        first_request = self._make_request(shared_header)
        second_request = self._make_request(shared_header)
        first_key, _ = _get_rate_limit_key(first_request)
        second_key, _ = _get_rate_limit_key(second_request)
        assert first_key == second_key

    def test_key_does_not_contain_raw_token_suffix(self):
        """Neither the raw token nor its last 16 characters may appear in the key."""
        secret = "my_secret_jwt_token_xyz"
        request = self._make_request(f"Bearer {secret}")
        key, _ = _get_rate_limit_key(request)
        # Check the suffix first, then the whole token, as the original does.
        for leaked_fragment in (secret[-16:], secret):
            assert leaked_fragment not in key
|