Compare commits

..

1 Commits

Author SHA1 Message Date
Stockboy Steve 53e802746c fix(api): run Alembic migrations on startup to fix auth 500s
Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-03-31 19:52:00 +00:00
31 changed files with 129 additions and 284 deletions
+1 -1
View File
@@ -95,7 +95,7 @@ jobs:
run: | run: |
CHROME_PATH=$(find /home/runner/.cache/ms-playwright -name chrome -type f 2>/dev/null | head -1) CHROME_PATH=$(find /home/runner/.cache/ms-playwright -name chrome -type f 2>/dev/null | head -1)
npm install -g @lhci/cli npm install -g @lhci/cli
CHROME_PATH="$CHROME_PATH" lhci autorun --chrome-flags="--headless=new --no-sandbox --disable-gpu --disable-dev-shm-usage" LHCI_CHROME_PATH="$CHROME_PATH" lhci autorun
build-and-push: build-and-push:
runs-on: runners-cartsnitch runs-on: runners-cartsnitch
+10 -22
View File
@@ -5,6 +5,7 @@ Sessions are verified by querying the shared sessions table directly.
""" """
from datetime import UTC, datetime from datetime import UTC, datetime
from uuid import UUID
from fastapi import Cookie, Depends, Header, HTTPException, Request, status from fastapi import Cookie, Depends, Header, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -18,27 +19,18 @@ from cartsnitch_api.database import get_db
# but we support Bearer tokens for service-to-service or mobile clients. # but we support Bearer tokens for service-to-service or mobile clients.
bearer_scheme = HTTPBearer(auto_error=False) bearer_scheme = HTTPBearer(auto_error=False)
# Better-Auth session cookie names. # Better-Auth session cookie name
# Over HTTPS Better-Auth adds the __Secure- prefix automatically. SESSION_COOKIE_NAME = "better-auth.session_token"
SESSION_COOKIE_NAMES = [
"__Secure-better-auth.session_token", # HTTPS (deployed)
"better-auth.session_token", # HTTP (local dev)
]
async def _validate_session_token(token: str, db: AsyncSession) -> str: async def _validate_session_token(token: str, db: AsyncSession) -> UUID:
"""Validate a Better-Auth session token against the sessions table. """Validate a Better-Auth session token against the sessions table.
Returns the user_id (as str) if the session is valid and not expired. Returns the user_id (as UUID) if the session is valid and not expired.
Better-Auth v1.5.6 stores raw tokens in the DB. The session cookie
is signed: ``rawToken.base64HMACSignature``. Strip the signature
before querying.
""" """
# Signed cookie format: rawToken.hmacSignature — split and use only the token part
raw_token = token.split(".")[0] if "." in token else token
result = await db.execute( result = await db.execute(
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"), text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
{"token": raw_token}, {"token": token},
) )
row = result.first() row = result.first()
@@ -59,14 +51,14 @@ async def _validate_session_token(token: str, db: AsyncSession) -> str:
detail="Session expired", detail="Session expired",
) )
return str(user_id) return UUID(str(user_id))
async def get_current_user( async def get_current_user(
request: Request, request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme), credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> str: ) -> UUID:
"""Extract and validate the session token from cookie or Authorization header. """Extract and validate the session token from cookie or Authorization header.
Checks in order: Checks in order:
@@ -75,12 +67,8 @@ async def get_current_user(
""" """
token: str | None = None token: str | None = None
# 1. Check session cookie (try both names for HTTP/HTTPS compatibility) # 1. Check session cookie
cookie_token = None cookie_token = request.cookies.get(SESSION_COOKIE_NAME)
for name in SESSION_COOKIE_NAMES:
cookie_token = request.cookies.get(name)
if cookie_token:
break
if cookie_token: if cookie_token:
token = cookie_token token = cookie_token
+5 -4
View File
@@ -2,21 +2,22 @@
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import Any, cast from typing import Any, cast
from uuid import UUID
from jose import JWTError, jwt from jose import JWTError, jwt
from cartsnitch_api.config import settings from cartsnitch_api.config import settings
def create_access_token(user_id: str) -> str: def create_access_token(user_id: UUID) -> str:
expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes) expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_access_token_expire_minutes)
payload = {"sub": user_id, "exp": expire, "type": "access"} payload = {"sub": str(user_id), "exp": expire, "type": "access"}
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)) return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
def create_refresh_token(user_id: str) -> str: def create_refresh_token(user_id: UUID) -> str:
expire = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days) expire = datetime.now(UTC) + timedelta(days=settings.jwt_refresh_token_expire_days)
payload = {"sub": user_id, "exp": expire, "type": "refresh"} payload = {"sub": str(user_id), "exp": expire, "type": "refresh"}
return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)) return cast(str, jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
+5 -3
View File
@@ -5,6 +5,8 @@ the Better-Auth service (auth/). This router provides user profile
endpoints that query our own user data from the shared database. endpoints that query our own user data from the shared database.
""" """
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -21,7 +23,7 @@ router = APIRouter(prefix="/auth", tags=["auth"])
@router.get("/me", response_model=UserResponse) @router.get("/me", response_model=UserResponse)
async def get_me( async def get_me(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = AuthService(db) svc = AuthService(db)
@@ -36,7 +38,7 @@ async def get_me(
@router.patch("/me", response_model=UserResponse) @router.patch("/me", response_model=UserResponse)
async def update_me( async def update_me(
body: UpdateUserRequest, body: UpdateUserRequest,
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = AuthService(db) svc = AuthService(db)
@@ -52,7 +54,7 @@ async def update_me(
@router.delete("/me", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/me", status_code=status.HTTP_204_NO_CONTENT)
async def delete_me( async def delete_me(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = AuthService(db) svc = AuthService(db)
+10 -14
View File
@@ -2,7 +2,7 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import APIRouter, FastAPI from fastapi import FastAPI
from cartsnitch_api.auth.routes import router as auth_router from cartsnitch_api.auth.routes import router as auth_router
from cartsnitch_api.middleware.cors import add_cors_middleware from cartsnitch_api.middleware.cors import add_cors_middleware
@@ -46,19 +46,15 @@ def create_app() -> FastAPI:
# Routers # Routers
app.include_router(health_router) app.include_router(health_router)
app.include_router(auth_router) app.include_router(auth_router)
app.include_router(stores_router)
# Data endpoints mounted under /api/v1 app.include_router(purchases_router)
v1_router = APIRouter(prefix="/api/v1") app.include_router(products_router)
v1_router.include_router(stores_router) app.include_router(prices_router)
v1_router.include_router(purchases_router) app.include_router(coupons_router)
v1_router.include_router(products_router) app.include_router(shopping_router)
v1_router.include_router(prices_router) app.include_router(alerts_router)
v1_router.include_router(coupons_router) app.include_router(scraping_router)
v1_router.include_router(shopping_router) app.include_router(public_router)
v1_router.include_router(alerts_router)
v1_router.include_router(scraping_router)
v1_router.include_router(public_router)
app.include_router(v1_router)
return app return app
+2 -2
View File
@@ -9,14 +9,14 @@ from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import DiscountType from cartsnitch_api.constants import DiscountType
from cartsnitch_api.models.base import Base, UUIDPrimaryKeyMixin from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from cartsnitch_api.models.product import NormalizedProduct from cartsnitch_api.models.product import NormalizedProduct
from cartsnitch_api.models.store import Store from cartsnitch_api.models.store import Store
class Coupon(UUIDPrimaryKeyMixin, Base): class Coupon(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""A coupon or deal for a product at a store.""" """A coupon or deal for a product at a store."""
__tablename__ = "coupons" __tablename__ = "coupons"
+2 -2
View File
@@ -9,7 +9,7 @@ from sqlalchemy import Date, ForeignKey, Index, Numeric, String
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import PriceSource from cartsnitch_api.constants import PriceSource
from cartsnitch_api.models.base import Base, UUIDPrimaryKeyMixin from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from cartsnitch_api.models.product import NormalizedProduct from cartsnitch_api.models.product import NormalizedProduct
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
from cartsnitch_api.models.store import Store from cartsnitch_api.models.store import Store
class PriceHistory(UUIDPrimaryKeyMixin, Base): class PriceHistory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""A single price observation for a product at a store on a date.""" """A single price observation for a product at a store on a date."""
__tablename__ = "price_history" __tablename__ = "price_history"
+3 -3
View File
@@ -18,7 +18,7 @@ from sqlalchemy import (
) )
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.models.base import Base, UUIDPrimaryKeyMixin from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from cartsnitch_api.models.price import PriceHistory from cartsnitch_api.models.price import PriceHistory
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from cartsnitch_api.models.user import User from cartsnitch_api.models.user import User
class Purchase(UUIDPrimaryKeyMixin, Base): class Purchase(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""A single shopping trip / receipt.""" """A single shopping trip / receipt."""
__tablename__ = "purchases" __tablename__ = "purchases"
@@ -61,7 +61,7 @@ class Purchase(UUIDPrimaryKeyMixin, Base):
) )
class PurchaseItem(UUIDPrimaryKeyMixin, Base): class PurchaseItem(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""Individual line item on a receipt.""" """Individual line item on a receipt."""
__tablename__ = "purchase_items" __tablename__ = "purchase_items"
@@ -9,13 +9,13 @@ from sqlalchemy import Date, ForeignKey, Numeric, String
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_api.constants import SizeUnit from cartsnitch_api.constants import SizeUnit
from cartsnitch_api.models.base import Base, UUIDPrimaryKeyMixin from cartsnitch_api.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from cartsnitch_api.models.product import NormalizedProduct from cartsnitch_api.models.product import NormalizedProduct
class ShrinkflationEvent(UUIDPrimaryKeyMixin, Base): class ShrinkflationEvent(UUIDPrimaryKeyMixin, TimestampMixin, Base):
"""Detected shrinkflation event — product size changed while price held or rose.""" """Detected shrinkflation event — product size changed while price held or rose."""
__tablename__ = "shrinkflation_events" __tablename__ = "shrinkflation_events"
+5 -3
View File
@@ -1,5 +1,7 @@
"""Alert routes: list alerts, manage settings.""" """Alert routes: list alerts, manage settings."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,7 +15,7 @@ router = APIRouter(prefix="/alerts", tags=["alerts"])
@router.get("", response_model=list[AlertResponse]) @router.get("", response_model=list[AlertResponse])
async def list_alerts( async def list_alerts(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = AlertService(db) svc = AlertService(db)
@@ -22,7 +24,7 @@ async def list_alerts(
@router.get("/settings", response_model=AlertSettingsResponse) @router.get("/settings", response_model=AlertSettingsResponse)
async def get_alert_settings( async def get_alert_settings(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = AlertService(db) svc = AlertService(db)
@@ -32,7 +34,7 @@ async def get_alert_settings(
@router.put("/settings") @router.put("/settings")
async def update_alert_settings( async def update_alert_settings(
body: AlertSettingsRequest, body: AlertSettingsRequest,
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
raise HTTPException( raise HTTPException(
+2 -2
View File
@@ -16,7 +16,7 @@ router = APIRouter(prefix="/coupons", tags=["coupons"])
@router.get("", response_model=list[CouponResponse]) @router.get("", response_model=list[CouponResponse])
async def list_coupons( async def list_coupons(
store_id: UUID | None = Query(None), store_id: UUID | None = Query(None),
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = CouponService(db) svc = CouponService(db)
@@ -25,7 +25,7 @@ async def list_coupons(
@router.get("/relevant", response_model=list[CouponResponse]) @router.get("/relevant", response_model=list[CouponResponse])
async def relevant_coupons( async def relevant_coupons(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = CouponService(db) svc = CouponService(db)
+3 -3
View File
@@ -20,7 +20,7 @@ router = APIRouter(prefix="/prices", tags=["prices"])
@router.get("/trends", response_model=list[PriceTrendResponse]) @router.get("/trends", response_model=list[PriceTrendResponse])
async def price_trends( async def price_trends(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
category: str | None = Query(None), category: str | None = Query(None),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
@@ -30,7 +30,7 @@ async def price_trends(
@router.get("/increases", response_model=list[PriceIncreaseResponse]) @router.get("/increases", response_model=list[PriceIncreaseResponse])
async def price_increases( async def price_increases(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = PriceService(db) svc = PriceService(db)
@@ -40,7 +40,7 @@ async def price_increases(
@router.get("/comparison", response_model=list[PriceComparisonResponse]) @router.get("/comparison", response_model=list[PriceComparisonResponse])
async def price_comparison( async def price_comparison(
product_ids: Annotated[list[UUID], Query()], product_ids: Annotated[list[UUID], Query()],
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = PriceService(db) svc = PriceService(db)
+3 -3
View File
@@ -15,7 +15,7 @@ router = APIRouter(prefix="/products", tags=["products"])
@router.get("", response_model=list[ProductResponse]) @router.get("", response_model=list[ProductResponse])
async def list_products( async def list_products(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
q: str | None = Query(None), q: str | None = Query(None),
category: str | None = Query(None), category: str | None = Query(None),
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
@@ -29,7 +29,7 @@ async def list_products(
@router.get("/{product_id}", response_model=ProductDetailResponse) @router.get("/{product_id}", response_model=ProductDetailResponse)
async def get_product( async def get_product(
product_id: UUID, product_id: UUID,
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = ProductService(db) svc = ProductService(db)
@@ -44,7 +44,7 @@ async def get_product(
@router.get("/{product_id}/prices", response_model=PriceTrendResponse) @router.get("/{product_id}/prices", response_model=PriceTrendResponse)
async def get_product_prices( async def get_product_prices(
product_id: UUID, product_id: UUID,
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = ProductService(db) svc = ProductService(db)
+3 -3
View File
@@ -15,7 +15,7 @@ router = APIRouter(prefix="/purchases", tags=["purchases"])
@router.get("", response_model=list[PurchaseResponse]) @router.get("", response_model=list[PurchaseResponse])
async def list_purchases( async def list_purchases(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
store_id: UUID | None = Query(None), store_id: UUID | None = Query(None),
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100), page_size: int = Query(20, ge=1, le=100),
@@ -27,7 +27,7 @@ async def list_purchases(
@router.get("/stats", response_model=PurchaseStatsResponse) @router.get("/stats", response_model=PurchaseStatsResponse)
async def purchase_stats( async def purchase_stats(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = PurchaseService(db) svc = PurchaseService(db)
@@ -37,7 +37,7 @@ async def purchase_stats(
@router.get("/{purchase_id}", response_model=PurchaseDetailResponse) @router.get("/{purchase_id}", response_model=PurchaseDetailResponse)
async def get_purchase( async def get_purchase(
purchase_id: UUID, purchase_id: UUID,
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = PurchaseService(db) svc = PurchaseService(db)
+4 -2
View File
@@ -1,5 +1,7 @@
"""Scraping routes: trigger sync, check status (proxy to ReceiptWitness).""" """Scraping routes: trigger sync, check status (proxy to ReceiptWitness)."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from httpx import HTTPStatusError, RequestError from httpx import HTTPStatusError, RequestError
@@ -11,7 +13,7 @@ router = APIRouter(prefix="/scraping", tags=["scraping"])
@router.post("/{store_slug}/sync", response_model=SyncTriggerResponse) @router.post("/{store_slug}/sync", response_model=SyncTriggerResponse)
async def trigger_sync(store_slug: str, user_id: str = Depends(get_current_user)): async def trigger_sync(store_slug: str, user_id: UUID = Depends(get_current_user)):
client = ReceiptWitnessClient() client = ReceiptWitnessClient()
try: try:
result = await client.trigger_sync(str(user_id), store_slug) result = await client.trigger_sync(str(user_id), store_slug)
@@ -29,7 +31,7 @@ async def trigger_sync(store_slug: str, user_id: str = Depends(get_current_user)
@router.get("/status", response_model=list[SyncStatusResponse]) @router.get("/status", response_model=list[SyncStatusResponse])
async def sync_status(user_id: str = Depends(get_current_user)): async def sync_status(user_id: UUID = Depends(get_current_user)):
client = ReceiptWitnessClient() client = ReceiptWitnessClient()
try: try:
return await client.get_sync_status(str(user_id)) return await client.get_sync_status(str(user_id))
+4 -2
View File
@@ -1,5 +1,7 @@
"""Shopping routes: optimize list, saved lists.""" """Shopping routes: optimize list, saved lists."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from httpx import HTTPStatusError, RequestError from httpx import HTTPStatusError, RequestError
@@ -11,7 +13,7 @@ router = APIRouter(prefix="/shopping", tags=["shopping"])
@router.post("/optimize", response_model=OptimizeResponse) @router.post("/optimize", response_model=OptimizeResponse)
async def optimize_shopping(body: OptimizeRequest, user_id: str = Depends(get_current_user)): async def optimize_shopping(body: OptimizeRequest, user_id: UUID = Depends(get_current_user)):
client = ClipArtistClient() client = ClipArtistClient()
try: try:
result = await client.optimize( result = await client.optimize(
@@ -35,7 +37,7 @@ async def optimize_shopping(body: OptimizeRequest, user_id: str = Depends(get_cu
@router.get("/lists", response_model=list[ShoppingListResponse]) @router.get("/lists", response_model=list[ShoppingListResponse])
async def list_shopping_lists(user_id: str = Depends(get_current_user)): async def list_shopping_lists(user_id: UUID = Depends(get_current_user)):
client = ClipArtistClient() client = ClipArtistClient()
try: try:
return await client.get_shopping_lists(str(user_id)) return await client.get_shopping_lists(str(user_id))
+5 -3
View File
@@ -1,5 +1,7 @@
"""Store routes: list stores, manage user store connections.""" """Store routes: list stores, manage user store connections."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -19,7 +21,7 @@ async def list_stores(db: AsyncSession = Depends(get_db)):
@router.get("/me/stores", response_model=list[StoreAccountResponse]) @router.get("/me/stores", response_model=list[StoreAccountResponse])
async def list_user_stores( async def list_user_stores(
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = StoreService(db) svc = StoreService(db)
@@ -34,7 +36,7 @@ async def list_user_stores(
async def connect_store( async def connect_store(
store_slug: str, store_slug: str,
body: ConnectStoreRequest, body: ConnectStoreRequest,
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = StoreService(db) svc = StoreService(db)
@@ -49,7 +51,7 @@ async def connect_store(
@router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/me/stores/{store_slug}", status_code=status.HTTP_204_NO_CONTENT)
async def disconnect_store( async def disconnect_store(
store_slug: str, store_slug: str,
user_id: str = Depends(get_current_user), user_id: UUID = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
svc = StoreService(db) svc = StoreService(db)
+5 -3
View File
@@ -4,6 +4,8 @@ Alerts are generated by StickerShock and ShrinkRay services and written to the D
This service reads them for the API gateway. This service reads them for the API gateway.
""" """
from uuid import UUID
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
@@ -13,7 +15,7 @@ class AlertService:
def __init__(self, db: AsyncSession) -> None: def __init__(self, db: AsyncSession) -> None:
self.db = db self.db = db
async def list_alerts(self, user_id: str) -> list[dict]: async def list_alerts(self, user_id: UUID) -> list[dict]:
"""List shrinkflation events for products the user has purchased.""" """List shrinkflation events for products the user has purchased."""
from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent from cartsnitch_api.models import Purchase, PurchaseItem, ShrinkflationEvent
@@ -55,7 +57,7 @@ class AlertService:
for e in events for e in events
] ]
async def get_settings(self, user_id: str) -> dict: async def get_settings(self, user_id: UUID) -> dict:
# Alert settings would be stored in a user_settings table. # Alert settings would be stored in a user_settings table.
# For now, return defaults since the table doesn't exist yet in common lib. # For now, return defaults since the table doesn't exist yet in common lib.
return { return {
@@ -64,7 +66,7 @@ class AlertService:
"email_notifications": False, "email_notifications": False,
} }
async def update_settings(self, user_id: str, **fields) -> dict: async def update_settings(self, user_id: UUID, **fields) -> dict:
# Would update user_settings table. Return merged defaults for now. # Would update user_settings table. Return merged defaults for now.
current = await self.get_settings(user_id) current = await self.get_settings(user_id)
for k, v in fields.items(): for k, v in fields.items():
+5 -3
View File
@@ -5,6 +5,8 @@ handled by the Better-Auth service (auth/). This service provides
user lookup and profile update operations for the API gateway. user lookup and profile update operations for the API gateway.
""" """
from uuid import UUID
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,7 +15,7 @@ class AuthService:
def __init__(self, db: AsyncSession) -> None: def __init__(self, db: AsyncSession) -> None:
self.db = db self.db = db
async def get_user(self, user_id: str) -> dict: async def get_user(self, user_id: UUID) -> dict:
from cartsnitch_api.models import User from cartsnitch_api.models import User
result = await self.db.execute(select(User).where(User.id == user_id)) result = await self.db.execute(select(User).where(User.id == user_id))
@@ -28,7 +30,7 @@ class AuthService:
"created_at": user.created_at, "created_at": user.created_at,
} }
async def update_user(self, user_id: str, **fields) -> dict: async def update_user(self, user_id: UUID, **fields) -> dict:
from cartsnitch_api.models import User from cartsnitch_api.models import User
result = await self.db.execute(select(User).where(User.id == user_id)) result = await self.db.execute(select(User).where(User.id == user_id))
@@ -56,7 +58,7 @@ class AuthService:
"created_at": user.created_at, "created_at": user.created_at,
} }
async def delete_user(self, user_id: str) -> None: async def delete_user(self, user_id: UUID) -> None:
from cartsnitch_api.models import User from cartsnitch_api.models import User
result = await self.db.execute(select(User).where(User.id == user_id)) result = await self.db.execute(select(User).where(User.id == user_id))
+1 -1
View File
@@ -29,7 +29,7 @@ class CouponService:
coupons = result.scalars().all() coupons = result.scalars().all()
return [self._to_dict(c) for c in coupons] return [self._to_dict(c) for c in coupons]
async def relevant_coupons(self, user_id: str) -> list[dict]: async def relevant_coupons(self, user_id: UUID) -> list[dict]:
"""Coupons for products the user has purchased.""" """Coupons for products the user has purchased."""
from cartsnitch_api.models import Coupon, PurchaseItem from cartsnitch_api.models import Coupon, PurchaseItem
+3 -3
View File
@@ -13,7 +13,7 @@ class PurchaseService:
async def list_purchases( async def list_purchases(
self, self,
user_id: str, user_id: UUID,
store_id: UUID | None = None, store_id: UUID | None = None,
page: int = 1, page: int = 1,
page_size: int = 20, page_size: int = 20,
@@ -56,7 +56,7 @@ class PurchaseService:
for p, item_count, store_name in result.all() for p, item_count, store_name in result.all()
] ]
async def get_purchase(self, purchase_id: UUID, user_id: str) -> dict: async def get_purchase(self, purchase_id: UUID, user_id: UUID) -> dict:
from cartsnitch_api.models import Purchase from cartsnitch_api.models import Purchase
result = await self.db.execute( result = await self.db.execute(
@@ -88,7 +88,7 @@ class PurchaseService:
], ],
} }
async def get_stats(self, user_id: str) -> dict: async def get_stats(self, user_id: UUID) -> dict:
from cartsnitch_api.models import Purchase from cartsnitch_api.models import Purchase
result = await self.db.execute( result = await self.db.execute(
+4 -3
View File
@@ -1,6 +1,7 @@
"""Store service — list stores, manage user store account connections.""" """Store service — list stores, manage user store account connections."""
import json import json
from uuid import UUID
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from sqlalchemy import select from sqlalchemy import select
@@ -34,7 +35,7 @@ class StoreService:
for s in stores for s in stores
] ]
async def list_user_stores(self, user_id: str) -> list[dict]: async def list_user_stores(self, user_id: UUID) -> list[dict]:
from cartsnitch_api.models import UserStoreAccount from cartsnitch_api.models import UserStoreAccount
result = await self.db.execute( result = await self.db.execute(
@@ -59,7 +60,7 @@ class StoreService:
for a in accounts for a in accounts
] ]
async def connect_store(self, user_id: str, store_slug: str, credentials: dict | None) -> dict: async def connect_store(self, user_id: UUID, store_slug: str, credentials: dict | None) -> dict:
from cartsnitch_api.models import Store, UserStoreAccount from cartsnitch_api.models import Store, UserStoreAccount
result = await self.db.execute(select(Store).where(Store.slug == store_slug)) result = await self.db.execute(select(Store).where(Store.slug == store_slug))
@@ -106,7 +107,7 @@ class StoreService:
"sync_status": "active", "sync_status": "active",
} }
async def disconnect_store(self, user_id: str, store_slug: str) -> None: async def disconnect_store(self, user_id: UUID, store_slug: str) -> None:
from cartsnitch_api.models import Store, UserStoreAccount from cartsnitch_api.models import Store, UserStoreAccount
result = await self.db.execute(select(Store).where(Store.slug == store_slug)) result = await self.db.execute(select(Store).where(Store.slug == store_slug))
+1 -6
View File
@@ -3,12 +3,7 @@
"collect": { "collect": {
"staticDistDir": "./dist", "staticDistDir": "./dist",
"url": ["http://localhost:4173/"], "url": ["http://localhost:4173/"],
"numberOfRuns": 1, "numberOfRuns": 1
"settings": {
"chromeFlags": ["--headless=new", "--no-sandbox", "--disable-gpu", "--disable-dev-shm-usage"],
"skipAudits": ["bf-cache"],
"disableFullPageScreenshot": true
}
}, },
"assert": { "assert": {
"assertions": { "assertions": {
-61
View File
@@ -1,61 +0,0 @@
# seed-dev-job.yaml
# K8s Job to run the CartSnitch seed runner against the dev database.
#
# Usage:
# kubectl apply -f seed-dev-job.yaml -n cartsnitch-dev
#
# To view logs:
# kubectl logs -n cartsnitch-dev job/seed-dev -f
#
# To re-run after fixing issues:
# kubectl delete -f seed-dev-job.yaml -n cartsnitch-dev && kubectl apply -f seed-dev-job.yaml -n cartsnitch-dev
#
apiVersion: batch/v1
kind: Job
metadata:
name: seed-dev
namespace: cartsnitch-dev
labels:
app: cartsnitch
component: seed
environment: dev
annotations:
description: "Runs cartsnitch-common seed runner to populate dev database with realistic test data."
spec:
# Prevent retries — a failed seed run should be investigated, not auto-repeated.
backoffLimit: 0
# Do not run concurrently; sequential runs are safer for truncate+reseed.
concurrencyPolicy: Forbid
template:
metadata:
labels:
app: cartsnitch
component: seed
environment: dev
spec:
restartPolicy: Never
containers:
- name: seed
# Use slim Python image with the cartsnitch-common package installed from git.
# The common repo is public; no additional secret is needed for the pip install.
image: python:3.12-slim
command:
- sh
- -c
- |
pip install --no-cache-dir "cartsnitch-common @ git+https://github.com/cartsnitch/common.git@main" && \
python -m cartsnitch_common.seed --database-url "$${DATABASE_URL}"
env:
- name: DATABASE_URL
valueFrom:
secretKeyRef:
name: cartsnitch-secrets
key: database-url-pg
optional: false
resources:
requests:
cpu: 100m
memory: 256Mi
limits:
cpu: 500m
memory: 512Mi
-104
View File
@@ -1,104 +0,0 @@
#!/usr/bin/env bash
# =============================================================================
# seed-dev.sh — Run the CartSnitch seed runner against the dev database.
#
# Usage:
# ./seed-dev.sh Run full seed against dev
# ./seed-dev.sh --dry-run Show planned record counts without writing
# ./seed-dev.sh --help Show this help
#
# Prerequisites:
# - kubectl configured for the cartsnitch-dev cluster
# - Namespace cartsnitch-dev exists (CNPG Postgres must be running)
#
# What it does:
# 1. Starts a background port-forward to cartsnitch-pg-rw:5432
# 2. Waits for the tunnel to be ready
# 3. Runs python -m cartsnitch_common.seed with --database-url pointing
# to localhost:<forwarded-port>/cartsnitch
# 4. Cleans up the port-forward on exit (normal, interrupt, or error)
# =============================================================================
set -euo pipefail
# --- Config -------------------------------------------------------------------
readonly NAMESPACE="cartsnitch-dev"
readonly SVC_NAME="cartsnitch-pg-rw"
readonly LOCAL_PORT="5433" # use a non-privileged port to avoid conflicts
readonly DB_NAME="cartsnitch"
readonly PG_USER="cartsnitch"
# Retrieve password from the CNPG credentials secret
readonly PG_PASSWORD="$(
kubectl get secret cartsnitch-pg-credentials \
-n "$NAMESPACE" \
-o jsonpath='{.data.password}' \
| base64 -d
)"
readonly DB_URL="postgresql://${PG_USER}:${PG_PASSWORD}@localhost:${LOCAL_PORT}/${DB_NAME}"
# --- Helpers ------------------------------------------------------------------
log() { echo "[seed-dev] $*"; }
fail() { log "ERROR: $*" >&2; exit 1; }
# Cleanup port-forward and exit.
cleanup() {
if [[ -n "${PF_PID:-}" ]]; then
log "Stopping port-forward (PID $PF_PID)..."
kill "$PF_PID" 2>/dev/null || true
wait "$PF_PID" 2>/dev/null || true
fi
}
trap cleanup EXIT
# --- Args ---------------------------------------------------------------------
DRY_RUN=""
HELP_FLAG=""
while [[ $# -gt 0 ]]; do
case "$1" in
--dry-run) DRY_RUN="--dry-run"; shift ;;
--help) HELP_FLAG="1"; shift ;;
*) fail "Unknown argument: $1";;
esac
done
if [[ -n "$HELP_FLAG" ]]; then
sed -n '3,/^# ---/p' "$0" | head -n -1 | sed 's/^# //'
echo ""
echo "Additional arguments are passed through to the seed runner."
echo "Common seed-runner options:"
echo " --dry-run Show planned record counts without writing"
echo " --seed N Set random seed (default: 42)"
exit 0
fi
# --- Prerequisites ------------------------------------------------------------
if ! command -v kubectl &>/dev/null; then
fail "kubectl not found — must be installed and configured."
fi
# --- Port-forward -------------------------------------------------------------
log "Starting port-forward ${SVC_NAME}:5432 -> localhost:${LOCAL_PORT} ..."
kubectl port-forward \
-n "$NAMESPACE" \
svc/"$SVC_NAME" \
"${LOCAL_PORT}:5432" \
&>/dev/null &
PF_PID=$!
# Give the tunnel a moment to establish
sleep 2
# Verify the tunnel is up
if ! kill -0 "$PF_PID" 2>/dev/null; then
fail "Port-forward failed to start."
fi
log "Port-forward active (PID $PF_PID) on localhost:${LOCAL_PORT}"
# --- Seed --------------------------------------------------------------------
log "Running seed against dev database..."
set -x
python -m cartsnitch_common.seed --database-url "$DB_URL" $DRY_RUN
set +x
log "Done."
+2 -2
View File
@@ -35,7 +35,7 @@ export function useProduct(id: string) {
export function usePriceHistory(productId: string) { export function usePriceHistory(productId: string) {
return useQuery({ return useQuery({
queryKey: ['priceHistory', productId], queryKey: ['priceHistory', productId],
queryFn: () => api.get<PriceHistory[]>(`/products/${productId}/prices`), queryFn: () => api.get<PriceHistory[]>(`/products/${productId}/price-history`),
enabled: !!productId, enabled: !!productId,
}) })
} }
@@ -50,6 +50,6 @@ export function useCoupons() {
export function usePriceAlerts() { export function usePriceAlerts() {
return useQuery({ return useQuery({
queryKey: ['priceAlerts'], queryKey: ['priceAlerts'],
queryFn: () => api.get<PriceAlert[]>('/alerts'), queryFn: () => api.get<PriceAlert[]>('/price-alerts'),
}) })
} }
+2 -2
View File
@@ -15,7 +15,7 @@ const mockRoutes: Record<string, (path: string) => unknown> = {
'/purchases': () => mockPurchases, '/purchases': () => mockPurchases,
'/products': () => mockProducts, '/products': () => mockProducts,
'/coupons': () => mockCoupons, '/coupons': () => mockCoupons,
'/alerts': () => mockAlerts, '/price-alerts': () => mockAlerts,
} }
function matchMockRoute<T>(path: string): T | null { function matchMockRoute<T>(path: string): T | null {
@@ -30,7 +30,7 @@ function matchMockRoute<T>(path: string): T | null {
} }
// /products/:id/price-history // /products/:id/price-history
const priceHistoryMatch = path.match(/^\/products\/(.+)\/prices$/) const priceHistoryMatch = path.match(/^\/products\/(.+)\/price-history$/)
if (priceHistoryMatch) { if (priceHistoryMatch) {
return getMockPriceHistory(priceHistoryMatch[1]) as T return getMockPriceHistory(priceHistoryMatch[1]) as T
} }
+31 -3
View File
@@ -1,8 +1,13 @@
import React, { Suspense } from 'react'
import { Link } from 'react-router-dom' import { Link } from 'react-router-dom'
import { authClient } from '../lib/auth-client.ts' import { authClient } from '../lib/auth-client.ts'
import { usePurchases, usePriceAlerts } from '../hooks/useApi.ts' import { usePurchases, usePriceAlerts, usePriceHistory } from '../hooks/useApi.ts'
import { StoreIcon } from '../components/StoreIcon.tsx' import { StoreIcon } from '../components/StoreIcon.tsx'
const LazySparklineCard = React.lazy(() =>
import('../components/SparklineChart.tsx').then((mod) => ({ default: mod.SparklineCard }))
)
export function Dashboard() { export function Dashboard() {
const { data: session, isPending } = authClient.useSession() const { data: session, isPending } = authClient.useSession()
@@ -39,11 +44,19 @@ export function Dashboard() {
function AuthenticatedDashboard({ userName }: { userName: string }) { function AuthenticatedDashboard({ userName }: { userName: string }) {
const { data: purchases = [], isLoading: purchasesLoading } = usePurchases() const { data: purchases = [], isLoading: purchasesLoading } = usePurchases()
const { data: alerts = [], isLoading: alertsLoading } = usePriceAlerts() const { data: alerts = [], isLoading: alertsLoading } = usePriceAlerts()
const { data: eggHistory = [] } = usePriceHistory('prod10')
const { data: milkHistory = [] } = usePriceHistory('prod1')
const triggeredAlerts = alerts.filter((a) => a.triggered) const triggeredAlerts = alerts.filter((a) => a.triggered)
const watchingAlerts = alerts.filter((a) => !a.triggered) const watchingAlerts = alerts.filter((a) => !a.triggered)
const recentPurchases = purchases.slice(0, 3) const recentPurchases = purchases.slice(0, 3)
const sparklineData = eggHistory.filter((p) => p.storeId === 'meijer').slice(-8)
const milkSparkline = milkHistory.filter((p) => p.storeId === 'kroger').slice(-8)
const eggCurrent = sparklineData.length > 0 ? `$${sparklineData[sparklineData.length - 1].price.toFixed(2)}` : '—'
const milkCurrent = milkSparkline.length > 0 ? `$${milkSparkline[milkSparkline.length - 1].price.toFixed(2)}` : '—'
if (purchasesLoading || alertsLoading) { if (purchasesLoading || alertsLoading) {
return <DashboardSkeleton /> return <DashboardSkeleton />
} }
@@ -93,8 +106,11 @@ function AuthenticatedDashboard({ userName }: { userName: string }) {
{/* Price trend sparklines */} {/* Price trend sparklines */}
<section className="mt-6"> <section className="mt-6">
<h2 className="mb-3 text-lg font-semibold text-gray-700">Price Trends</h2> <h2 className="mb-3 text-lg font-semibold text-gray-700">Price Trends</h2>
<div className="rounded-xl bg-white p-4 shadow-sm text-center text-sm text-gray-400"> <div className="space-y-3">
Connect a store to see price trends <Suspense fallback={<SparklinePlaceholder />}>
<LazySparklineCard label="Eggs (dozen)" data={sparklineData} current={eggCurrent} />
<LazySparklineCard label="Whole Milk (1 gal)" data={milkSparkline} current={milkCurrent} />
</Suspense>
</div> </div>
</section> </section>
@@ -171,3 +187,15 @@ function DashboardSkeleton() {
</div> </div>
) )
} }
function SparklinePlaceholder() {
return (
<div className="flex items-center gap-4 rounded-xl bg-white p-4 shadow-sm animate-pulse">
<div className="min-w-0 flex-1">
<div className="h-4 w-24 rounded bg-gray-200" />
<div className="mt-2 h-6 w-16 rounded bg-gray-200" />
</div>
<div className="h-10 w-24 rounded bg-gray-100" />
</div>
)
}
+1 -7
View File
@@ -31,14 +31,8 @@ export function Login() {
throw new Error(authError.message ?? 'Sign in failed') throw new Error(authError.message ?? 'Sign in failed')
} }
// After successful signIn, force a session fetch to confirm the cookie is set setAuthenticated(true)
// before navigating to the protected route
const sessionResult = await authClient.getSession()
if (sessionResult.data) {
navigate('/') navigate('/')
} else {
setError('Sign in failed. Please try again.')
}
} catch { } catch {
if (import.meta.env.VITE_MOCK_AUTH === 'true') { if (import.meta.env.VITE_MOCK_AUTH === 'true') {
setAuthenticated(true) setAuthenticated(true)
+1 -8
View File
@@ -38,15 +38,8 @@ export function Register() {
throw new Error(authError.message ?? 'Registration failed') throw new Error(authError.message ?? 'Registration failed')
} }
// After successful signUp, force a session fetch to confirm the cookie is set setAuthenticated(true)
// before navigating to the protected route
const sessionResult = await authClient.getSession()
if (sessionResult.data) {
navigate('/') navigate('/')
} else {
// Session not established — show success message and link to login
setError('Account created! Please sign in.')
}
} catch { } catch {
if (import.meta.env.VITE_MOCK_AUTH === 'true') { if (import.meta.env.VITE_MOCK_AUTH === 'true') {
setAuthenticated(true) setAuthenticated(true)
+1 -1
View File
@@ -61,5 +61,5 @@ export const handlers = [
http.get('/api/v1/products', () => HttpResponse.json(mockProducts)), http.get('/api/v1/products', () => HttpResponse.json(mockProducts)),
http.get('/api/v1/products/prod_1', () => HttpResponse.json(mockProducts[0])), http.get('/api/v1/products/prod_1', () => HttpResponse.json(mockProducts[0])),
http.get('/api/v1/coupons', () => HttpResponse.json(mockCoupons)), http.get('/api/v1/coupons', () => HttpResponse.json(mockCoupons)),
http.get('/api/v1/alerts', () => HttpResponse.json(mockAlerts)), http.get('/api/v1/price-alerts', () => HttpResponse.json(mockAlerts)),
] ]