forked from cartsnitch/api
feat: Redis-backed rate limiting with stricter auth limits (#194)
feat: Redis-backed rate limiting with stricter auth limits
This commit is contained in:
@@ -32,6 +32,9 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
rate_limit_requests: int = 60
|
rate_limit_requests: int = 60
|
||||||
rate_limit_window_seconds: int = 60
|
rate_limit_window_seconds: int = 60
|
||||||
|
rate_limit_auth_requests: int = 5
|
||||||
|
rate_limit_auth_window_seconds: int = 60
|
||||||
|
rate_limit_redis_enabled: bool = True
|
||||||
rate_limit_enabled: bool = True
|
rate_limit_enabled: bool = True
|
||||||
|
|
||||||
_PLACEHOLDER_VALUES = {"change-me-in-production"}
|
_PLACEHOLDER_VALUES = {"change-me-in-production"}
|
||||||
@@ -72,7 +75,9 @@ class Settings(BaseSettings):
|
|||||||
def normalize_database_url(self):
|
def normalize_database_url(self):
|
||||||
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
|
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
|
||||||
if self.database_url.startswith("postgresql://"):
|
if self.database_url.startswith("postgresql://"):
|
||||||
self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
self.database_url = self.database_url.replace(
|
||||||
|
"postgresql://", "postgresql+asyncpg://", 1
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,18 +5,31 @@ Per-IP limiting on public endpoints, per-token limiting on authenticated endpoin
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, status
|
from fastapi import FastAPI, Request, status
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from redis.asyncio import Redis, RedisError
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
from cartsnitch_api.config import settings
|
from cartsnitch_api.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class _SlidingWindowCounter:
|
|
||||||
|
class RateLimitBackend(Protocol):
|
||||||
|
"""Protocol for rate limit backends."""
|
||||||
|
|
||||||
|
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||||
|
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||||
|
|
||||||
|
|
||||||
|
class InMemorySlidingWindow:
|
||||||
"""Thread-safe in-memory sliding window rate limiter."""
|
"""Thread-safe in-memory sliding window rate limiter."""
|
||||||
|
|
||||||
def __init__(self, max_requests: int, window_seconds: int) -> None:
|
def __init__(self, max_requests: int, window_seconds: int) -> None:
|
||||||
@@ -25,13 +38,12 @@ class _SlidingWindowCounter:
|
|||||||
self._hits: dict[str, list[float]] = defaultdict(list)
|
self._hits: dict[str, list[float]] = defaultdict(list)
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
|
|
||||||
def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
cutoff = now - self.window_seconds
|
cutoff = now - self.window_seconds
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
# Prune expired entries
|
|
||||||
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
|
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
|
||||||
|
|
||||||
current_count = len(self._hits[key])
|
current_count = len(self._hits[key])
|
||||||
@@ -44,15 +56,84 @@ class _SlidingWindowCounter:
|
|||||||
return True, remaining, 0
|
return True, remaining, 0
|
||||||
|
|
||||||
|
|
||||||
# Module-level counters — one for public (per-IP), one for auth (per-token)
|
class RedisSlidingWindow:
|
||||||
_public_limiter = _SlidingWindowCounter(
|
"""Redis-backed sliding window rate limiter using sorted sets."""
|
||||||
max_requests=settings.rate_limit_requests,
|
|
||||||
window_seconds=settings.rate_limit_window_seconds,
|
def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None:
|
||||||
)
|
self.redis = redis
|
||||||
_auth_limiter = _SlidingWindowCounter(
|
self.max_requests = max_requests
|
||||||
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
|
self.window_seconds = window_seconds
|
||||||
window_seconds=settings.rate_limit_window_seconds,
|
|
||||||
)
|
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||||
|
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||||
|
try:
|
||||||
|
now = time.monotonic()
|
||||||
|
cutoff = now - self.window_seconds
|
||||||
|
now_ms = int(now * 1000)
|
||||||
|
cutoff_ms = int(cutoff * 1000)
|
||||||
|
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.zremrangebyscore(key, 0, cutoff_ms)
|
||||||
|
pipe.zcard(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
current_count = results[1]
|
||||||
|
|
||||||
|
if current_count >= self.max_requests:
|
||||||
|
oldest = await self.redis.zrange(key, 0, 0, withscores=True)
|
||||||
|
if oldest:
|
||||||
|
retry_after = int((oldest[0][1] - cutoff) / 1000) + 1
|
||||||
|
else:
|
||||||
|
retry_after = self.window_seconds
|
||||||
|
return False, 0, retry_after
|
||||||
|
|
||||||
|
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.zadd(key, {member: now_ms})
|
||||||
|
pipe.expire(key, self.window_seconds)
|
||||||
|
await pipe.execute()
|
||||||
|
|
||||||
|
remaining = self.max_requests - current_count - 1
|
||||||
|
return True, remaining, 0
|
||||||
|
|
||||||
|
except RedisError as e:
|
||||||
|
logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
|
||||||
|
in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds)
|
||||||
|
return await in_memory.is_allowed(key)
|
||||||
|
|
||||||
|
|
||||||
|
_redis_client: Redis | None = None
|
||||||
|
_use_redis = False
|
||||||
|
|
||||||
|
if settings.rate_limit_redis_enabled:
|
||||||
|
try:
|
||||||
|
_redis_client = Redis.from_url(settings.redis_url)
|
||||||
|
_use_redis = True
|
||||||
|
logger.info("Rate limiting will use Redis at %s", settings.redis_url)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
|
||||||
|
_use_redis = False
|
||||||
|
|
||||||
|
if _use_redis and _redis_client:
|
||||||
|
_public_limiter = RedisSlidingWindow(
|
||||||
|
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
|
||||||
|
)
|
||||||
|
_auth_limiter = RedisSlidingWindow(
|
||||||
|
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
|
||||||
|
)
|
||||||
|
_auth_strict_limiter = RedisSlidingWindow(
|
||||||
|
_redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_public_limiter = InMemorySlidingWindow(
|
||||||
|
settings.rate_limit_requests, settings.rate_limit_window_seconds
|
||||||
|
)
|
||||||
|
_auth_limiter = InMemorySlidingWindow(
|
||||||
|
settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
|
||||||
|
)
|
||||||
|
_auth_strict_limiter = InMemorySlidingWindow(
|
||||||
|
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_client_ip(request: Request) -> str:
|
def _get_client_ip(request: Request) -> str:
|
||||||
@@ -63,30 +144,30 @@ def _get_client_ip(request: Request) -> str:
|
|||||||
return request.client.host if request.client else "unknown"
|
return request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
|
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
|
||||||
"""Determine rate limit key and which limiter to use."""
|
"""Determine rate limit key and which limiter to use."""
|
||||||
if request.url.path.startswith("/public"):
|
if request.url.path.startswith("/public"):
|
||||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||||
|
|
||||||
# For authenticated endpoints, use Bearer token as key if present
|
if request.url.path.startswith("/auth/") and request.method == "POST":
|
||||||
|
return f"ip:{_get_client_ip(request)}", _auth_strict_limiter
|
||||||
|
|
||||||
auth_header = request.headers.get("authorization", "")
|
auth_header = request.headers.get("authorization", "")
|
||||||
if auth_header.startswith("Bearer "):
|
if auth_header.startswith("Bearer "):
|
||||||
token = auth_header[7:]
|
token = auth_header[7:]
|
||||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||||
return f"token:{token_hash}", _auth_limiter
|
return f"token:{token_hash}", _auth_limiter
|
||||||
|
|
||||||
# Fallback to IP for unauthenticated non-public endpoints
|
|
||||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||||
|
|
||||||
|
|
||||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
# Skip rate limiting when disabled (e.g. in tests) or for health checks
|
|
||||||
if not settings.rate_limit_enabled or request.url.path == "/health":
|
if not settings.rate_limit_enabled or request.url.path == "/health":
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
key, limiter = _get_rate_limit_key(request)
|
key, limiter = _get_rate_limit_key(request)
|
||||||
allowed, remaining, retry_after = limiter.is_allowed(key)
|
allowed, remaining, retry_after = await limiter.is_allowed(key)
|
||||||
|
|
||||||
if not allowed:
|
if not allowed:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|||||||
@@ -1,49 +1,184 @@
|
|||||||
"""Tests for rate limiting middleware."""
|
"""Tests for rate limiting middleware."""
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key
|
from cartsnitch_api.config import settings
|
||||||
|
from cartsnitch_api.middleware.rate_limit import (
|
||||||
|
InMemorySlidingWindow,
|
||||||
|
RedisSlidingWindow,
|
||||||
|
_get_client_ip,
|
||||||
|
_get_rate_limit_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSlidingWindowCounter:
|
class TestInMemorySlidingWindow:
|
||||||
def test_allows_within_limit(self):
|
def test_allows_within_limit(self):
|
||||||
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
|
limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
allowed, remaining, retry = limiter.is_allowed("test-key")
|
||||||
assert allowed is True
|
assert allowed is True
|
||||||
assert remaining == 4 - i
|
assert remaining == 4 - i
|
||||||
|
|
||||||
def test_blocks_over_limit(self):
|
def test_blocks_over_limit(self):
|
||||||
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
|
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
counter.is_allowed("test-key")
|
limiter.is_allowed("test-key")
|
||||||
|
|
||||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
allowed, remaining, retry = limiter.is_allowed("test-key")
|
||||||
assert allowed is False
|
assert allowed is False
|
||||||
assert remaining == 0
|
assert remaining == 0
|
||||||
assert retry > 0
|
assert retry > 0
|
||||||
|
|
||||||
def test_separate_keys(self):
|
def test_separate_keys(self):
|
||||||
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60)
|
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
|
||||||
# Fill key-a
|
limiter.is_allowed("key-a")
|
||||||
counter.is_allowed("key-a")
|
limiter.is_allowed("key-a")
|
||||||
counter.is_allowed("key-a")
|
allowed_a, _, _ = limiter.is_allowed("key-a")
|
||||||
allowed_a, _, _ = counter.is_allowed("key-a")
|
|
||||||
assert allowed_a is False
|
assert allowed_a is False
|
||||||
|
|
||||||
# key-b should still be allowed
|
allowed_b, remaining, _ = limiter.is_allowed("key-b")
|
||||||
allowed_b, remaining, _ = counter.is_allowed("key-b")
|
|
||||||
assert allowed_b is True
|
assert allowed_b is True
|
||||||
assert remaining == 1
|
assert remaining == 1
|
||||||
|
|
||||||
|
def test_resets_after_window_expires(self):
|
||||||
|
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
|
||||||
|
for _ in range(2):
|
||||||
|
limiter.is_allowed("test-key")
|
||||||
|
allowed, remaining, _ = limiter.is_allowed("test-key")
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
time.sleep(1.1)
|
||||||
|
allowed, remaining, _ = limiter.is_allowed("test-key")
|
||||||
|
assert allowed is True
|
||||||
|
assert remaining == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetClientIp:
|
||||||
|
def test_x_forwarded_for_single(self):
|
||||||
|
req = MagicMock()
|
||||||
|
req.headers = {"x-forwarded-for": "192.168.1.1"}
|
||||||
|
req.client = None
|
||||||
|
assert _get_client_ip(req) == "192.168.1.1"
|
||||||
|
|
||||||
|
def test_x_forwarded_for_multiple(self):
|
||||||
|
req = MagicMock()
|
||||||
|
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
|
||||||
|
req.client = None
|
||||||
|
assert _get_client_ip(req) == "192.168.1.1"
|
||||||
|
|
||||||
|
def test_x_forwarded_for_with_port(self):
|
||||||
|
req = MagicMock()
|
||||||
|
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
|
||||||
|
req.client = None
|
||||||
|
assert _get_client_ip(req) == "192.168.1.1"
|
||||||
|
|
||||||
|
def test_no_forwarded_header(self):
|
||||||
|
req = MagicMock()
|
||||||
|
req.headers = {}
|
||||||
|
req.client.host = "127.0.0.1"
|
||||||
|
assert _get_client_ip(req) == "127.0.0.1"
|
||||||
|
|
||||||
|
def test_no_client(self):
|
||||||
|
req = MagicMock()
|
||||||
|
req.headers = {}
|
||||||
|
req.client = None
|
||||||
|
assert _get_client_ip(req) == "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetRateLimitKey:
|
||||||
|
def _make_request(
|
||||||
|
self,
|
||||||
|
path: str = "/purchases",
|
||||||
|
method: str = "GET",
|
||||||
|
auth_header: str = "",
|
||||||
|
headers: dict | None = None,
|
||||||
|
) -> MagicMock:
|
||||||
|
req = MagicMock()
|
||||||
|
req.url.path = path
|
||||||
|
req.method = method
|
||||||
|
req.headers = dict(headers) if headers else {}
|
||||||
|
if auth_header:
|
||||||
|
req.headers["authorization"] = auth_header
|
||||||
|
return req
|
||||||
|
|
||||||
|
def test_public_path_uses_public_limiter(self):
|
||||||
|
req = self._make_request("/public/inflation")
|
||||||
|
key, limiter = _get_rate_limit_key(req)
|
||||||
|
assert key.startswith("ip:")
|
||||||
|
assert limiter.max_requests == settings.rate_limit_requests
|
||||||
|
|
||||||
|
def test_auth_post_path_uses_strict_limiter(self):
|
||||||
|
req = self._make_request("/auth/login", method="POST")
|
||||||
|
key, limiter = _get_rate_limit_key(req)
|
||||||
|
assert key.startswith("ip:")
|
||||||
|
assert limiter.max_requests == settings.rate_limit_auth_requests
|
||||||
|
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
|
||||||
|
|
||||||
|
def test_auth_get_path_uses_auth_limiter(self):
|
||||||
|
req = self._make_request("/auth/me", method="GET")
|
||||||
|
key, limiter = _get_rate_limit_key(req)
|
||||||
|
assert key.startswith("ip:")
|
||||||
|
assert limiter.max_requests == settings.rate_limit_requests * 5
|
||||||
|
|
||||||
|
def test_authenticated_token_uses_auth_limiter(self):
|
||||||
|
req = self._make_request("/purchases", auth_header="Bearer token123")
|
||||||
|
key, limiter = _get_rate_limit_key(req)
|
||||||
|
assert key.startswith("token:")
|
||||||
|
assert limiter.max_requests == settings.rate_limit_requests * 5
|
||||||
|
|
||||||
|
def test_distinct_tokens_produce_distinct_keys(self):
|
||||||
|
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
|
||||||
|
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
|
||||||
|
key1, _ = _get_rate_limit_key(req1)
|
||||||
|
key2, _ = _get_rate_limit_key(req2)
|
||||||
|
assert key1 != key2
|
||||||
|
|
||||||
|
def test_same_token_produces_same_key(self):
|
||||||
|
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
||||||
|
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
||||||
|
key1, _ = _get_rate_limit_key(req1)
|
||||||
|
key2, _ = _get_rate_limit_key(req2)
|
||||||
|
assert key1 == key2
|
||||||
|
|
||||||
|
def test_key_does_not_contain_raw_token_suffix(self):
|
||||||
|
raw_token = "my_secret_jwt_token_xyz"
|
||||||
|
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
|
||||||
|
key, _ = _get_rate_limit_key(req)
|
||||||
|
assert raw_token[-16:] not in key
|
||||||
|
assert raw_token not in key
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisSlidingWindowFallback:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_on_redis_connection_error(self):
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.pipeline.return_value = AsyncMock()
|
||||||
|
pipe_mock = AsyncMock()
|
||||||
|
pipe_mock.execute.side_effect = Exception("Connection refused")
|
||||||
|
mock_redis.pipeline.return_value = pipe_mock
|
||||||
|
|
||||||
|
limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
|
||||||
|
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
||||||
|
assert allowed is True
|
||||||
|
assert remaining == 4
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_on_redis_error_during_pipeline(self):
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
pipe_mock = AsyncMock()
|
||||||
|
pipe_mock.execute.side_effect = Exception("Redis error")
|
||||||
|
mock_redis.pipeline.return_value = pipe_mock
|
||||||
|
|
||||||
|
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
|
||||||
|
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rate_limit_returns_429(client):
|
async def test_rate_limit_returns_429(client):
|
||||||
"""Public endpoint should return 429 after limit exceeded."""
|
|
||||||
# The default limit is 60/min — we won't hit it in normal tests,
|
|
||||||
# but we verify the middleware adds rate limit headers.
|
|
||||||
resp = await client.get("/public/inflation")
|
resp = await client.get("/public/inflation")
|
||||||
assert "x-ratelimit-limit" in resp.headers
|
assert "x-ratelimit-limit" in resp.headers
|
||||||
assert "x-ratelimit-remaining" in resp.headers
|
assert "x-ratelimit-remaining" in resp.headers
|
||||||
@@ -51,36 +186,6 @@ async def test_rate_limit_returns_429(client):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_health_skips_rate_limit(client):
|
async def test_health_skips_rate_limit(client):
|
||||||
"""Health endpoint should not have rate limit headers."""
|
|
||||||
resp = await client.get("/health")
|
resp = await client.get("/health")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert "x-ratelimit-limit" not in resp.headers
|
assert "x-ratelimit-limit" not in resp.headers
|
||||||
|
|
||||||
|
|
||||||
class TestGetRateLimitKey:
|
|
||||||
def _make_request(self, auth_header: str = "") -> MagicMock:
|
|
||||||
req = MagicMock()
|
|
||||||
req.url.path = "/purchases"
|
|
||||||
req.headers = {"authorization": auth_header} if auth_header else {}
|
|
||||||
return req
|
|
||||||
|
|
||||||
def test_distinct_tokens_produce_distinct_keys(self):
|
|
||||||
req1 = self._make_request("Bearer token_alpha_12345")
|
|
||||||
req2 = self._make_request("Bearer token_beta_67890")
|
|
||||||
key1, _ = _get_rate_limit_key(req1)
|
|
||||||
key2, _ = _get_rate_limit_key(req2)
|
|
||||||
assert key1 != key2
|
|
||||||
|
|
||||||
def test_same_token_produces_same_key(self):
|
|
||||||
req1 = self._make_request("Bearer same_token_value_abc")
|
|
||||||
req2 = self._make_request("Bearer same_token_value_abc")
|
|
||||||
key1, _ = _get_rate_limit_key(req1)
|
|
||||||
key2, _ = _get_rate_limit_key(req2)
|
|
||||||
assert key1 == key2
|
|
||||||
|
|
||||||
def test_key_does_not_contain_raw_token_suffix(self):
|
|
||||||
raw_token = "my_secret_jwt_token_xyz"
|
|
||||||
req = self._make_request(f"Bearer {raw_token}")
|
|
||||||
key, _ = _get_rate_limit_key(req)
|
|
||||||
assert raw_token[-16:] not in key
|
|
||||||
assert raw_token not in key
|
|
||||||
|
|||||||
Reference in New Issue
Block a user