Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cfea2586cb | |||
| ee97f64db6 | |||
| 538a5f4f4d | |||
| 4485bf1d5e | |||
| f7bf767da5 | |||
| 2f1833e90d | |||
| b2725fd512 | |||
| 5532b43e38 |
@@ -53,6 +53,7 @@ def run_migrations_online() -> None:
|
||||
# checkfirst=True ensures this is a no-op on existing databases.
|
||||
try:
|
||||
Base.metadata.create_all(bind=connection, checkfirst=True)
|
||||
connection.commit()
|
||||
except Exception as exc:
|
||||
import logging
|
||||
logging.getLogger("alembic.env").warning(
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Create domain tables (stores, purchases, coupons, etc.).
|
||||
|
||||
Revision ID: 008_create_domain_tables
|
||||
Revises: 007_bootstrap_users_table
|
||||
Create Date: 2026-04-04
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "008_create_domain_tables"
|
||||
down_revision = "007_bootstrap_users_table"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
# 1. stores
|
||||
if not inspector.has_table("stores"):
|
||||
op.create_table(
|
||||
"stores",
|
||||
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
|
||||
sa.Column("name", sa.String(100), nullable=False),
|
||||
sa.Column("slug", sa.String(20), nullable=False, unique=True),
|
||||
sa.Column("logo_url", sa.String(500), nullable=True),
|
||||
sa.Column("website_url", sa.String(500), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 2. store_locations
|
||||
if not inspector.has_table("store_locations"):
|
||||
op.create_table(
|
||||
"store_locations",
|
||||
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
|
||||
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
|
||||
sa.Column("address", sa.String(300), nullable=False),
|
||||
sa.Column("city", sa.String(100), nullable=False),
|
||||
sa.Column("state", sa.String(2), nullable=False),
|
||||
sa.Column("zip", sa.String(10), nullable=False),
|
||||
sa.Column("lat", sa.Float(), nullable=True),
|
||||
sa.Column("lng", sa.Float(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 3. normalized_products
|
||||
if not inspector.has_table("normalized_products"):
|
||||
op.create_table(
|
||||
"normalized_products",
|
||||
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
|
||||
sa.Column("canonical_name", sa.String(300), nullable=False),
|
||||
sa.Column("category", sa.String(50), nullable=True),
|
||||
sa.Column("subcategory", sa.String(100), nullable=True),
|
||||
sa.Column("brand", sa.String(200), nullable=True),
|
||||
sa.Column("size", sa.String(50), nullable=True),
|
||||
sa.Column("size_unit", sa.String(10), nullable=True),
|
||||
sa.Column("upc_variants", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 4. purchases
|
||||
if not inspector.has_table("purchases"):
|
||||
op.create_table(
|
||||
"purchases",
|
||||
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
|
||||
sa.Column("user_id", sa.Text(), sa.ForeignKey("users.id"), nullable=False),
|
||||
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
|
||||
sa.Column("store_location_id", sa.Uuid(), sa.ForeignKey("store_locations.id"), nullable=True),
|
||||
sa.Column("receipt_id", sa.String(200), nullable=False),
|
||||
sa.Column("purchase_date", sa.Date(), nullable=False),
|
||||
sa.Column("total", sa.Numeric(10, 2), nullable=False),
|
||||
sa.Column("subtotal", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("tax", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("savings_total", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("source_url", sa.String(500), nullable=True),
|
||||
sa.Column("raw_data", sa.JSON(), nullable=True),
|
||||
sa.Column("ingested_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"),
|
||||
sa.Index("ix_purchases_user_store", "user_id", "store_id"),
|
||||
)
|
||||
|
||||
# 5. purchase_items
|
||||
if not inspector.has_table("purchase_items"):
|
||||
op.create_table(
|
||||
"purchase_items",
|
||||
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
|
||||
sa.Column("purchase_id", sa.Uuid(), sa.ForeignKey("purchases.id"), nullable=False),
|
||||
sa.Column("product_name_raw", sa.String(300), nullable=False),
|
||||
sa.Column("upc", sa.String(20), nullable=True),
|
||||
sa.Column("quantity", sa.Numeric(10, 3), nullable=False),
|
||||
sa.Column("unit_price", sa.Numeric(10, 2), nullable=False),
|
||||
sa.Column("extended_price", sa.Numeric(10, 2), nullable=False),
|
||||
sa.Column("regular_price", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("sale_price", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("coupon_discount", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("loyalty_discount", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("category_raw", sa.String(100), nullable=True),
|
||||
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 6. coupons
|
||||
if not inspector.has_table("coupons"):
|
||||
op.create_table(
|
||||
"coupons",
|
||||
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
|
||||
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
|
||||
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=True),
|
||||
sa.Column("title", sa.String(300), nullable=False),
|
||||
sa.Column("description", sa.String(1000), nullable=True),
|
||||
sa.Column("discount_type", sa.String(20), nullable=False),
|
||||
sa.Column("discount_value", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("min_purchase", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("valid_from", sa.Date(), nullable=True),
|
||||
sa.Column("valid_to", sa.Date(), nullable=True),
|
||||
sa.Column("requires_clip", sa.Boolean(), server_default=text("false"), nullable=False),
|
||||
sa.Column("coupon_code", sa.String(100), nullable=True),
|
||||
sa.Column("source_url", sa.String(500), nullable=True),
|
||||
sa.Column("scraped_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 7. price_history
|
||||
if not inspector.has_table("price_history"):
|
||||
op.create_table(
|
||||
"price_history",
|
||||
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
|
||||
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=False),
|
||||
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
|
||||
sa.Column("observed_date", sa.Date(), nullable=False),
|
||||
sa.Column("regular_price", sa.Numeric(10, 2), nullable=False),
|
||||
sa.Column("sale_price", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("loyalty_price", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("coupon_price", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("source", sa.String(20), nullable=False),
|
||||
sa.Column("purchase_item_id", sa.Uuid(), sa.ForeignKey("purchase_items.id"), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Index("ix_price_history_product_store_date", "normalized_product_id", "store_id", "observed_date"),
|
||||
)
|
||||
|
||||
# 8. shrinkflation_events
|
||||
if not inspector.has_table("shrinkflation_events"):
|
||||
op.create_table(
|
||||
"shrinkflation_events",
|
||||
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
|
||||
sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=False),
|
||||
sa.Column("detected_date", sa.Date(), nullable=False),
|
||||
sa.Column("old_size", sa.String(50), nullable=False),
|
||||
sa.Column("new_size", sa.String(50), nullable=False),
|
||||
sa.Column("old_unit", sa.String(10), nullable=True),
|
||||
sa.Column("new_unit", sa.String(10), nullable=True),
|
||||
sa.Column("price_at_old_size", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("price_at_new_size", sa.Numeric(10, 2), nullable=True),
|
||||
sa.Column("confidence", sa.Numeric(3, 2), server_default=text("1.00"), nullable=False),
|
||||
sa.Column("notes", sa.String(1000), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 9. user_store_accounts
|
||||
if not inspector.has_table("user_store_accounts"):
|
||||
op.create_table(
|
||||
"user_store_accounts",
|
||||
sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True),
|
||||
sa.Column("user_id", sa.Text(), sa.ForeignKey("users.id"), nullable=False),
|
||||
sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False),
|
||||
sa.Column("session_data", sa.JSON(), nullable=True),
|
||||
sa.Column("session_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("status", sa.String(20), server_default=text("'active'"), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.UniqueConstraint("user_id", "store_id", name="uq_user_store_account"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
if inspector.has_table("user_store_accounts"):
|
||||
op.drop_table("user_store_accounts")
|
||||
if inspector.has_table("shrinkflation_events"):
|
||||
op.drop_table("shrinkflation_events")
|
||||
if inspector.has_table("price_history"):
|
||||
op.drop_table("price_history")
|
||||
if inspector.has_table("coupons"):
|
||||
op.drop_table("coupons")
|
||||
if inspector.has_table("purchase_items"):
|
||||
op.drop_table("purchase_items")
|
||||
if inspector.has_table("purchases"):
|
||||
op.drop_table("purchases")
|
||||
if inspector.has_table("normalized_products"):
|
||||
op.drop_table("normalized_products")
|
||||
if inspector.has_table("store_locations"):
|
||||
op.drop_table("store_locations")
|
||||
if inspector.has_table("stores"):
|
||||
op.drop_table("stores")
|
||||
@@ -13,14 +13,13 @@ class Settings(BaseSettings):
|
||||
)
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
jwt_secret_key: str = "change-me-in-production"
|
||||
jwt_secret_key: str
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_access_token_expire_minutes: int = 15
|
||||
jwt_refresh_token_expire_days: int = 7
|
||||
|
||||
service_key: str = "change-me-in-production"
|
||||
# Valid Fernet key for local dev — MUST be overridden in production
|
||||
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||
service_key: str
|
||||
fernet_key: str
|
||||
|
||||
auth_service_url: str = "http://auth:3001"
|
||||
|
||||
@@ -35,9 +34,26 @@ class Settings(BaseSettings):
|
||||
rate_limit_window_seconds: int = 60
|
||||
rate_limit_enabled: bool = True
|
||||
|
||||
_PLACEHOLDER_VALUES = {"change-me-in-production"}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_fernet_key(self):
|
||||
"""Validate fernet_key is a valid 32-byte url-safe base64 key at startup."""
|
||||
def validate_secrets(self):
|
||||
if not self.jwt_secret_key or self.jwt_secret_key in self._PLACEHOLDER_VALUES:
|
||||
raise ValueError(
|
||||
"CARTSNITCH_JWT_SECRET_KEY must be set to a secure value. "
|
||||
'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||
)
|
||||
if not self.service_key or self.service_key in self._PLACEHOLDER_VALUES:
|
||||
raise ValueError(
|
||||
"CARTSNITCH_SERVICE_KEY must be set to a secure value. "
|
||||
'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||
)
|
||||
if not self.fernet_key or self.fernet_key in self._PLACEHOLDER_VALUES:
|
||||
raise ValueError(
|
||||
"CARTSNITCH_FERNET_KEY must be set to a valid Fernet key. "
|
||||
"Generate one with: python -c "
|
||||
"'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'"
|
||||
)
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
|
||||
if len(decoded) != 32:
|
||||
|
||||
@@ -18,10 +18,14 @@ router = APIRouter(prefix="/public", tags=["public"])
|
||||
|
||||
|
||||
@router.get("/trends/{product_id}", response_model=PublicTrendResponse)
|
||||
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
async def public_price_trend(
|
||||
product_id: UUID,
|
||||
days: int = Query(90, ge=1, le=365),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PublicService(db)
|
||||
try:
|
||||
return await svc.get_trend(product_id)
|
||||
return await svc.get_trend(product_id, days=days)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
@@ -31,6 +35,7 @@ async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db
|
||||
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
|
||||
async def public_store_comparison(
|
||||
product_ids: Annotated[list[UUID], Query(max_length=20)],
|
||||
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not product_ids:
|
||||
@@ -39,10 +44,14 @@ async def public_store_comparison(
|
||||
detail="At least one product_id is required",
|
||||
)
|
||||
svc = PublicService(db)
|
||||
return await svc.get_store_comparison(product_ids)
|
||||
return await svc.get_store_comparison(product_ids, category=category)
|
||||
|
||||
|
||||
@router.get("/inflation", response_model=PublicInflationResponse)
|
||||
async def public_inflation(db: AsyncSession = Depends(get_db)):
|
||||
async def public_inflation(
|
||||
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
|
||||
period: str = Query("all-time", pattern=r"^(all-time|1y|6m|3m|1m)$"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PublicService(db)
|
||||
return await svc.get_inflation()
|
||||
return await svc.get_inflation(category=category, period=period)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Public service — unauthenticated price transparency endpoints."""
|
||||
|
||||
from datetime import date, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
@@ -13,7 +14,7 @@ class PublicService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_trend(self, product_id: UUID) -> dict:
|
||||
async def get_trend(self, product_id: UUID, days: int = 90) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
@@ -23,9 +24,13 @@ class PublicService:
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
date_threshold = date.today() - timedelta(days=days)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.where(
|
||||
PriceHistory.normalized_product_id == product_id,
|
||||
PriceHistory.observed_date >= date_threshold,
|
||||
)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
@@ -45,20 +50,25 @@ class PublicService:
|
||||
],
|
||||
}
|
||||
|
||||
async def get_store_comparison(self, product_ids: list[UUID]) -> dict:
|
||||
async def get_store_comparison(
|
||||
self, product_ids: list[UUID], category: str | None = None
|
||||
) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
if not product_ids:
|
||||
return {"products": []}
|
||||
|
||||
# Fetch all products in one query
|
||||
prod_result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
)
|
||||
product_query = select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
if category:
|
||||
product_query = product_query.where(NormalizedProduct.category == category)
|
||||
prod_result = await self.db.execute(product_query)
|
||||
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
||||
|
||||
# Latest prices for all requested products in one query
|
||||
subq = latest_price_per_store(product_ids)
|
||||
if not products_by_id:
|
||||
return {"products": []}
|
||||
|
||||
filtered_product_ids = list(products_by_id.keys())
|
||||
subq = latest_price_per_store(filtered_product_ids)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
@@ -69,18 +79,17 @@ class PublicService:
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
.where(PriceHistory.normalized_product_id.in_(filtered_product_ids))
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
all_prices = prices_result.scalars().all()
|
||||
|
||||
# Group by product
|
||||
prices_by_product: dict[UUID, list] = {}
|
||||
for ph in all_prices:
|
||||
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
||||
|
||||
products = []
|
||||
for pid in product_ids:
|
||||
for pid in filtered_product_ids:
|
||||
product = products_by_id.get(pid)
|
||||
if not product:
|
||||
continue
|
||||
@@ -102,19 +111,29 @@ class PublicService:
|
||||
|
||||
return {"products": products}
|
||||
|
||||
async def get_inflation(self) -> dict:
|
||||
async def get_inflation(self, category: str | None = None, period: str = "all-time") -> dict:
|
||||
"""Aggregate price change stats. Compares average prices across periods."""
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
# Get average prices grouped by category for recent vs older data
|
||||
result = await self.db.execute(
|
||||
select(
|
||||
NormalizedProduct.category,
|
||||
func.avg(PriceHistory.regular_price),
|
||||
)
|
||||
.join(NormalizedProduct)
|
||||
.group_by(NormalizedProduct.category)
|
||||
)
|
||||
date_threshold = None
|
||||
if period != "all-time":
|
||||
days_map = {"1y": 365, "6m": 180, "3m": 90, "1m": 30}
|
||||
days = days_map.get(period, 365)
|
||||
date_threshold = date.today() - timedelta(days=days)
|
||||
|
||||
query = select(
|
||||
NormalizedProduct.category,
|
||||
func.avg(PriceHistory.regular_price),
|
||||
).join(NormalizedProduct)
|
||||
|
||||
if category:
|
||||
query = query.where(NormalizedProduct.category == category)
|
||||
if date_threshold:
|
||||
query = query.where(PriceHistory.observed_date >= date_threshold)
|
||||
|
||||
query = query.group_by(NormalizedProduct.category)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
categories = {}
|
||||
for row in result.all():
|
||||
cat, avg_price = row
|
||||
@@ -122,7 +141,7 @@ class PublicService:
|
||||
categories[cat] = float(avg_price) if avg_price else 0.0
|
||||
|
||||
return {
|
||||
"period": "all-time",
|
||||
"period": period,
|
||||
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
|
||||
"cpi_baseline": 100.0,
|
||||
"categories": categories,
|
||||
|
||||
+34
-7
@@ -19,6 +19,25 @@ from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.main import create_app
|
||||
from cartsnitch_api.models import Base
|
||||
|
||||
TEST_JWT_SECRET = secrets.token_urlsafe(32)
|
||||
TEST_SERVICE_KEY = secrets.token_urlsafe(32)
|
||||
TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_settings():
|
||||
original_jwt = cartsnitch_settings.jwt_secret_key
|
||||
original_service = cartsnitch_settings.service_key
|
||||
original_fernet = cartsnitch_settings.fernet_key
|
||||
cartsnitch_settings.jwt_secret_key = TEST_JWT_SECRET
|
||||
cartsnitch_settings.service_key = TEST_SERVICE_KEY
|
||||
cartsnitch_settings.fernet_key = TEST_FERNET_KEY
|
||||
yield
|
||||
cartsnitch_settings.jwt_secret_key = original_jwt
|
||||
cartsnitch_settings.service_key = original_service
|
||||
cartsnitch_settings.fernet_key = original_fernet
|
||||
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
|
||||
@@ -60,7 +79,8 @@ async def db_engine():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
||||
await conn.execute(text("""
|
||||
await conn.execute(
|
||||
text("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
@@ -71,8 +91,10 @@ async def db_engine():
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
await conn.execute(text("""
|
||||
""")
|
||||
)
|
||||
await conn.execute(
|
||||
text("""
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
@@ -88,8 +110,10 @@ async def db_engine():
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
await conn.execute(text("""
|
||||
""")
|
||||
)
|
||||
await conn.execute(
|
||||
text("""
|
||||
CREATE TABLE IF NOT EXISTS verifications (
|
||||
id TEXT PRIMARY KEY,
|
||||
identifier TEXT NOT NULL,
|
||||
@@ -98,7 +122,8 @@ async def db_engine():
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
)
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
yield engine
|
||||
|
||||
@@ -133,7 +158,9 @@ async def client(db_engine):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
||||
async def _create_test_user_and_session(
|
||||
client: AsyncClient, db_engine, **user_overrides
|
||||
) -> tuple[dict, str]:
|
||||
"""Create a test user and a valid session directly in the DB.
|
||||
|
||||
Returns (user_dict, session_token). Better-Auth stores the raw token
|
||||
|
||||
@@ -71,3 +71,97 @@ async def test_public_inflation(client, public_data):
|
||||
data = resp.json()
|
||||
assert "categories" in data
|
||||
assert "cartsnitch_index" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_invalid_uuid(client):
|
||||
resp = await client.get("/public/trends/not-a-uuid")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_days_zero(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}?days=0")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_days_negative(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}?days=-1")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_days_over_max(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}?days=999")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_days_valid(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}?days=30")
|
||||
assert resp.status_code == 200
|
||||
assert "product_name" in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_comparison_empty_list(client):
|
||||
resp = await client.get("/public/store-comparison")
|
||||
assert resp.status_code == 400
|
||||
assert "detail" in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_comparison_category_xss(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(
|
||||
f"/public/store-comparison?product_ids={pid}&category=<script>alert(1)</script>"
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_comparison_category_sql_injection(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/store-comparison?product_ids={pid}&category='; DROP TABLE--")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inflation_invalid_period(client, public_data):
|
||||
resp = await client.get("/public/inflation?period=10years")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inflation_valid_periods(client, public_data):
|
||||
for period in ["all-time", "1y", "6m", "3m", "1m"]:
|
||||
resp = await client.get(f"/public/inflation?period={period}")
|
||||
assert resp.status_code == 200, f"period={period} failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inflation_category_too_long(client, public_data):
|
||||
long_category = "x" * 200
|
||||
resp = await client.get(f"/public/inflation?category={long_category}")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
Reference in New Issue
Block a user