diff --git a/src/cartsnitch_api/middleware/rate_limit.py b/src/cartsnitch_api/middleware/rate_limit.py index 424ed19..319b363 100644 --- a/src/cartsnitch_api/middleware/rate_limit.py +++ b/src/cartsnitch_api/middleware/rate_limit.py @@ -4,6 +4,7 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available. Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. """ +import hashlib import time from collections import defaultdict from threading import Lock @@ -71,8 +72,8 @@ def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]: auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer "): token = auth_header[7:] - # Use last 16 chars of token as key to avoid storing full tokens - return f"token:{token[-16:]}", _auth_limiter + token_hash = hashlib.sha256(token.encode()).hexdigest() + return f"token:{token_hash}", _auth_limiter # Fallback to IP for unauthenticated non-public endpoints return f"ip:{_get_client_ip(request)}", _public_limiter diff --git a/tests/test_middleware/test_rate_limit.py b/tests/test_middleware/test_rate_limit.py index d5b7691..59386a1 100644 --- a/tests/test_middleware/test_rate_limit.py +++ b/tests/test_middleware/test_rate_limit.py @@ -1,8 +1,10 @@ """Tests for rate limiting middleware.""" +from unittest.mock import MagicMock + import pytest -from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter +from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key class TestSlidingWindowCounter: @@ -53,3 +55,32 @@ async def test_health_skips_rate_limit(client): resp = await client.get("/health") assert resp.status_code == 200 assert "x-ratelimit-limit" not in resp.headers + + +class TestGetRateLimitKey: + def _make_request(self, auth_header: str = "") -> MagicMock: + req = MagicMock() + req.url.path = "/purchases" + req.headers = {"authorization": auth_header} if auth_header else {} + return req + + def test_distinct_tokens_produce_distinct_keys(self): + req1 = self._make_request("Bearer token_alpha_12345") + req2 = self._make_request("Bearer token_beta_67890") + key1, _ = _get_rate_limit_key(req1) + key2, _ = _get_rate_limit_key(req2) + assert key1 != key2 + + def test_same_token_produces_same_key(self): + req1 = self._make_request("Bearer same_token_value_abc") + req2 = self._make_request("Bearer same_token_value_abc") + key1, _ = _get_rate_limit_key(req1) + key2, _ = _get_rate_limit_key(req2) + assert key1 == key2 + + def test_key_does_not_contain_raw_token_suffix(self): + raw_token = "my_secret_jwt_token_xyz" + req = self._make_request(f"Bearer {raw_token}") + key, _ = _get_rate_limit_key(req) + assert raw_token[-16:] not in key + assert raw_token not in key