feat(api): add input validation on public endpoints

- Add days query param to GET /public/trends/{product_id} (ge=1, le=365)
- Add category query param to GET /public/store-comparison
- Add category and period query params to GET /public/inflation
- Add boundary and malicious input test cases

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
CartSnitch Engineer Bot
2026-04-14 11:45:53 +00:00
parent 39dfacff86
commit ef4d0cc13f
3 changed files with 150 additions and 28 deletions
+14 -5
View File
@@ -18,10 +18,14 @@ router = APIRouter(prefix="/public", tags=["public"])
@router.get("/trends/{product_id}", response_model=PublicTrendResponse) @router.get("/trends/{product_id}", response_model=PublicTrendResponse)
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)): async def public_price_trend(
product_id: UUID,
days: int = Query(90, ge=1, le=365),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db) svc = PublicService(db)
try: try:
return await svc.get_trend(product_id) return await svc.get_trend(product_id, days=days)
except LookupError: except LookupError:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
@@ -31,6 +35,7 @@ async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse) @router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
async def public_store_comparison( async def public_store_comparison(
product_ids: Annotated[list[UUID], Query(max_length=20)], product_ids: Annotated[list[UUID], Query(max_length=20)],
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if not product_ids: if not product_ids:
@@ -39,10 +44,14 @@ async def public_store_comparison(
detail="At least one product_id is required", detail="At least one product_id is required",
) )
svc = PublicService(db) svc = PublicService(db)
return await svc.get_store_comparison(product_ids) return await svc.get_store_comparison(product_ids, category=category)
@router.get("/inflation", response_model=PublicInflationResponse) @router.get("/inflation", response_model=PublicInflationResponse)
async def public_inflation(db: AsyncSession = Depends(get_db)): async def public_inflation(
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
period: str = Query("all-time", pattern=r"^(all-time|1y|6m|3m|1m)$"),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db) svc = PublicService(db)
return await svc.get_inflation() return await svc.get_inflation(category=category, period=period)
+42 -23
View File
@@ -1,5 +1,6 @@
"""Public service — unauthenticated price transparency endpoints.""" """Public service — unauthenticated price transparency endpoints."""
from datetime import date, timedelta
from uuid import UUID from uuid import UUID
from sqlalchemy import and_, func, select from sqlalchemy import and_, func, select
@@ -13,7 +14,7 @@ class PublicService:
def __init__(self, db: AsyncSession) -> None: def __init__(self, db: AsyncSession) -> None:
self.db = db self.db = db
async def get_trend(self, product_id: UUID) -> dict: async def get_trend(self, product_id: UUID, days: int = 90) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
result = await self.db.execute( result = await self.db.execute(
@@ -23,9 +24,13 @@ class PublicService:
if not product: if not product:
raise LookupError("Product not found") raise LookupError("Product not found")
date_threshold = date.today() - timedelta(days=days)
prices_result = await self.db.execute( prices_result = await self.db.execute(
select(PriceHistory) select(PriceHistory)
.where(PriceHistory.normalized_product_id == product_id) .where(
PriceHistory.normalized_product_id == product_id,
PriceHistory.observed_date >= date_threshold,
)
.options(selectinload(PriceHistory.store)) .options(selectinload(PriceHistory.store))
.order_by(PriceHistory.observed_date) .order_by(PriceHistory.observed_date)
) )
@@ -45,20 +50,25 @@ class PublicService:
], ],
} }
async def get_store_comparison(self, product_ids: list[UUID]) -> dict: async def get_store_comparison(
self, product_ids: list[UUID], category: str | None = None
) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
if not product_ids: if not product_ids:
return {"products": []} return {"products": []}
# Fetch all products in one query product_query = select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
prod_result = await self.db.execute( if category:
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) product_query = product_query.where(NormalizedProduct.category == category)
) prod_result = await self.db.execute(product_query)
products_by_id = {p.id: p for p in prod_result.scalars().all()} products_by_id = {p.id: p for p in prod_result.scalars().all()}
# Latest prices for all requested products in one query if not products_by_id:
subq = latest_price_per_store(product_ids) return {"products": []}
filtered_product_ids = list(products_by_id.keys())
subq = latest_price_per_store(filtered_product_ids)
prices_result = await self.db.execute( prices_result = await self.db.execute(
select(PriceHistory) select(PriceHistory)
.join( .join(
@@ -69,18 +79,17 @@ class PublicService:
PriceHistory.normalized_product_id == subq.c.normalized_product_id, PriceHistory.normalized_product_id == subq.c.normalized_product_id,
), ),
) )
.where(PriceHistory.normalized_product_id.in_(product_ids)) .where(PriceHistory.normalized_product_id.in_(filtered_product_ids))
.options(selectinload(PriceHistory.store)) .options(selectinload(PriceHistory.store))
) )
all_prices = prices_result.scalars().all() all_prices = prices_result.scalars().all()
# Group by product
prices_by_product: dict[UUID, list] = {} prices_by_product: dict[UUID, list] = {}
for ph in all_prices: for ph in all_prices:
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph) prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
products = [] products = []
for pid in product_ids: for pid in filtered_product_ids:
product = products_by_id.get(pid) product = products_by_id.get(pid)
if not product: if not product:
continue continue
@@ -102,19 +111,29 @@ class PublicService:
return {"products": products} return {"products": products}
async def get_inflation(self) -> dict: async def get_inflation(self, category: str | None = None, period: str = "all-time") -> dict:
"""Aggregate price change stats. Compares average prices across periods.""" """Aggregate price change stats. Compares average prices across periods."""
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
# Get average prices grouped by category for recent vs older data date_threshold = None
result = await self.db.execute( if period != "all-time":
select( days_map = {"1y": 365, "6m": 180, "3m": 90, "1m": 30}
NormalizedProduct.category, days = days_map.get(period, 365)
func.avg(PriceHistory.regular_price), date_threshold = date.today() - timedelta(days=days)
)
.join(NormalizedProduct) query = select(
.group_by(NormalizedProduct.category) NormalizedProduct.category,
) func.avg(PriceHistory.regular_price),
).join(NormalizedProduct)
if category:
query = query.where(NormalizedProduct.category == category)
if date_threshold:
query = query.where(PriceHistory.observed_date >= date_threshold)
query = query.group_by(NormalizedProduct.category)
result = await self.db.execute(query)
categories = {} categories = {}
for row in result.all(): for row in result.all():
cat, avg_price = row cat, avg_price = row
@@ -122,7 +141,7 @@ class PublicService:
categories[cat] = float(avg_price) if avg_price else 0.0 categories[cat] = float(avg_price) if avg_price else 0.0
return { return {
"period": "all-time", "period": period,
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1), "cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
"cpi_baseline": 100.0, "cpi_baseline": 100.0,
"categories": categories, "categories": categories,
+94
View File
@@ -71,3 +71,97 @@ async def test_public_inflation(client, public_data):
data = resp.json() data = resp.json()
assert "categories" in data assert "categories" in data
assert "cartsnitch_index" in data assert "cartsnitch_index" in data
@pytest.mark.asyncio
async def test_trend_invalid_uuid(client):
resp = await client.get("/public/trends/not-a-uuid")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_zero(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=0")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_negative(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=-1")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_over_max(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=999")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_trend_days_valid(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}?days=30")
assert resp.status_code == 200
assert "product_name" in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_empty_list(client):
resp = await client.get("/public/store-comparison")
assert resp.status_code == 400
assert "detail" in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_category_xss(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(
f"/public/store-comparison?product_ids={pid}&category=<script>alert(1)</script>"
)
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_store_comparison_category_sql_injection(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/store-comparison?product_ids={pid}&category='; DROP TABLE--")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_inflation_invalid_period(client, public_data):
resp = await client.get("/public/inflation?period=10years")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()
@pytest.mark.asyncio
async def test_inflation_valid_periods(client, public_data):
for period in ["all-time", "1y", "6m", "3m", "1m"]:
resp = await client.get(f"/public/inflation?period={period}")
assert resp.status_code == 200, f"period={period} failed"
@pytest.mark.asyncio
async def test_inflation_category_too_long(client, public_data):
long_category = "x" * 200
resp = await client.get(f"/public/inflation?category={long_category}")
assert resp.status_code == 422
assert "detail" in resp.json()
assert "stack" not in resp.json()