forked from cartsnitch/cartsnitch
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1267caf43c | |||
| 015401861a | |||
| 9891e1aefb | |||
| 69ad161e36 | |||
| 485f890df3 | |||
| bf3ed0ede3 | |||
| 3f41eb7346 | |||
| 6cbd1ef298 | |||
| 94214f762e | |||
| 562c6ef6f6 | |||
| ccc8189d88 | |||
| 86594e4a8e | |||
| c2f1a83c1d | |||
| 6f8e5a9577 | |||
| bbfa816e57 | |||
| 5904eb03a2 | |||
| 87b6433ff7 | |||
| d7c9938f7e | |||
| 02434060ee |
@@ -13,13 +13,14 @@ class Settings(BaseSettings):
|
|||||||
)
|
)
|
||||||
redis_url: str = "redis://localhost:6379/0"
|
redis_url: str = "redis://localhost:6379/0"
|
||||||
|
|
||||||
jwt_secret_key: str
|
jwt_secret_key: str = "change-me-in-production"
|
||||||
jwt_algorithm: str = "HS256"
|
jwt_algorithm: str = "HS256"
|
||||||
jwt_access_token_expire_minutes: int = 15
|
jwt_access_token_expire_minutes: int = 15
|
||||||
jwt_refresh_token_expire_days: int = 7
|
jwt_refresh_token_expire_days: int = 7
|
||||||
|
|
||||||
service_key: str
|
service_key: str = "change-me-in-production"
|
||||||
fernet_key: str
|
# Valid Fernet key for local dev — MUST be overridden in production
|
||||||
|
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||||
|
|
||||||
auth_service_url: str = "http://auth:3001"
|
auth_service_url: str = "http://auth:3001"
|
||||||
|
|
||||||
@@ -34,26 +35,9 @@ class Settings(BaseSettings):
|
|||||||
rate_limit_window_seconds: int = 60
|
rate_limit_window_seconds: int = 60
|
||||||
rate_limit_enabled: bool = True
|
rate_limit_enabled: bool = True
|
||||||
|
|
||||||
_PLACEHOLDER_VALUES = {"change-me-in-production"}
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_secrets(self):
|
def validate_fernet_key(self):
|
||||||
if not self.jwt_secret_key or self.jwt_secret_key in self._PLACEHOLDER_VALUES:
|
"""Validate fernet_key is a valid 32-byte url-safe base64 key at startup."""
|
||||||
raise ValueError(
|
|
||||||
"CARTSNITCH_JWT_SECRET_KEY must be set to a secure value. "
|
|
||||||
'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
|
||||||
)
|
|
||||||
if not self.service_key or self.service_key in self._PLACEHOLDER_VALUES:
|
|
||||||
raise ValueError(
|
|
||||||
"CARTSNITCH_SERVICE_KEY must be set to a secure value. "
|
|
||||||
'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
|
||||||
)
|
|
||||||
if not self.fernet_key or self.fernet_key in self._PLACEHOLDER_VALUES:
|
|
||||||
raise ValueError(
|
|
||||||
"CARTSNITCH_FERNET_KEY must be set to a valid Fernet key. "
|
|
||||||
"Generate one with: python -c "
|
|
||||||
"'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
|
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
|
||||||
if len(decoded) != 32:
|
if len(decoded) != 32:
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
|
|||||||
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
|
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
@@ -72,8 +71,8 @@ def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
|
|||||||
auth_header = request.headers.get("authorization", "")
|
auth_header = request.headers.get("authorization", "")
|
||||||
if auth_header.startswith("Bearer "):
|
if auth_header.startswith("Bearer "):
|
||||||
token = auth_header[7:]
|
token = auth_header[7:]
|
||||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
# Use last 16 chars of token as key to avoid storing full tokens
|
||||||
return f"token:{token_hash}", _auth_limiter
|
return f"token:{token[-16:]}", _auth_limiter
|
||||||
|
|
||||||
# Fallback to IP for unauthenticated non-public endpoints
|
# Fallback to IP for unauthenticated non-public endpoints
|
||||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||||
|
|||||||
+7
-34
@@ -19,25 +19,6 @@ from cartsnitch_api.database import get_db
|
|||||||
from cartsnitch_api.main import create_app
|
from cartsnitch_api.main import create_app
|
||||||
from cartsnitch_api.models import Base
|
from cartsnitch_api.models import Base
|
||||||
|
|
||||||
TEST_JWT_SECRET = secrets.token_urlsafe(32)
|
|
||||||
TEST_SERVICE_KEY = secrets.token_urlsafe(32)
|
|
||||||
TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def setup_test_settings():
|
|
||||||
original_jwt = cartsnitch_settings.jwt_secret_key
|
|
||||||
original_service = cartsnitch_settings.service_key
|
|
||||||
original_fernet = cartsnitch_settings.fernet_key
|
|
||||||
cartsnitch_settings.jwt_secret_key = TEST_JWT_SECRET
|
|
||||||
cartsnitch_settings.service_key = TEST_SERVICE_KEY
|
|
||||||
cartsnitch_settings.fernet_key = TEST_FERNET_KEY
|
|
||||||
yield
|
|
||||||
cartsnitch_settings.jwt_secret_key = original_jwt
|
|
||||||
cartsnitch_settings.service_key = original_service
|
|
||||||
cartsnitch_settings.fernet_key = original_fernet
|
|
||||||
|
|
||||||
|
|
||||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||||
|
|
||||||
|
|
||||||
@@ -79,8 +60,7 @@ async def db_engine():
|
|||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
||||||
await conn.execute(
|
await conn.execute(text("""
|
||||||
text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS sessions (
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
token TEXT NOT NULL UNIQUE,
|
token TEXT NOT NULL UNIQUE,
|
||||||
@@ -91,10 +71,8 @@ async def db_engine():
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
"""))
|
||||||
)
|
await conn.execute(text("""
|
||||||
await conn.execute(
|
|
||||||
text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS accounts (
|
CREATE TABLE IF NOT EXISTS accounts (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
@@ -110,10 +88,8 @@ async def db_engine():
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
"""))
|
||||||
)
|
await conn.execute(text("""
|
||||||
await conn.execute(
|
|
||||||
text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS verifications (
|
CREATE TABLE IF NOT EXISTS verifications (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
identifier TEXT NOT NULL,
|
identifier TEXT NOT NULL,
|
||||||
@@ -122,8 +98,7 @@ async def db_engine():
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
"""))
|
||||||
)
|
|
||||||
|
|
||||||
yield engine
|
yield engine
|
||||||
|
|
||||||
@@ -158,9 +133,7 @@ async def client(db_engine):
|
|||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
async def _create_test_user_and_session(
|
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
||||||
client: AsyncClient, db_engine, **user_overrides
|
|
||||||
) -> tuple[dict, str]:
|
|
||||||
"""Create a test user and a valid session directly in the DB.
|
"""Create a test user and a valid session directly in the DB.
|
||||||
|
|
||||||
Returns (user_dict, session_token). Better-Auth stores the raw token
|
Returns (user_dict, session_token). Better-Auth stores the raw token
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
"""Tests for rate limiting middleware."""
|
"""Tests for rate limiting middleware."""
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key
|
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter
|
||||||
|
|
||||||
|
|
||||||
class TestSlidingWindowCounter:
|
class TestSlidingWindowCounter:
|
||||||
@@ -55,32 +53,3 @@ async def test_health_skips_rate_limit(client):
|
|||||||
resp = await client.get("/health")
|
resp = await client.get("/health")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert "x-ratelimit-limit" not in resp.headers
|
assert "x-ratelimit-limit" not in resp.headers
|
||||||
|
|
||||||
|
|
||||||
class TestGetRateLimitKey:
|
|
||||||
def _make_request(self, auth_header: str = "") -> MagicMock:
|
|
||||||
req = MagicMock()
|
|
||||||
req.url.path = "/purchases"
|
|
||||||
req.headers = {"authorization": auth_header} if auth_header else {}
|
|
||||||
return req
|
|
||||||
|
|
||||||
def test_distinct_tokens_produce_distinct_keys(self):
|
|
||||||
req1 = self._make_request("Bearer token_alpha_12345")
|
|
||||||
req2 = self._make_request("Bearer token_beta_67890")
|
|
||||||
key1, _ = _get_rate_limit_key(req1)
|
|
||||||
key2, _ = _get_rate_limit_key(req2)
|
|
||||||
assert key1 != key2
|
|
||||||
|
|
||||||
def test_same_token_produces_same_key(self):
|
|
||||||
req1 = self._make_request("Bearer same_token_value_abc")
|
|
||||||
req2 = self._make_request("Bearer same_token_value_abc")
|
|
||||||
key1, _ = _get_rate_limit_key(req1)
|
|
||||||
key2, _ = _get_rate_limit_key(req2)
|
|
||||||
assert key1 == key2
|
|
||||||
|
|
||||||
def test_key_does_not_contain_raw_token_suffix(self):
|
|
||||||
raw_token = "my_secret_jwt_token_xyz"
|
|
||||||
req = self._make_request(f"Bearer {raw_token}")
|
|
||||||
key, _ = _get_rate_limit_key(req)
|
|
||||||
assert raw_token[-16:] not in key
|
|
||||||
assert raw_token not in key
|
|
||||||
|
|||||||
Reference in New Issue
Block a user