Merge pull request #212 from cartsnitch/dev

Promote to UAT: input validation + audit logging (PR #171, #183)
This commit is contained in:
cartsnitch-cto[bot]
2026-04-15 03:30:04 +00:00
committed by GitHub
6 changed files with 222 additions and 30 deletions
+6 -2
View File
@@ -69,7 +69,9 @@ async def get_current_user(
token: str | None = None
# 1. Check session cookie — prefer __Secure- variant (HTTPS) over plain (HTTP dev)
cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(SESSION_COOKIE_NAME)
cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(
SESSION_COOKIE_NAME
)
if cookie_token:
# Better-Auth cookie format is "token.sessionId" — extract just the token part
token = cookie_token.split(".")[0] if "." in cookie_token else cookie_token
@@ -86,7 +88,9 @@ async def get_current_user(
detail="Authentication required",
)
return await _validate_session_token(token, db)
user_id = await _validate_session_token(token, db)
request.state.user_id = user_id
return user_id
async def verify_service_key(x_service_key: str = Header()) -> None:
+2
View File
@@ -10,6 +10,7 @@ from cartsnitch_api.database import dispose_engine
from cartsnitch_api.middleware.cors import add_cors_middleware
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
from cartsnitch_api.middleware.audit import add_audit_middleware
from cartsnitch_api.routes.alerts import router as alerts_router
from cartsnitch_api.routes.coupons import router as coupons_router
from cartsnitch_api.routes.health import router as health_router
@@ -43,6 +44,7 @@ def create_app() -> FastAPI:
add_cors_middleware(app)
add_error_monitor_middleware(app)
add_rate_limit_middleware(app)
add_audit_middleware(app)
# Exception handlers
add_error_handlers(app)
+64
View File
@@ -0,0 +1,64 @@
"""Audit logging middleware for sensitive API operations.
Logs structured JSON for POST/PUT/PATCH/DELETE requests and GET /auth/me.
Never logs request bodies, response bodies, Authorization headers, or cookie values.
"""
import json
import logging
import time
from collections.abc import Awaitable, Callable
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger("cartsnitch_api.audit")
HEALTH_PATHS = {"/health", "/healthz", "/ready"}
class AuditMiddleware(BaseHTTPMiddleware):
    """Emit one structured JSON audit line per sensitive request.

    Audited requests are POST/PUT/PATCH/DELETE on any path plus
    ``GET /auth/me``. OPTIONS requests and the paths in ``HEALTH_PATHS`` are
    passed straight through without logging. Only the method, path, client
    IP, user id, status code and duration are recorded; bodies, headers and
    cookies are never read here.
    """

    _WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable],
    ):
        """Run the downstream handler and log an audit entry around it.

        Args:
            request: The incoming request.
            call_next: Callable that invokes the rest of the middleware/route
                stack and returns the response.

        Returns:
            The downstream response, unchanged.

        Raises:
            Exception: Anything raised downstream is re-raised after an audit
                entry (with ``status_code`` 500 and the exception class name)
                has been written, so failed operations still leave a trace.
        """
        method = request.method
        path = request.url.path

        if method == "OPTIONS" or path in HEALTH_PATHS:
            return await call_next(request)

        is_sensitive_write = method in self._WRITE_METHODS
        is_auth_me_read = method == "GET" and path == "/auth/me"
        if not (is_sensitive_write or is_auth_me_read):
            return await call_next(request)

        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception as exc:
            # An exception escaping the inner stack would otherwise skip the
            # audit line entirely; record it before propagating.
            self._log_event(
                request, method, path, 500, start, error=type(exc).__name__
            )
            raise

        self._log_event(request, method, path, response.status_code, start)
        return response

    @staticmethod
    def _log_event(
        request: Request,
        method: str,
        path: str,
        status_code: int,
        start: float,
        error: str | None = None,
    ) -> None:
        """Build the audit record and write it to the audit logger as JSON."""
        duration_ms = (time.perf_counter() - start) * 1000

        # user_id is whatever a downstream dependency stored on request.state;
        # it is absent (None) for unauthenticated requests.
        user_id = getattr(request.state, "user_id", None)
        client_ip = request.client.host if request.client else "unknown"

        log_entry = {
            "event": "audit",
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "user_id": user_id,
            "method": method,
            "path": path,
            "client_ip": client_ip,
            "status_code": status_code,
            "duration_ms": round(duration_ms, 2),
        }
        if error is not None:
            log_entry["error"] = error

        # default=str is deliberate: the type of user_id is not visible here
        # (it may be a UUID or similar), and a serialisation error in audit
        # logging must not break an otherwise completed request.
        logger.info(json.dumps(log_entry, default=str))
