diff --git a/src/cartsnitch_api/routes/public.py b/src/cartsnitch_api/routes/public.py index 5d0b87b..4b5c5dc 100644 --- a/src/cartsnitch_api/routes/public.py +++ b/src/cartsnitch_api/routes/public.py @@ -18,10 +18,14 @@ router = APIRouter(prefix="/public", tags=["public"]) @router.get("/trends/{product_id}", response_model=PublicTrendResponse) -async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)): +async def public_price_trend( + product_id: UUID, + days: int = Query(90, ge=1, le=365), + db: AsyncSession = Depends(get_db), +): svc = PublicService(db) try: - return await svc.get_trend(product_id) + return await svc.get_trend(product_id, days=days) except LookupError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" @@ -31,6 +35,7 @@ async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db @router.get("/store-comparison", response_model=PublicStoreComparisonResponse) async def public_store_comparison( product_ids: Annotated[list[UUID], Query(max_length=20)], + category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"), db: AsyncSession = Depends(get_db), ): if not product_ids: @@ -39,10 +44,14 @@ async def public_store_comparison( detail="At least one product_id is required", ) svc = PublicService(db) - return await svc.get_store_comparison(product_ids) + return await svc.get_store_comparison(product_ids, category=category) @router.get("/inflation", response_model=PublicInflationResponse) -async def public_inflation(db: AsyncSession = Depends(get_db)): +async def public_inflation( + category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"), + period: str = Query("all-time", pattern=r"^(all-time|1y|6m|3m|1m)$"), + db: AsyncSession = Depends(get_db), +): svc = PublicService(db) - return await svc.get_inflation() + return await svc.get_inflation(category=category, period=period) diff --git a/src/cartsnitch_api/services/public.py b/src/cartsnitch_api/services/public.py index f1ccbeb..5ff5e8d 100644 --- a/src/cartsnitch_api/services/public.py +++ b/src/cartsnitch_api/services/public.py @@ -1,5 +1,6 @@ """Public service — unauthenticated price transparency endpoints.""" +from datetime import date, timedelta from uuid import UUID from sqlalchemy import and_, func, select @@ -13,7 +14,7 @@ class PublicService: def __init__(self, db: AsyncSession) -> None: self.db = db - async def get_trend(self, product_id: UUID) -> dict: + async def get_trend(self, product_id: UUID, days: int = 90) -> dict: from cartsnitch_api.models import NormalizedProduct, PriceHistory result = await self.db.execute( @@ -23,9 +24,13 @@ class PublicService: if not product: raise LookupError("Product not found") + date_threshold = date.today() - timedelta(days=days) prices_result = await self.db.execute( select(PriceHistory) - .where(PriceHistory.normalized_product_id == product_id) + .where( + PriceHistory.normalized_product_id == product_id, + PriceHistory.observed_date >= date_threshold, + ) .options(selectinload(PriceHistory.store)) .order_by(PriceHistory.observed_date) ) @@ -45,20 +50,25 @@ class PublicService: ], } - async def get_store_comparison(self, product_ids: list[UUID]) -> dict: + async def get_store_comparison( + self, product_ids: list[UUID], category: str | None = None + ) -> dict: from cartsnitch_api.models import NormalizedProduct, PriceHistory if not product_ids: return {"products": []} - # Fetch all products in one query - prod_result = await self.db.execute( - select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) - ) + product_query = select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) + if category: + product_query = product_query.where(NormalizedProduct.category == category) + prod_result = await self.db.execute(product_query) products_by_id = {p.id: p for p in prod_result.scalars().all()} - # Latest prices for all requested products in one query - subq = latest_price_per_store(product_ids) + if not products_by_id: + return {"products": []} + + filtered_product_ids = list(products_by_id.keys()) + subq = latest_price_per_store(filtered_product_ids) prices_result = await self.db.execute( select(PriceHistory) .join( @@ -69,18 +79,17 @@ class PublicService: PriceHistory.normalized_product_id == subq.c.normalized_product_id, ), ) - .where(PriceHistory.normalized_product_id.in_(product_ids)) + .where(PriceHistory.normalized_product_id.in_(filtered_product_ids)) .options(selectinload(PriceHistory.store)) ) all_prices = prices_result.scalars().all() - # Group by product prices_by_product: dict[UUID, list] = {} for ph in all_prices: prices_by_product.setdefault(ph.normalized_product_id, []).append(ph) products = [] - for pid in product_ids: + for pid in filtered_product_ids: product = products_by_id.get(pid) if not product: continue @@ -102,19 +111,29 @@ class PublicService: return {"products": products} - async def get_inflation(self) -> dict: + async def get_inflation(self, category: str | None = None, period: str = "all-time") -> dict: """Aggregate price change stats. Compares average prices across periods.""" from cartsnitch_api.models import NormalizedProduct, PriceHistory - # Get average prices grouped by category for recent vs older data - result = await self.db.execute( - select( - NormalizedProduct.category, - func.avg(PriceHistory.regular_price), - ) - .join(NormalizedProduct) - .group_by(NormalizedProduct.category) - ) + date_threshold = None + if period != "all-time": + days_map = {"1y": 365, "6m": 180, "3m": 90, "1m": 30} + days = days_map.get(period, 365) + date_threshold = date.today() - timedelta(days=days) + + query = select( + NormalizedProduct.category, + func.avg(PriceHistory.regular_price), + ).join(NormalizedProduct) + + if category: + query = query.where(NormalizedProduct.category == category) + if date_threshold: + query = query.where(PriceHistory.observed_date >= date_threshold) + + query = query.group_by(NormalizedProduct.category) + + result = await self.db.execute(query) categories = {} for row in result.all(): cat, avg_price = row @@ -122,7 +141,7 @@ class PublicService: categories[cat] = float(avg_price) if avg_price else 0.0 return { - "period": "all-time", + "period": period, "cartsnitch_index": sum(categories.values()) / max(len(categories), 1), "cpi_baseline": 100.0, "categories": categories, diff --git a/tests/test_routes/test_public.py b/tests/test_routes/test_public.py index 08a5d29..931bca5 100644 --- a/tests/test_routes/test_public.py +++ b/tests/test_routes/test_public.py @@ -71,3 +71,97 @@ async def test_public_inflation(client, public_data): data = resp.json() assert "categories" in data assert "cartsnitch_index" in data + + +@pytest.mark.asyncio +async def test_trend_invalid_uuid(client): + resp = await client.get("/public/trends/not-a-uuid") + assert resp.status_code == 422 + assert "detail" in resp.json() + assert "stack" not in resp.json() + + +@pytest.mark.asyncio +async def test_trend_days_zero(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get(f"/public/trends/{pid}?days=0") + assert resp.status_code == 422 + assert "detail" in resp.json() + assert "stack" not in resp.json() + + +@pytest.mark.asyncio +async def test_trend_days_negative(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get(f"/public/trends/{pid}?days=-1") + assert resp.status_code == 422 + assert "detail" in resp.json() + assert "stack" not in resp.json() + + +@pytest.mark.asyncio +async def test_trend_days_over_max(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get(f"/public/trends/{pid}?days=999") + assert resp.status_code == 422 + assert "detail" in resp.json() + assert "stack" not in resp.json() + + +@pytest.mark.asyncio +async def test_trend_days_valid(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get(f"/public/trends/{pid}?days=30") + assert resp.status_code == 200 + assert "product_name" in resp.json() + + +@pytest.mark.asyncio +async def test_store_comparison_empty_list(client): + resp = await client.get("/public/store-comparison") + assert resp.status_code == 400 + assert "detail" in resp.json() + + +@pytest.mark.asyncio +async def test_store_comparison_category_xss(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get( + f"/public/store-comparison?product_ids={pid}&category=" + ) + assert resp.status_code == 422 + assert "detail" in resp.json() + assert "stack" not in resp.json() + + +@pytest.mark.asyncio +async def test_store_comparison_category_sql_injection(client, public_data): + pid = str(public_data["product"].id) + resp = await client.get(f"/public/store-comparison?product_ids={pid}&category='; DROP TABLE--") + assert resp.status_code == 422 + assert "detail" in resp.json() + assert "stack" not in resp.json() + + +@pytest.mark.asyncio +async def test_inflation_invalid_period(client, public_data): + resp = await client.get("/public/inflation?period=10years") + assert resp.status_code == 422 + assert "detail" in resp.json() + assert "stack" not in resp.json() + + +@pytest.mark.asyncio +async def test_inflation_valid_periods(client, public_data): + for period in ["all-time", "1y", "6m", "3m", "1m"]: + resp = await client.get(f"/public/inflation?period={period}") + assert resp.status_code == 200, f"period={period} failed" + + +@pytest.mark.asyncio +async def test_inflation_category_too_long(client, public_data): + long_category = "x" * 200 + resp = await client.get(f"/public/inflation?category={long_category}") + assert resp.status_code == 422 + assert "detail" in resp.json() + assert "stack" not in resp.json()