Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 82978f072b | |||
| c7b7494151 | |||
| ffdc26cce5 | |||
| 2e96e8f0a7 | |||
| 66279716ba | |||
| 15ab4ed38c | |||
| fbd77a9434 | |||
| fef5e86645 | |||
| cf39ed1dcd | |||
| 71e2978f52 | |||
| 4945ac71ae | |||
| 5308923136 | |||
| bdaca519f6 | |||
| 90e23ac592 | |||
| c03e599ae3 | |||
| 908ebde4c6 | |||
| 1ce5d738d1 | |||
| e69b3c47be | |||
| 4c217757c3 | |||
| 121dc5724e | |||
| ee45400c7c | |||
| 1aff898545 | |||
| 6b75d4906f | |||
| 24f0dd0e67 | |||
| cfea2586cb |
@@ -553,6 +553,7 @@ jobs:
|
||||
git config user.name "cartsnitch-ci[bot]"
|
||||
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
|
||||
git add apps/overlays/dev/kustomization.yaml
|
||||
git diff --cached --quiet && echo "No image changes to deploy" && exit 0
|
||||
git commit -m "ci(dev): update cartsnitch, auth, receiptwitness, and api images"
|
||||
git pull --rebase origin main
|
||||
git push origin main
|
||||
@@ -651,6 +652,7 @@ jobs:
|
||||
git config user.name "cartsnitch-ci[bot]"
|
||||
git config user.email "cartsnitch-ci[bot]@users.noreply.github.com"
|
||||
git add apps/overlays/uat/kustomization.yaml
|
||||
git diff --cached --quiet && echo "No image changes to deploy" && exit 0
|
||||
git commit -m "ci(uat): update cartsnitch, auth, receiptwitness, and api images"
|
||||
git pull --rebase origin main
|
||||
git push origin main
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Add GIN index on upc_variants and alter column to JSONB.
|
||||
|
||||
Revision ID: 009_add_gin_index_upc_variants
|
||||
Revises: 008_create_domain_tables
|
||||
Create Date: 2026-04-14
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "009_add_gin_index_upc_variants"
|
||||
down_revision = "008_create_domain_tables"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"normalized_products",
|
||||
"upc_variants",
|
||||
type_=sa.dialects.postgresql.JSONB(),
|
||||
postgresql_using="upc_variants::jsonb",
|
||||
)
|
||||
op.create_index(
|
||||
"ix_normalized_products_upc_variants_gin",
|
||||
"normalized_products",
|
||||
["upc_variants"],
|
||||
postgresql_using="gin",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_normalized_products_upc_variants_gin", table_name="normalized_products")
|
||||
op.alter_column(
|
||||
"normalized_products",
|
||||
"upc_variants",
|
||||
type_=sa.JSON(),
|
||||
)
|
||||
@@ -69,7 +69,9 @@ async def get_current_user(
|
||||
token: str | None = None
|
||||
|
||||
# 1. Check session cookie — prefer __Secure- variant (HTTPS) over plain (HTTP dev)
|
||||
cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(SESSION_COOKIE_NAME)
|
||||
cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(
|
||||
SESSION_COOKIE_NAME
|
||||
)
|
||||
if cookie_token:
|
||||
# Better-Auth cookie format is "token.sessionId" — extract just the token part
|
||||
token = cookie_token.split(".")[0] if "." in cookie_token else cookie_token
|
||||
@@ -86,7 +88,9 @@ async def get_current_user(
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
return await _validate_session_token(token, db)
|
||||
user_id = await _validate_session_token(token, db)
|
||||
request.state.user_id = user_id
|
||||
return user_id
|
||||
|
||||
|
||||
async def verify_service_key(x_service_key: str = Header()) -> None:
|
||||
|
||||
@@ -47,5 +47,30 @@ class CacheClient:
|
||||
return
|
||||
await self._client.delete(key)
|
||||
|
||||
async def invalidate_price_cache(self, product_id: str) -> None:
|
||||
"""Invalidate all price-related cache entries for a product."""
|
||||
if not self._client:
|
||||
return
|
||||
pattern = f"price:*:{product_id}"
|
||||
await self._delete_pattern(pattern)
|
||||
|
||||
async def invalidate_product_cache(self, product_id: str) -> None:
|
||||
"""Invalidate the product detail cache entry."""
|
||||
if not self._client:
|
||||
return
|
||||
await self._client.delete(f"product:{product_id}")
|
||||
|
||||
async def _delete_pattern(self, pattern: str) -> None:
|
||||
"""Delete all keys matching a pattern using SCAN."""
|
||||
if not self._client:
|
||||
return
|
||||
cursor = 0
|
||||
while True:
|
||||
cursor, keys = await self._client.scan(cursor=cursor, match=pattern, count=100)
|
||||
if keys:
|
||||
await self._client.delete(*keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
|
||||
cache_client = CacheClient()
|
||||
|
||||
@@ -32,6 +32,9 @@ class Settings(BaseSettings):
|
||||
|
||||
rate_limit_requests: int = 60
|
||||
rate_limit_window_seconds: int = 60
|
||||
rate_limit_auth_requests: int = 5
|
||||
rate_limit_auth_window_seconds: int = 60
|
||||
rate_limit_redis_enabled: bool = True
|
||||
rate_limit_enabled: bool = True
|
||||
|
||||
_PLACEHOLDER_VALUES = {"change-me-in-production"}
|
||||
@@ -72,7 +75,9 @@ class Settings(BaseSettings):
|
||||
def normalize_database_url(self):
|
||||
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
|
||||
if self.database_url.startswith("postgresql://"):
|
||||
self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
self.database_url = self.database_url.replace(
|
||||
"postgresql://", "postgresql+asyncpg://", 1
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from cartsnitch_api.database import dispose_engine
|
||||
from cartsnitch_api.middleware.cors import add_cors_middleware
|
||||
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
|
||||
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
|
||||
from cartsnitch_api.middleware.audit import add_audit_middleware
|
||||
from cartsnitch_api.routes.alerts import router as alerts_router
|
||||
from cartsnitch_api.routes.coupons import router as coupons_router
|
||||
from cartsnitch_api.routes.health import router as health_router
|
||||
@@ -43,6 +44,7 @@ def create_app() -> FastAPI:
|
||||
add_cors_middleware(app)
|
||||
add_error_monitor_middleware(app)
|
||||
add_rate_limit_middleware(app)
|
||||
add_audit_middleware(app)
|
||||
|
||||
# Exception handlers
|
||||
add_error_handlers(app)
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Audit logging middleware for sensitive API operations.
|
||||
|
||||
Logs structured JSON for POST/PUT/PATCH/DELETE requests and GET /auth/me.
|
||||
Never logs request bodies, response bodies, Authorization headers, or cookie values.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger("cartsnitch_api.audit")
|
||||
|
||||
HEALTH_PATHS = {"/health", "/healthz", "/ready"}
|
||||
|
||||
|
||||
class AuditMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to log structured audit events for sensitive operations."""
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable],
|
||||
):
|
||||
if request.method == "OPTIONS" or request.url.path in HEALTH_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
method = request.method
|
||||
path = request.url.path
|
||||
|
||||
is_sensitive_write = method in {"POST", "PUT", "PATCH", "DELETE"}
|
||||
is_auth_me_read = method == "GET" and path == "/auth/me"
|
||||
|
||||
if not (is_sensitive_write or is_auth_me_read):
|
||||
return await call_next(request)
|
||||
|
||||
start = time.perf_counter()
|
||||
response = await call_next(request)
|
||||
duration_ms = (time.perf_counter() - start) * 1000
|
||||
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
log_entry = {
|
||||
"event": "audit",
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
"user_id": user_id,
|
||||
"method": method,
|
||||
"path": path,
|
||||
"client_ip": client_ip,
|
||||
"status_code": response.status_code,
|
||||
"duration_ms": round(duration_ms, 2),
|
||||
}
|
||||
|
||||
logger.info(json.dumps(log_entry))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def add_audit_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(AuditMiddleware)
|
||||
@@ -5,18 +5,31 @@ Per-IP limiting on public endpoints, per-token limiting on authenticated endpoin
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
from typing import Protocol
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis, RedisError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class _SlidingWindowCounter:
|
||||
|
||||
class RateLimitBackend(Protocol):
|
||||
"""Protocol for rate limit backends."""
|
||||
|
||||
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
|
||||
|
||||
class InMemorySlidingWindow:
|
||||
"""Thread-safe in-memory sliding window rate limiter."""
|
||||
|
||||
def __init__(self, max_requests: int, window_seconds: int) -> None:
|
||||
@@ -25,13 +38,12 @@ class _SlidingWindowCounter:
|
||||
self._hits: dict[str, list[float]] = defaultdict(list)
|
||||
self._lock = Lock()
|
||||
|
||||
def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
with self._lock:
|
||||
# Prune expired entries
|
||||
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
|
||||
|
||||
current_count = len(self._hits[key])
|
||||
@@ -44,15 +56,84 @@ class _SlidingWindowCounter:
|
||||
return True, remaining, 0
|
||||
|
||||
|
||||
# Module-level counters — one for public (per-IP), one for auth (per-token)
|
||||
_public_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests,
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
_auth_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
class RedisSlidingWindow:
|
||||
"""Redis-backed sliding window rate limiter using sorted sets."""
|
||||
|
||||
def __init__(self, redis: Redis, max_requests: int, window_seconds: int) -> None:
|
||||
self.redis = redis
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
|
||||
async def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
try:
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
now_ms = int(now * 1000)
|
||||
cutoff_ms = int(cutoff * 1000)
|
||||
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.zremrangebyscore(key, 0, cutoff_ms)
|
||||
pipe.zcard(key)
|
||||
results = await pipe.execute()
|
||||
|
||||
current_count = results[1]
|
||||
|
||||
if current_count >= self.max_requests:
|
||||
oldest = await self.redis.zrange(key, 0, 0, withscores=True)
|
||||
if oldest:
|
||||
retry_after = int((oldest[0][1] - cutoff) / 1000) + 1
|
||||
else:
|
||||
retry_after = self.window_seconds
|
||||
return False, 0, retry_after
|
||||
|
||||
member = f"{now_ms}:{uuid.uuid4().hex[:8]}"
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.zadd(key, {member: now_ms})
|
||||
pipe.expire(key, self.window_seconds)
|
||||
await pipe.execute()
|
||||
|
||||
remaining = self.max_requests - current_count - 1
|
||||
return True, remaining, 0
|
||||
|
||||
except RedisError as e:
|
||||
logger.warning("Redis rate limit error, falling back to in-memory: %s", e)
|
||||
in_memory = InMemorySlidingWindow(self.max_requests, self.window_seconds)
|
||||
return await in_memory.is_allowed(key)
|
||||
|
||||
|
||||
_redis_client: Redis | None = None
|
||||
_use_redis = False
|
||||
|
||||
if settings.rate_limit_redis_enabled:
|
||||
try:
|
||||
_redis_client = Redis.from_url(settings.redis_url)
|
||||
_use_redis = True
|
||||
logger.info("Rate limiting will use Redis at %s", settings.redis_url)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e)
|
||||
_use_redis = False
|
||||
|
||||
if _use_redis and _redis_client:
|
||||
_public_limiter = RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_limiter = RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_strict_limiter = RedisSlidingWindow(
|
||||
_redis_client, settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
|
||||
)
|
||||
else:
|
||||
_public_limiter = InMemorySlidingWindow(
|
||||
settings.rate_limit_requests, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_limiter = InMemorySlidingWindow(
|
||||
settings.rate_limit_requests * 5, settings.rate_limit_window_seconds
|
||||
)
|
||||
_auth_strict_limiter = InMemorySlidingWindow(
|
||||
settings.rate_limit_auth_requests, settings.rate_limit_auth_window_seconds
|
||||
)
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
@@ -63,30 +144,30 @@ def _get_client_ip(request: Request) -> str:
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, RateLimitBackend]:
|
||||
"""Determine rate limit key and which limiter to use."""
|
||||
if request.url.path.startswith("/public"):
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
# For authenticated endpoints, use Bearer token as key if present
|
||||
if request.url.path.startswith("/auth/") and request.method == "POST":
|
||||
return f"ip:{_get_client_ip(request)}", _auth_strict_limiter
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
return f"token:{token_hash}", _auth_limiter
|
||||
|
||||
# Fallback to IP for unauthenticated non-public endpoints
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip rate limiting when disabled (e.g. in tests) or for health checks
|
||||
if not settings.rate_limit_enabled or request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
key, limiter = _get_rate_limit_key(request)
|
||||
allowed, remaining, retry_after = limiter.is_allowed(key)
|
||||
allowed, remaining, retry_after = await limiter.is_allowed(key)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
|
||||
@@ -18,10 +18,14 @@ router = APIRouter(prefix="/public", tags=["public"])
|
||||
|
||||
|
||||
@router.get("/trends/{product_id}", response_model=PublicTrendResponse)
|
||||
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
async def public_price_trend(
|
||||
product_id: UUID,
|
||||
days: int = Query(90, ge=1, le=365),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PublicService(db)
|
||||
try:
|
||||
return await svc.get_trend(product_id)
|
||||
return await svc.get_trend(product_id, days=days)
|
||||
except LookupError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||
@@ -31,6 +35,7 @@ async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db
|
||||
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
|
||||
async def public_store_comparison(
|
||||
product_ids: Annotated[list[UUID], Query(max_length=20)],
|
||||
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not product_ids:
|
||||
@@ -39,10 +44,14 @@ async def public_store_comparison(
|
||||
detail="At least one product_id is required",
|
||||
)
|
||||
svc = PublicService(db)
|
||||
return await svc.get_store_comparison(product_ids)
|
||||
return await svc.get_store_comparison(product_ids, category=category)
|
||||
|
||||
|
||||
@router.get("/inflation", response_model=PublicInflationResponse)
|
||||
async def public_inflation(db: AsyncSession = Depends(get_db)):
|
||||
async def public_inflation(
|
||||
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
|
||||
period: str = Query("all-time", pattern=r"^(all-time|1y|6m|3m|1m)$"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PublicService(db)
|
||||
return await svc.get_inflation()
|
||||
return await svc.get_inflation(category=category, period=period)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Public service — unauthenticated price transparency endpoints."""
|
||||
|
||||
from datetime import date, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
@@ -13,7 +14,7 @@ class PublicService:
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self.db = db
|
||||
|
||||
async def get_trend(self, product_id: UUID) -> dict:
|
||||
async def get_trend(self, product_id: UUID, days: int = 90) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
result = await self.db.execute(
|
||||
@@ -23,9 +24,13 @@ class PublicService:
|
||||
if not product:
|
||||
raise LookupError("Product not found")
|
||||
|
||||
date_threshold = date.today() - timedelta(days=days)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.where(PriceHistory.normalized_product_id == product_id)
|
||||
.where(
|
||||
PriceHistory.normalized_product_id == product_id,
|
||||
PriceHistory.observed_date >= date_threshold,
|
||||
)
|
||||
.options(selectinload(PriceHistory.store))
|
||||
.order_by(PriceHistory.observed_date)
|
||||
)
|
||||
@@ -45,20 +50,25 @@ class PublicService:
|
||||
],
|
||||
}
|
||||
|
||||
async def get_store_comparison(self, product_ids: list[UUID]) -> dict:
|
||||
async def get_store_comparison(
|
||||
self, product_ids: list[UUID], category: str | None = None
|
||||
) -> dict:
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
if not product_ids:
|
||||
return {"products": []}
|
||||
|
||||
# Fetch all products in one query
|
||||
prod_result = await self.db.execute(
|
||||
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
)
|
||||
product_query = select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||
if category:
|
||||
product_query = product_query.where(NormalizedProduct.category == category)
|
||||
prod_result = await self.db.execute(product_query)
|
||||
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
||||
|
||||
# Latest prices for all requested products in one query
|
||||
subq = latest_price_per_store(product_ids)
|
||||
if not products_by_id:
|
||||
return {"products": []}
|
||||
|
||||
filtered_product_ids = list(products_by_id.keys())
|
||||
subq = latest_price_per_store(filtered_product_ids)
|
||||
prices_result = await self.db.execute(
|
||||
select(PriceHistory)
|
||||
.join(
|
||||
@@ -69,18 +79,17 @@ class PublicService:
|
||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||
),
|
||||
)
|
||||
.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||
.where(PriceHistory.normalized_product_id.in_(filtered_product_ids))
|
||||
.options(selectinload(PriceHistory.store))
|
||||
)
|
||||
all_prices = prices_result.scalars().all()
|
||||
|
||||
# Group by product
|
||||
prices_by_product: dict[UUID, list] = {}
|
||||
for ph in all_prices:
|
||||
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
||||
|
||||
products = []
|
||||
for pid in product_ids:
|
||||
for pid in filtered_product_ids:
|
||||
product = products_by_id.get(pid)
|
||||
if not product:
|
||||
continue
|
||||
@@ -102,19 +111,29 @@ class PublicService:
|
||||
|
||||
return {"products": products}
|
||||
|
||||
async def get_inflation(self) -> dict:
|
||||
async def get_inflation(self, category: str | None = None, period: str = "all-time") -> dict:
|
||||
"""Aggregate price change stats. Compares average prices across periods."""
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||
|
||||
# Get average prices grouped by category for recent vs older data
|
||||
result = await self.db.execute(
|
||||
select(
|
||||
NormalizedProduct.category,
|
||||
func.avg(PriceHistory.regular_price),
|
||||
)
|
||||
.join(NormalizedProduct)
|
||||
.group_by(NormalizedProduct.category)
|
||||
)
|
||||
date_threshold = None
|
||||
if period != "all-time":
|
||||
days_map = {"1y": 365, "6m": 180, "3m": 90, "1m": 30}
|
||||
days = days_map.get(period, 365)
|
||||
date_threshold = date.today() - timedelta(days=days)
|
||||
|
||||
query = select(
|
||||
NormalizedProduct.category,
|
||||
func.avg(PriceHistory.regular_price),
|
||||
).join(NormalizedProduct)
|
||||
|
||||
if category:
|
||||
query = query.where(NormalizedProduct.category == category)
|
||||
if date_threshold:
|
||||
query = query.where(PriceHistory.observed_date >= date_threshold)
|
||||
|
||||
query = query.group_by(NormalizedProduct.category)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
categories = {}
|
||||
for row in result.all():
|
||||
cat, avg_price = row
|
||||
@@ -122,7 +141,7 @@ class PublicService:
|
||||
categories[cat] = float(avg_price) if avg_price else 0.0
|
||||
|
||||
return {
|
||||
"period": "all-time",
|
||||
"period": period,
|
||||
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
|
||||
"cpi_baseline": 100.0,
|
||||
"categories": categories,
|
||||
|
||||
@@ -1,49 +1,184 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.middleware.rate_limit import (
|
||||
InMemorySlidingWindow,
|
||||
RedisSlidingWindow,
|
||||
_get_client_ip,
|
||||
_get_rate_limit_key,
|
||||
)
|
||||
|
||||
|
||||
class TestSlidingWindowCounter:
|
||||
class TestInMemorySlidingWindow:
|
||||
def test_allows_within_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
|
||||
limiter = InMemorySlidingWindow(max_requests=5, window_seconds=60)
|
||||
for i in range(5):
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
allowed, remaining, retry = limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 4 - i
|
||||
|
||||
def test_blocks_over_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
|
||||
limiter = InMemorySlidingWindow(max_requests=3, window_seconds=60)
|
||||
for _ in range(3):
|
||||
counter.is_allowed("test-key")
|
||||
limiter.is_allowed("test-key")
|
||||
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
allowed, remaining, retry = limiter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
assert remaining == 0
|
||||
assert retry > 0
|
||||
|
||||
def test_separate_keys(self):
|
||||
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60)
|
||||
# Fill key-a
|
||||
counter.is_allowed("key-a")
|
||||
counter.is_allowed("key-a")
|
||||
allowed_a, _, _ = counter.is_allowed("key-a")
|
||||
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=60)
|
||||
limiter.is_allowed("key-a")
|
||||
limiter.is_allowed("key-a")
|
||||
allowed_a, _, _ = limiter.is_allowed("key-a")
|
||||
assert allowed_a is False
|
||||
|
||||
# key-b should still be allowed
|
||||
allowed_b, remaining, _ = counter.is_allowed("key-b")
|
||||
allowed_b, remaining, _ = limiter.is_allowed("key-b")
|
||||
assert allowed_b is True
|
||||
assert remaining == 1
|
||||
|
||||
def test_resets_after_window_expires(self):
|
||||
limiter = InMemorySlidingWindow(max_requests=2, window_seconds=1)
|
||||
for _ in range(2):
|
||||
limiter.is_allowed("test-key")
|
||||
allowed, remaining, _ = limiter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
|
||||
time.sleep(1.1)
|
||||
allowed, remaining, _ = limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 1
|
||||
|
||||
|
||||
class TestGetClientIp:
|
||||
def test_x_forwarded_for_single(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_x_forwarded_for_multiple(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1, 10.0.0.1, 172.16.0.1"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_x_forwarded_for_with_port(self):
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "192.168.1.1:8080"}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "192.168.1.1"
|
||||
|
||||
def test_no_forwarded_header(self):
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client.host = "127.0.0.1"
|
||||
assert _get_client_ip(req) == "127.0.0.1"
|
||||
|
||||
def test_no_client(self):
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client = None
|
||||
assert _get_client_ip(req) == "unknown"
|
||||
|
||||
|
||||
class TestGetRateLimitKey:
|
||||
def _make_request(
|
||||
self,
|
||||
path: str = "/purchases",
|
||||
method: str = "GET",
|
||||
auth_header: str = "",
|
||||
headers: dict | None = None,
|
||||
) -> MagicMock:
|
||||
req = MagicMock()
|
||||
req.url.path = path
|
||||
req.method = method
|
||||
req.headers = dict(headers) if headers else {}
|
||||
if auth_header:
|
||||
req.headers["authorization"] = auth_header
|
||||
return req
|
||||
|
||||
def test_public_path_uses_public_limiter(self):
|
||||
req = self._make_request("/public/inflation")
|
||||
key, limiter = _get_rate_limit_key(req)
|
||||
assert key.startswith("ip:")
|
||||
assert limiter.max_requests == settings.rate_limit_requests
|
||||
|
||||
def test_auth_post_path_uses_strict_limiter(self):
|
||||
req = self._make_request("/auth/login", method="POST")
|
||||
key, limiter = _get_rate_limit_key(req)
|
||||
assert key.startswith("ip:")
|
||||
assert limiter.max_requests == settings.rate_limit_auth_requests
|
||||
assert limiter.window_seconds == settings.rate_limit_auth_window_seconds
|
||||
|
||||
def test_auth_get_path_uses_auth_limiter(self):
|
||||
req = self._make_request("/auth/me", method="GET")
|
||||
key, limiter = _get_rate_limit_key(req)
|
||||
assert key.startswith("ip:")
|
||||
assert limiter.max_requests == settings.rate_limit_requests * 5
|
||||
|
||||
def test_authenticated_token_uses_auth_limiter(self):
|
||||
req = self._make_request("/purchases", auth_header="Bearer token123")
|
||||
key, limiter = _get_rate_limit_key(req)
|
||||
assert key.startswith("token:")
|
||||
assert limiter.max_requests == settings.rate_limit_requests * 5
|
||||
|
||||
def test_distinct_tokens_produce_distinct_keys(self):
|
||||
req1 = self._make_request("/purchases", auth_header="Bearer token_alpha_12345")
|
||||
req2 = self._make_request("/purchases", auth_header="Bearer token_beta_67890")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 != key2
|
||||
|
||||
def test_same_token_produces_same_key(self):
|
||||
req1 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
||||
req2 = self._make_request("/purchases", auth_header="Bearer same_token_value_abc")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 == key2
|
||||
|
||||
def test_key_does_not_contain_raw_token_suffix(self):
|
||||
raw_token = "my_secret_jwt_token_xyz"
|
||||
req = self._make_request("/purchases", auth_header=f"Bearer {raw_token}")
|
||||
key, _ = _get_rate_limit_key(req)
|
||||
assert raw_token[-16:] not in key
|
||||
assert raw_token not in key
|
||||
|
||||
|
||||
class TestRedisSlidingWindowFallback:
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_redis_connection_error(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline.return_value = AsyncMock()
|
||||
pipe_mock = AsyncMock()
|
||||
pipe_mock.execute.side_effect = Exception("Connection refused")
|
||||
mock_redis.pipeline.return_value = pipe_mock
|
||||
|
||||
limiter = RedisSlidingWindow(mock_redis, max_requests=5, window_seconds=60)
|
||||
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_redis_error_during_pipeline(self):
|
||||
mock_redis = AsyncMock()
|
||||
pipe_mock = AsyncMock()
|
||||
pipe_mock.execute.side_effect = Exception("Redis error")
|
||||
mock_redis.pipeline.return_value = pipe_mock
|
||||
|
||||
limiter = RedisSlidingWindow(mock_redis, max_requests=3, window_seconds=60)
|
||||
allowed, remaining, retry = await limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_returns_429(client):
|
||||
"""Public endpoint should return 429 after limit exceeded."""
|
||||
# The default limit is 60/min — we won't hit it in normal tests,
|
||||
# but we verify the middleware adds rate limit headers.
|
||||
resp = await client.get("/public/inflation")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
@@ -51,36 +186,6 @@ async def test_rate_limit_returns_429(client):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_skips_rate_limit(client):
|
||||
"""Health endpoint should not have rate limit headers."""
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert "x-ratelimit-limit" not in resp.headers
|
||||
|
||||
|
||||
class TestGetRateLimitKey:
|
||||
def _make_request(self, auth_header: str = "") -> MagicMock:
|
||||
req = MagicMock()
|
||||
req.url.path = "/purchases"
|
||||
req.headers = {"authorization": auth_header} if auth_header else {}
|
||||
return req
|
||||
|
||||
def test_distinct_tokens_produce_distinct_keys(self):
|
||||
req1 = self._make_request("Bearer token_alpha_12345")
|
||||
req2 = self._make_request("Bearer token_beta_67890")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 != key2
|
||||
|
||||
def test_same_token_produces_same_key(self):
|
||||
req1 = self._make_request("Bearer same_token_value_abc")
|
||||
req2 = self._make_request("Bearer same_token_value_abc")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 == key2
|
||||
|
||||
def test_key_does_not_contain_raw_token_suffix(self):
|
||||
raw_token = "my_secret_jwt_token_xyz"
|
||||
req = self._make_request(f"Bearer {raw_token}")
|
||||
key, _ = _get_rate_limit_key(req)
|
||||
assert raw_token[-16:] not in key
|
||||
assert raw_token not in key
|
||||
|
||||
@@ -71,3 +71,97 @@ async def test_public_inflation(client, public_data):
|
||||
data = resp.json()
|
||||
assert "categories" in data
|
||||
assert "cartsnitch_index" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_invalid_uuid(client):
|
||||
resp = await client.get("/public/trends/not-a-uuid")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_days_zero(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}?days=0")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_days_negative(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}?days=-1")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_days_over_max(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}?days=999")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_days_valid(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}?days=30")
|
||||
assert resp.status_code == 200
|
||||
assert "product_name" in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_comparison_empty_list(client):
|
||||
resp = await client.get("/public/store-comparison")
|
||||
assert resp.status_code == 400
|
||||
assert "detail" in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_comparison_category_xss(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(
|
||||
f"/public/store-comparison?product_ids={pid}&category=<script>alert(1)</script>"
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_comparison_category_sql_injection(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/store-comparison?product_ids={pid}&category='; DROP TABLE--")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inflation_invalid_period(client, public_data):
|
||||
resp = await client.get("/public/inflation?period=10years")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inflation_valid_periods(client, public_data):
|
||||
for period in ["all-time", "1y", "6m", "3m", "1m"]:
|
||||
resp = await client.get(f"/public/inflation?period={period}")
|
||||
assert resp.status_code == 200, f"period={period} failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inflation_category_too_long(client, public_data):
|
||||
long_category = "x" * 200
|
||||
resp = await client.get(f"/public/inflation?category={long_category}")
|
||||
assert resp.status_code == 422
|
||||
assert "detail" in resp.json()
|
||||
assert "stack" not in resp.json()
|
||||
|
||||
@@ -9,3 +9,7 @@ DATABASE_URL=postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch
|
||||
|
||||
# Port the auth service listens on
|
||||
PORT=3001
|
||||
|
||||
# Resend email provider for transactional email
|
||||
RESEND_API_KEY=re_your_api_key_here
|
||||
FROM_EMAIL=CartSnitch <noreply@cartsnitch.com>
|
||||
|
||||
Generated
+74
-1
@@ -10,7 +10,8 @@
|
||||
"dependencies": {
|
||||
"bcrypt": "^6.0.0",
|
||||
"better-auth": "^1.2.0",
|
||||
"pg": "^8.13.0"
|
||||
"pg": "^8.13.0",
|
||||
"resend": "^6.11.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/bcrypt": "^6.0.0",
|
||||
@@ -633,6 +634,12 @@
|
||||
"node": ">=14"
|
||||
}
|
||||
},
|
||||
"node_modules/@stablelib/base64": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@stablelib/base64/-/base64-1.0.1.tgz",
|
||||
"integrity": "sha512-1bnPQqSxSuc3Ii6MhBysoWCg58j97aUjuCSZrGSmDxNqtytIi0k8utUenAwTZN4V5mXXYGsVUI9zeBqy+jBOSQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@standard-schema/spec": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz",
|
||||
@@ -858,6 +865,12 @@
|
||||
"@esbuild/win32-x64": "0.27.4"
|
||||
}
|
||||
},
|
||||
"node_modules/fast-sha256": {
|
||||
"version": "1.3.0",
|
||||
"resolved": "https://registry.npmjs.org/fast-sha256/-/fast-sha256-1.3.0.tgz",
|
||||
"integrity": "sha512-n11RGP/lrWEFI/bWdygLxhI+pVeo1ZYIVwvvPkW7azl/rOy+F3HYRZ2K5zeE9mmkhQppyv9sQFx0JM9UabnpPQ==",
|
||||
"license": "Unlicense"
|
||||
},
|
||||
"node_modules/fsevents": {
|
||||
"version": "2.3.3",
|
||||
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
|
||||
@@ -1028,6 +1041,12 @@
|
||||
"split2": "^4.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/postal-mime": {
|
||||
"version": "2.7.4",
|
||||
"resolved": "https://registry.npmjs.org/postal-mime/-/postal-mime-2.7.4.tgz",
|
||||
"integrity": "sha512-0WdnFQYUrPGGTFu1uOqD2s7omwua8xaeYGdO6rb88oD5yJ/4pPHDA4sdWqfD8wQVfCny563n/HQS7zTFft+f/g==",
|
||||
"license": "MIT-0"
|
||||
},
|
||||
"node_modules/postgres-array": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/postgres-array/-/postgres-array-2.0.0.tgz",
|
||||
@@ -1067,6 +1086,27 @@
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/resend": {
|
||||
"version": "6.11.0",
|
||||
"resolved": "https://registry.npmjs.org/resend/-/resend-6.11.0.tgz",
|
||||
"integrity": "sha512-S9gxOccfwc+E6Cr3q28Gu8NkiIjYlYPlj9rqk4zkIuzlEoh8sWu/IvJSg7U7t+o3g0Ov2IOCzcneUaCi/M/WdQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"postal-mime": "2.7.4",
|
||||
"svix": "1.90.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@react-email/render": "*"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@react-email/render": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/resolve-pkg-maps": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz",
|
||||
@@ -1098,6 +1138,26 @@
|
||||
"node": ">= 10.x"
|
||||
}
|
||||
},
|
||||
"node_modules/standardwebhooks": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/standardwebhooks/-/standardwebhooks-1.0.0.tgz",
|
||||
"integrity": "sha512-BbHGOQK9olHPMvQNHWul6MYlrRTAOKn03rOe4A8O3CLWhNf4YHBqq2HJKKC+sfqpxiBY52pNeesD6jIiLDz8jg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@stablelib/base64": "^1.0.0",
|
||||
"fast-sha256": "^1.3.0"
|
||||
}
|
||||
},
|
||||
"node_modules/svix": {
|
||||
"version": "1.90.0",
|
||||
"resolved": "https://registry.npmjs.org/svix/-/svix-1.90.0.tgz",
|
||||
"integrity": "sha512-ljkZuyy2+IBEoESkIpn8sLM+sxJHQcPxlZFxU+nVDhltNfUMisMBzWX/UR8SjEnzoI28ZjCzMbmYAPwSTucoMw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"standardwebhooks": "1.0.0",
|
||||
"uuid": "^10.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/tsx": {
|
||||
"version": "4.21.0",
|
||||
"resolved": "https://registry.npmjs.org/tsx/-/tsx-4.21.0.tgz",
|
||||
@@ -1139,6 +1199,19 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/uuid": {
|
||||
"version": "10.0.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-10.0.0.tgz",
|
||||
"integrity": "sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==",
|
||||
"funding": [
|
||||
"https://github.com/sponsors/broofa",
|
||||
"https://github.com/sponsors/ctavan"
|
||||
],
|
||||
"license": "MIT",
|
||||
"bin": {
|
||||
"uuid": "dist/bin/uuid"
|
||||
}
|
||||
},
|
||||
"node_modules/xtend": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",
|
||||
|
||||
+4
-3
@@ -10,15 +10,16 @@
|
||||
"generate": "npx @better-auth/cli generate"
|
||||
},
|
||||
"dependencies": {
|
||||
"bcrypt": "^6.0.0",
|
||||
"better-auth": "^1.2.0",
|
||||
"pg": "^8.13.0",
|
||||
"bcrypt": "^6.0.0"
|
||||
"resend": "^6.11.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/bcrypt": "^6.0.0",
|
||||
"@types/node": "^22.0.0",
|
||||
"@types/pg": "^8.11.0",
|
||||
"@types/bcrypt": "^6.0.0",
|
||||
"tsx": "^4.19.0",
|
||||
"typescript": "^5.7.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
+18
-1
@@ -1,6 +1,7 @@
|
||||
import { betterAuth } from "better-auth";
|
||||
import bcrypt from "bcrypt";
|
||||
import pg from "pg";
|
||||
import { Resend } from "resend";
|
||||
|
||||
const { Pool } = pg;
|
||||
|
||||
@@ -21,6 +22,9 @@ export const pool = new Pool({
|
||||
connectionString: databaseUrl ?? "postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
|
||||
});
|
||||
|
||||
const resend = new Resend(process.env.RESEND_API_KEY);
|
||||
const fromEmail = process.env.FROM_EMAIL || "CartSnitch <noreply@cartsnitch.com>";
|
||||
|
||||
export const auth = betterAuth({
|
||||
database: pool,
|
||||
basePath: "/auth",
|
||||
@@ -41,6 +45,19 @@ export const auth = betterAuth({
|
||||
},
|
||||
},
|
||||
|
||||
emailVerification: {
|
||||
sendOnSignUp: true,
|
||||
autoSignInAfterVerification: true,
|
||||
sendVerificationEmail: async ({ user, url }) => {
|
||||
await resend.emails.send({
|
||||
from: fromEmail,
|
||||
to: user.email,
|
||||
subject: "Verify your CartSnitch email",
|
||||
html: `<p>Hi ${user.name || ""},</p><p>Click the link below to verify your email address:</p><p><a href="${url}">Verify Email</a></p><p>This link expires in 1 hour.</p><p>— CartSnitch</p>`,
|
||||
});
|
||||
},
|
||||
},
|
||||
|
||||
session: {
|
||||
modelName: "sessions",
|
||||
fields: {
|
||||
@@ -103,4 +120,4 @@ export const auth = betterAuth({
|
||||
"https://cartsnitch.dev.farh.net",
|
||||
"https://cartsnitch.uat.farh.net",
|
||||
],
|
||||
});
|
||||
});
|
||||
@@ -3,6 +3,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_common.constants import ProductCategory, SizeUnit
|
||||
@@ -26,7 +27,9 @@ class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
brand: Mapped[str | None] = mapped_column(String(200))
|
||||
size: Mapped[str | None] = mapped_column(String(50))
|
||||
size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
|
||||
upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list)
|
||||
upc_variants: Mapped[list[str] | None] = mapped_column(
|
||||
JSON().with_variant(JSONB(), "postgresql"), default=list
|
||||
)
|
||||
|
||||
# Relationships
|
||||
purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
|
||||
|
||||
+89
-1
@@ -1,4 +1,4 @@
|
||||
import { test as base, expect } from "@playwright/test";
|
||||
import { test as base, expect, type Page } from "@playwright/test";
|
||||
import AxeBuilder from "@axe-core/playwright";
|
||||
|
||||
export const test = base.extend<{ axeCheck: void }>({
|
||||
@@ -10,3 +10,91 @@ export const test = base.extend<{ axeCheck: void }>({
|
||||
});
|
||||
|
||||
export { expect } from "@playwright/test";
|
||||
|
||||
const MOCK_USER_ID = "mock_user_123";
|
||||
const MOCK_SESSION_ID = "mock_session_456";
|
||||
|
||||
async function mockAuthRoutes(page: Page, authenticated = false) {
|
||||
await page.route(/.*\/auth\/sign-up\/email.*/, async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
token: null,
|
||||
user: {
|
||||
id: MOCK_USER_ID,
|
||||
email: "mock@cartsnitch.test",
|
||||
name: "Mock User",
|
||||
emailVerified: true,
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
},
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
await page.route(/.*\/auth\/sign-in\/email.*/, async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
redirect: false,
|
||||
token: "mock_token_123",
|
||||
user: {
|
||||
id: MOCK_USER_ID,
|
||||
email: "mock@cartsnitch.test",
|
||||
name: "Mock User",
|
||||
emailVerified: true,
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
},
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
await page.route(/.*\/auth\/get-session.*/, async (route) => {
|
||||
if (authenticated) {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
session: {
|
||||
id: MOCK_SESSION_ID,
|
||||
expiresAt: new Date(Date.now() + 7 * 24 * 60 * 60 * 1000).toISOString(),
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
ipAddress: null,
|
||||
userAgent: null,
|
||||
},
|
||||
user: {
|
||||
id: MOCK_USER_ID,
|
||||
email: "mock@cartsnitch.test",
|
||||
name: "Mock User",
|
||||
emailVerified: true,
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
},
|
||||
}),
|
||||
});
|
||||
} else {
|
||||
await route.fulfill({
|
||||
status: 401,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ error: "Unauthorized" }),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
export async function mockSessionDelayed(page: Page, delayMs = 3000) {
|
||||
await page.route(/.*\/auth\/get-session.*/, async (route) => {
|
||||
await new Promise((r) => setTimeout(r, delayMs));
|
||||
await route.fulfill({
|
||||
status: 401,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ error: "Unauthorized" }),
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export { mockAuthRoutes };
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import { test, expect } from '@playwright/test';
|
||||
import { mockAuthRoutes } from '../fixtures';
|
||||
|
||||
const uniqueEmail = () => `betty+e2e-${Date.now()}@cartsnitch.test`;
|
||||
|
||||
test.describe('J1: Registration and Login', () => {
|
||||
test('can register a new account and lands on dashboard', async ({ page }) => {
|
||||
test('can register a new account and see check your email screen', async ({ page }) => {
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/register');
|
||||
await page.fill('[placeholder="Full Name"]', 'Betty Tester');
|
||||
await page.fill('[placeholder="Email"]', uniqueEmail());
|
||||
await page.fill('[placeholder="Password (min. 8 characters)"]', 'TestPass123!');
|
||||
await page.click('button[type="submit"]');
|
||||
|
||||
// With VITE_MOCK_AUTH=true the app navigates to "/" on success
|
||||
await expect(page).toHaveURL('http://localhost:5173/');
|
||||
await expect(page.getByRole('heading', { name: /cart/i })).toBeVisible();
|
||||
await expect(page.getByRole('heading', { name: /check your email/i })).toBeVisible();
|
||||
});
|
||||
|
||||
test('shows validation error when registration fields are empty', async ({ page }) => {
|
||||
@@ -31,22 +31,9 @@ test.describe('J1: Registration and Login', () => {
|
||||
});
|
||||
|
||||
test('can sign in with credentials and land on dashboard', async ({ page }) => {
|
||||
// Register first so we have a real account
|
||||
const email = uniqueEmail();
|
||||
await page.goto('/register');
|
||||
await page.fill('[placeholder="Full Name"]', 'Login Betty');
|
||||
await page.fill('[placeholder="Email"]', email);
|
||||
await page.fill('[placeholder="Password (min. 8 characters)"]', 'TestPass123!');
|
||||
await page.click('button[type="submit"]');
|
||||
await expect(page).toHaveURL('http://localhost:5173/');
|
||||
|
||||
// Sign out by clearing the mock session (reload with no session)
|
||||
await page.goto('/');
|
||||
await page.reload();
|
||||
|
||||
// Now sign in
|
||||
await mockAuthRoutes(page, true);
|
||||
await page.goto('/login');
|
||||
await page.fill('[placeholder="Email"]', email);
|
||||
await page.fill('[placeholder="Email"]', 'test@cartsnitch.test');
|
||||
await page.fill('[placeholder="Password"]', 'TestPass123!');
|
||||
await page.click('button[type="submit"]');
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { test, expect } from '@playwright/test';
|
||||
import { mockAuthRoutes, mockSessionDelayed } from '../fixtures';
|
||||
|
||||
test.describe('J8: Unauthenticated Access', () => {
|
||||
test('redirects /dashboard (/) to /login when not authenticated', async ({ page }) => {
|
||||
// No session cookie — start fresh
|
||||
await page.context().clearCookies();
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/');
|
||||
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
@@ -11,7 +11,7 @@ test.describe('J8: Unauthenticated Access', () => {
|
||||
});
|
||||
|
||||
test('redirects /purchases to /login when not authenticated', async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/purchases');
|
||||
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
@@ -19,7 +19,7 @@ test.describe('J8: Unauthenticated Access', () => {
|
||||
});
|
||||
|
||||
test('redirects /products to /login when not authenticated', async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/products');
|
||||
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
@@ -27,7 +27,7 @@ test.describe('J8: Unauthenticated Access', () => {
|
||||
});
|
||||
|
||||
test('redirects /coupons to /login when not authenticated', async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/coupons');
|
||||
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
@@ -35,15 +35,9 @@ test.describe('J8: Unauthenticated Access', () => {
|
||||
});
|
||||
|
||||
test('shows loading spinner while auth session is pending', async ({ page }) => {
|
||||
// Intercept but don't respond — session stays pending
|
||||
await page.context().clearCookies();
|
||||
await page.request.fetch('/api/auth/session', {
|
||||
method: 'GET',
|
||||
});
|
||||
|
||||
// Just navigate to a protected route — ProtectedRoute will show spinner while session is pending
|
||||
await mockSessionDelayed(page, 3000);
|
||||
await page.goto('/purchases');
|
||||
// Spinner is visible briefly; once resolved, should redirect to login
|
||||
await expect(page.locator('.animate-spin')).toBeVisible({ timeout: 2000 });
|
||||
await expect(page).toHaveURL(/\/login/, { timeout: 10_000 });
|
||||
});
|
||||
});
|
||||
|
||||
+2
-2
@@ -1,8 +1,8 @@
|
||||
import { test, expect } from './fixtures';
|
||||
import { test, expect, mockAuthRoutes } from './fixtures';
|
||||
|
||||
test('app loads', async ({ page }) => {
|
||||
await mockAuthRoutes(page, false);
|
||||
await page.goto('/');
|
||||
// Unauthenticated users are redirected to /login
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
await expect(page.getByRole('heading', { name: /CartSnitch/i })).toBeVisible();
|
||||
});
|
||||
|
||||
Generated
+100
-28
@@ -38,7 +38,7 @@
|
||||
"tailwindcss": "^4.0.0",
|
||||
"typescript": "^5.7.3",
|
||||
"typescript-eslint": "^8.56.1",
|
||||
"vite": "^6.3.5",
|
||||
"vite": "^6.4.2",
|
||||
"vite-plugin-pwa": "^0.21.2",
|
||||
"vitest": "^3.2.4"
|
||||
}
|
||||
@@ -1867,6 +1867,40 @@
|
||||
"node": ">=18"
|
||||
}
|
||||
},
|
||||
"node_modules/@emnapi/core": {
|
||||
"version": "1.9.2",
|
||||
"resolved": "https://registry.npmjs.org/@emnapi/core/-/core-1.9.2.tgz",
|
||||
"integrity": "sha512-UC+ZhH3XtczQYfOlu3lNEkdW/p4dsJ1r/bP7H8+rhao3TTTMO1ATq/4DdIi23XuGoFY+Cz0JmCbdVl0hz9jZcA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@emnapi/wasi-threads": "1.2.1",
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@emnapi/runtime": {
|
||||
"version": "1.9.2",
|
||||
"resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.9.2.tgz",
|
||||
"integrity": "sha512-3U4+MIWHImeyu1wnmVygh5WlgfYDtyf0k8AbLhMFxOipihf6nrWC4syIm/SwEeec0mNSafiiNnMJwbza/Is6Lw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@emnapi/wasi-threads": {
|
||||
"version": "1.2.1",
|
||||
"resolved": "https://registry.npmjs.org/@emnapi/wasi-threads/-/wasi-threads-1.2.1.tgz",
|
||||
"integrity": "sha512-uTII7OYF+/Mes/MrcIOYp5yOtSMLBWSIoLPpcgwipoiKbli6k322tcoFsxoIIxPDqW01SQGAgko4EzZi2BNv2w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@esbuild/aix-ppc64": {
|
||||
"version": "0.25.12",
|
||||
"resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz",
|
||||
@@ -2669,6 +2703,25 @@
|
||||
"node": ">=18"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/wasm-runtime": {
|
||||
"version": "1.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/wasm-runtime/-/wasm-runtime-1.1.3.tgz",
|
||||
"integrity": "sha512-xK9sGVbJWYb08+mTJt3/YV24WxvxpXcXtP6B172paPZ+Ts69Re9dAr7lKwJoeIx8OoeuimEiRZ7umkiUVClmmQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@tybys/wasm-util": "^0.10.1"
|
||||
},
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/Brooooooklyn"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@emnapi/core": "^1.7.1",
|
||||
"@emnapi/runtime": "^1.7.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@noble/ciphers": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@noble/ciphers/-/ciphers-2.1.1.tgz",
|
||||
@@ -3659,6 +3712,17 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@tybys/wasm-util": {
|
||||
"version": "0.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tybys/wasm-util/-/wasm-util-0.10.1.tgz",
|
||||
"integrity": "sha512-9tTaPJLSiejZKx+Bmog4uSubteqTvFrVrURwkmHixBo0G4seD0zUxp98E1DzUBJxLQ3NPwXrGKDiVjwx/DpPsg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/aria-query": {
|
||||
"version": "5.0.4",
|
||||
"resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz",
|
||||
@@ -4166,33 +4230,6 @@
|
||||
"url": "https://opencollective.com/vitest"
|
||||
}
|
||||
},
|
||||
"node_modules/@vitest/mocker": {
|
||||
"version": "3.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-3.2.4.tgz",
|
||||
"integrity": "sha512-46ryTE9RZO/rfDd7pEqFl7etuyzekzEhUbTW3BvmeO/BcCMEgq59BKhek3dXDWgAj4oMK6OZi+vRr1wPW6qjEQ==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@vitest/spy": "3.2.4",
|
||||
"estree-walker": "^3.0.3",
|
||||
"magic-string": "^0.30.17"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://opencollective.com/vitest"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"msw": "^2.4.9",
|
||||
"vite": "^5.0.0 || ^6.0.0 || ^7.0.0-0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"msw": {
|
||||
"optional": true
|
||||
},
|
||||
"vite": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@vitest/pretty-format": {
|
||||
"version": "3.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-3.2.4.tgz",
|
||||
@@ -9473,6 +9510,14 @@
|
||||
"typescript": ">=4.8.4"
|
||||
}
|
||||
},
|
||||
"node_modules/tslib": {
|
||||
"version": "2.8.1",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz",
|
||||
"integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==",
|
||||
"dev": true,
|
||||
"license": "0BSD",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/type-check": {
|
||||
"version": "0.4.0",
|
||||
"resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz",
|
||||
@@ -10020,6 +10065,33 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/vitest/node_modules/@vitest/mocker": {
|
||||
"version": "3.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-3.2.4.tgz",
|
||||
"integrity": "sha512-46ryTE9RZO/rfDd7pEqFl7etuyzekzEhUbTW3BvmeO/BcCMEgq59BKhek3dXDWgAj4oMK6OZi+vRr1wPW6qjEQ==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@vitest/spy": "3.2.4",
|
||||
"estree-walker": "^3.0.3",
|
||||
"magic-string": "^0.30.17"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://opencollective.com/vitest"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"msw": "^2.4.9",
|
||||
"vite": "^5.0.0 || ^6.0.0 || ^7.0.0-0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"msw": {
|
||||
"optional": true
|
||||
},
|
||||
"vite": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/w3c-xmlserializer": {
|
||||
"version": "5.0.0",
|
||||
"resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-5.0.0.tgz",
|
||||
|
||||
+1
-1
@@ -43,7 +43,7 @@
|
||||
"tailwindcss": "^4.0.0",
|
||||
"typescript": "^5.7.3",
|
||||
"typescript-eslint": "^8.56.1",
|
||||
"vite": "^6.3.5",
|
||||
"vite": "^6.4.2",
|
||||
"vite-plugin-pwa": "^0.21.2",
|
||||
"vitest": "^3.2.4"
|
||||
},
|
||||
|
||||
@@ -9,7 +9,7 @@ export default defineConfig({
|
||||
},
|
||||
],
|
||||
webServer: {
|
||||
command: 'VITE_MOCK_AUTH=true npm run dev',
|
||||
command: 'npm run dev',
|
||||
url: 'http://localhost:5173',
|
||||
reuseExistingServer: !process.env.CI,
|
||||
},
|
||||
|
||||
@@ -5,12 +5,14 @@ Matches products across retailers by:
|
||||
2. Fuzzy name matching via token-based Jaccard similarity (lower confidence)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
from cartsnitch_common.models.product import NormalizedProduct
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import cast, func, select, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@@ -96,17 +98,24 @@ def jaccard_similarity(a: str, b: str) -> float:
|
||||
def match_by_upc(session: Session, upc: str) -> MatchResult | None:
|
||||
"""Find a normalized product by exact UPC match.
|
||||
|
||||
Loads products with upc_variants and checks membership in Python
|
||||
for cross-database compatibility (works on both PostgreSQL and SQLite).
|
||||
Uses PostgreSQL JSONB containment (@>) for production efficiency.
|
||||
Falls back to LIKE on SQLite for test compatibility.
|
||||
"""
|
||||
# TODO: Use PostgreSQL JSON containment query (@>) for production.
|
||||
# Current approach loads all products into memory — acceptable for tests
|
||||
# and small datasets, but will not scale.
|
||||
stmt = select(NormalizedProduct).where(NormalizedProduct.upc_variants.is_not(None))
|
||||
products = session.execute(stmt).scalars().all()
|
||||
for product in products:
|
||||
if product.upc_variants and upc in product.upc_variants:
|
||||
return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC)
|
||||
dialect_name = session.bind.dialect.name if session.bind else "default"
|
||||
if dialect_name == "postgresql":
|
||||
stmt = select(NormalizedProduct).where(
|
||||
cast(NormalizedProduct.upc_variants, JSONB).op("@>")(
|
||||
func.cast(json.dumps([upc]), JSONB)
|
||||
)
|
||||
)
|
||||
else:
|
||||
stmt = select(NormalizedProduct).where(
|
||||
NormalizedProduct.upc_variants.is_not(None),
|
||||
cast(NormalizedProduct.upc_variants, String).contains(upc),
|
||||
)
|
||||
product = session.execute(stmt).scalars().first()
|
||||
if product:
|
||||
return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import { AccountLinking } from './pages/AccountLinking.tsx'
|
||||
import { Login } from './pages/Login.tsx'
|
||||
import { Register } from './pages/Register.tsx'
|
||||
import { ForgotPassword } from './pages/ForgotPassword.tsx'
|
||||
import { VerifyEmail } from './pages/VerifyEmail.tsx'
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
@@ -47,6 +48,7 @@ export default function App() {
|
||||
<Route path="login" element={<Login />} />
|
||||
<Route path="register" element={<Register />} />
|
||||
<Route path="forgot-password" element={<ForgotPassword />} />
|
||||
<Route path="verify-email" element={<VerifyEmail />} />
|
||||
</Routes>
|
||||
</BrowserRouter>
|
||||
</QueryClientProvider>
|
||||
|
||||
@@ -1,25 +1,8 @@
|
||||
import { useEffect } from 'react'
|
||||
import { Navigate, Outlet } from 'react-router-dom'
|
||||
import { authClient } from '../lib/auth-client.ts'
|
||||
import { useAuthStore } from '../stores/auth.ts'
|
||||
|
||||
export function ProtectedRoute() {
|
||||
const isMockAuth = import.meta.env.VITE_MOCK_AUTH === 'true'
|
||||
const { data: session, isPending } = authClient.useSession()
|
||||
const isAuthenticated = useAuthStore((s) => s.isAuthenticated)
|
||||
const setAuthenticated = useAuthStore((s) => s.setAuthenticated)
|
||||
|
||||
useEffect(() => {
|
||||
if (!isMockAuth) {
|
||||
setAuthenticated(!!session)
|
||||
}
|
||||
}, [session, setAuthenticated, isMockAuth])
|
||||
|
||||
// In mock auth mode, rely on Zustand store (set by Login/Register pages)
|
||||
if (isMockAuth) {
|
||||
if (!isAuthenticated) return <Navigate to="/login" replace />
|
||||
return <Outlet />
|
||||
}
|
||||
|
||||
if (isPending) {
|
||||
return (
|
||||
|
||||
@@ -79,21 +79,21 @@ function AuthenticatedDashboard({ userName }: { userName: string }) {
|
||||
<div className="rounded-xl bg-white p-4 shadow-sm">
|
||||
<p className="text-xs font-medium text-gray-500">Watching</p>
|
||||
<p className="mt-1 text-2xl font-bold text-gray-900">{watchingAlerts.length}</p>
|
||||
<p className="text-xs text-gray-400">price alerts</p>
|
||||
<p className="text-xs text-gray-600">price alerts</p>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white p-4 shadow-sm">
|
||||
<p className="text-xs font-medium text-gray-500">This Month</p>
|
||||
<p className="mt-1 text-2xl font-bold text-gray-900">
|
||||
${recentPurchases.reduce((sum, p) => sum + p.total, 0).toFixed(0)}
|
||||
</p>
|
||||
<p className="text-xs text-gray-400">grocery spend</p>
|
||||
<p className="text-xs text-gray-600">grocery spend</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Price trend sparklines */}
|
||||
<section className="mt-6">
|
||||
<h2 className="mb-3 text-lg font-semibold text-gray-700">Price Trends</h2>
|
||||
<div className="rounded-xl bg-white p-4 shadow-sm text-center text-sm text-gray-400">
|
||||
<div className="rounded-xl bg-white p-4 shadow-sm text-center text-sm text-gray-600">
|
||||
Connect a store to see price trends
|
||||
</div>
|
||||
</section>
|
||||
|
||||
+1
-8
@@ -1,7 +1,6 @@
|
||||
import { useState } from 'react'
|
||||
import { Link, useNavigate } from 'react-router-dom'
|
||||
import { authClient } from '../lib/auth-client.ts'
|
||||
import { useAuthStore } from '../stores/auth.ts'
|
||||
|
||||
export function Login() {
|
||||
const [email, setEmail] = useState('')
|
||||
@@ -9,7 +8,6 @@ export function Login() {
|
||||
const [error, setError] = useState('')
|
||||
const [loading, setLoading] = useState(false)
|
||||
const navigate = useNavigate()
|
||||
const setAuthenticated = useAuthStore((s) => s.setAuthenticated)
|
||||
|
||||
async function handleSubmit(e: React.FormEvent) {
|
||||
e.preventDefault()
|
||||
@@ -40,12 +38,7 @@ export function Login() {
|
||||
setError('Sign in failed. Please try again.')
|
||||
}
|
||||
} catch {
|
||||
if (import.meta.env.VITE_MOCK_AUTH === 'true') {
|
||||
setAuthenticated(true)
|
||||
navigate('/')
|
||||
} else {
|
||||
setError('Invalid email or password. Please try again.')
|
||||
}
|
||||
setError('Invalid email or password. Please try again.')
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
+49
-19
@@ -1,7 +1,6 @@
|
||||
import { useState } from 'react'
|
||||
import { Link, useNavigate } from 'react-router-dom'
|
||||
import { Link } from 'react-router-dom'
|
||||
import { authClient } from '../lib/auth-client.ts'
|
||||
import { useAuthStore } from '../stores/auth.ts'
|
||||
|
||||
export function Register() {
|
||||
const [name, setName] = useState('')
|
||||
@@ -9,8 +8,9 @@ export function Register() {
|
||||
const [password, setPassword] = useState('')
|
||||
const [error, setError] = useState('')
|
||||
const [loading, setLoading] = useState(false)
|
||||
const navigate = useNavigate()
|
||||
const setAuthenticated = useAuthStore((s) => s.setAuthenticated)
|
||||
const [registrationComplete, setRegistrationComplete] = useState(false)
|
||||
const [resendLoading, setResendLoading] = useState(false)
|
||||
const [resendMessage, setResendMessage] = useState('')
|
||||
|
||||
async function handleSubmit(e: React.FormEvent) {
|
||||
e.preventDefault()
|
||||
@@ -38,27 +38,57 @@ export function Register() {
|
||||
throw new Error(authError.message ?? 'Registration failed')
|
||||
}
|
||||
|
||||
// After successful signUp, force a session fetch to confirm the cookie is set
|
||||
// before navigating to the protected route
|
||||
const sessionResult = await authClient.getSession()
|
||||
if (sessionResult.data) {
|
||||
navigate('/')
|
||||
} else {
|
||||
// Session not established — show success message and link to login
|
||||
setError('Account created! Please sign in.')
|
||||
}
|
||||
setRegistrationComplete(true)
|
||||
} catch {
|
||||
if (import.meta.env.VITE_MOCK_AUTH === 'true') {
|
||||
setAuthenticated(true)
|
||||
navigate('/')
|
||||
} else {
|
||||
setError('Registration failed. Please try again.')
|
||||
}
|
||||
setError('Registration failed. Please try again.')
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
async function handleResendVerification() {
|
||||
setResendLoading(true)
|
||||
setResendMessage('')
|
||||
try {
|
||||
const { error } = await authClient.sendVerificationEmail({ email })
|
||||
if (error) {
|
||||
setResendMessage('Failed to resend. Please try again.')
|
||||
} else {
|
||||
setResendMessage('Verification email sent!')
|
||||
}
|
||||
} finally {
|
||||
setResendLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
if (registrationComplete) {
|
||||
return (
|
||||
<div className="flex min-h-screen flex-col items-center justify-center px-4">
|
||||
<h1 className="mb-2 text-3xl font-bold text-gray-900">Check your email</h1>
|
||||
<p className="mb-8 text-sm text-gray-500">
|
||||
We sent a verification link to {email}. Click it to activate your account.
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleResendVerification}
|
||||
disabled={resendLoading}
|
||||
className="min-h-12 rounded-xl bg-brand-blue px-6 py-3 text-base font-medium text-white active:bg-brand-blue/90 disabled:opacity-60"
|
||||
>
|
||||
{resendLoading ? 'Sending...' : 'Resend email'}
|
||||
</button>
|
||||
{resendMessage && (
|
||||
<p className="mt-4 text-sm text-gray-500">{resendMessage}</p>
|
||||
)}
|
||||
<p className="mt-6 text-sm text-gray-500">
|
||||
Already have an account?{' '}
|
||||
<Link to="/login" className="text-brand-blue">
|
||||
Sign in
|
||||
</Link>
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen flex-col items-center justify-center px-4">
|
||||
<h1 className="mb-2 text-3xl font-bold text-gray-900">Create Account</h1>
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import { useNavigate, useSearchParams } from "react-router-dom";
|
||||
import { authClient } from "../lib/auth-client.ts";
|
||||
|
||||
export function VerifyEmail() {
|
||||
const [searchParams] = useSearchParams();
|
||||
const navigate = useNavigate();
|
||||
const [status, setStatus] = useState<"verifying" | "success" | "error">("verifying");
|
||||
const [resendEmail, setResendEmail] = useState("");
|
||||
const [showResend, setShowResend] = useState(false);
|
||||
const [resending, setResending] = useState(false);
|
||||
const [resendMessage, setResendMessage] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const token = searchParams.get("token");
|
||||
const callbackURL = searchParams.get("callbackURL") || "/";
|
||||
|
||||
if (!token) {
|
||||
setStatus("error");
|
||||
return;
|
||||
}
|
||||
|
||||
authClient.verifyEmail({ query: { token } })
|
||||
.then(() => {
|
||||
setStatus("success");
|
||||
setTimeout(() => {
|
||||
navigate(callbackURL);
|
||||
}, 2000);
|
||||
})
|
||||
.catch(() => {
|
||||
setStatus("error");
|
||||
});
|
||||
}, [searchParams, navigate]);
|
||||
|
||||
async function handleResend() {
|
||||
if (!resendEmail) {
|
||||
setResendMessage("Please enter your email address.");
|
||||
return;
|
||||
}
|
||||
|
||||
setResending(true);
|
||||
setResendMessage("");
|
||||
|
||||
try {
|
||||
const { error } = await authClient.sendVerificationEmail({ email: resendEmail });
|
||||
if (error) {
|
||||
setResendMessage("Failed to resend. Please try again.");
|
||||
} else {
|
||||
setResendMessage("Verification email sent!");
|
||||
setShowResend(false);
|
||||
}
|
||||
} finally {
|
||||
setResending(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen flex-col items-center justify-center px-4">
|
||||
{status === "verifying" && (
|
||||
<>
|
||||
<div className="mb-4 h-8 w-8 animate-spin rounded-full border-4 border-gray-200 border-t-brand-blue" />
|
||||
<h1 className="mb-2 text-2xl font-bold text-gray-900">Verifying your email...</h1>
|
||||
<p className="text-sm text-gray-500">Please wait while we verify your email address.</p>
|
||||
</>
|
||||
)}
|
||||
|
||||
{status === "success" && (
|
||||
<>
|
||||
<h1 className="mb-2 text-2xl font-bold text-gray-900">Email verified!</h1>
|
||||
<p className="text-sm text-gray-500">Redirecting you shortly...</p>
|
||||
</>
|
||||
)}
|
||||
|
||||
{status === "error" && (
|
||||
<>
|
||||
<h1 className="mb-2 text-2xl font-bold text-gray-900">Verification failed</h1>
|
||||
<p className="mb-6 text-sm text-gray-500">The verification link may have expired or is invalid.</p>
|
||||
|
||||
{!showResend ? (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowResend(true)}
|
||||
className="min-h-12 rounded-xl bg-brand-blue px-6 py-3 text-base font-medium text-white active:bg-brand-blue/90"
|
||||
>
|
||||
Resend verification email
|
||||
</button>
|
||||
) : (
|
||||
<div className="w-full max-w-sm space-y-4">
|
||||
<input
|
||||
type="email"
|
||||
placeholder="Your email address"
|
||||
value={resendEmail}
|
||||
onChange={(e) => setResendEmail(e.target.value)}
|
||||
className="min-h-12 w-full rounded-xl border border-gray-200 px-4 text-base focus:border-brand-blue focus:outline-none focus:ring-1 focus:ring-brand-blue"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleResend}
|
||||
disabled={resending}
|
||||
className="min-h-12 w-full rounded-xl bg-brand-blue px-4 py-3 text-base font-medium text-white active:bg-brand-blue/90 disabled:opacity-60"
|
||||
>
|
||||
{resending ? "Sending..." : "Send verification email"}
|
||||
</button>
|
||||
{resendMessage && (
|
||||
<p className="text-sm text-gray-500">{resendMessage}</p>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user