def add_audit_middleware(app: FastAPI) -> None:
    """Register ``AuditMiddleware`` on the given application."""
    middleware_cls = AuditMiddleware
    app.add_middleware(middleware_cls)
+14 -5
View File
@@ -18,10 +18,14 @@ router = APIRouter(prefix="/public", tags=["public"])
@router.get("/trends/{product_id}", response_model=PublicTrendResponse)
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)):
async def public_price_trend(
product_id: UUID,
days: int = Query(90, ge=1, le=365),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db)
try:
return await svc.get_trend(product_id)
return await svc.get_trend(product_id, days=days)
except LookupError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
@@ -31,6 +35,7 @@ async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
async def public_store_comparison(
product_ids: Annotated[list[UUID], Query(max_length=20)],
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
db: AsyncSession = Depends(get_db),
):
if not product_ids:
@@ -39,10 +44,14 @@ async def public_store_comparison(
detail="At least one product_id is required",
)
svc = PublicService(db)
return await svc.get_store_comparison(product_ids)
return await svc.get_store_comparison(product_ids, category=category)
@router.get("/inflation", response_model=PublicInflationResponse)
async def public_inflation(db: AsyncSession = Depends(get_db)):
async def public_inflation(
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
period: str = Query("all-time", pattern=r"^(all-time|1y|6m|3m|1m)$"),
db: AsyncSession = Depends(get_db),
):
svc = PublicService(db)
return await svc.get_inflation()
return await svc.get_inflation(category=category, period=period)
+42 -23
View File
@@ -1,5 +1,6 @@
"""Public service — unauthenticated price transparency endpoints."""
from datetime import date, timedelta
from uuid import UUID
from sqlalchemy import and_, func, select
@@ -13,7 +14,7 @@ class PublicService:
def __init__(self, db: AsyncSession) -> None:
self.db = db
async def get_trend(self, product_id: UUID) -> dict:
async def get_trend(self, product_id: UUID, days: int = 90) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory
result = await self.db.execute(
@@ -23,9 +24,13 @@ class PublicService:
if not product:
raise LookupError("Product not found")
date_threshold = date.today() - timedelta(days=days)
prices_result = await self.db.execute(
select(PriceHistory)
.where(PriceHistory.normalized_product_id == product_id)
.where(
PriceHistory.normalized_product_id == product_id,
PriceHistory.observed_date >= date_threshold,
)
.options(selectinload(PriceHistory.store))
.order_by(PriceHistory.observed_date)
)
@@ -45,20 +50,25 @@ class PublicService:
],
}
async def get_store_comparison(self, product_ids: list[UUID]) -> dict:
async def get_store_comparison(
self, product_ids: list[UUID], category: str | None = None
) -> dict:
from cartsnitch_api.models import NormalizedProduct, PriceHistory
if not product_ids:
return {"products": []}
# Fetch all products in one query
prod_result = await self.db.execute(
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
)
product_query = select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
if category:
product_query = product_query.where(NormalizedProduct.category == category)
prod_result = await self.db.execute(product_query)
products_by_id = {p.id: p for p in prod_result.scalars().all()}
# Latest prices for all requested products in one query
subq = latest_price_per_store(product_ids)
if not products_by_id:
return {"products": []}
filtered_product_ids = list(products_by_id.keys())
subq = latest_price_per_store(filtered_product_ids)
prices_result = await self.db.execute(
select(PriceHistory)
.join(
@@ -69,18 +79,17 @@ class PublicService:
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
),
)
.where(PriceHistory.normalized_product_id.in_(product_ids))
.where(PriceHistory.normalized_product_id.in_(filtered_product_ids))
.options(selectinload(PriceHistory.store))
)
all_prices = prices_result.scalars().all()
# Group by product
prices_by_product: dict[UUID, list] = {}
for ph in all_prices:
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
products = []
for pid in product_ids:
for pid in filtered_product_ids:
product = products_by_id.get(pid)
if not product:
continue
@@ -102,19 +111,29 @@ class PublicService:
return {"products": products}
async def get_inflation(self) -> dict:
async def get_inflation(self, category: str | None = None, period: str = "all-time") -> dict:
"""Aggregate price change stats. Compares average prices across periods."""
from cartsnitch_api.models import NormalizedProduct, PriceHistory
# Get average prices grouped by category for recent vs older data
result = await self.db.execute(
select(
NormalizedProduct.category,
func.avg(PriceHistory.regular_price),
)
.join(NormalizedProduct)
.group_by(NormalizedProduct.category)
)
date_threshold = None
if period != "all-time":
days_map = {"1y": 365, "6m": 180, "3m": 90, "1m": 30}
days = days_map.get(period, 365)
date_threshold = date.today() - timedelta(days=days)
query = select(
NormalizedProduct.category,
func.avg(PriceHistory.regular_price),
).join(NormalizedProduct)
if category:
query = query.where(NormalizedProduct.category == category)
if date_threshold:
query = query.where(PriceHistory.observed_date >= date_threshold)
query = query.group_by(NormalizedProduct.category)
result = await self.db.execute(query)
categories = {}
for row in result.all():
cat, avg_price = row
@@ -122,7 +141,7 @@ class PublicService:
categories[cat] = float(avg_price) if avg_price else 0.0
return {
"period": "all-time",
"period": period,
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
"cpi_baseline": 100.0,
"categories": categories,
+94
View File
@@ -71,3 +71,97 @@ async def test_public_inflation(client, public_data):
data = resp.json()
assert "categories" in data
assert "cartsnitch_index" in data
@pytest.mark.asyncio
async def test_trend_invalid_uuid(client):
    """A non-UUID product id yields 422 with a detail field and no stack key."""
    response = await client.get("/public/trends/not-a-uuid")
    assert response.status_code == 422
    body = response.json()
    assert "detail" in body
    assert "stack" not in body
@pytest.mark.asyncio
async def test_trend_days_zero(client, public_data):
    """``days=0`` is below the allowed minimum and must be rejected with 422."""
    product_id = str(public_data["product"].id)
    url = f"/public/trends/{product_id}?days=0"
    response = await client.get(url)
    assert response.status_code == 422
    body = response.json()
    assert "detail" in body
    assert "stack" not in body
@pytest.mark.asyncio
async def test_trend_days_negative(client, public_data):
    """A negative ``days`` value must be rejected with 422."""
    product_id = str(public_data["product"].id)
    url = f"/public/trends/{product_id}?days=-1"
    response = await client.get(url)
    assert response.status_code == 422
    body = response.json()
    assert "detail" in body
    assert "stack" not in body
@pytest.mark.asyncio
async def test_trend_days_over_max(client, public_data):
    """``days=999`` exceeds the allowed maximum and must be rejected with 422."""
    product_id = str(public_data["product"].id)
    url = f"/public/trends/{product_id}?days=999"
    response = await client.get(url)
    assert response.status_code == 422
    body = response.json()
    assert "detail" in body
    assert "stack" not in body
@pytest.mark.asyncio
async def test_trend_days_valid(client, public_data):
    """An in-range ``days`` value returns 200 and a product_name field."""
    product_id = str(public_data["product"].id)
    url = f"/public/trends/{product_id}?days=30"
    response = await client.get(url)
    assert response.status_code == 200
    body = response.json()
    assert "product_name" in body
@pytest.mark.asyncio
async def test_store_comparison_empty_list(client):
    """Omitting ``product_ids`` entirely is answered with 400 and a detail."""
    response = await client.get("/public/store-comparison")
    assert response.status_code == 400
    body = response.json()
    assert "detail" in body
@pytest.mark.asyncio
async def test_store_comparison_category_xss(client, public_data):
    """A script-tag payload in ``category`` must be rejected with 422."""
    product_id = str(public_data["product"].id)
    url = f"/public/store-comparison?product_ids={product_id}&category=<script>alert(1)</script>"
    response = await client.get(url)
    assert response.status_code == 422
    body = response.json()
    assert "detail" in body
    assert "stack" not in body
@pytest.mark.asyncio
async def test_store_comparison_category_sql_injection(client, public_data):
    """A SQL-injection style payload in ``category`` must be rejected with 422."""
    product_id = str(public_data["product"].id)
    url = f"/public/store-comparison?product_ids={product_id}&category='; DROP TABLE--"
    response = await client.get(url)
    assert response.status_code == 422
    body = response.json()
    assert "detail" in body
    assert "stack" not in body
@pytest.mark.asyncio
async def test_inflation_invalid_period(client, public_data):
    """A ``period`` outside the accepted set must be rejected with 422."""
    response = await client.get("/public/inflation?period=10years")
    assert response.status_code == 422
    body = response.json()
    assert "detail" in body
    assert "stack" not in body
@pytest.mark.asyncio
async def test_inflation_valid_periods(client, public_data):
    """Every accepted ``period`` value returns 200."""
    accepted_periods = ("all-time", "1y", "6m", "3m", "1m")
    for period in accepted_periods:
        url = f"/public/inflation?period={period}"
        response = await client.get(url)
        assert response.status_code == 200, f"period={period} failed"
@pytest.mark.asyncio
async def test_inflation_category_too_long(client, public_data):
    """A 200-character ``category`` must be rejected with 422."""
    oversized_category = "x" * 200
    url = f"/public/inflation?category={oversized_category}"
    response = await client.get(url)
    assert response.status_code == 422
    body = response.json()
    assert "detail" in body
    assert "stack" not in body