feat: merge cartsnitch/api into api/ subdirectory
Consolidate API gateway service into monorepo. Squashed from https://github.com/cartsnitch/api main (89bacb1). Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -0,0 +1,101 @@
|
||||
"""Shared test fixtures with in-memory SQLite database."""
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from cartsnitch_api.config import settings as cartsnitch_settings
|
||||
from cartsnitch_api.database import get_db
|
||||
from cartsnitch_api.main import create_app
|
||||
from cartsnitch_api.models import Base
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limiting():
|
||||
"""Disable rate limiting for all tests to prevent 429 interference."""
|
||||
cartsnitch_settings.rate_limit_enabled = False
|
||||
yield
|
||||
cartsnitch_settings.rate_limit_enabled = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
"""Sync in-memory SQLite engine for model unit tests."""
|
||||
eng = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(eng)
|
||||
yield eng
|
||||
eng.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(engine):
|
||||
"""Sync SQLAlchemy session for model unit tests."""
|
||||
factory = sessionmaker(bind=engine)
|
||||
with factory() as sess:
|
||||
yield sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(db_engine):
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
async def override_get_db():
|
||||
async with factory() as session:
|
||||
yield session
|
||||
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client):
|
||||
"""Register a test user and return auth headers."""
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "test@example.com",
|
||||
"password": "testpass123",
|
||||
"display_name": "Test User",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
token = resp.json()["access_token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
@@ -0,0 +1,209 @@
|
||||
"""Integration tests for auth endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_success(client):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "new@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "New User",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
assert data["expires_in"] == 900 # 15 min * 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_duplicate_email(client):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "dupe@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "User One",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "dupe@example.com",
|
||||
"password": "securepass456",
|
||||
"display_name": "User Two",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_short_password(client):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "short@example.com",
|
||||
"password": "short",
|
||||
"display_name": "Short Pass",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_success(client):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "login@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Login User",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "login@example.com",
|
||||
"password": "securepass123",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "access_token" in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_wrong_password(client):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "wrong@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Wrong Pass",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "wrong@example.com",
|
||||
"password": "badpassword1",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_nonexistent_user(client):
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "ghost@example.com",
|
||||
"password": "doesntmatter",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token(client):
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "refresh@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Refresh User",
|
||||
},
|
||||
)
|
||||
refresh_token = reg.json()["refresh_token"]
|
||||
|
||||
resp = await client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "access_token" in resp.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_with_invalid_token(client):
|
||||
resp = await client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": "invalid.token.here",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me(client, auth_headers):
|
||||
resp = await client.get("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["email"] == "test@example.com"
|
||||
assert data["display_name"] == "Test User"
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_unauthorized(client):
|
||||
resp = await client.get("/auth/me")
|
||||
assert resp.status_code in (401, 403) # No auth header
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me(client, auth_headers):
|
||||
resp = await client.patch(
|
||||
"/auth/me",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"display_name": "Updated Name",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["display_name"] == "Updated Name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_me(client, auth_headers):
|
||||
resp = await client.delete("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Verify user is gone (token still valid but user deleted)
|
||||
resp = await client.get("/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_after_delete_fails(client):
|
||||
"""Refresh token for a deleted user must be rejected."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "ghost@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Ghost User",
|
||||
},
|
||||
)
|
||||
tokens = reg.json()
|
||||
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
|
||||
# Delete the user
|
||||
resp = await client.delete("/auth/me", headers=headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Refresh token should now fail
|
||||
resp = await client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": tokens["refresh_token"],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Shared fixtures for E2E integration tests.
|
||||
|
||||
Seeds a realistic dataset with stores, products, price history,
|
||||
purchases, coupons, and shrinkflation events so E2E flows can
|
||||
exercise cross-resource queries against real data.
|
||||
"""
|
||||
|
||||
from datetime import date, timedelta
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.auth.jwt import decode_token
|
||||
from cartsnitch_api.models import (
|
||||
Coupon,
|
||||
NormalizedProduct,
|
||||
PriceHistory,
|
||||
Purchase,
|
||||
PurchaseItem,
|
||||
ShrinkflationEvent,
|
||||
Store,
|
||||
)
|
||||
|
||||
# Shared test constants
|
||||
ZERO_UUID = "00000000-0000-0000-0000-000000000000"
|
||||
BAD_UUID = "not-a-uuid"
|
||||
# Fixed anchor date for deterministic tests
|
||||
ANCHOR_DATE = date(2026, 3, 15)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def seed_data(db_engine, auth_headers):
|
||||
"""Seed a full dataset and return identifiers for test assertions."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
# -- Stores --
|
||||
meijer = Store(name="Meijer", slug="meijer")
|
||||
kroger = Store(name="Kroger", slug="kroger")
|
||||
target = Store(name="Target", slug="target")
|
||||
session.add_all([meijer, kroger, target])
|
||||
await session.flush()
|
||||
|
||||
# -- Products --
|
||||
cheerios = NormalizedProduct(
|
||||
canonical_name="Cheerios 18oz",
|
||||
category="pantry",
|
||||
brand="General Mills",
|
||||
size="18",
|
||||
size_unit="oz",
|
||||
upc_variants=["016000275263"],
|
||||
)
|
||||
milk = NormalizedProduct(
|
||||
canonical_name="Whole Milk 1gal",
|
||||
category="dairy",
|
||||
brand="Meijer",
|
||||
size="1",
|
||||
size_unit="gal",
|
||||
)
|
||||
chicken = NormalizedProduct(
|
||||
canonical_name="Chicken Breast 1lb",
|
||||
category="meat",
|
||||
brand=None,
|
||||
size="1",
|
||||
size_unit="lb",
|
||||
)
|
||||
session.add_all([cheerios, milk, chicken])
|
||||
await session.flush()
|
||||
|
||||
# -- Price history (multiple dates, multiple stores) --
|
||||
today = ANCHOR_DATE
|
||||
prices = []
|
||||
# Cheerios at Meijer: price increase over time
|
||||
for i, price_val in enumerate([Decimal("3.99"), Decimal("4.29"), Decimal("4.79")]):
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=cheerios.id,
|
||||
store_id=meijer.id,
|
||||
observed_date=today - timedelta(days=60 - i * 30),
|
||||
regular_price=price_val,
|
||||
source="receipt",
|
||||
)
|
||||
)
|
||||
# Cheerios at Kroger: stable price
|
||||
for i in range(3):
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=cheerios.id,
|
||||
store_id=kroger.id,
|
||||
observed_date=today - timedelta(days=60 - i * 30),
|
||||
regular_price=Decimal("4.49"),
|
||||
source="catalog",
|
||||
)
|
||||
)
|
||||
# Milk at Meijer
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=milk.id,
|
||||
store_id=meijer.id,
|
||||
observed_date=today - timedelta(days=7),
|
||||
regular_price=Decimal("3.29"),
|
||||
source="receipt",
|
||||
)
|
||||
)
|
||||
# Milk at Kroger
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=milk.id,
|
||||
store_id=kroger.id,
|
||||
observed_date=today - timedelta(days=5),
|
||||
regular_price=Decimal("3.49"),
|
||||
source="catalog",
|
||||
)
|
||||
)
|
||||
# Chicken at Target
|
||||
prices.append(
|
||||
PriceHistory(
|
||||
normalized_product_id=chicken.id,
|
||||
store_id=target.id,
|
||||
observed_date=today - timedelta(days=3),
|
||||
regular_price=Decimal("5.99"),
|
||||
source="catalog",
|
||||
)
|
||||
)
|
||||
session.add_all(prices)
|
||||
await session.flush()
|
||||
|
||||
# -- Purchases (need the user_id from the registered test user) --
|
||||
token = auth_headers["Authorization"].split(" ")[1]
|
||||
payload = decode_token(token)
|
||||
user_id = UUID(payload["sub"])
|
||||
|
||||
purchase1 = Purchase(
|
||||
user_id=user_id,
|
||||
store_id=meijer.id,
|
||||
receipt_id="meijer-2026-001",
|
||||
purchase_date=today - timedelta(days=10),
|
||||
total=Decimal("23.45"),
|
||||
subtotal=Decimal("21.50"),
|
||||
tax=Decimal("1.95"),
|
||||
)
|
||||
purchase2 = Purchase(
|
||||
user_id=user_id,
|
||||
store_id=kroger.id,
|
||||
receipt_id="kroger-2026-001",
|
||||
purchase_date=today - timedelta(days=5),
|
||||
total=Decimal("15.78"),
|
||||
subtotal=Decimal("14.50"),
|
||||
tax=Decimal("1.28"),
|
||||
)
|
||||
session.add_all([purchase1, purchase2])
|
||||
await session.flush()
|
||||
|
||||
# -- Purchase Items --
|
||||
item1 = PurchaseItem(
|
||||
purchase_id=purchase1.id,
|
||||
product_name_raw="Cheerios 18oz Box",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("4.79"),
|
||||
extended_price=Decimal("4.79"),
|
||||
normalized_product_id=cheerios.id,
|
||||
)
|
||||
item2 = PurchaseItem(
|
||||
purchase_id=purchase1.id,
|
||||
product_name_raw="Meijer Whole Milk 1gal",
|
||||
quantity=Decimal("2"),
|
||||
unit_price=Decimal("3.29"),
|
||||
extended_price=Decimal("6.58"),
|
||||
normalized_product_id=milk.id,
|
||||
)
|
||||
item3 = PurchaseItem(
|
||||
purchase_id=purchase2.id,
|
||||
product_name_raw="KRO CHEERIOS 18OZ",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("4.49"),
|
||||
extended_price=Decimal("4.49"),
|
||||
normalized_product_id=cheerios.id,
|
||||
)
|
||||
session.add_all([item1, item2, item3])
|
||||
await session.flush()
|
||||
|
||||
# -- Coupons --
|
||||
coupon1 = Coupon(
|
||||
store_id=meijer.id,
|
||||
normalized_product_id=cheerios.id,
|
||||
title="$1 off Cheerios",
|
||||
description="Save $1 on any Cheerios 18oz or larger",
|
||||
discount_type="fixed",
|
||||
discount_value=Decimal("1.00"),
|
||||
valid_from=today - timedelta(days=7),
|
||||
valid_to=today + timedelta(days=30),
|
||||
)
|
||||
coupon2 = Coupon(
|
||||
store_id=kroger.id,
|
||||
normalized_product_id=None,
|
||||
title="10% off dairy",
|
||||
description="10% off all dairy products",
|
||||
discount_type="percent",
|
||||
discount_value=Decimal("10.00"),
|
||||
valid_from=today - timedelta(days=3),
|
||||
valid_to=today + timedelta(days=14),
|
||||
)
|
||||
session.add_all([coupon1, coupon2])
|
||||
await session.flush()
|
||||
|
||||
# -- Shrinkflation events --
|
||||
shrink = ShrinkflationEvent(
|
||||
normalized_product_id=cheerios.id,
|
||||
detected_date=today - timedelta(days=15),
|
||||
old_size="20",
|
||||
new_size="18",
|
||||
old_unit="oz",
|
||||
new_unit="oz",
|
||||
price_at_old_size=Decimal("3.99"),
|
||||
price_at_new_size=Decimal("4.29"),
|
||||
confidence=Decimal("0.95"),
|
||||
notes="Size reduced from 20oz to 18oz while price increased",
|
||||
)
|
||||
session.add(shrink)
|
||||
await session.commit()
|
||||
|
||||
for obj in [
|
||||
meijer,
|
||||
kroger,
|
||||
target,
|
||||
cheerios,
|
||||
milk,
|
||||
chicken,
|
||||
purchase1,
|
||||
purchase2,
|
||||
item1,
|
||||
item2,
|
||||
item3,
|
||||
coupon1,
|
||||
coupon2,
|
||||
shrink,
|
||||
]:
|
||||
await session.refresh(obj)
|
||||
|
||||
return {
|
||||
"headers": auth_headers,
|
||||
"user_id": user_id,
|
||||
"stores": {"meijer": meijer, "kroger": kroger, "target": target},
|
||||
"products": {"cheerios": cheerios, "milk": milk, "chicken": chicken},
|
||||
"purchases": {"meijer_trip": purchase1, "kroger_trip": purchase2},
|
||||
"items": {"cheerios_meijer": item1, "milk_meijer": item2, "cheerios_kroger": item3},
|
||||
"coupons": {"cheerios_coupon": coupon1, "dairy_coupon": coupon2},
|
||||
"shrinkflation": {"cheerios_shrink": shrink},
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
"""E2E: Auth and token validation flows."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthRegistrationLogin:
|
||||
"""Full registration → login → token refresh → profile flow."""
|
||||
|
||||
async def test_full_auth_lifecycle(self, client, db_engine):
|
||||
"""Register → login → get profile → refresh → get profile again."""
|
||||
# Register
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "lifecycle@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Lifecycle User",
|
||||
},
|
||||
)
|
||||
assert reg.status_code == 201
|
||||
tokens = reg.json()
|
||||
assert "access_token" in tokens
|
||||
assert "refresh_token" in tokens
|
||||
assert tokens["token_type"] == "bearer"
|
||||
assert tokens["expires_in"] > 0
|
||||
|
||||
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
|
||||
# Get profile with access token
|
||||
me = await client.get("/auth/me", headers=headers)
|
||||
assert me.status_code == 200
|
||||
assert me.json()["email"] == "lifecycle@example.com"
|
||||
assert me.json()["display_name"] == "Lifecycle User"
|
||||
|
||||
# Sleep 1s so the new token has a different exp than the registration token
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Login with same credentials
|
||||
login = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "lifecycle@example.com", "password": "securepass123"},
|
||||
)
|
||||
assert login.status_code == 200
|
||||
login_tokens = login.json()
|
||||
assert login_tokens["access_token"] != tokens["access_token"]
|
||||
|
||||
# Refresh token
|
||||
refresh = await client.post(
|
||||
"/auth/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
assert refresh.status_code == 200
|
||||
new_tokens = refresh.json()
|
||||
assert new_tokens["access_token"] != tokens["access_token"]
|
||||
|
||||
# Use refreshed token to access profile
|
||||
new_headers = {"Authorization": f"Bearer {new_tokens['access_token']}"}
|
||||
me2 = await client.get("/auth/me", headers=new_headers)
|
||||
assert me2.status_code == 200
|
||||
assert me2.json()["email"] == "lifecycle@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTokenValidation:
|
||||
"""Token edge cases and error responses."""
|
||||
|
||||
async def test_expired_token_rejected(self, client, db_engine):
|
||||
"""Manually craft an expired token and verify rejection."""
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from jose import jwt
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
payload = {
|
||||
"sub": str(uuid.uuid4()),
|
||||
"exp": datetime.now(UTC) - timedelta(minutes=5),
|
||||
"type": "access",
|
||||
}
|
||||
token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_invalid_token_rejected(self, client, db_engine):
|
||||
resp = await client.get("/auth/me", headers={"Authorization": "Bearer not-a-real-token"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_missing_auth_header(self, client, db_engine):
|
||||
resp = await client.get("/auth/me")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
async def test_refresh_token_cannot_access_endpoints(self, client, db_engine):
|
||||
"""A refresh token should not work as an access token."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "refresh-test@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Refresh Test",
|
||||
},
|
||||
)
|
||||
refresh_token = reg.json()["refresh_token"]
|
||||
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {refresh_token}"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_deleted_user_token_invalid(self, client, db_engine):
|
||||
"""After deleting an account, tokens should no longer work."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "delete-me@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "Delete Me",
|
||||
},
|
||||
)
|
||||
tokens = reg.json()
|
||||
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
|
||||
# Delete account
|
||||
delete_resp = await client.delete("/auth/me", headers=headers)
|
||||
assert delete_resp.status_code == 204
|
||||
|
||||
# Profile should fail
|
||||
me = await client.get("/auth/me", headers=headers)
|
||||
assert me.status_code in (401, 404)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthProtectedEndpoints:
|
||||
"""Verify auth is enforced on all user-specific endpoints."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method,path",
|
||||
[
|
||||
("GET", "/purchases"),
|
||||
("GET", "/products"),
|
||||
("GET", "/prices/trends"),
|
||||
("GET", "/prices/increases"),
|
||||
("GET", "/coupons"),
|
||||
("GET", "/alerts"),
|
||||
("GET", "/me/stores"),
|
||||
],
|
||||
)
|
||||
async def test_endpoints_require_auth(self, client, db_engine, method, path):
|
||||
resp = await client.request(method, path)
|
||||
assert resp.status_code in (401, 403), f"{method} {path} should require auth"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCrossUserDataIsolation:
|
||||
"""Verify that users cannot access other users' data."""
|
||||
|
||||
async def test_user_b_cannot_access_user_a_purchases(self, client, seed_data):
|
||||
"""Register a second user and verify they cannot see User A's purchases."""
|
||||
# User A's purchase (from seed_data)
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
|
||||
# Register User B
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "userb@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "User B",
|
||||
},
|
||||
)
|
||||
assert reg.status_code == 201
|
||||
user_b_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
||||
|
||||
# User B tries to access User A's specific purchase
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
|
||||
assert resp.status_code in (403, 404), (
|
||||
"User B should not be able to access User A's purchase"
|
||||
)
|
||||
|
||||
async def test_user_b_purchase_list_is_empty(self, client, seed_data):
|
||||
"""A new user should see no purchases (not User A's purchases)."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "userc@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "User C",
|
||||
},
|
||||
)
|
||||
assert reg.status_code == 201
|
||||
user_c_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
||||
|
||||
resp = await client.get("/purchases", headers=user_c_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0, "New user should have no purchases"
|
||||
|
||||
async def test_user_b_stores_isolated(self, client, seed_data):
|
||||
"""User B's connected stores should be independent from User A."""
|
||||
reg = await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "userd@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "User D",
|
||||
},
|
||||
)
|
||||
assert reg.status_code == 201
|
||||
user_d_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
|
||||
|
||||
# User D should have no connected stores
|
||||
resp = await client.get("/me/stores", headers=user_d_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0, "New user should have no connected stores"
|
||||
@@ -0,0 +1,114 @@
|
||||
"""E2E: Cross-resource flows — store connect → purchases → prices → coupons → alerts."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestStoreConnectToPurchaseFlow:
|
||||
"""Connect a store, then verify purchases and related data are accessible."""
|
||||
|
||||
async def test_connect_store_then_list(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
# Connect to Meijer
|
||||
resp = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
|
||||
assert resp.status_code in (200, 201)
|
||||
|
||||
# Verify store appears in user's connected stores
|
||||
stores = await client.get("/me/stores", headers=headers)
|
||||
assert stores.status_code == 200
|
||||
slugs = [s["store"]["slug"] for s in stores.json()]
|
||||
assert "meijer" in slugs
|
||||
|
||||
async def test_disconnect_store(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
await client.post("/me/stores/kroger/connect", json={}, headers=headers)
|
||||
resp = await client.delete("/me/stores/kroger", headers=headers)
|
||||
assert resp.status_code in (200, 204)
|
||||
|
||||
# Verify store no longer in connected list
|
||||
stores = await client.get("/me/stores", headers=headers)
|
||||
slugs = [s["store"]["slug"] for s in stores.json()]
|
||||
assert "kroger" not in slugs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseToPriceFlow:
|
||||
"""Verify purchase data links to price comparison data."""
|
||||
|
||||
async def test_purchase_items_link_to_products(self, client, seed_data):
|
||||
"""Items from purchases reference products that have price data."""
|
||||
headers = seed_data["headers"]
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
|
||||
# Get purchase detail
|
||||
purchase = await client.get(f"/purchases/{purchase_id}", headers=headers)
|
||||
assert purchase.status_code == 200
|
||||
items = purchase.json()["line_items"]
|
||||
|
||||
# Get product detail for an item that has a product_id
|
||||
product_ids = [li["product_id"] for li in items if li.get("product_id")]
|
||||
assert len(product_ids) >= 1
|
||||
|
||||
for pid in product_ids:
|
||||
product = await client.get(f"/products/{pid}", headers=headers)
|
||||
assert product.status_code == 200
|
||||
assert len(product.json()["prices_by_store"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCouponFlow:
|
||||
"""Verify coupon listing and relevance filtering."""
|
||||
|
||||
async def test_list_all_coupons(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/coupons", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 2
|
||||
descriptions = [c["description"] for c in data]
|
||||
assert any("Cheerios" in d for d in descriptions)
|
||||
|
||||
async def test_filter_coupons_by_store(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
meijer_id = str(seed_data["stores"]["meijer"].id)
|
||||
resp = await client.get("/coupons", params={"store_id": meijer_id}, headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert all(c["store_name"] == "Meijer" for c in data)
|
||||
|
||||
async def test_relevant_coupons_for_user(self, client, seed_data):
|
||||
"""User bought Cheerios, so the Cheerios coupon should be relevant."""
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/coupons/relevant", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1, "Expected at least one relevant coupon for user with purchases"
|
||||
descriptions = [c["description"] for c in data]
|
||||
assert any("Cheerios" in d for d in descriptions)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAlertFlow:
|
||||
"""Verify alert listing with seeded data."""
|
||||
|
||||
async def test_list_alerts(self, client, seed_data):
|
||||
"""User bought Cheerios which has a shrinkflation event — may appear as alert."""
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/alerts", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, list)
|
||||
# If alerts are generated synchronously, verify shrinkflation alert content
|
||||
if len(data) > 0:
|
||||
alert_types = [a["alert_type"] for a in data]
|
||||
product_names = [a["product_name"] for a in data]
|
||||
assert any(t in ("shrinkflation", "price_increase") for t in alert_types)
|
||||
assert any("Cheerios" in name for name in product_names)
|
||||
|
||||
async def test_alert_settings_default(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
resp = await client.get("/alerts/settings", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "price_increase_threshold_pct" in data
|
||||
assert "shrinkflation_enabled" in data
|
||||
@@ -0,0 +1,127 @@
|
||||
"""E2E: Error responses for bad input across all endpoint categories."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRegistrationErrors:
|
||||
"""Validation errors during user registration."""
|
||||
|
||||
async def test_short_password(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "short@example.com", "password": "short", "display_name": "Test"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_email(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "not-an-email", "password": "securepass123", "display_name": "Test"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_missing_fields(self, client, db_engine):
|
||||
resp = await client.post("/auth/register", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_empty_display_name(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "empty@example.com", "password": "securepass123", "display_name": ""},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_duplicate_email(self, client, db_engine):
|
||||
payload = {
|
||||
"email": "dupe@example.com",
|
||||
"password": "securepass123",
|
||||
"display_name": "First",
|
||||
}
|
||||
first = await client.post("/auth/register", json=payload)
|
||||
assert first.status_code == 201
|
||||
second = await client.post("/auth/register", json=payload)
|
||||
assert second.status_code == 409
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLoginErrors:
|
||||
"""Login failure modes."""
|
||||
|
||||
async def test_wrong_password(self, client, db_engine):
|
||||
await client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "login-err@example.com",
|
||||
"password": "correctpass1",
|
||||
"display_name": "Login",
|
||||
},
|
||||
)
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "login-err@example.com", "password": "wrongpass123"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_nonexistent_user(self, client, db_engine):
|
||||
resp = await client.post(
|
||||
"/auth/login",
|
||||
json={"email": "nobody@example.com", "password": "doesntmatter"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestNotFoundErrors:
|
||||
"""404 responses for missing resources."""
|
||||
|
||||
async def test_product_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_purchase_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/purchases/{ZERO_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_public_trend_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/public/trends/{ZERO_UUID}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestMalformedInput:
|
||||
"""Invalid UUID formats and bad query params."""
|
||||
|
||||
async def test_invalid_uuid_product(self, client, seed_data):
|
||||
resp = await client.get(f"/products/{BAD_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_uuid_purchase(self, client, seed_data):
|
||||
resp = await client.get(f"/purchases/{BAD_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_invalid_uuid_public_trend(self, client, seed_data):
|
||||
resp = await client.get(f"/public/trends/{BAD_UUID}")
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestStoreConnectionErrors:
|
||||
"""Store connection edge cases."""
|
||||
|
||||
async def test_connect_nonexistent_store(self, client, seed_data):
|
||||
resp = await client.post(
|
||||
"/me/stores/nonexistent-store/connect",
|
||||
json={},
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_connect_store_twice(self, client, seed_data):
|
||||
headers = seed_data["headers"]
|
||||
first = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
|
||||
assert first.status_code in (200, 201)
|
||||
second = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
|
||||
assert second.status_code == 409
|
||||
@@ -0,0 +1,102 @@
|
||||
"""E2E: Price history queries returning correct data."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPriceTrends:
|
||||
"""Verify price trend aggregation against seeded history."""
|
||||
|
||||
async def test_trends_returns_all_products(self, client, seed_data):
|
||||
resp = await client.get("/prices/trends", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
product_names = [t["product_name"] for t in data]
|
||||
assert "Cheerios 18oz" in product_names
|
||||
assert "Whole Milk 1gal" in product_names
|
||||
|
||||
async def test_trends_filter_by_category(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
"/prices/trends", params={"category": "dairy"}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
# Only dairy products should appear
|
||||
for trend in data:
|
||||
assert trend["product_name"] == "Whole Milk 1gal"
|
||||
|
||||
async def test_trends_contain_data_points(self, client, seed_data):
|
||||
resp = await client.get("/prices/trends", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
cheerios_trend = next(t for t in data if t["product_name"] == "Cheerios 18oz")
|
||||
assert len(cheerios_trend["data_points"]) >= 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPriceIncreases:
|
||||
"""Detect price increases from seeded price history."""
|
||||
|
||||
async def test_increases_detected(self, client, seed_data):
|
||||
resp = await client.get("/prices/increases", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# Cheerios at Meijer went from 3.99 → 4.29 → 4.79
|
||||
cheerios_increases = [inc for inc in data if inc["product_name"] == "Cheerios 18oz"]
|
||||
assert len(cheerios_increases) >= 1
|
||||
# Verify the increase data makes sense
|
||||
for inc in cheerios_increases:
|
||||
assert inc["new_price"] > inc["old_price"]
|
||||
assert inc["increase_pct"] > 0
|
||||
assert inc["store_name"] == "Meijer"
|
||||
|
||||
async def test_stable_prices_not_flagged(self, client, seed_data):
|
||||
"""Kroger Cheerios price is stable at $4.49 — should not appear as increase."""
|
||||
resp = await client.get("/prices/increases", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
kroger_increases = [
|
||||
inc
|
||||
for inc in data
|
||||
if inc["product_name"] == "Cheerios 18oz" and inc["store_name"] == "Kroger"
|
||||
]
|
||||
assert len(kroger_increases) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPriceComparison:
|
||||
"""Compare prices across stores for specific products."""
|
||||
|
||||
async def test_compare_cheerios_across_stores(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(
|
||||
"/prices/comparison",
|
||||
params={"product_ids": cheerios_id},
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
cheerios_cmp = data[0]
|
||||
assert cheerios_cmp["product_name"] == "Cheerios 18oz"
|
||||
store_names = [p["store_name"] for p in cheerios_cmp["prices"]]
|
||||
assert "Meijer" in store_names
|
||||
assert "Kroger" in store_names
|
||||
|
||||
async def test_compare_requires_product_ids(self, client, seed_data):
|
||||
"""product_ids is required — omitting it must return 422."""
|
||||
resp = await client.get("/prices/comparison", headers=seed_data["headers"])
|
||||
assert resp.status_code == 422
|
||||
|
||||
async def test_compare_multiple_products(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
milk_id = str(seed_data["products"]["milk"].id)
|
||||
resp = await client.get(
|
||||
"/prices/comparison",
|
||||
params=[("product_ids", cheerios_id), ("product_ids", milk_id)],
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
names = [c["product_name"] for c in data]
|
||||
assert "Cheerios 18oz" in names
|
||||
assert "Whole Milk 1gal" in names
|
||||
@@ -0,0 +1,82 @@
|
||||
"""E2E: Product search/lookup endpoints with real DB fixtures."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.test_e2e.conftest import ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestProductSearch:
|
||||
"""Search and filter products against seeded data."""
|
||||
|
||||
async def test_list_all_products(self, client, seed_data):
|
||||
resp = await client.get("/products", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()
|
||||
names = [p["name"] for p in products]
|
||||
assert "Cheerios 18oz" in names
|
||||
assert "Whole Milk 1gal" in names
|
||||
assert "Chicken Breast 1lb" in names
|
||||
|
||||
async def test_search_by_name(self, client, seed_data):
|
||||
resp = await client.get("/products", params={"q": "cheerios"}, headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()
|
||||
assert len(products) >= 1
|
||||
assert all("cheerios" in p["name"].lower() for p in products)
|
||||
|
||||
async def test_search_by_category(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
"/products", params={"category": "dairy"}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
products = resp.json()
|
||||
assert len(products) >= 1
|
||||
assert all(p["category"] == "dairy" for p in products)
|
||||
|
||||
async def test_search_no_results(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
"/products", params={"q": "nonexistentxyz"}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestProductLookup:
|
||||
"""Detailed product lookups with cross-store pricing."""
|
||||
|
||||
async def test_get_product_detail_with_prices(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "Cheerios 18oz"
|
||||
assert data["brand"] == "General Mills"
|
||||
assert data["category"] == "pantry"
|
||||
# Should have prices from both Meijer and Kroger
|
||||
store_names = [p["store_name"] for p in data["prices_by_store"]]
|
||||
assert "Meijer" in store_names
|
||||
assert "Kroger" in store_names
|
||||
|
||||
async def test_product_prices_reflect_latest(self, client, seed_data):
|
||||
"""The latest Meijer price for Cheerios should be 4.79 (the increase)."""
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
meijer_price = next(p for p in data["prices_by_store"] if p["store_name"] == "Meijer")
|
||||
assert meijer_price["current_price"] == 4.79
|
||||
|
||||
async def test_product_not_found(self, client, seed_data):
|
||||
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_product_price_history(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/products/{cheerios_id}/prices", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["data_points"]) >= 3 # At least the 3 Meijer observations
|
||||
# Verify chronological ordering exists
|
||||
prices = [dp["price"] for dp in data["data_points"]]
|
||||
assert len(prices) >= 3
|
||||
@@ -0,0 +1,59 @@
|
||||
"""E2E: Public price transparency endpoints (no auth required)."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPublicTrends:
|
||||
"""Public price trend endpoint — no auth, real data."""
|
||||
|
||||
async def test_public_trend_returns_data(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/public/trends/{cheerios_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["product_name"] == "Cheerios 18oz"
|
||||
assert len(data["data_points"]) >= 3
|
||||
|
||||
async def test_public_trend_no_auth_needed(self, client, seed_data):
|
||||
"""Confirm no Authorization header is required."""
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(f"/public/trends/{cheerios_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPublicStoreComparison:
|
||||
"""Public store comparison endpoint."""
|
||||
|
||||
async def test_store_comparison(self, client, seed_data):
|
||||
cheerios_id = str(seed_data["products"]["cheerios"].id)
|
||||
resp = await client.get(
|
||||
"/public/store-comparison",
|
||||
params=[("product_ids", cheerios_id)],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "products" in data
|
||||
assert len(data["products"]) >= 1
|
||||
|
||||
async def test_store_comparison_rejects_more_than_20_ids(self, client):
|
||||
"""max_length=20 guard: 21 product IDs must return 422."""
|
||||
too_many = [("product_ids", str(uuid.uuid4())) for _ in range(21)]
|
||||
resp = await client.get("/public/store-comparison", params=too_many)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPublicInflation:
|
||||
"""Public inflation index endpoint."""
|
||||
|
||||
async def test_inflation_returns_index(self, client, seed_data):
|
||||
resp = await client.get("/public/inflation")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "cartsnitch_index" in data
|
||||
assert "cpi_baseline" in data
|
||||
assert "categories" in data
|
||||
@@ -0,0 +1,87 @@
|
||||
"""E2E: Purchase listing, detail, and stats against real DB fixtures."""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.test_e2e.conftest import ZERO_UUID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseList:
|
||||
"""List and filter a user's purchases."""
|
||||
|
||||
async def test_list_user_purchases(self, client, seed_data):
|
||||
resp = await client.get("/purchases", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 2
|
||||
store_names = [p["store_name"] for p in data]
|
||||
assert "Meijer" in store_names
|
||||
assert "Kroger" in store_names
|
||||
|
||||
async def test_filter_purchases_by_store(self, client, seed_data):
|
||||
meijer_id = str(seed_data["stores"]["meijer"].id)
|
||||
resp = await client.get(
|
||||
"/purchases", params={"store_id": meijer_id}, headers=seed_data["headers"]
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert all(p["store_name"] == "Meijer" for p in data)
|
||||
|
||||
async def test_purchases_require_auth(self, client, seed_data):
|
||||
resp = await client.get("/purchases")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseDetail:
|
||||
"""Retrieve individual purchase with line items."""
|
||||
|
||||
async def test_get_purchase_detail(self, client, seed_data):
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["store_name"] == "Meijer"
|
||||
assert data["total"] == 23.45
|
||||
assert len(data["line_items"]) == 2
|
||||
item_names = [li["name"] for li in data["line_items"]]
|
||||
assert "Cheerios 18oz Box" in item_names
|
||||
assert "Meijer Whole Milk 1gal" in item_names
|
||||
|
||||
async def test_line_item_amounts_correct(self, client, seed_data):
|
||||
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
|
||||
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
cheerios_item = next(li for li in data["line_items"] if "Cheerios" in li["name"])
|
||||
assert cheerios_item["unit_price"] == 4.79
|
||||
assert cheerios_item["quantity"] == 1.0
|
||||
assert cheerios_item["total_price"] == 4.79
|
||||
|
||||
async def test_purchase_not_found(self, client, seed_data):
|
||||
resp = await client.get(
|
||||
f"/purchases/{ZERO_UUID}",
|
||||
headers=seed_data["headers"],
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPurchaseStats:
|
||||
"""Verify spending aggregation across purchases."""
|
||||
|
||||
async def test_purchase_stats_totals(self, client, seed_data):
|
||||
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["purchase_count"] == 2
|
||||
# 23.45 + 15.78 = 39.23
|
||||
assert abs(data["total_spent"] - 39.23) < 0.01
|
||||
|
||||
async def test_purchase_stats_by_store(self, client, seed_data):
|
||||
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
|
||||
data = resp.json()
|
||||
assert "Meijer" in data["by_store"]
|
||||
assert "Kroger" in data["by_store"]
|
||||
assert abs(data["by_store"]["Meijer"] - 23.45) < 0.01
|
||||
assert abs(data["by_store"]["Kroger"] - 15.78) < 0.01
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Tests for EncryptedJSON TypeDecorator and session_data encryption."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import Fernet
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import column, create_engine, table, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
from cartsnitch_api.models import Base
|
||||
from cartsnitch_api.models.store import Store
|
||||
from cartsnitch_api.models.user import User, UserStoreAccount
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
eng = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(eng)
|
||||
yield eng
|
||||
eng.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(engine):
|
||||
factory = sessionmaker(bind=engine)
|
||||
with factory() as sess:
|
||||
yield sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(session):
|
||||
s = Store(name="Test Store", slug="test-store")
|
||||
session.add(s)
|
||||
session.commit()
|
||||
session.refresh(s)
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user(session):
|
||||
u = User(email="alice@example.com", hashed_password="fakehash")
|
||||
session.add(u)
|
||||
session.commit()
|
||||
session.refresh(u)
|
||||
return u
|
||||
|
||||
|
||||
class TestEncryptedJSONType:
|
||||
"""Unit tests for the EncryptedJSON TypeDecorator."""
|
||||
|
||||
def test_round_trip(self, session, user, store):
|
||||
"""Data written via the ORM comes back as the original dict."""
|
||||
original = {"token": "abc123", "cookies": {"session_id": "xyz"}}
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data == original
|
||||
|
||||
def test_stored_value_is_encrypted(self, session, user, store):
|
||||
"""The raw value in the DB should be a Fernet token, not plaintext JSON."""
|
||||
original = {"secret": "do-not-leak"}
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
# Use a raw table construct to bypass TypeDecorator on read
|
||||
raw_table = table("user_store_accounts", column("id"), column("session_data"))
|
||||
raw = session.execute(raw_table.select().where(raw_table.c.id == str(account.id))).first()
|
||||
# If UUID matching fails with str, try bytes format
|
||||
if raw is None:
|
||||
raw = session.execute(
|
||||
text("SELECT session_data FROM user_store_accounts LIMIT 1")
|
||||
).scalar_one()
|
||||
else:
|
||||
raw = raw[1]
|
||||
|
||||
assert raw != json.dumps(original)
|
||||
assert raw.startswith("gAAAAA")
|
||||
|
||||
# Verify we can decrypt the raw value manually
|
||||
f = Fernet(settings.fernet_key.encode())
|
||||
decrypted = json.loads(f.decrypt(raw.encode()))
|
||||
assert decrypted == original
|
||||
|
||||
def test_null_round_trip(self, session, user, store):
|
||||
"""NULL session_data stays NULL."""
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=None)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data is None
|
||||
|
||||
def test_empty_dict_round_trip(self, session, user, store):
|
||||
"""Empty dict round-trips correctly."""
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={})
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data == {}
|
||||
|
||||
def test_update_session_data(self, session, user, store):
|
||||
"""Updating session_data re-encrypts the new value."""
|
||||
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={"v": 1})
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
account.session_data = {"v": 2, "new_field": True}
|
||||
session.commit()
|
||||
|
||||
loaded = session.get(UserStoreAccount, account.id)
|
||||
assert loaded.session_data == {"v": 2, "new_field": True}
|
||||
|
||||
|
||||
class TestEncryptionKeyValidation:
|
||||
"""Test that invalid/missing keys are caught at startup."""
|
||||
|
||||
def test_invalid_fernet_key_rejected(self, monkeypatch):
|
||||
"""Settings validation rejects a bad key."""
|
||||
monkeypatch.setenv("CARTSNITCH_FERNET_KEY", "not-a-valid-key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
from cartsnitch_api.config import Settings
|
||||
|
||||
Settings()
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Conftest for middleware tests — re-enables rate limiting after global disable."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.config import settings as cartsnitch_settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_rate_limiting():
|
||||
"""Re-enable rate limiting after the global disable_rate_limiting fixture runs.
|
||||
|
||||
The root conftest disables rate limiting for all tests to prevent 429
|
||||
interference. Middleware tests need it active to verify headers and
|
||||
enforcement. This fixture runs after the root fixture (more local = later
|
||||
in setup order) so True is the effective value during the test body.
|
||||
"""
|
||||
cartsnitch_settings.rate_limit_enabled = True
|
||||
yield
|
||||
cartsnitch_settings.rate_limit_enabled = False
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Tests for structured error responses and error monitoring."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_404_returns_structured_error(client):
|
||||
"""Non-existent route should return structured error."""
|
||||
resp = await client.get("/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
body = resp.json()
|
||||
assert "detail" in body
|
||||
assert "code" in body
|
||||
assert body["code"] == "NOT_FOUND"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_error_returns_422_with_field_errors(client):
|
||||
"""Invalid request body should return structured validation errors."""
|
||||
resp = await client.post(
|
||||
"/auth/register",
|
||||
json={"email": "not-an-email", "password": "short", "display_name": ""},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
body = resp.json()
|
||||
assert body["code"] == "VALIDATION_ERROR"
|
||||
assert "errors" in body
|
||||
assert isinstance(body["errors"], list)
|
||||
assert len(body["errors"]) > 0
|
||||
# Each error should have field, message, type
|
||||
for err in body["errors"]:
|
||||
assert "field" in err
|
||||
assert "message" in err
|
||||
assert "type" in err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_stats_requires_service_key(client):
|
||||
"""Error stats endpoint should require X-Service-Key."""
|
||||
resp = await client.get("/internal/error-stats")
|
||||
assert resp.status_code == 422 # Missing required header
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_stats_with_valid_key(client):
|
||||
"""Error stats endpoint returns monitoring data with valid key."""
|
||||
resp = await client.get(
|
||||
"/internal/error-stats",
|
||||
headers={"X-Service-Key": "change-me-in-production"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "error_counts" in body
|
||||
assert "recent_5xx_count" in body
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter
|
||||
|
||||
|
||||
class TestSlidingWindowCounter:
|
||||
def test_allows_within_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
|
||||
for i in range(5):
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert remaining == 4 - i
|
||||
|
||||
def test_blocks_over_limit(self):
|
||||
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
|
||||
for _ in range(3):
|
||||
counter.is_allowed("test-key")
|
||||
|
||||
allowed, remaining, retry = counter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
assert remaining == 0
|
||||
assert retry > 0
|
||||
|
||||
def test_separate_keys(self):
|
||||
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60)
|
||||
# Fill key-a
|
||||
counter.is_allowed("key-a")
|
||||
counter.is_allowed("key-a")
|
||||
allowed_a, _, _ = counter.is_allowed("key-a")
|
||||
assert allowed_a is False
|
||||
|
||||
# key-b should still be allowed
|
||||
allowed_b, remaining, _ = counter.is_allowed("key-b")
|
||||
assert allowed_b is True
|
||||
assert remaining == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_returns_429(client):
|
||||
"""Public endpoint should return 429 after limit exceeded."""
|
||||
# The default limit is 60/min — we won't hit it in normal tests,
|
||||
# but we verify the middleware adds rate limit headers.
|
||||
resp = await client.get("/public/inflation")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_skips_rate_limit(client):
|
||||
"""Health endpoint should not have rate limit headers."""
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert "x-ratelimit-limit" not in resp.headers
|
||||
@@ -0,0 +1,376 @@
|
||||
"""Tests for SQLAlchemy ORM models."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from cartsnitch_api.constants import (
|
||||
AccountStatus,
|
||||
DiscountType,
|
||||
PriceSource,
|
||||
ProductCategory,
|
||||
SizeUnit,
|
||||
StoreSlug,
|
||||
)
|
||||
from cartsnitch_api.models import (
|
||||
Coupon,
|
||||
NormalizedProduct,
|
||||
PriceHistory,
|
||||
Purchase,
|
||||
PurchaseItem,
|
||||
ShrinkflationEvent,
|
||||
Store,
|
||||
StoreLocation,
|
||||
User,
|
||||
UserStoreAccount,
|
||||
)
|
||||
|
||||
|
||||
class TestTableCreation:
|
||||
"""Verify all expected tables are created."""
|
||||
|
||||
def test_all_tables_exist(self, engine):
|
||||
inspector = inspect(engine)
|
||||
table_names = set(inspector.get_table_names())
|
||||
expected = {
|
||||
"stores",
|
||||
"store_locations",
|
||||
"users",
|
||||
"user_store_accounts",
|
||||
"purchases",
|
||||
"purchase_items",
|
||||
"normalized_products",
|
||||
"price_history",
|
||||
"coupons",
|
||||
"shrinkflation_events",
|
||||
}
|
||||
assert expected.issubset(table_names)
|
||||
|
||||
def test_ten_tables_total(self, engine):
|
||||
inspector = inspect(engine)
|
||||
assert len(inspector.get_table_names()) == 10
|
||||
|
||||
|
||||
class TestUUIDPrimaryKeys:
|
||||
"""All models use UUID PKs."""
|
||||
|
||||
def test_store_uuid_pk(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Meijer",
|
||||
slug=StoreSlug.MEIJER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.commit()
|
||||
assert isinstance(store.id, uuid.UUID)
|
||||
|
||||
def test_user_uuid_pk(self, session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
assert isinstance(user.id, uuid.UUID)
|
||||
|
||||
|
||||
class TestStoreModel:
|
||||
def test_store_slug_enum(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Kroger",
|
||||
slug=StoreSlug.KROGER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.commit()
|
||||
assert store.slug == StoreSlug.KROGER
|
||||
|
||||
def test_store_unique_slug(self, session):
|
||||
s1 = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
s2 = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target Duplicate",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(s1)
|
||||
session.commit()
|
||||
session.add(s2)
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
session.commit()
|
||||
session.rollback()
|
||||
|
||||
|
||||
class TestStoreLocationModel:
|
||||
def test_store_location_fields(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Meijer",
|
||||
slug=StoreSlug.MEIJER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.flush()
|
||||
loc = StoreLocation(
|
||||
id=uuid.uuid4(),
|
||||
store_id=store.id,
|
||||
address="123 Main St",
|
||||
city="Ann Arbor",
|
||||
state="MI",
|
||||
zip="48104",
|
||||
lat=42.2808,
|
||||
lng=-83.7430,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(loc)
|
||||
session.commit()
|
||||
assert loc.city == "Ann Arbor"
|
||||
assert loc.lat == pytest.approx(42.2808)
|
||||
|
||||
|
||||
class TestUserStoreAccountModel:
|
||||
def test_account_status_enum(self, session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@test.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Kroger",
|
||||
slug=StoreSlug.KROGER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([user, store])
|
||||
session.flush()
|
||||
acct = UserStoreAccount(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
status=AccountStatus.ACTIVE,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(acct)
|
||||
session.commit()
|
||||
assert acct.status == AccountStatus.ACTIVE
|
||||
|
||||
def test_unique_user_store_constraint(self, session):
|
||||
"""One account per user per store."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="unique@test.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([user, store])
|
||||
session.flush()
|
||||
a1 = UserStoreAccount(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
status=AccountStatus.ACTIVE,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
a2 = UserStoreAccount(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
status=AccountStatus.EXPIRED,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(a1)
|
||||
session.commit()
|
||||
session.add(a2)
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
session.commit()
|
||||
session.rollback()
|
||||
|
||||
|
||||
class TestPurchaseModel:
|
||||
def test_purchase_with_items(self, session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="buyer@test.com",
|
||||
hashed_password="hashed",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Meijer",
|
||||
slug=StoreSlug.MEIJER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([user, store])
|
||||
session.flush()
|
||||
purchase = Purchase(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
receipt_id="RCP-001",
|
||||
purchase_date=date(2026, 3, 15),
|
||||
total=Decimal("42.50"),
|
||||
ingested_at=datetime.now(UTC),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(purchase)
|
||||
session.flush()
|
||||
item = PurchaseItem(
|
||||
id=uuid.uuid4(),
|
||||
purchase_id=purchase.id,
|
||||
product_name_raw="Meijer Whole Milk 1 Gallon",
|
||||
upc="0041250000001",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("3.49"),
|
||||
extended_price=Decimal("3.49"),
|
||||
)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
assert item.product_name_raw == "Meijer Whole Milk 1 Gallon"
|
||||
assert item.unit_price == Decimal("3.49")
|
||||
|
||||
|
||||
class TestNormalizedProductModel:
|
||||
def test_product_with_upc_variants(self, session):
|
||||
product = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Whole Milk, 1 Gallon",
|
||||
category=ProductCategory.DAIRY,
|
||||
brand="Store Brand",
|
||||
size="128",
|
||||
size_unit=SizeUnit.FL_OZ,
|
||||
upc_variants=["0041250000001", "0041250000002"],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(product)
|
||||
session.commit()
|
||||
assert product.category == ProductCategory.DAIRY
|
||||
assert product.size_unit == SizeUnit.FL_OZ
|
||||
|
||||
|
||||
class TestPriceHistoryModel:
|
||||
def test_price_source_enum(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Kroger",
|
||||
slug=StoreSlug.KROGER,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
product = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Eggs, Large, 12ct",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([store, product])
|
||||
session.flush()
|
||||
ph = PriceHistory(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 15),
|
||||
regular_price=Decimal("4.99"),
|
||||
sale_price=Decimal("3.99"),
|
||||
source=PriceSource.RECEIPT,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(ph)
|
||||
session.commit()
|
||||
assert ph.source == PriceSource.RECEIPT
|
||||
assert ph.regular_price == Decimal("4.99")
|
||||
|
||||
|
||||
class TestCouponModel:
|
||||
def test_coupon_discount_types(self, session):
|
||||
store = Store(
|
||||
id=uuid.uuid4(),
|
||||
name="Target",
|
||||
slug=StoreSlug.TARGET,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(store)
|
||||
session.flush()
|
||||
coupon = Coupon(
|
||||
id=uuid.uuid4(),
|
||||
store_id=store.id,
|
||||
title="$2 off eggs",
|
||||
discount_type=DiscountType.FIXED,
|
||||
discount_value=Decimal("2.00"),
|
||||
requires_clip=True,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(coupon)
|
||||
session.commit()
|
||||
assert coupon.discount_type == DiscountType.FIXED
|
||||
assert coupon.discount_value == Decimal("2.00")
|
||||
|
||||
|
||||
class TestShrinkflationEventModel:
|
||||
def test_shrinkflation_event(self, session):
|
||||
product = NormalizedProduct(
|
||||
id=uuid.uuid4(),
|
||||
canonical_name="Cereal, Honey Oats",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(product)
|
||||
session.flush()
|
||||
event = ShrinkflationEvent(
|
||||
id=uuid.uuid4(),
|
||||
normalized_product_id=product.id,
|
||||
detected_date=date(2026, 3, 10),
|
||||
old_size="18",
|
||||
new_size="15.4",
|
||||
old_unit=SizeUnit.OZ,
|
||||
new_unit=SizeUnit.OZ,
|
||||
price_at_old_size=Decimal("4.99"),
|
||||
price_at_new_size=Decimal("4.99"),
|
||||
confidence=Decimal("0.95"),
|
||||
notes="Size reduced by 14.4%, price unchanged",
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(event)
|
||||
session.commit()
|
||||
assert event.confidence == Decimal("0.95")
|
||||
assert event.old_unit == SizeUnit.OZ
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Verify all expected routes are present in the OpenAPI spec."""
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from cartsnitch_api.main import app
|
||||
|
||||
EXPECTED_ROUTES = [
|
||||
# Auth (6)
|
||||
("post", "/auth/register"),
|
||||
("post", "/auth/login"),
|
||||
("post", "/auth/refresh"),
|
||||
("get", "/auth/me"),
|
||||
("patch", "/auth/me"),
|
||||
("delete", "/auth/me"),
|
||||
# Stores (4)
|
||||
("get", "/stores"),
|
||||
("get", "/me/stores"),
|
||||
("post", "/me/stores/{store_slug}/connect"),
|
||||
("delete", "/me/stores/{store_slug}"),
|
||||
# Purchases (3)
|
||||
("get", "/purchases"),
|
||||
("get", "/purchases/stats"),
|
||||
("get", "/purchases/{purchase_id}"),
|
||||
# Products (3)
|
||||
("get", "/products"),
|
||||
("get", "/products/{product_id}"),
|
||||
("get", "/products/{product_id}/prices"),
|
||||
# Prices (3)
|
||||
("get", "/prices/trends"),
|
||||
("get", "/prices/increases"),
|
||||
("get", "/prices/comparison"),
|
||||
# Coupons (2)
|
||||
("get", "/coupons"),
|
||||
("get", "/coupons/relevant"),
|
||||
# Shopping (2)
|
||||
("post", "/shopping/optimize"),
|
||||
("get", "/shopping/lists"),
|
||||
# Alerts (3)
|
||||
("get", "/alerts"),
|
||||
("get", "/alerts/settings"),
|
||||
("put", "/alerts/settings"),
|
||||
# Scraping (2)
|
||||
("post", "/scraping/{store_slug}/sync"),
|
||||
("get", "/scraping/status"),
|
||||
# Public (3)
|
||||
("get", "/public/trends/{product_id}"),
|
||||
("get", "/public/store-comparison"),
|
||||
("get", "/public/inflation"),
|
||||
# Health (1)
|
||||
("get", "/health"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_routes_in_openapi():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/openapi.json")
|
||||
assert resp.status_code == 200
|
||||
spec = resp.json()
|
||||
paths = spec["paths"]
|
||||
|
||||
registered = set()
|
||||
for path, methods in paths.items():
|
||||
for method in methods:
|
||||
if method in ("get", "post", "put", "delete", "patch"):
|
||||
registered.add((method, path))
|
||||
|
||||
missing = []
|
||||
for method, path in EXPECTED_ROUTES:
|
||||
if (method, path) not in registered:
|
||||
missing.append(f"{method.upper()} {path}")
|
||||
|
||||
assert not missing, "Missing routes in OpenAPI spec:\n" + "\n".join(missing)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_count():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/openapi.json")
|
||||
spec = resp.json()
|
||||
paths = spec["paths"]
|
||||
|
||||
count = 0
|
||||
for _path, methods in paths.items():
|
||||
for method in methods:
|
||||
if method in ("get", "post", "put", "delete", "patch"):
|
||||
count += 1
|
||||
|
||||
assert count == 33, f"Expected 33 routes, found {count}"
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Integration tests for alert endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_alerts_empty(client, auth_headers):
|
||||
"""No purchases means no alerts."""
|
||||
resp = await client.get("/alerts", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_alert_settings(client, auth_headers):
|
||||
resp = await client.get("/alerts/settings", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["price_increase_threshold_pct"] == 5.0
|
||||
assert data["shrinkflation_enabled"] is True
|
||||
assert data["email_notifications"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_alert_settings_returns_501(client, auth_headers):
|
||||
resp = await client.put(
|
||||
"/alerts/settings",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"price_increase_threshold_pct": 10.0,
|
||||
"shrinkflation_enabled": False,
|
||||
"email_notifications": True,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 501
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Integration tests for coupon endpoints."""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import Coupon, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def coupon_data(db_engine, auth_headers):
|
||||
"""Seed stores and coupons."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Target", slug="target")
|
||||
session.add(store)
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
|
||||
coupon = Coupon(
|
||||
store_id=store.id,
|
||||
title="$2 off laundry",
|
||||
description="$2 off any laundry detergent",
|
||||
discount_value=Decimal("2.00"),
|
||||
discount_type="fixed",
|
||||
valid_from=date(2026, 1, 1),
|
||||
valid_to=date(2026, 12, 31),
|
||||
)
|
||||
session.add(coupon)
|
||||
await session.commit()
|
||||
|
||||
return {"store": store, "coupon": coupon, "headers": auth_headers}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_coupons(client, coupon_data):
|
||||
resp = await client.get("/coupons", headers=coupon_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_coupons_by_store(client, coupon_data):
|
||||
store_id = str(coupon_data["store"].id)
|
||||
resp = await client.get(f"/coupons?store_id={store_id}", headers=coupon_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relevant_coupons_empty(client, auth_headers):
|
||||
"""No purchases means no relevant coupons."""
|
||||
resp = await client.get("/coupons/relevant", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Integration tests for price endpoints."""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def price_data(db_engine, auth_headers):
|
||||
"""Seed products with price history showing an increase."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Walmart", slug="walmart")
|
||||
product = NormalizedProduct(
|
||||
canonical_name="Tide Pods 42ct",
|
||||
category="household",
|
||||
brand="Tide",
|
||||
)
|
||||
session.add_all([store, product])
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
await session.refresh(product)
|
||||
|
||||
# Two price points — second is higher (increase)
|
||||
ph1 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 2, 1),
|
||||
regular_price=Decimal("12.99"),
|
||||
source="receipt",
|
||||
)
|
||||
ph2 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 1),
|
||||
regular_price=Decimal("14.49"),
|
||||
source="receipt",
|
||||
)
|
||||
session.add_all([ph1, ph2])
|
||||
await session.commit()
|
||||
|
||||
return {"product": product, "store": store, "headers": auth_headers}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_trends(client, price_data):
|
||||
resp = await client.get("/prices/trends", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["product_name"] == "Tide Pods 42ct"
|
||||
assert len(data[0]["data_points"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_trends_by_category(client, price_data):
|
||||
resp = await client.get("/prices/trends?category=household", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
resp = await client.get("/prices/trends?category=nonexistent", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_increases(client, price_data):
|
||||
resp = await client.get("/prices/increases", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
increase = data[0]
|
||||
assert increase["old_price"] == 12.99
|
||||
assert increase["new_price"] == 14.49
|
||||
assert increase["increase_pct"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_comparison(client, price_data):
|
||||
pid = str(price_data["product"].id)
|
||||
resp = await client.get(f"/prices/comparison?product_ids={pid}", headers=price_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["product_name"] == "Tide Pods 42ct"
|
||||
assert len(data[0]["prices"]) >= 1
|
||||
@@ -0,0 +1,94 @@
|
||||
"""Integration tests for product endpoints."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def product_data(db_engine, auth_headers):
|
||||
"""Seed products and price history."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Meijer", slug="meijer")
|
||||
product = NormalizedProduct(
|
||||
canonical_name="Cheerios 18oz",
|
||||
category="pantry",
|
||||
brand="General Mills",
|
||||
upc_variants=["016000275263"],
|
||||
)
|
||||
session.add_all([store, product])
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
await session.refresh(product)
|
||||
|
||||
ph1 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 1),
|
||||
regular_price=Decimal("4.99"),
|
||||
source="receipt",
|
||||
)
|
||||
ph2 = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 10),
|
||||
regular_price=Decimal("5.49"),
|
||||
source="receipt",
|
||||
)
|
||||
session.add_all([ph1, ph2])
|
||||
await session.commit()
|
||||
|
||||
return {"product": product, "store": store, "headers": auth_headers}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_products(client, product_data):
|
||||
resp = await client.get("/products", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["name"] == "Cheerios 18oz"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_products(client, product_data):
|
||||
resp = await client.get("/products?q=Cheerios", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
resp = await client.get("/products?q=nonexistent", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_product_detail(client, product_data):
|
||||
pid = str(product_data["product"].id)
|
||||
resp = await client.get(f"/products/{pid}", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "Cheerios 18oz"
|
||||
assert data["brand"] == "General Mills"
|
||||
assert len(data["prices_by_store"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_product_not_found(client, auth_headers):
|
||||
resp = await client.get(f"/products/{uuid.uuid4()}", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_product_prices(client, product_data):
|
||||
pid = str(product_data["product"].id)
|
||||
resp = await client.get(f"/products/{pid}/prices", headers=product_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["product_name"] == "Cheerios 18oz"
|
||||
assert len(data["data_points"]) == 2
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Integration tests for public endpoints (no auth)."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def public_data(db_engine):
|
||||
"""Seed data for public endpoints."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Target", slug="target")
|
||||
product = NormalizedProduct(
|
||||
canonical_name="Skippy PB 16oz",
|
||||
category="pantry",
|
||||
brand="Skippy",
|
||||
)
|
||||
session.add_all([store, product])
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
await session.refresh(product)
|
||||
|
||||
ph = PriceHistory(
|
||||
normalized_product_id=product.id,
|
||||
store_id=store.id,
|
||||
observed_date=date(2026, 3, 5),
|
||||
regular_price=Decimal("3.99"),
|
||||
source="receipt",
|
||||
)
|
||||
session.add(ph)
|
||||
await session.commit()
|
||||
|
||||
return {"product": product, "store": store}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_trend(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/trends/{pid}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["product_name"] == "Skippy PB 16oz"
|
||||
assert len(data["data_points"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_trend_not_found(client):
|
||||
resp = await client.get(f"/public/trends/{uuid.uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_store_comparison(client, public_data):
|
||||
pid = str(public_data["product"].id)
|
||||
resp = await client.get(f"/public/store-comparison?product_ids={pid}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["products"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_inflation(client, public_data):
|
||||
resp = await client.get("/public/inflation")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "categories" in data
|
||||
assert "cartsnitch_index" in data
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Integration tests for purchase endpoints."""
|
||||
|
||||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from cartsnitch_api.auth.jwt import create_access_token
|
||||
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def purchase_data(db_engine):
|
||||
"""Seed a user, store, purchase, and items."""
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
from cartsnitch_api.auth.passwords import hash_password
|
||||
|
||||
user = User(
|
||||
email="buyer@example.com",
|
||||
hashed_password=hash_password("testpass123"),
|
||||
display_name="Buyer",
|
||||
)
|
||||
store = Store(name="Kroger", slug="kroger")
|
||||
session.add_all([user, store])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(store)
|
||||
|
||||
purchase = Purchase(
|
||||
user_id=user.id,
|
||||
store_id=store.id,
|
||||
receipt_id="receipt-001",
|
||||
purchase_date=date(2026, 3, 10),
|
||||
total=Decimal("42.50"),
|
||||
)
|
||||
session.add(purchase)
|
||||
await session.commit()
|
||||
await session.refresh(purchase)
|
||||
|
||||
item = PurchaseItem(
|
||||
purchase_id=purchase.id,
|
||||
product_name_raw="Organic Milk 1gal",
|
||||
quantity=Decimal("1"),
|
||||
unit_price=Decimal("5.99"),
|
||||
extended_price=Decimal("5.99"),
|
||||
)
|
||||
session.add(item)
|
||||
await session.commit()
|
||||
|
||||
token = create_access_token(user.id)
|
||||
return {
|
||||
"user": user,
|
||||
"store": store,
|
||||
"purchase": purchase,
|
||||
"headers": {"Authorization": f"Bearer {token}"},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_purchases(client, purchase_data):
|
||||
resp = await client.get("/purchases", headers=purchase_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["store_name"] == "Kroger"
|
||||
assert data[0]["total"] == 42.50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_purchase_detail(client, purchase_data):
|
||||
pid = str(purchase_data["purchase"].id)
|
||||
resp = await client.get(f"/purchases/{pid}", headers=purchase_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["line_items"]) == 1
|
||||
assert data["line_items"][0]["name"] == "Organic Milk 1gal"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_purchase_not_found(client, auth_headers):
|
||||
resp = await client.get(f"/purchases/{uuid.uuid4()}", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_purchase_stats(client, purchase_data):
|
||||
resp = await client.get("/purchases/stats", headers=purchase_data["headers"])
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total_spent"] == 42.50
|
||||
assert data["purchase_count"] == 1
|
||||
assert "Kroger" in data["by_store"]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Integration tests for store endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.models import Store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def seeded_store(db_engine):
|
||||
"""Insert a test store directly into the DB."""
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with factory() as session:
|
||||
store = Store(name="Meijer", slug="meijer", logo_url=None, website_url=None)
|
||||
session.add(store)
|
||||
await session.commit()
|
||||
await session.refresh(store)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_stores(client, seeded_store):
|
||||
resp = await client.get("/stores")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 1
|
||||
assert data[0]["slug"] == "meijer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_user_stores_empty(client, auth_headers):
|
||||
resp = await client.get("/me/stores", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_and_disconnect_store(client, auth_headers, seeded_store):
|
||||
# Connect
|
||||
resp = await client.post(
|
||||
"/me/stores/meijer/connect",
|
||||
headers=auth_headers,
|
||||
json={"credentials": None},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["connected"] is True
|
||||
|
||||
# List should show connected
|
||||
resp = await client.get("/me/stores", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) == 1
|
||||
|
||||
# Disconnect
|
||||
resp = await client.delete("/me/stores/meijer", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# List should be empty again
|
||||
resp = await client.get("/me/stores", headers=auth_headers)
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_nonexistent_store(client, auth_headers):
|
||||
resp = await client.post(
|
||||
"/me/stores/nonexistent/connect",
|
||||
headers=auth_headers,
|
||||
json={},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_duplicate_store(client, auth_headers, seeded_store):
|
||||
await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
|
||||
resp = await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
|
||||
assert resp.status_code == 409
|
||||
Reference in New Issue
Block a user