diff --git a/src/cartsnitch_api/config.py b/src/cartsnitch_api/config.py index 7642deb..da68fe6 100644 --- a/src/cartsnitch_api/config.py +++ b/src/cartsnitch_api/config.py @@ -13,14 +13,13 @@ class Settings(BaseSettings): ) redis_url: str = "redis://localhost:6379/0" - jwt_secret_key: str = "change-me-in-production" + jwt_secret_key: str jwt_algorithm: str = "HS256" jwt_access_token_expire_minutes: int = 15 jwt_refresh_token_expire_days: int = 7 - service_key: str = "change-me-in-production" - # Valid Fernet key for local dev — MUST be overridden in production - fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=" + service_key: str + fernet_key: str auth_service_url: str = "http://auth:3001" @@ -35,9 +34,26 @@ class Settings(BaseSettings): rate_limit_window_seconds: int = 60 rate_limit_enabled: bool = True + _PLACEHOLDER_VALUES = {"change-me-in-production"} + @model_validator(mode="after") - def validate_fernet_key(self): - """Validate fernet_key is a valid 32-byte url-safe base64 key at startup.""" + def validate_secrets(self): + if not self.jwt_secret_key or self.jwt_secret_key in self._PLACEHOLDER_VALUES: + raise ValueError( + "CARTSNITCH_JWT_SECRET_KEY must be set to a secure value. " + 'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"' + ) + if not self.service_key or self.service_key in self._PLACEHOLDER_VALUES: + raise ValueError( + "CARTSNITCH_SERVICE_KEY must be set to a secure value. " + 'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"' + ) + if not self.fernet_key or self.fernet_key in self._PLACEHOLDER_VALUES: + raise ValueError( + "CARTSNITCH_FERNET_KEY must be set to a valid Fernet key. " + "Generate one with: python -c " + "'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'" + ) try: decoded = base64.urlsafe_b64decode(self.fernet_key.encode()) if len(decoded) != 32: diff --git a/src/cartsnitch_api/middleware/cors.py b/src/cartsnitch_api/middleware/cors.py index 0e6a4ae..3bba4af 100644 --- a/src/cartsnitch_api/middleware/cors.py +++ b/src/cartsnitch_api/middleware/cors.py @@ -11,6 +11,6 @@ def add_cors_middleware(app: FastAPI) -> None: CORSMiddleware, allow_origins=settings.cors_origins, allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With"], ) diff --git a/src/cartsnitch_api/middleware/rate_limit.py b/src/cartsnitch_api/middleware/rate_limit.py index 424ed19..319b363 100644 --- a/src/cartsnitch_api/middleware/rate_limit.py +++ b/src/cartsnitch_api/middleware/rate_limit.py @@ -4,6 +4,7 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available. Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. """ +import hashlib import time from collections import defaultdict from threading import Lock @@ -71,8 +72,8 @@ def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]: auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer "): token = auth_header[7:] - # Use last 16 chars of token as key to avoid storing full tokens - return f"token:{token[-16:]}", _auth_limiter + token_hash = hashlib.sha256(token.encode()).hexdigest() + return f"token:{token_hash}", _auth_limiter # Fallback to IP for unauthenticated non-public endpoints return f"ip:{_get_client_ip(request)}", _public_limiter diff --git a/tests/conftest.py b/tests/conftest.py index bb84c20..b684a41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,25 @@ from cartsnitch_api.database import get_db from cartsnitch_api.main import create_app from cartsnitch_api.models import Base +TEST_JWT_SECRET = secrets.token_urlsafe(32) +TEST_SERVICE_KEY = secrets.token_urlsafe(32) +TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=" + + +@pytest.fixture(autouse=True) +def setup_test_settings(): + original_jwt = cartsnitch_settings.jwt_secret_key + original_service = cartsnitch_settings.service_key + original_fernet = cartsnitch_settings.fernet_key + cartsnitch_settings.jwt_secret_key = TEST_JWT_SECRET + cartsnitch_settings.service_key = TEST_SERVICE_KEY + cartsnitch_settings.fernet_key = TEST_FERNET_KEY + yield + cartsnitch_settings.jwt_secret_key = original_jwt + cartsnitch_settings.service_key = original_service + cartsnitch_settings.fernet_key = original_fernet + + TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" @@ -60,7 +79,8 @@ async def db_engine(): async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) # Create Better-Auth tables (not managed by SQLAlchemy models) - await conn.execute(text(""" + await conn.execute( + text(""" CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, token TEXT NOT NULL UNIQUE, @@ -71,8 +91,10 @@ async def db_engine(): created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ) - """)) - await conn.execute(text(""" + """) + ) + await conn.execute( + text(""" CREATE TABLE IF NOT EXISTS accounts ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, @@ -88,8 +110,10 @@ async def db_engine(): created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ) - """)) - await conn.execute(text(""" + """) + ) + await conn.execute( + text(""" CREATE TABLE IF NOT EXISTS verifications ( id TEXT PRIMARY KEY, identifier TEXT NOT NULL, @@ -98,7 +122,8 @@ async def db_engine(): created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ) - """)) + """) + ) yield engine @@ -133,7 +158,9 @@ async def client(db_engine): app.dependency_overrides.clear() -async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]: +async def _create_test_user_and_session( + client: AsyncClient, db_engine, **user_overrides +) -> tuple[dict, str]: """Create a test user and a valid session directly in the DB. Returns (user_dict, session_token). Better-Auth stores the raw token diff --git a/tests/test_middleware/test_rate_limit.py b/tests/test_middleware/test_rate_limit.py index d5b7691..59386a1 100644 --- a/tests/test_middleware/test_rate_limit.py +++ b/tests/test_middleware/test_rate_limit.py @@ -1,8 +1,10 @@ """Tests for rate limiting middleware.""" +from unittest.mock import MagicMock + import pytest -from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter +from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key class TestSlidingWindowCounter: @@ -53,3 +55,32 @@ async def test_health_skips_rate_limit(client): resp = await client.get("/health") assert resp.status_code == 200 assert "x-ratelimit-limit" not in resp.headers + + +class TestGetRateLimitKey: + def _make_request(self, auth_header: str = "") -> MagicMock: + req = MagicMock() + req.url.path = "/purchases" + req.headers = {"authorization": auth_header} if auth_header else {} + return req + + def test_distinct_tokens_produce_distinct_keys(self): + req1 = self._make_request("Bearer token_alpha_12345") + req2 = self._make_request("Bearer token_beta_67890") + key1, _ = _get_rate_limit_key(req1) + key2, _ = _get_rate_limit_key(req2) + assert key1 != key2 + + def test_same_token_produces_same_key(self): + req1 = self._make_request("Bearer same_token_value_abc") + req2 = self._make_request("Bearer same_token_value_abc") + key1, _ = _get_rate_limit_key(req1) + key2, _ = _get_rate_limit_key(req2) + assert key1 == key2 + + def test_key_does_not_contain_raw_token_suffix(self): + raw_token = "my_secret_jwt_token_xyz" + req = self._make_request(f"Bearer {raw_token}") + key, _ = _get_rate_limit_key(req) + assert raw_token[-16:] not in key + assert raw_token not in key