Compare commits

..

1 Commits

Author SHA1 Message Date
CartSnitch Engineer Bot cfea2586cb feat(api): add input validation on public endpoints
- Add days query param to GET /public/trends/{product_id} (ge=1, le=365)
- Add category query param to GET /public/store-comparison
- Add category and period query params to GET /public/inflation
- Add boundary and malicious input test cases

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-14 11:45:53 +00:00
9 changed files with 167 additions and 101 deletions
@@ -1,38 +0,0 @@
"""Add GIN index on upc_variants and alter column to JSONB.
Revision ID: 009_add_gin_index_upc_variants
Revises: 008_create_domain_tables
Create Date: 2026-04-14
"""
import sqlalchemy as sa
from alembic import op
# Alembic revision identifiers: this revision's id and the revision it revises.
revision = "009_add_gin_index_upc_variants"
down_revision = "008_create_domain_tables"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Convert ``normalized_products.upc_variants`` to JSONB and add a GIN index.

    The column type is altered first, casting existing values with
    ``upc_variants::jsonb``; the GIN index is then created on the column.
    """
    # Import the dialect module explicitly: ``import sqlalchemy as sa`` alone
    # does not guarantee that ``sa.dialects.postgresql`` has been loaded.
    from sqlalchemy.dialects import postgresql

    op.alter_column(
        "normalized_products",
        "upc_variants",
        type_=postgresql.JSONB(),
        postgresql_using="upc_variants::jsonb",
    )
    op.create_index(
        "ix_normalized_products_upc_variants_gin",
        "normalized_products",
        ["upc_variants"],
        postgresql_using="gin",
    )
def downgrade() -> None:
    """Drop the GIN index and change ``upc_variants`` back to ``sa.JSON``."""
    table = "normalized_products"
    op.drop_index("ix_normalized_products_upc_variants_gin", table_name=table)
    op.alter_column(table, "upc_variants", type_=sa.JSON())
+2 -2
View File
@@ -11,6 +11,6 @@ def add_cors_middleware(app: FastAPI) -> None:
CORSMiddleware, CORSMiddleware,
allow_origins=settings.cors_origins, allow_origins=settings.cors_origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], allow_methods=["*"],
allow_headers=["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With"], allow_headers=["*"],
) )
+14 -5
View File
@@ -18,10 +18,14 @@ router = APIRouter(prefix="/public", tags=["public"])
@router.get("/trends/{product_id}", response_model=PublicTrendResponse) @router.get("/trends/{product_id}", response_model=PublicTrendResponse)
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)): async def public_price_trend(
product_id: UUID,
days: int = Query(90, ge=1, le=365),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db) svc = PublicService(db)
try: try:
return await svc.get_trend(product_id) return await svc.get_trend(product_id, days=days)
except LookupError: except LookupError:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found" status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
@@ -31,6 +35,7 @@ async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse) @router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
async def public_store_comparison( async def public_store_comparison(
product_ids: Annotated[list[UUID], Query(max_length=20)], product_ids: Annotated[list[UUID], Query(max_length=20)],
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if not product_ids: if not product_ids:
@@ -39,10 +44,14 @@ async def public_store_comparison(
detail="At least one product_id is required", detail="At least one product_id is required",
) )
svc = PublicService(db) svc = PublicService(db)
return await svc.get_store_comparison(product_ids) return await svc.get_store_comparison(product_ids, category=category)
@router.get("/inflation", response_model=PublicInflationResponse) @router.get("/inflation", response_model=PublicInflationResponse)
async def public_inflation(db: AsyncSession = Depends(get_db)): async def public_inflation(
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
period: str = Query("all-time", pattern=r"^(all-time|1y|6m|3m|1m)$"),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db) svc = PublicService(db)
return await svc.get_inflation() return await svc.get_inflation(category=category, period=period)
+42 -23
View File
@@ -1,5 +1,6 @@
"""Public service — unauthenticated price transparency endpoints.""" """Public service — unauthenticated price transparency endpoints."""
from datetime import date, timedelta
from uuid import UUID from uuid import UUID
from sqlalchemy import and_, func, select from sqlalchemy import and_, func, select
@@ -13,7 +14,7 @@ class PublicService:
def __init__(self, db: AsyncSession) -> None: def __init__(self, db: AsyncSession) -> None:
self.db = db self.db = db
async def get_trend(self, product_id: UUID) -> dict: async def get_trend(self, product_id: UUID, days: int = 90) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
result = await self.db.execute( result = await self.db.execute(
@@ -23,9 +24,13 @@ class PublicService:
if not product: if not product:
raise LookupError("Product not found") raise LookupError("Product not found")
date_threshold = date.today() - timedelta(days=days)
prices_result = await self.db.execute( prices_result = await self.db.execute(
select(PriceHistory) select(PriceHistory)
.where(PriceHistory.normalized_product_id == product_id) .where(
PriceHistory.normalized_product_id == product_id,
PriceHistory.observed_date >= date_threshold,
)
.options(selectinload(PriceHistory.store)) .options(selectinload(PriceHistory.store))
.order_by(PriceHistory.observed_date) .order_by(PriceHistory.observed_date)
) )
@@ -45,20 +50,25 @@ class PublicService:
], ],
} }
async def get_store_comparison(self, product_ids: list[UUID]) -> dict: async def get_store_comparison(
self, product_ids: list[UUID], category: str | None = None
) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
if not product_ids: if not product_ids:
return {"products": []} return {"products": []}
# Fetch all products in one query product_query = select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
prod_result = await self.db.execute( if category:
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids)) product_query = product_query.where(NormalizedProduct.category == category)
) prod_result = await self.db.execute(product_query)
products_by_id = {p.id: p for p in prod_result.scalars().all()} products_by_id = {p.id: p for p in prod_result.scalars().all()}
# Latest prices for all requested products in one query if not products_by_id:
subq = latest_price_per_store(product_ids) return {"products": []}
filtered_product_ids = list(products_by_id.keys())
subq = latest_price_per_store(filtered_product_ids)
prices_result = await self.db.execute( prices_result = await self.db.execute(
select(PriceHistory) select(PriceHistory)
.join( .join(
@@ -69,18 +79,17 @@ class PublicService:
PriceHistory.normalized_product_id == subq.c.normalized_product_id, PriceHistory.normalized_product_id == subq.c.normalized_product_id,
), ),
) )
.where(PriceHistory.normalized_product_id.in_(product_ids)) .where(PriceHistory.normalized_product_id.in_(filtered_product_ids))
.options(selectinload(PriceHistory.store)) .options(selectinload(PriceHistory.store))
) )
all_prices = prices_result.scalars().all() all_prices = prices_result.scalars().all()
# Group by product
prices_by_product: dict[UUID, list] = {} prices_by_product: dict[UUID, list] = {}
for ph in all_prices: for ph in all_prices:
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph) prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
products = [] products = []
for pid in product_ids: for pid in filtered_product_ids:
product = products_by_id.get(pid) product = products_by_id.get(pid)
if not product: if not product:
continue continue
@@ -102,19 +111,29 @@ class PublicService:
return {"products": products} return {"products": products}
async def get_inflation(self) -> dict: async def get_inflation(self, category: str | None = None, period: str = "all-time") -> dict:
"""Aggregate price change stats. Compares average prices across periods.""" """Aggregate price change stats. Compares average prices across periods."""
from cartsnitch_api.models import NormalizedProduct, PriceHistory from cartsnitch_api.models import NormalizedProduct, PriceHistory
# Get average prices grouped by category for recent vs older data date_threshold = None
result = await self.db.execute( if period != "all-time":
select( days_map = {"1y": 365, "6m": 180, "3m": 90, "1m": 30}
NormalizedProduct.category, days = days_map.get(period, 365)
func.avg(PriceHistory.regular_price), date_threshold = date.today() - timedelta(days=days)
)
.join(NormalizedProduct) query = select(
.group_by(NormalizedProduct.category) NormalizedProduct.category,
) func.avg(PriceHistory.regular_price),
).join(NormalizedProduct)
if category:
query = query.where(NormalizedProduct.category == category)
if date_threshold:
query = query.where(PriceHistory.observed_date >= date_threshold)
query = query.group_by(NormalizedProduct.category)
result = await self.db.execute(query)
categories = {} categories = {}
for row in result.all(): for row in result.all():
cat, avg_price = row cat, avg_price = row
@@ -122,7 +141,7 @@ class PublicService:
categories[cat] = float(avg_price) if avg_price else 0.0 categories[cat] = float(avg_price) if avg_price else 0.0
return { return {
"period": "all-time", "period": period,
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1), "cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
"cpi_baseline": 100.0, "cpi_baseline": 100.0,
"categories": categories, "categories": categories,
+94
View File
@@ -71,3 +71,97 @@ async def test_public_inflation(client, public_data):
data = resp.json() data = resp.json()
assert "categories" in data assert "categories" in data
assert "cartsnitch_index" in data assert "cartsnitch_index" in data
@pytest.mark.asyncio
async def test_trend_invalid_uuid(client):
    """Expect a 422 with a ``detail`` key when the product id is not a UUID."""
    response = await client.get("/public/trends/not-a-uuid")
    assert response.status_code == 422
    payload = response.json()
    assert "detail" in payload
    # The error body must not expose a top-level "stack" key.
    assert "stack" not in payload
@pytest.mark.asyncio
async def test_trend_days_zero(client, public_data):
    """Expect a 422 with a ``detail`` key when ``days=0`` is requested."""
    product_id = str(public_data["product"].id)
    url = f"/public/trends/{product_id}?days=0"
    response = await client.get(url)
    assert response.status_code == 422
    payload = response.json()
    assert "detail" in payload
    # The error body must not expose a top-level "stack" key.
    assert "stack" not in payload
@pytest.mark.asyncio
async def test_trend_days_negative(client, public_data):
    """Expect a 422 with a ``detail`` key when ``days=-1`` is requested."""
    product_id = str(public_data["product"].id)
    url = f"/public/trends/{product_id}?days=-1"
    response = await client.get(url)
    assert response.status_code == 422
    payload = response.json()
    assert "detail" in payload
    # The error body must not expose a top-level "stack" key.
    assert "stack" not in payload
@pytest.mark.asyncio
async def test_trend_days_over_max(client, public_data):
    """Expect a 422 with a ``detail`` key when ``days=999`` is requested."""
    product_id = str(public_data["product"].id)
    url = f"/public/trends/{product_id}?days=999"
    response = await client.get(url)
    assert response.status_code == 422
    payload = response.json()
    assert "detail" in payload
    # The error body must not expose a top-level "stack" key.
    assert "stack" not in payload
@pytest.mark.asyncio
async def test_trend_days_valid(client, public_data):
    """Expect a 200 containing ``product_name`` when ``days=30`` is requested."""
    product_id = str(public_data["product"].id)
    url = f"/public/trends/{product_id}?days=30"
    response = await client.get(url)
    assert response.status_code == 200
    payload = response.json()
    assert "product_name" in payload
@pytest.mark.asyncio
async def test_store_comparison_empty_list(client):
    """Expect a 400 with a ``detail`` key when no ``product_ids`` are sent."""
    # NOTE(review): if the route declares ``product_ids`` as a required query
    # parameter, the framework may answer 422 before the handler's own 400
    # branch runs -- confirm that 400 is really what this request returns.
    response = await client.get("/public/store-comparison")
    assert response.status_code == 400
    payload = response.json()
    assert "detail" in payload
@pytest.mark.asyncio
async def test_store_comparison_category_xss(client, public_data):
    """Expect a 422 when ``category`` carries an HTML script tag."""
    product_id = str(public_data["product"].id)
    url = (
        f"/public/store-comparison?product_ids={product_id}"
        "&category=<script>alert(1)</script>"
    )
    response = await client.get(url)
    assert response.status_code == 422
    payload = response.json()
    assert "detail" in payload
    # The error body must not expose a top-level "stack" key.
    assert "stack" not in payload
@pytest.mark.asyncio
async def test_store_comparison_category_sql_injection(client, public_data):
    """Expect a 422 when ``category`` carries an SQL-injection style payload."""
    product_id = str(public_data["product"].id)
    url = f"/public/store-comparison?product_ids={product_id}&category='; DROP TABLE--"
    response = await client.get(url)
    assert response.status_code == 422
    payload = response.json()
    assert "detail" in payload
    # The error body must not expose a top-level "stack" key.
    assert "stack" not in payload
@pytest.mark.asyncio
async def test_inflation_invalid_period(client, public_data):
    """Expect a 422 with a ``detail`` key for the unsupported period ``10years``."""
    response = await client.get("/public/inflation?period=10years")
    assert response.status_code == 422
    payload = response.json()
    assert "detail" in payload
    # The error body must not expose a top-level "stack" key.
    assert "stack" not in payload
@pytest.mark.asyncio
async def test_inflation_valid_periods(client, public_data):
    """Expect a 200 for each of the listed ``period`` values."""
    accepted_periods = ("all-time", "1y", "6m", "3m", "1m")
    for period in accepted_periods:
        url = f"/public/inflation?period={period}"
        response = await client.get(url)
        assert response.status_code == 200, f"period={period} failed"
@pytest.mark.asyncio
async def test_inflation_category_too_long(client, public_data):
    """Expect a 422 with a ``detail`` key for a 200-character ``category``."""
    oversized_category = "x" * 200
    url = f"/public/inflation?category={oversized_category}"
    response = await client.get(url)
    assert response.status_code == 422
    payload = response.json()
    assert "detail" in payload
    # The error body must not expose a top-level "stack" key.
    assert "stack" not in payload
@@ -3,7 +3,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqlalchemy import JSON, String from sqlalchemy import JSON, String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_common.constants import ProductCategory, SizeUnit from cartsnitch_common.constants import ProductCategory, SizeUnit
@@ -27,9 +26,7 @@ class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
brand: Mapped[str | None] = mapped_column(String(200)) brand: Mapped[str | None] = mapped_column(String(200))
size: Mapped[str | None] = mapped_column(String(50)) size: Mapped[str | None] = mapped_column(String(50))
size_unit: Mapped[SizeUnit | None] = mapped_column(String(10)) size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
upc_variants: Mapped[list[str] | None] = mapped_column( upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list)
JSON().with_variant(JSONB(), "postgresql"), default=list
)
# Relationships # Relationships
purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product") purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
-6
View File
@@ -9,12 +9,6 @@ server {
gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript image/svg+xml; gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript image/svg+xml;
gzip_min_length 256; gzip_min_length 256;
# Security headers
add_header X-Frame-Options "SAMEORIGIN" always;
add_header X-Content-Type-Options "nosniff" always;
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
add_header Content-Security-Policy "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self'; connect-src 'self' https://*.cartsnitch.com https://*.farh.net; frame-ancestors 'self'" always;
# Health endpoint for K8s probes # Health endpoint for K8s probes
location /health { location /health {
access_log off; access_log off;
+3 -3
View File
@@ -9805,9 +9805,9 @@
} }
}, },
"node_modules/vite": { "node_modules/vite": {
"version": "6.4.2", "version": "6.4.1",
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.2.tgz", "resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
"integrity": "sha512-2N/55r4JDJ4gdrCvGgINMy+HH3iRpNIz8K6SFwVsA+JbQScLiC+clmAxBgwiSPgcG9U15QmvqCGWzMbqda5zGQ==", "integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"devOptional": true, "devOptional": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
@@ -5,14 +5,12 @@ Matches products across retailers by:
2. Fuzzy name matching via token-based Jaccard similarity (lower confidence) 2. Fuzzy name matching via token-based Jaccard similarity (lower confidence)
""" """
import json
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from enum import StrEnum from enum import StrEnum
from cartsnitch_common.models.product import NormalizedProduct from cartsnitch_common.models.product import NormalizedProduct
from sqlalchemy import cast, func, select, String from sqlalchemy import select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -98,24 +96,17 @@ def jaccard_similarity(a: str, b: str) -> float:
def match_by_upc(session: Session, upc: str) -> MatchResult | None: def match_by_upc(session: Session, upc: str) -> MatchResult | None:
"""Find a normalized product by exact UPC match. """Find a normalized product by exact UPC match.
Uses PostgreSQL JSONB containment (@>) for production efficiency. Loads products with upc_variants and checks membership in Python
Falls back to LIKE on SQLite for test compatibility. for cross-database compatibility (works on both PostgreSQL and SQLite).
""" """
dialect_name = session.bind.dialect.name if session.bind else "default" # TODO: Use PostgreSQL JSON containment query (@>) for production.
if dialect_name == "postgresql": # Current approach loads all products into memory — acceptable for tests
stmt = select(NormalizedProduct).where( # and small datasets, but will not scale.
cast(NormalizedProduct.upc_variants, JSONB).op("@>")( stmt = select(NormalizedProduct).where(NormalizedProduct.upc_variants.is_not(None))
func.cast(json.dumps([upc]), JSONB) products = session.execute(stmt).scalars().all()
) for product in products:
) if product.upc_variants and upc in product.upc_variants:
else: return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC)
stmt = select(NormalizedProduct).where(
NormalizedProduct.upc_variants.is_not(None),
cast(NormalizedProduct.upc_variants, String).contains(upc),
)
product = session.execute(stmt).scalars().first()
if product:
return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC)
return None return None