Compare commits

..

1 Commits

Author SHA1 Message Date
CartSnitch Engineer Bot bc5e03e7a0 fix(security): use SHA-256 hash for rate limit key instead of token suffix
Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-14 11:36:17 +00:00
5 changed files with 63 additions and 153 deletions
@@ -4,6 +4,7 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
""" """
import hashlib
import time import time
from collections import defaultdict from collections import defaultdict
from threading import Lock from threading import Lock
@@ -71,8 +72,8 @@ def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
auth_header = request.headers.get("authorization", "") auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):
token = auth_header[7:] token = auth_header[7:]
# Use last 16 chars of token as key to avoid storing full tokens token_hash = hashlib.sha256(token.encode()).hexdigest()
return f"token:{token[-16:]}", _auth_limiter return f"token:{token_hash}", _auth_limiter
# Fallback to IP for unauthenticated non-public endpoints # Fallback to IP for unauthenticated non-public endpoints
return f"ip:{_get_client_ip(request)}", _public_limiter return f"ip:{_get_client_ip(request)}", _public_limiter
+5 -14
View File
@@ -18,14 +18,10 @@ router = APIRouter(prefix="/public", tags=["public"])
@router.get("/trends/{product_id}", response_model=PublicTrendResponse) @router.get("/trends/{product_id}", response_model=PublicTrendResponse)
async def public_price_trend( async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)):
product_id: UUID,
days: int = Query(90, ge=1, le=365),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db) svc = PublicService(db)
try: try:
return await svc.get_trend(product_id, days=days) return await svc.get_trend(product_id)
except LookupError: except LookupError:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
@@ -35,7 +31,6 @@ async def public_price_trend(
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse) @router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
async def public_store_comparison( async def public_store_comparison(
product_ids: Annotated[list[UUID], Query(max_length=20)], product_ids: Annotated[list[UUID], Query(max_length=20)],
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if not product_ids: if not product_ids:
@@ -44,14 +39,10 @@ async def public_store_comparison(
detail="At least one product_id is required", detail="At least one product_id is required",
) )
svc = PublicService(db) svc = PublicService(db)
return await svc.get_store_comparison(product_ids, category=category) return await svc.get_store_comparison(product_ids)
@router.get("/inflation", response_model=PublicInflationResponse) @router.get("/inflation", response_model=PublicInflationResponse)
async def public_inflation( async def public_inflation(db: AsyncSession = Depends(get_db)):
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
period: str = Query("all-time", pattern=r"^(all-time|1y|6m|3m|1m)$"),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db) svc = PublicService(db)
return await svc.get_inflation(category=category, period=period) return await svc.get_inflation()
+23 -42
View File
@@ -1,6 +1,5 @@
"""Public service — unauthenticated price transparency endpoints.""" """Public service — unauthenticated price transparency endpoints."""
from datetime import date, timedelta
from uuid import UUID from uuid import UUID
from sqlalchemy import and_, func, select from sqlalchemy import and_, func, select
@@ -14,7 +13,7 @@ class PublicService:
def __init__(self, db: AsyncSession) -> None: def __init__(self, db: AsyncSession) -> None:
self.db = db self.db = db
async def get_trend(self, product_id: UUID, days: int = 90) -> dict: async def get_trend(self, product_id: UUID) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
result = await self.db.execute( result = await self.db.execute(
@@ -24,13 +23,9 @@ class PublicService:
if not product: if not product:
raise LookupError("Product not found") raise LookupError("Product not found")
date_threshold = date.today() - timedelta(days=days)
prices_result = await self.db.execute( prices_result = await self.db.execute(
select(PriceHistory) select(PriceHistory)
.where( .where(PriceHistory.normalized_product_id == product_id)
PriceHistory.normalized_product_id == product_id,
PriceHistory.observed_date >= date_threshold,
)
.options(selectinload(PriceHistory.store)) .options(selectinload(PriceHistory.store))
.order_by(PriceHistory.observed_date) .order_by(PriceHistory.observed_date)
) )
@@ -50,25 +45,20 @@ class PublicService:
], ],
} }
async def get_store_comparison( async def get_store_comparison(self, product_ids: list[UUID]) -> dict:
self, product_ids: list[UUID], category: str | None = None
) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
if not product_ids: if not product_ids:
return {"products": []} return {"products": []}
product_query = select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) # Fetch all products in one query
if category: prod_result = await self.db.execute(
product_query = product_query.where(NormalizedProduct.category == category) select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
prod_result = await self.db.execute(product_query) )
products_by_id = {p.id: p for p in prod_result.scalars().all()} products_by_id = {p.id: p for p in prod_result.scalars().all()}
if not products_by_id: # Latest prices for all requested products in one query
return {"products": []} subq = latest_price_per_store(product_ids)
filtered_product_ids = list(products_by_id.keys())
subq = latest_price_per_store(filtered_product_ids)
prices_result = await self.db.execute( prices_result = await self.db.execute(
select(PriceHistory) select(PriceHistory)
.join( .join(
@@ -79,17 +69,18 @@ class PublicService:
PriceHistory.normalized_product_id == subq.c.normalized_product_id, PriceHistory.normalized_product_id == subq.c.normalized_product_id,
), ),
) )
.where(PriceHistory.normalized_product_id.in_(filtered_product_ids)) .where(PriceHistory.normalized_product_id.in_(product_ids))
.options(selectinload(PriceHistory.store)) .options(selectinload(PriceHistory.store))
) )
all_prices = prices_result.scalars().all() all_prices = prices_result.scalars().all()
# Group by product
prices_by_product: dict[UUID, list] = {} prices_by_product: dict[UUID, list] = {}
for ph in all_prices: for ph in all_prices:
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph) prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
products = [] products = []
for pid in filtered_product_ids: for pid in product_ids:
product = products_by_id.get(pid) product = products_by_id.get(pid)
if not product: if not product:
continue continue
@@ -111,29 +102,19 @@ class PublicService:
return {"products": products} return {"products": products}
async def get_inflation(self, category: str | None = None, period: str = "all-time") -> dict: async def get_inflation(self) -> dict:
"""Aggregate price change stats. Compares average prices across periods.""" """Aggregate price change stats. Compares average prices across periods."""
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
date_threshold = None # Get average prices grouped by category for recent vs older data
if period != "all-time": result = await self.db.execute(
days_map = {"1y": 365, "6m": 180, "3m": 90, "1m": 30} select(
days = days_map.get(period, 365) NormalizedProduct.category,
date_threshold = date.today() - timedelta(days=days) func.avg(PriceHistory.regular_price),
)
query = select( .join(NormalizedProduct)
NormalizedProduct.category, .group_by(NormalizedProduct.category)
func.avg(PriceHistory.regular_price), )
).join(NormalizedProduct)
if category:
query = query.where(NormalizedProduct.category == category)
if date_threshold:
query = query.where(PriceHistory.observed_date >= date_threshold)
query = query.group_by(NormalizedProduct.category)
result = await self.db.execute(query)
categories = {} categories = {}
for row in result.all(): for row in result.all():
cat, avg_price = row cat, avg_price = row
@@ -141,7 +122,7 @@ class PublicService:
categories[cat] = float(avg_price) if avg_price else 0.0 categories[cat] = float(avg_price) if avg_price else 0.0
return { return {
"period": period, "period": "all-time",
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1), "cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
"cpi_baseline": 100.0, "cpi_baseline": 100.0,
"categories": categories, "categories": categories,
+32 -1
View File
@@ -1,8 +1,10 @@
"""Tests for rate limiting middleware.""" """Tests for rate limiting middleware."""
from unittest.mock import MagicMock
import pytest import pytest
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key
class TestSlidingWindowCounter: class TestSlidingWindowCounter:
@@ -53,3 +55,32 @@ async def test_health_skips_rate_limit(client):
resp = await client.get("/health") resp = await client.get("/health")
assert resp.status_code == 200 assert resp.status_code == 200
assert "x-ratelimit-limit" not in resp.headers assert "x-ratelimit-limit" not in resp.headers
class TestGetRateLimitKey:
def _make_request(self, auth_header: str = "") -> MagicMock:
req = MagicMock()
req.url.path = "/purchases"
req.headers = {"authorization": auth_header} if auth_header else {}
return req
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("Bearer token_alpha_12345")
req2 = self._make_request("Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("Bearer same_token_value_abc")
req2 = self._make_request("Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request(f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key
-94
View File
@@ -71,97 +71,3 @@ async def test_public_inflation(client, public_data):
data = resp.json() data = resp.json()
assert "categories" in data assert "categories" in data
assert "cartsnitch_index" in data assert "cartsnitch_index" in data
@pytest.mark.asyncio
async def test_trend_invalid_uuid(client):
resp = await client.get("/public/trends/not-a-uuid")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_zero(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=0")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_negative(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=-1")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_over_max(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=999")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_valid(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=30")
assert resp.status_code == 200
assert "product_name" in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_empty_list(client):
resp = await client.get("/public/store-comparison")
assert resp.status_code == 400
assert "detail" in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_category_xss(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(
f"/public/store-comparison?product_ids={pid}&category=<script>alert(1)</script>"
)
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_category_sql_injection(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/store-comparison?product_ids={pid}&category='; DROP TABLE--")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_inflation_invalid_period(client, public_data):
resp = await client.get("/public/inflation?period=10years")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_inflation_valid_periods(client, public_data):
for period in ["all-time", "1y", "6m", "3m", "1m"]:
resp = await client.get(f"/public/inflation?period={period}")
assert resp.status_code == 200, f"period={period} failed"
@pytest.mark.asyncio
async def test_inflation_category_too_long(client, public_data):
long_category = "x" * 200
resp = await client.get(f"/public/inflation?category={long_category}")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()