feat: merge cartsnitch/api into api/ subdirectory

Consolidate API gateway service into monorepo.
Squashed from https://github.com/cartsnitch/api main (89bacb1).

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
Coupon Carl
2026-03-28 02:24:02 +00:00
commit b7e6f637a7
91 changed files with 6296 additions and 0 deletions
View File
+101
View File
@@ -0,0 +1,101 @@
"""Shared test fixtures with in-memory SQLite database."""
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import create_engine, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from cartsnitch_api.config import settings as cartsnitch_settings
from cartsnitch_api.database import get_db
from cartsnitch_api.main import create_app
from cartsnitch_api.models import Base
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
@pytest.fixture(autouse=True)
def disable_rate_limiting():
"""Disable rate limiting for all tests to prevent 429 interference."""
cartsnitch_settings.rate_limit_enabled = False
yield
cartsnitch_settings.rate_limit_enabled = True
@pytest.fixture
def engine():
"""Sync in-memory SQLite engine for model unit tests."""
eng = create_engine("sqlite:///:memory:")
Base.metadata.create_all(eng)
yield eng
eng.dispose()
@pytest.fixture
def session(engine):
"""Sync SQLAlchemy session for model unit tests."""
factory = sessionmaker(bind=engine)
with factory() as sess:
yield sess
@pytest.fixture
async def db_engine():
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
yield session
@pytest.fixture
async def client(db_engine):
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async def override_get_db():
async with factory() as session:
yield session
app = create_app()
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
app.dependency_overrides.clear()
@pytest.fixture
async def auth_headers(client):
"""Register a test user and return auth headers."""
resp = await client.post(
"/auth/register",
json={
"email": "test@example.com",
"password": "testpass123",
"display_name": "Test User",
},
)
assert resp.status_code == 201
token = resp.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
View File
+209
View File
@@ -0,0 +1,209 @@
"""Integration tests for auth endpoints."""
import pytest
@pytest.mark.asyncio
async def test_register_success(client):
resp = await client.post(
"/auth/register",
json={
"email": "new@example.com",
"password": "securepass123",
"display_name": "New User",
},
)
assert resp.status_code == 201
data = resp.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
assert data["expires_in"] == 900 # 15 min * 60
@pytest.mark.asyncio
async def test_register_duplicate_email(client):
await client.post(
"/auth/register",
json={
"email": "dupe@example.com",
"password": "securepass123",
"display_name": "User One",
},
)
resp = await client.post(
"/auth/register",
json={
"email": "dupe@example.com",
"password": "securepass456",
"display_name": "User Two",
},
)
assert resp.status_code == 409
@pytest.mark.asyncio
async def test_register_short_password(client):
resp = await client.post(
"/auth/register",
json={
"email": "short@example.com",
"password": "short",
"display_name": "Short Pass",
},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_login_success(client):
await client.post(
"/auth/register",
json={
"email": "login@example.com",
"password": "securepass123",
"display_name": "Login User",
},
)
resp = await client.post(
"/auth/login",
json={
"email": "login@example.com",
"password": "securepass123",
},
)
assert resp.status_code == 200
assert "access_token" in resp.json()
@pytest.mark.asyncio
async def test_login_wrong_password(client):
await client.post(
"/auth/register",
json={
"email": "wrong@example.com",
"password": "securepass123",
"display_name": "Wrong Pass",
},
)
resp = await client.post(
"/auth/login",
json={
"email": "wrong@example.com",
"password": "badpassword1",
},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_login_nonexistent_user(client):
resp = await client.post(
"/auth/login",
json={
"email": "ghost@example.com",
"password": "doesntmatter",
},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_refresh_token(client):
reg = await client.post(
"/auth/register",
json={
"email": "refresh@example.com",
"password": "securepass123",
"display_name": "Refresh User",
},
)
refresh_token = reg.json()["refresh_token"]
resp = await client.post(
"/auth/refresh",
json={
"refresh_token": refresh_token,
},
)
assert resp.status_code == 200
assert "access_token" in resp.json()
@pytest.mark.asyncio
async def test_refresh_with_invalid_token(client):
resp = await client.post(
"/auth/refresh",
json={
"refresh_token": "invalid.token.here",
},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_get_me(client, auth_headers):
resp = await client.get("/auth/me", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
assert data["email"] == "test@example.com"
assert data["display_name"] == "Test User"
assert "id" in data
assert "created_at" in data
@pytest.mark.asyncio
async def test_get_me_unauthorized(client):
resp = await client.get("/auth/me")
assert resp.status_code in (401, 403) # No auth header
@pytest.mark.asyncio
async def test_update_me(client, auth_headers):
resp = await client.patch(
"/auth/me",
headers=auth_headers,
json={
"display_name": "Updated Name",
},
)
assert resp.status_code == 200
assert resp.json()["display_name"] == "Updated Name"
@pytest.mark.asyncio
async def test_delete_me(client, auth_headers):
resp = await client.delete("/auth/me", headers=auth_headers)
assert resp.status_code == 204
# Verify user is gone (token still valid but user deleted)
resp = await client.get("/auth/me", headers=auth_headers)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_refresh_after_delete_fails(client):
"""Refresh token for a deleted user must be rejected."""
reg = await client.post(
"/auth/register",
json={
"email": "ghost@example.com",
"password": "securepass123",
"display_name": "Ghost User",
},
)
tokens = reg.json()
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
# Delete the user
resp = await client.delete("/auth/me", headers=headers)
assert resp.status_code == 204
# Refresh token should now fail
resp = await client.post(
"/auth/refresh",
json={
"refresh_token": tokens["refresh_token"],
},
)
assert resp.status_code == 401
View File
+250
View File
@@ -0,0 +1,250 @@
"""Shared fixtures for E2E integration tests.
Seeds a realistic dataset with stores, products, price history,
purchases, coupons, and shrinkflation events so E2E flows can
exercise cross-resource queries against real data.
"""
from datetime import date, timedelta
from decimal import Decimal
from uuid import UUID
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from cartsnitch_api.auth.jwt import decode_token
from cartsnitch_api.models import (
Coupon,
NormalizedProduct,
PriceHistory,
Purchase,
PurchaseItem,
ShrinkflationEvent,
Store,
)
# Shared test constants
ZERO_UUID = "00000000-0000-0000-0000-000000000000"
BAD_UUID = "not-a-uuid"
# Fixed anchor date for deterministic tests
ANCHOR_DATE = date(2026, 3, 15)
@pytest.fixture
async def seed_data(db_engine, auth_headers):
"""Seed a full dataset and return identifiers for test assertions."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
# -- Stores --
meijer = Store(name="Meijer", slug="meijer")
kroger = Store(name="Kroger", slug="kroger")
target = Store(name="Target", slug="target")
session.add_all([meijer, kroger, target])
await session.flush()
# -- Products --
cheerios = NormalizedProduct(
canonical_name="Cheerios 18oz",
category="pantry",
brand="General Mills",
size="18",
size_unit="oz",
upc_variants=["016000275263"],
)
milk = NormalizedProduct(
canonical_name="Whole Milk 1gal",
category="dairy",
brand="Meijer",
size="1",
size_unit="gal",
)
chicken = NormalizedProduct(
canonical_name="Chicken Breast 1lb",
category="meat",
brand=None,
size="1",
size_unit="lb",
)
session.add_all([cheerios, milk, chicken])
await session.flush()
# -- Price history (multiple dates, multiple stores) --
today = ANCHOR_DATE
prices = []
# Cheerios at Meijer: price increase over time
for i, price_val in enumerate([Decimal("3.99"), Decimal("4.29"), Decimal("4.79")]):
prices.append(
PriceHistory(
normalized_product_id=cheerios.id,
store_id=meijer.id,
observed_date=today - timedelta(days=60 - i * 30),
regular_price=price_val,
source="receipt",
)
)
# Cheerios at Kroger: stable price
for i in range(3):
prices.append(
PriceHistory(
normalized_product_id=cheerios.id,
store_id=kroger.id,
observed_date=today - timedelta(days=60 - i * 30),
regular_price=Decimal("4.49"),
source="catalog",
)
)
# Milk at Meijer
prices.append(
PriceHistory(
normalized_product_id=milk.id,
store_id=meijer.id,
observed_date=today - timedelta(days=7),
regular_price=Decimal("3.29"),
source="receipt",
)
)
# Milk at Kroger
prices.append(
PriceHistory(
normalized_product_id=milk.id,
store_id=kroger.id,
observed_date=today - timedelta(days=5),
regular_price=Decimal("3.49"),
source="catalog",
)
)
# Chicken at Target
prices.append(
PriceHistory(
normalized_product_id=chicken.id,
store_id=target.id,
observed_date=today - timedelta(days=3),
regular_price=Decimal("5.99"),
source="catalog",
)
)
session.add_all(prices)
await session.flush()
# -- Purchases (need the user_id from the registered test user) --
token = auth_headers["Authorization"].split(" ")[1]
payload = decode_token(token)
user_id = UUID(payload["sub"])
purchase1 = Purchase(
user_id=user_id,
store_id=meijer.id,
receipt_id="meijer-2026-001",
purchase_date=today - timedelta(days=10),
total=Decimal("23.45"),
subtotal=Decimal("21.50"),
tax=Decimal("1.95"),
)
purchase2 = Purchase(
user_id=user_id,
store_id=kroger.id,
receipt_id="kroger-2026-001",
purchase_date=today - timedelta(days=5),
total=Decimal("15.78"),
subtotal=Decimal("14.50"),
tax=Decimal("1.28"),
)
session.add_all([purchase1, purchase2])
await session.flush()
# -- Purchase Items --
item1 = PurchaseItem(
purchase_id=purchase1.id,
product_name_raw="Cheerios 18oz Box",
quantity=Decimal("1"),
unit_price=Decimal("4.79"),
extended_price=Decimal("4.79"),
normalized_product_id=cheerios.id,
)
item2 = PurchaseItem(
purchase_id=purchase1.id,
product_name_raw="Meijer Whole Milk 1gal",
quantity=Decimal("2"),
unit_price=Decimal("3.29"),
extended_price=Decimal("6.58"),
normalized_product_id=milk.id,
)
item3 = PurchaseItem(
purchase_id=purchase2.id,
product_name_raw="KRO CHEERIOS 18OZ",
quantity=Decimal("1"),
unit_price=Decimal("4.49"),
extended_price=Decimal("4.49"),
normalized_product_id=cheerios.id,
)
session.add_all([item1, item2, item3])
await session.flush()
# -- Coupons --
coupon1 = Coupon(
store_id=meijer.id,
normalized_product_id=cheerios.id,
title="$1 off Cheerios",
description="Save $1 on any Cheerios 18oz or larger",
discount_type="fixed",
discount_value=Decimal("1.00"),
valid_from=today - timedelta(days=7),
valid_to=today + timedelta(days=30),
)
coupon2 = Coupon(
store_id=kroger.id,
normalized_product_id=None,
title="10% off dairy",
description="10% off all dairy products",
discount_type="percent",
discount_value=Decimal("10.00"),
valid_from=today - timedelta(days=3),
valid_to=today + timedelta(days=14),
)
session.add_all([coupon1, coupon2])
await session.flush()
# -- Shrinkflation events --
shrink = ShrinkflationEvent(
normalized_product_id=cheerios.id,
detected_date=today - timedelta(days=15),
old_size="20",
new_size="18",
old_unit="oz",
new_unit="oz",
price_at_old_size=Decimal("3.99"),
price_at_new_size=Decimal("4.29"),
confidence=Decimal("0.95"),
notes="Size reduced from 20oz to 18oz while price increased",
)
session.add(shrink)
await session.commit()
for obj in [
meijer,
kroger,
target,
cheerios,
milk,
chicken,
purchase1,
purchase2,
item1,
item2,
item3,
coupon1,
coupon2,
shrink,
]:
await session.refresh(obj)
return {
"headers": auth_headers,
"user_id": user_id,
"stores": {"meijer": meijer, "kroger": kroger, "target": target},
"products": {"cheerios": cheerios, "milk": milk, "chicken": chicken},
"purchases": {"meijer_trip": purchase1, "kroger_trip": purchase2},
"items": {"cheerios_meijer": item1, "milk_meijer": item2, "cheerios_kroger": item3},
"coupons": {"cheerios_coupon": coupon1, "dairy_coupon": coupon2},
"shrinkflation": {"cheerios_shrink": shrink},
}
+213
View File
@@ -0,0 +1,213 @@
"""E2E: Auth and token validation flows."""
import asyncio
import pytest
@pytest.mark.asyncio
class TestAuthRegistrationLogin:
"""Full registration → login → token refresh → profile flow."""
async def test_full_auth_lifecycle(self, client, db_engine):
"""Register → login → get profile → refresh → get profile again."""
# Register
reg = await client.post(
"/auth/register",
json={
"email": "lifecycle@example.com",
"password": "securepass123",
"display_name": "Lifecycle User",
},
)
assert reg.status_code == 201
tokens = reg.json()
assert "access_token" in tokens
assert "refresh_token" in tokens
assert tokens["token_type"] == "bearer"
assert tokens["expires_in"] > 0
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
# Get profile with access token
me = await client.get("/auth/me", headers=headers)
assert me.status_code == 200
assert me.json()["email"] == "lifecycle@example.com"
assert me.json()["display_name"] == "Lifecycle User"
# Sleep 1s so the new token has a different exp than the registration token
await asyncio.sleep(1)
# Login with same credentials
login = await client.post(
"/auth/login",
json={"email": "lifecycle@example.com", "password": "securepass123"},
)
assert login.status_code == 200
login_tokens = login.json()
assert login_tokens["access_token"] != tokens["access_token"]
# Refresh token
refresh = await client.post(
"/auth/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert refresh.status_code == 200
new_tokens = refresh.json()
assert new_tokens["access_token"] != tokens["access_token"]
# Use refreshed token to access profile
new_headers = {"Authorization": f"Bearer {new_tokens['access_token']}"}
me2 = await client.get("/auth/me", headers=new_headers)
assert me2.status_code == 200
assert me2.json()["email"] == "lifecycle@example.com"
@pytest.mark.asyncio
class TestTokenValidation:
"""Token edge cases and error responses."""
async def test_expired_token_rejected(self, client, db_engine):
"""Manually craft an expired token and verify rejection."""
import uuid
from datetime import UTC, datetime, timedelta
from jose import jwt
from cartsnitch_api.config import settings
payload = {
"sub": str(uuid.uuid4()),
"exp": datetime.now(UTC) - timedelta(minutes=5),
"type": "access",
}
token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 401
async def test_invalid_token_rejected(self, client, db_engine):
resp = await client.get("/auth/me", headers={"Authorization": "Bearer not-a-real-token"})
assert resp.status_code == 401
async def test_missing_auth_header(self, client, db_engine):
resp = await client.get("/auth/me")
assert resp.status_code in (401, 403)
async def test_refresh_token_cannot_access_endpoints(self, client, db_engine):
"""A refresh token should not work as an access token."""
reg = await client.post(
"/auth/register",
json={
"email": "refresh-test@example.com",
"password": "securepass123",
"display_name": "Refresh Test",
},
)
refresh_token = reg.json()["refresh_token"]
resp = await client.get("/auth/me", headers={"Authorization": f"Bearer {refresh_token}"})
assert resp.status_code == 401
async def test_deleted_user_token_invalid(self, client, db_engine):
"""After deleting an account, tokens should no longer work."""
reg = await client.post(
"/auth/register",
json={
"email": "delete-me@example.com",
"password": "securepass123",
"display_name": "Delete Me",
},
)
tokens = reg.json()
headers = {"Authorization": f"Bearer {tokens['access_token']}"}
# Delete account
delete_resp = await client.delete("/auth/me", headers=headers)
assert delete_resp.status_code == 204
# Profile should fail
me = await client.get("/auth/me", headers=headers)
assert me.status_code in (401, 404)
@pytest.mark.asyncio
class TestAuthProtectedEndpoints:
"""Verify auth is enforced on all user-specific endpoints."""
@pytest.mark.parametrize(
"method,path",
[
("GET", "/purchases"),
("GET", "/products"),
("GET", "/prices/trends"),
("GET", "/prices/increases"),
("GET", "/coupons"),
("GET", "/alerts"),
("GET", "/me/stores"),
],
)
async def test_endpoints_require_auth(self, client, db_engine, method, path):
resp = await client.request(method, path)
assert resp.status_code in (401, 403), f"{method} {path} should require auth"
@pytest.mark.asyncio
class TestCrossUserDataIsolation:
"""Verify that users cannot access other users' data."""
async def test_user_b_cannot_access_user_a_purchases(self, client, seed_data):
"""Register a second user and verify they cannot see User A's purchases."""
# User A's purchase (from seed_data)
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
# Register User B
reg = await client.post(
"/auth/register",
json={
"email": "userb@example.com",
"password": "securepass123",
"display_name": "User B",
},
)
assert reg.status_code == 201
user_b_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
# User B tries to access User A's specific purchase
resp = await client.get(f"/purchases/{purchase_id}", headers=user_b_headers)
assert resp.status_code in (403, 404), (
"User B should not be able to access User A's purchase"
)
async def test_user_b_purchase_list_is_empty(self, client, seed_data):
"""A new user should see no purchases (not User A's purchases)."""
reg = await client.post(
"/auth/register",
json={
"email": "userc@example.com",
"password": "securepass123",
"display_name": "User C",
},
)
assert reg.status_code == 201
user_c_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
resp = await client.get("/purchases", headers=user_c_headers)
assert resp.status_code == 200
assert len(resp.json()) == 0, "New user should have no purchases"
async def test_user_b_stores_isolated(self, client, seed_data):
"""User B's connected stores should be independent from User A."""
reg = await client.post(
"/auth/register",
json={
"email": "userd@example.com",
"password": "securepass123",
"display_name": "User D",
},
)
assert reg.status_code == 201
user_d_headers = {"Authorization": f"Bearer {reg.json()['access_token']}"}
# User D should have no connected stores
resp = await client.get("/me/stores", headers=user_d_headers)
assert resp.status_code == 200
assert len(resp.json()) == 0, "New user should have no connected stores"
+114
View File
@@ -0,0 +1,114 @@
"""E2E: Cross-resource flows — store connect → purchases → prices → coupons → alerts."""
import pytest
@pytest.mark.asyncio
class TestStoreConnectToPurchaseFlow:
"""Connect a store, then verify purchases and related data are accessible."""
async def test_connect_store_then_list(self, client, seed_data):
headers = seed_data["headers"]
# Connect to Meijer
resp = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
assert resp.status_code in (200, 201)
# Verify store appears in user's connected stores
stores = await client.get("/me/stores", headers=headers)
assert stores.status_code == 200
slugs = [s["store"]["slug"] for s in stores.json()]
assert "meijer" in slugs
async def test_disconnect_store(self, client, seed_data):
headers = seed_data["headers"]
await client.post("/me/stores/kroger/connect", json={}, headers=headers)
resp = await client.delete("/me/stores/kroger", headers=headers)
assert resp.status_code in (200, 204)
# Verify store no longer in connected list
stores = await client.get("/me/stores", headers=headers)
slugs = [s["store"]["slug"] for s in stores.json()]
assert "kroger" not in slugs
@pytest.mark.asyncio
class TestPurchaseToPriceFlow:
"""Verify purchase data links to price comparison data."""
async def test_purchase_items_link_to_products(self, client, seed_data):
"""Items from purchases reference products that have price data."""
headers = seed_data["headers"]
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
# Get purchase detail
purchase = await client.get(f"/purchases/{purchase_id}", headers=headers)
assert purchase.status_code == 200
items = purchase.json()["line_items"]
# Get product detail for an item that has a product_id
product_ids = [li["product_id"] for li in items if li.get("product_id")]
assert len(product_ids) >= 1
for pid in product_ids:
product = await client.get(f"/products/{pid}", headers=headers)
assert product.status_code == 200
assert len(product.json()["prices_by_store"]) >= 1
@pytest.mark.asyncio
class TestCouponFlow:
"""Verify coupon listing and relevance filtering."""
async def test_list_all_coupons(self, client, seed_data):
headers = seed_data["headers"]
resp = await client.get("/coupons", headers=headers)
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 2
descriptions = [c["description"] for c in data]
assert any("Cheerios" in d for d in descriptions)
async def test_filter_coupons_by_store(self, client, seed_data):
headers = seed_data["headers"]
meijer_id = str(seed_data["stores"]["meijer"].id)
resp = await client.get("/coupons", params={"store_id": meijer_id}, headers=headers)
assert resp.status_code == 200
data = resp.json()
assert all(c["store_name"] == "Meijer" for c in data)
async def test_relevant_coupons_for_user(self, client, seed_data):
"""User bought Cheerios, so the Cheerios coupon should be relevant."""
headers = seed_data["headers"]
resp = await client.get("/coupons/relevant", headers=headers)
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1, "Expected at least one relevant coupon for user with purchases"
descriptions = [c["description"] for c in data]
assert any("Cheerios" in d for d in descriptions)
@pytest.mark.asyncio
class TestAlertFlow:
"""Verify alert listing with seeded data."""
async def test_list_alerts(self, client, seed_data):
"""User bought Cheerios which has a shrinkflation event — may appear as alert."""
headers = seed_data["headers"]
resp = await client.get("/alerts", headers=headers)
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
# If alerts are generated synchronously, verify shrinkflation alert content
if len(data) > 0:
alert_types = [a["alert_type"] for a in data]
product_names = [a["product_name"] for a in data]
assert any(t in ("shrinkflation", "price_increase") for t in alert_types)
assert any("Cheerios" in name for name in product_names)
async def test_alert_settings_default(self, client, seed_data):
headers = seed_data["headers"]
resp = await client.get("/alerts/settings", headers=headers)
assert resp.status_code == 200
data = resp.json()
assert "price_increase_threshold_pct" in data
assert "shrinkflation_enabled" in data
+127
View File
@@ -0,0 +1,127 @@
"""E2E: Error responses for bad input across all endpoint categories."""
import pytest
from tests.test_e2e.conftest import BAD_UUID, ZERO_UUID
@pytest.mark.asyncio
class TestRegistrationErrors:
"""Validation errors during user registration."""
async def test_short_password(self, client, db_engine):
resp = await client.post(
"/auth/register",
json={"email": "short@example.com", "password": "short", "display_name": "Test"},
)
assert resp.status_code == 422
async def test_invalid_email(self, client, db_engine):
resp = await client.post(
"/auth/register",
json={"email": "not-an-email", "password": "securepass123", "display_name": "Test"},
)
assert resp.status_code == 422
async def test_missing_fields(self, client, db_engine):
resp = await client.post("/auth/register", json={})
assert resp.status_code == 422
async def test_empty_display_name(self, client, db_engine):
resp = await client.post(
"/auth/register",
json={"email": "empty@example.com", "password": "securepass123", "display_name": ""},
)
assert resp.status_code == 422
async def test_duplicate_email(self, client, db_engine):
payload = {
"email": "dupe@example.com",
"password": "securepass123",
"display_name": "First",
}
first = await client.post("/auth/register", json=payload)
assert first.status_code == 201
second = await client.post("/auth/register", json=payload)
assert second.status_code == 409
@pytest.mark.asyncio
class TestLoginErrors:
"""Login failure modes."""
async def test_wrong_password(self, client, db_engine):
await client.post(
"/auth/register",
json={
"email": "login-err@example.com",
"password": "correctpass1",
"display_name": "Login",
},
)
resp = await client.post(
"/auth/login",
json={"email": "login-err@example.com", "password": "wrongpass123"},
)
assert resp.status_code == 401
async def test_nonexistent_user(self, client, db_engine):
resp = await client.post(
"/auth/login",
json={"email": "nobody@example.com", "password": "doesntmatter"},
)
assert resp.status_code == 401
@pytest.mark.asyncio
class TestNotFoundErrors:
"""404 responses for missing resources."""
async def test_product_not_found(self, client, seed_data):
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
assert resp.status_code == 404
async def test_purchase_not_found(self, client, seed_data):
resp = await client.get(f"/purchases/{ZERO_UUID}", headers=seed_data["headers"])
assert resp.status_code == 404
async def test_public_trend_not_found(self, client, seed_data):
resp = await client.get(f"/public/trends/{ZERO_UUID}")
assert resp.status_code == 404
@pytest.mark.asyncio
class TestMalformedInput:
"""Invalid UUID formats and bad query params."""
async def test_invalid_uuid_product(self, client, seed_data):
resp = await client.get(f"/products/{BAD_UUID}", headers=seed_data["headers"])
assert resp.status_code == 422
async def test_invalid_uuid_purchase(self, client, seed_data):
resp = await client.get(f"/purchases/{BAD_UUID}", headers=seed_data["headers"])
assert resp.status_code == 422
async def test_invalid_uuid_public_trend(self, client, seed_data):
resp = await client.get(f"/public/trends/{BAD_UUID}")
assert resp.status_code == 422
@pytest.mark.asyncio
class TestStoreConnectionErrors:
"""Store connection edge cases."""
async def test_connect_nonexistent_store(self, client, seed_data):
resp = await client.post(
"/me/stores/nonexistent-store/connect",
json={},
headers=seed_data["headers"],
)
assert resp.status_code == 404
async def test_connect_store_twice(self, client, seed_data):
headers = seed_data["headers"]
first = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
assert first.status_code in (200, 201)
second = await client.post("/me/stores/meijer/connect", json={}, headers=headers)
assert second.status_code == 409
+102
View File
@@ -0,0 +1,102 @@
"""E2E: Price history queries returning correct data."""
import pytest
@pytest.mark.asyncio
class TestPriceTrends:
"""Verify price trend aggregation against seeded history."""
async def test_trends_returns_all_products(self, client, seed_data):
resp = await client.get("/prices/trends", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
product_names = [t["product_name"] for t in data]
assert "Cheerios 18oz" in product_names
assert "Whole Milk 1gal" in product_names
async def test_trends_filter_by_category(self, client, seed_data):
resp = await client.get(
"/prices/trends", params={"category": "dairy"}, headers=seed_data["headers"]
)
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
# Only dairy products should appear
for trend in data:
assert trend["product_name"] == "Whole Milk 1gal"
async def test_trends_contain_data_points(self, client, seed_data):
resp = await client.get("/prices/trends", headers=seed_data["headers"])
data = resp.json()
cheerios_trend = next(t for t in data if t["product_name"] == "Cheerios 18oz")
assert len(cheerios_trend["data_points"]) >= 3
@pytest.mark.asyncio
class TestPriceIncreases:
"""Detect price increases from seeded price history."""
async def test_increases_detected(self, client, seed_data):
resp = await client.get("/prices/increases", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
# Cheerios at Meijer went from 3.99 → 4.29 → 4.79
cheerios_increases = [inc for inc in data if inc["product_name"] == "Cheerios 18oz"]
assert len(cheerios_increases) >= 1
# Verify the increase data makes sense
for inc in cheerios_increases:
assert inc["new_price"] > inc["old_price"]
assert inc["increase_pct"] > 0
assert inc["store_name"] == "Meijer"
async def test_stable_prices_not_flagged(self, client, seed_data):
"""Kroger Cheerios price is stable at $4.49 — should not appear as increase."""
resp = await client.get("/prices/increases", headers=seed_data["headers"])
data = resp.json()
kroger_increases = [
inc
for inc in data
if inc["product_name"] == "Cheerios 18oz" and inc["store_name"] == "Kroger"
]
assert len(kroger_increases) == 0
@pytest.mark.asyncio
class TestPriceComparison:
"""Compare prices across stores for specific products."""
async def test_compare_cheerios_across_stores(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(
"/prices/comparison",
params={"product_ids": cheerios_id},
headers=seed_data["headers"],
)
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
cheerios_cmp = data[0]
assert cheerios_cmp["product_name"] == "Cheerios 18oz"
store_names = [p["store_name"] for p in cheerios_cmp["prices"]]
assert "Meijer" in store_names
assert "Kroger" in store_names
async def test_compare_requires_product_ids(self, client, seed_data):
"""product_ids is required — omitting it must return 422."""
resp = await client.get("/prices/comparison", headers=seed_data["headers"])
assert resp.status_code == 422
async def test_compare_multiple_products(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
milk_id = str(seed_data["products"]["milk"].id)
resp = await client.get(
"/prices/comparison",
params=[("product_ids", cheerios_id), ("product_ids", milk_id)],
headers=seed_data["headers"],
)
assert resp.status_code == 200
data = resp.json()
names = [c["product_name"] for c in data]
assert "Cheerios 18oz" in names
assert "Whole Milk 1gal" in names
@@ -0,0 +1,82 @@
"""E2E: Product search/lookup endpoints with real DB fixtures."""
import pytest
from tests.test_e2e.conftest import ZERO_UUID
@pytest.mark.asyncio
class TestProductSearch:
"""Search and filter products against seeded data."""
async def test_list_all_products(self, client, seed_data):
resp = await client.get("/products", headers=seed_data["headers"])
assert resp.status_code == 200
products = resp.json()
names = [p["name"] for p in products]
assert "Cheerios 18oz" in names
assert "Whole Milk 1gal" in names
assert "Chicken Breast 1lb" in names
async def test_search_by_name(self, client, seed_data):
resp = await client.get("/products", params={"q": "cheerios"}, headers=seed_data["headers"])
assert resp.status_code == 200
products = resp.json()
assert len(products) >= 1
assert all("cheerios" in p["name"].lower() for p in products)
async def test_search_by_category(self, client, seed_data):
resp = await client.get(
"/products", params={"category": "dairy"}, headers=seed_data["headers"]
)
assert resp.status_code == 200
products = resp.json()
assert len(products) >= 1
assert all(p["category"] == "dairy" for p in products)
async def test_search_no_results(self, client, seed_data):
resp = await client.get(
"/products", params={"q": "nonexistentxyz"}, headers=seed_data["headers"]
)
assert resp.status_code == 200
assert resp.json() == []
@pytest.mark.asyncio
class TestProductLookup:
"""Detailed product lookups with cross-store pricing."""
async def test_get_product_detail_with_prices(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "Cheerios 18oz"
assert data["brand"] == "General Mills"
assert data["category"] == "pantry"
# Should have prices from both Meijer and Kroger
store_names = [p["store_name"] for p in data["prices_by_store"]]
assert "Meijer" in store_names
assert "Kroger" in store_names
async def test_product_prices_reflect_latest(self, client, seed_data):
"""The latest Meijer price for Cheerios should be 4.79 (the increase)."""
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(f"/products/{cheerios_id}", headers=seed_data["headers"])
data = resp.json()
meijer_price = next(p for p in data["prices_by_store"] if p["store_name"] == "Meijer")
assert meijer_price["current_price"] == 4.79
async def test_product_not_found(self, client, seed_data):
resp = await client.get(f"/products/{ZERO_UUID}", headers=seed_data["headers"])
assert resp.status_code == 404
async def test_product_price_history(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(f"/products/{cheerios_id}/prices", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data["data_points"]) >= 3 # At least the 3 Meijer observations
# Verify chronological ordering exists
prices = [dp["price"] for dp in data["data_points"]]
assert len(prices) >= 3
+59
View File
@@ -0,0 +1,59 @@
"""E2E: Public price transparency endpoints (no auth required)."""
import uuid
import pytest
@pytest.mark.asyncio
class TestPublicTrends:
"""Public price trend endpoint — no auth, real data."""
async def test_public_trend_returns_data(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(f"/public/trends/{cheerios_id}")
assert resp.status_code == 200
data = resp.json()
assert data["product_name"] == "Cheerios 18oz"
assert len(data["data_points"]) >= 3
async def test_public_trend_no_auth_needed(self, client, seed_data):
"""Confirm no Authorization header is required."""
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(f"/public/trends/{cheerios_id}")
assert resp.status_code == 200
@pytest.mark.asyncio
class TestPublicStoreComparison:
"""Public store comparison endpoint."""
async def test_store_comparison(self, client, seed_data):
cheerios_id = str(seed_data["products"]["cheerios"].id)
resp = await client.get(
"/public/store-comparison",
params=[("product_ids", cheerios_id)],
)
assert resp.status_code == 200
data = resp.json()
assert "products" in data
assert len(data["products"]) >= 1
async def test_store_comparison_rejects_more_than_20_ids(self, client):
"""max_length=20 guard: 21 product IDs must return 422."""
too_many = [("product_ids", str(uuid.uuid4())) for _ in range(21)]
resp = await client.get("/public/store-comparison", params=too_many)
assert resp.status_code == 422
@pytest.mark.asyncio
class TestPublicInflation:
"""Public inflation index endpoint."""
async def test_inflation_returns_index(self, client, seed_data):
resp = await client.get("/public/inflation")
assert resp.status_code == 200
data = resp.json()
assert "cartsnitch_index" in data
assert "cpi_baseline" in data
assert "categories" in data
+87
View File
@@ -0,0 +1,87 @@
"""E2E: Purchase listing, detail, and stats against real DB fixtures."""
import pytest
from tests.test_e2e.conftest import ZERO_UUID
@pytest.mark.asyncio
class TestPurchaseList:
"""List and filter a user's purchases."""
async def test_list_user_purchases(self, client, seed_data):
resp = await client.get("/purchases", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 2
store_names = [p["store_name"] for p in data]
assert "Meijer" in store_names
assert "Kroger" in store_names
async def test_filter_purchases_by_store(self, client, seed_data):
meijer_id = str(seed_data["stores"]["meijer"].id)
resp = await client.get(
"/purchases", params={"store_id": meijer_id}, headers=seed_data["headers"]
)
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
assert all(p["store_name"] == "Meijer" for p in data)
async def test_purchases_require_auth(self, client, seed_data):
resp = await client.get("/purchases")
assert resp.status_code in (401, 403)
@pytest.mark.asyncio
class TestPurchaseDetail:
"""Retrieve individual purchase with line items."""
async def test_get_purchase_detail(self, client, seed_data):
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["store_name"] == "Meijer"
assert data["total"] == 23.45
assert len(data["line_items"]) == 2
item_names = [li["name"] for li in data["line_items"]]
assert "Cheerios 18oz Box" in item_names
assert "Meijer Whole Milk 1gal" in item_names
async def test_line_item_amounts_correct(self, client, seed_data):
purchase_id = str(seed_data["purchases"]["meijer_trip"].id)
resp = await client.get(f"/purchases/{purchase_id}", headers=seed_data["headers"])
data = resp.json()
cheerios_item = next(li for li in data["line_items"] if "Cheerios" in li["name"])
assert cheerios_item["unit_price"] == 4.79
assert cheerios_item["quantity"] == 1.0
assert cheerios_item["total_price"] == 4.79
async def test_purchase_not_found(self, client, seed_data):
resp = await client.get(
f"/purchases/{ZERO_UUID}",
headers=seed_data["headers"],
)
assert resp.status_code == 404
@pytest.mark.asyncio
class TestPurchaseStats:
"""Verify spending aggregation across purchases."""
async def test_purchase_stats_totals(self, client, seed_data):
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["purchase_count"] == 2
# 23.45 + 15.78 = 39.23
assert abs(data["total_spent"] - 39.23) < 0.01
async def test_purchase_stats_by_store(self, client, seed_data):
resp = await client.get("/purchases/stats", headers=seed_data["headers"])
data = resp.json()
assert "Meijer" in data["by_store"]
assert "Kroger" in data["by_store"]
assert abs(data["by_store"]["Meijer"] - 23.45) < 0.01
assert abs(data["by_store"]["Kroger"] - 15.78) < 0.01
+130
View File
@@ -0,0 +1,130 @@
"""Tests for EncryptedJSON TypeDecorator and session_data encryption."""
import json
import pytest
from cryptography.fernet import Fernet
from pydantic import ValidationError
from sqlalchemy import column, create_engine, table, text
from sqlalchemy.orm import sessionmaker
from cartsnitch_api.config import settings
from cartsnitch_api.models import Base
from cartsnitch_api.models.store import Store
from cartsnitch_api.models.user import User, UserStoreAccount
@pytest.fixture
def engine():
eng = create_engine("sqlite:///:memory:")
Base.metadata.create_all(eng)
yield eng
eng.dispose()
@pytest.fixture
def session(engine):
factory = sessionmaker(bind=engine)
with factory() as sess:
yield sess
@pytest.fixture
def store(session):
s = Store(name="Test Store", slug="test-store")
session.add(s)
session.commit()
session.refresh(s)
return s
@pytest.fixture
def user(session):
u = User(email="alice@example.com", hashed_password="fakehash")
session.add(u)
session.commit()
session.refresh(u)
return u
class TestEncryptedJSONType:
"""Unit tests for the EncryptedJSON TypeDecorator."""
def test_round_trip(self, session, user, store):
"""Data written via the ORM comes back as the original dict."""
original = {"token": "abc123", "cookies": {"session_id": "xyz"}}
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
session.add(account)
session.commit()
loaded = session.get(UserStoreAccount, account.id)
assert loaded.session_data == original
def test_stored_value_is_encrypted(self, session, user, store):
"""The raw value in the DB should be a Fernet token, not plaintext JSON."""
original = {"secret": "do-not-leak"}
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
session.add(account)
session.commit()
# Use a raw table construct to bypass TypeDecorator on read
raw_table = table("user_store_accounts", column("id"), column("session_data"))
raw = session.execute(raw_table.select().where(raw_table.c.id == str(account.id))).first()
# If UUID matching fails with str, try bytes format
if raw is None:
raw = session.execute(
text("SELECT session_data FROM user_store_accounts LIMIT 1")
).scalar_one()
else:
raw = raw[1]
assert raw != json.dumps(original)
assert raw.startswith("gAAAAA")
# Verify we can decrypt the raw value manually
f = Fernet(settings.fernet_key.encode())
decrypted = json.loads(f.decrypt(raw.encode()))
assert decrypted == original
def test_null_round_trip(self, session, user, store):
"""NULL session_data stays NULL."""
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=None)
session.add(account)
session.commit()
loaded = session.get(UserStoreAccount, account.id)
assert loaded.session_data is None
def test_empty_dict_round_trip(self, session, user, store):
"""Empty dict round-trips correctly."""
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={})
session.add(account)
session.commit()
loaded = session.get(UserStoreAccount, account.id)
assert loaded.session_data == {}
def test_update_session_data(self, session, user, store):
"""Updating session_data re-encrypts the new value."""
account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={"v": 1})
session.add(account)
session.commit()
account.session_data = {"v": 2, "new_field": True}
session.commit()
loaded = session.get(UserStoreAccount, account.id)
assert loaded.session_data == {"v": 2, "new_field": True}
class TestEncryptionKeyValidation:
"""Test that invalid/missing keys are caught at startup."""
def test_invalid_fernet_key_rejected(self, monkeypatch):
"""Settings validation rejects a bad key."""
monkeypatch.setenv("CARTSNITCH_FERNET_KEY", "not-a-valid-key")
with pytest.raises(ValidationError):
from cartsnitch_api.config import Settings
Settings()
View File
+19
View File
@@ -0,0 +1,19 @@
"""Conftest for middleware tests — re-enables rate limiting after global disable."""
import pytest
from cartsnitch_api.config import settings as cartsnitch_settings
@pytest.fixture(autouse=True)
def enable_rate_limiting():
"""Re-enable rate limiting after the global disable_rate_limiting fixture runs.
The root conftest disables rate limiting for all tests to prevent 429
interference. Middleware tests need it active to verify headers and
enforcement. This fixture runs after the root fixture (more local = later
in setup order) so True is the effective value during the test body.
"""
cartsnitch_settings.rate_limit_enabled = True
yield
cartsnitch_settings.rate_limit_enabled = False
@@ -0,0 +1,54 @@
"""Tests for structured error responses and error monitoring."""
import pytest
@pytest.mark.asyncio
async def test_404_returns_structured_error(client):
"""Non-existent route should return structured error."""
resp = await client.get("/nonexistent")
assert resp.status_code == 404
body = resp.json()
assert "detail" in body
assert "code" in body
assert body["code"] == "NOT_FOUND"
@pytest.mark.asyncio
async def test_validation_error_returns_422_with_field_errors(client):
"""Invalid request body should return structured validation errors."""
resp = await client.post(
"/auth/register",
json={"email": "not-an-email", "password": "short", "display_name": ""},
)
assert resp.status_code == 422
body = resp.json()
assert body["code"] == "VALIDATION_ERROR"
assert "errors" in body
assert isinstance(body["errors"], list)
assert len(body["errors"]) > 0
# Each error should have field, message, type
for err in body["errors"]:
assert "field" in err
assert "message" in err
assert "type" in err
@pytest.mark.asyncio
async def test_error_stats_requires_service_key(client):
"""Error stats endpoint should require X-Service-Key."""
resp = await client.get("/internal/error-stats")
assert resp.status_code == 422 # Missing required header
@pytest.mark.asyncio
async def test_error_stats_with_valid_key(client):
"""Error stats endpoint returns monitoring data with valid key."""
resp = await client.get(
"/internal/error-stats",
headers={"X-Service-Key": "change-me-in-production"},
)
assert resp.status_code == 200
body = resp.json()
assert "error_counts" in body
assert "recent_5xx_count" in body
+55
View File
@@ -0,0 +1,55 @@
"""Tests for rate limiting middleware."""
import pytest
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter
class TestSlidingWindowCounter:
def test_allows_within_limit(self):
counter = _SlidingWindowCounter(max_requests=5, window_seconds=60)
for i in range(5):
allowed, remaining, retry = counter.is_allowed("test-key")
assert allowed is True
assert remaining == 4 - i
def test_blocks_over_limit(self):
counter = _SlidingWindowCounter(max_requests=3, window_seconds=60)
for _ in range(3):
counter.is_allowed("test-key")
allowed, remaining, retry = counter.is_allowed("test-key")
assert allowed is False
assert remaining == 0
assert retry > 0
def test_separate_keys(self):
counter = _SlidingWindowCounter(max_requests=2, window_seconds=60)
# Fill key-a
counter.is_allowed("key-a")
counter.is_allowed("key-a")
allowed_a, _, _ = counter.is_allowed("key-a")
assert allowed_a is False
# key-b should still be allowed
allowed_b, remaining, _ = counter.is_allowed("key-b")
assert allowed_b is True
assert remaining == 1
@pytest.mark.asyncio
async def test_rate_limit_returns_429(client):
"""Public endpoint should return 429 after limit exceeded."""
# The default limit is 60/min — we won't hit it in normal tests,
# but we verify the middleware adds rate limit headers.
resp = await client.get("/public/inflation")
assert "x-ratelimit-limit" in resp.headers
assert "x-ratelimit-remaining" in resp.headers
@pytest.mark.asyncio
async def test_health_skips_rate_limit(client):
"""Health endpoint should not have rate limit headers."""
resp = await client.get("/health")
assert resp.status_code == 200
assert "x-ratelimit-limit" not in resp.headers
+376
View File
@@ -0,0 +1,376 @@
"""Tests for SQLAlchemy ORM models."""
import uuid
from datetime import UTC, date, datetime
from decimal import Decimal
import pytest
from sqlalchemy import inspect
from cartsnitch_api.constants import (
AccountStatus,
DiscountType,
PriceSource,
ProductCategory,
SizeUnit,
StoreSlug,
)
from cartsnitch_api.models import (
Coupon,
NormalizedProduct,
PriceHistory,
Purchase,
PurchaseItem,
ShrinkflationEvent,
Store,
StoreLocation,
User,
UserStoreAccount,
)
class TestTableCreation:
"""Verify all expected tables are created."""
def test_all_tables_exist(self, engine):
inspector = inspect(engine)
table_names = set(inspector.get_table_names())
expected = {
"stores",
"store_locations",
"users",
"user_store_accounts",
"purchases",
"purchase_items",
"normalized_products",
"price_history",
"coupons",
"shrinkflation_events",
}
assert expected.issubset(table_names)
def test_ten_tables_total(self, engine):
inspector = inspect(engine)
assert len(inspector.get_table_names()) == 10
class TestUUIDPrimaryKeys:
"""All models use UUID PKs."""
def test_store_uuid_pk(self, session):
store = Store(
id=uuid.uuid4(),
name="Meijer",
slug=StoreSlug.MEIJER,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(store)
session.commit()
assert isinstance(store.id, uuid.UUID)
def test_user_uuid_pk(self, session):
user = User(
id=uuid.uuid4(),
email="test@example.com",
hashed_password="hashed",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(user)
session.commit()
assert isinstance(user.id, uuid.UUID)
class TestStoreModel:
def test_store_slug_enum(self, session):
store = Store(
id=uuid.uuid4(),
name="Kroger",
slug=StoreSlug.KROGER,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(store)
session.commit()
assert store.slug == StoreSlug.KROGER
def test_store_unique_slug(self, session):
s1 = Store(
id=uuid.uuid4(),
name="Target",
slug=StoreSlug.TARGET,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
s2 = Store(
id=uuid.uuid4(),
name="Target Duplicate",
slug=StoreSlug.TARGET,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(s1)
session.commit()
session.add(s2)
with pytest.raises(Exception): # noqa: B017
session.commit()
session.rollback()
class TestStoreLocationModel:
def test_store_location_fields(self, session):
store = Store(
id=uuid.uuid4(),
name="Meijer",
slug=StoreSlug.MEIJER,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(store)
session.flush()
loc = StoreLocation(
id=uuid.uuid4(),
store_id=store.id,
address="123 Main St",
city="Ann Arbor",
state="MI",
zip="48104",
lat=42.2808,
lng=-83.7430,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(loc)
session.commit()
assert loc.city == "Ann Arbor"
assert loc.lat == pytest.approx(42.2808)
class TestUserStoreAccountModel:
def test_account_status_enum(self, session):
user = User(
id=uuid.uuid4(),
email="test@test.com",
hashed_password="hashed",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
store = Store(
id=uuid.uuid4(),
name="Kroger",
slug=StoreSlug.KROGER,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add_all([user, store])
session.flush()
acct = UserStoreAccount(
id=uuid.uuid4(),
user_id=user.id,
store_id=store.id,
status=AccountStatus.ACTIVE,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(acct)
session.commit()
assert acct.status == AccountStatus.ACTIVE
def test_unique_user_store_constraint(self, session):
"""One account per user per store."""
user = User(
id=uuid.uuid4(),
email="unique@test.com",
hashed_password="hashed",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
store = Store(
id=uuid.uuid4(),
name="Target",
slug=StoreSlug.TARGET,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add_all([user, store])
session.flush()
a1 = UserStoreAccount(
id=uuid.uuid4(),
user_id=user.id,
store_id=store.id,
status=AccountStatus.ACTIVE,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
a2 = UserStoreAccount(
id=uuid.uuid4(),
user_id=user.id,
store_id=store.id,
status=AccountStatus.EXPIRED,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(a1)
session.commit()
session.add(a2)
with pytest.raises(Exception): # noqa: B017
session.commit()
session.rollback()
class TestPurchaseModel:
def test_purchase_with_items(self, session):
user = User(
id=uuid.uuid4(),
email="buyer@test.com",
hashed_password="hashed",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
store = Store(
id=uuid.uuid4(),
name="Meijer",
slug=StoreSlug.MEIJER,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add_all([user, store])
session.flush()
purchase = Purchase(
id=uuid.uuid4(),
user_id=user.id,
store_id=store.id,
receipt_id="RCP-001",
purchase_date=date(2026, 3, 15),
total=Decimal("42.50"),
ingested_at=datetime.now(UTC),
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(purchase)
session.flush()
item = PurchaseItem(
id=uuid.uuid4(),
purchase_id=purchase.id,
product_name_raw="Meijer Whole Milk 1 Gallon",
upc="0041250000001",
quantity=Decimal("1"),
unit_price=Decimal("3.49"),
extended_price=Decimal("3.49"),
)
session.add(item)
session.commit()
assert item.product_name_raw == "Meijer Whole Milk 1 Gallon"
assert item.unit_price == Decimal("3.49")
class TestNormalizedProductModel:
def test_product_with_upc_variants(self, session):
product = NormalizedProduct(
id=uuid.uuid4(),
canonical_name="Whole Milk, 1 Gallon",
category=ProductCategory.DAIRY,
brand="Store Brand",
size="128",
size_unit=SizeUnit.FL_OZ,
upc_variants=["0041250000001", "0041250000002"],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(product)
session.commit()
assert product.category == ProductCategory.DAIRY
assert product.size_unit == SizeUnit.FL_OZ
class TestPriceHistoryModel:
def test_price_source_enum(self, session):
store = Store(
id=uuid.uuid4(),
name="Kroger",
slug=StoreSlug.KROGER,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
product = NormalizedProduct(
id=uuid.uuid4(),
canonical_name="Eggs, Large, 12ct",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add_all([store, product])
session.flush()
ph = PriceHistory(
id=uuid.uuid4(),
normalized_product_id=product.id,
store_id=store.id,
observed_date=date(2026, 3, 15),
regular_price=Decimal("4.99"),
sale_price=Decimal("3.99"),
source=PriceSource.RECEIPT,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(ph)
session.commit()
assert ph.source == PriceSource.RECEIPT
assert ph.regular_price == Decimal("4.99")
class TestCouponModel:
def test_coupon_discount_types(self, session):
store = Store(
id=uuid.uuid4(),
name="Target",
slug=StoreSlug.TARGET,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(store)
session.flush()
coupon = Coupon(
id=uuid.uuid4(),
store_id=store.id,
title="$2 off eggs",
discount_type=DiscountType.FIXED,
discount_value=Decimal("2.00"),
requires_clip=True,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(coupon)
session.commit()
assert coupon.discount_type == DiscountType.FIXED
assert coupon.discount_value == Decimal("2.00")
class TestShrinkflationEventModel:
def test_shrinkflation_event(self, session):
product = NormalizedProduct(
id=uuid.uuid4(),
canonical_name="Cereal, Honey Oats",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(product)
session.flush()
event = ShrinkflationEvent(
id=uuid.uuid4(),
normalized_product_id=product.id,
detected_date=date(2026, 3, 10),
old_size="18",
new_size="15.4",
old_unit=SizeUnit.OZ,
new_unit=SizeUnit.OZ,
price_at_old_size=Decimal("4.99"),
price_at_new_size=Decimal("4.99"),
confidence=Decimal("0.95"),
notes="Size reduced by 14.4%, price unchanged",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
session.add(event)
session.commit()
assert event.confidence == Decimal("0.95")
assert event.old_unit == SizeUnit.OZ
+92
View File
@@ -0,0 +1,92 @@
"""Verify all expected routes are present in the OpenAPI spec."""
import pytest
from httpx import ASGITransport, AsyncClient
from cartsnitch_api.main import app
EXPECTED_ROUTES = [
# Auth (6)
("post", "/auth/register"),
("post", "/auth/login"),
("post", "/auth/refresh"),
("get", "/auth/me"),
("patch", "/auth/me"),
("delete", "/auth/me"),
# Stores (4)
("get", "/stores"),
("get", "/me/stores"),
("post", "/me/stores/{store_slug}/connect"),
("delete", "/me/stores/{store_slug}"),
# Purchases (3)
("get", "/purchases"),
("get", "/purchases/stats"),
("get", "/purchases/{purchase_id}"),
# Products (3)
("get", "/products"),
("get", "/products/{product_id}"),
("get", "/products/{product_id}/prices"),
# Prices (3)
("get", "/prices/trends"),
("get", "/prices/increases"),
("get", "/prices/comparison"),
# Coupons (2)
("get", "/coupons"),
("get", "/coupons/relevant"),
# Shopping (2)
("post", "/shopping/optimize"),
("get", "/shopping/lists"),
# Alerts (3)
("get", "/alerts"),
("get", "/alerts/settings"),
("put", "/alerts/settings"),
# Scraping (2)
("post", "/scraping/{store_slug}/sync"),
("get", "/scraping/status"),
# Public (3)
("get", "/public/trends/{product_id}"),
("get", "/public/store-comparison"),
("get", "/public/inflation"),
# Health (1)
("get", "/health"),
]
@pytest.mark.asyncio
async def test_all_routes_in_openapi():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.get("/openapi.json")
assert resp.status_code == 200
spec = resp.json()
paths = spec["paths"]
registered = set()
for path, methods in paths.items():
for method in methods:
if method in ("get", "post", "put", "delete", "patch"):
registered.add((method, path))
missing = []
for method, path in EXPECTED_ROUTES:
if (method, path) not in registered:
missing.append(f"{method.upper()} {path}")
assert not missing, "Missing routes in OpenAPI spec:\n" + "\n".join(missing)
@pytest.mark.asyncio
async def test_route_count():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.get("/openapi.json")
spec = resp.json()
paths = spec["paths"]
count = 0
for _path, methods in paths.items():
for method in methods:
if method in ("get", "post", "put", "delete", "patch"):
count += 1
assert count == 33, f"Expected 33 routes, found {count}"
View File
+35
View File
@@ -0,0 +1,35 @@
"""Integration tests for alert endpoints."""
import pytest
@pytest.mark.asyncio
async def test_list_alerts_empty(client, auth_headers):
"""No purchases means no alerts."""
resp = await client.get("/alerts", headers=auth_headers)
assert resp.status_code == 200
assert resp.json() == []
@pytest.mark.asyncio
async def test_get_alert_settings(client, auth_headers):
resp = await client.get("/alerts/settings", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
assert data["price_increase_threshold_pct"] == 5.0
assert data["shrinkflation_enabled"] is True
assert data["email_notifications"] is False
@pytest.mark.asyncio
async def test_update_alert_settings_returns_501(client, auth_headers):
resp = await client.put(
"/alerts/settings",
headers=auth_headers,
json={
"price_increase_threshold_pct": 10.0,
"shrinkflation_enabled": False,
"email_notifications": True,
},
)
assert resp.status_code == 501
+58
View File
@@ -0,0 +1,58 @@
"""Integration tests for coupon endpoints."""
from datetime import date
from decimal import Decimal
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from cartsnitch_api.models import Coupon, Store
@pytest.fixture
async def coupon_data(db_engine, auth_headers):
"""Seed stores and coupons."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
store = Store(name="Target", slug="target")
session.add(store)
await session.commit()
await session.refresh(store)
coupon = Coupon(
store_id=store.id,
title="$2 off laundry",
description="$2 off any laundry detergent",
discount_value=Decimal("2.00"),
discount_type="fixed",
valid_from=date(2026, 1, 1),
valid_to=date(2026, 12, 31),
)
session.add(coupon)
await session.commit()
return {"store": store, "coupon": coupon, "headers": auth_headers}
@pytest.mark.asyncio
async def test_list_coupons(client, coupon_data):
resp = await client.get("/coupons", headers=coupon_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
@pytest.mark.asyncio
async def test_list_coupons_by_store(client, coupon_data):
store_id = str(coupon_data["store"].id)
resp = await client.get(f"/coupons?store_id={store_id}", headers=coupon_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) >= 1
@pytest.mark.asyncio
async def test_relevant_coupons_empty(client, auth_headers):
"""No purchases means no relevant coupons."""
resp = await client.get("/coupons/relevant", headers=auth_headers)
assert resp.status_code == 200
assert resp.json() == []
+90
View File
@@ -0,0 +1,90 @@
"""Integration tests for price endpoints."""
from datetime import date
from decimal import Decimal
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
@pytest.fixture
async def price_data(db_engine, auth_headers):
"""Seed products with price history showing an increase."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
store = Store(name="Walmart", slug="walmart")
product = NormalizedProduct(
canonical_name="Tide Pods 42ct",
category="household",
brand="Tide",
)
session.add_all([store, product])
await session.commit()
await session.refresh(store)
await session.refresh(product)
# Two price points — second is higher (increase)
ph1 = PriceHistory(
normalized_product_id=product.id,
store_id=store.id,
observed_date=date(2026, 2, 1),
regular_price=Decimal("12.99"),
source="receipt",
)
ph2 = PriceHistory(
normalized_product_id=product.id,
store_id=store.id,
observed_date=date(2026, 3, 1),
regular_price=Decimal("14.49"),
source="receipt",
)
session.add_all([ph1, ph2])
await session.commit()
return {"product": product, "store": store, "headers": auth_headers}
@pytest.mark.asyncio
async def test_price_trends(client, price_data):
resp = await client.get("/prices/trends", headers=price_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
assert data[0]["product_name"] == "Tide Pods 42ct"
assert len(data[0]["data_points"]) == 2
@pytest.mark.asyncio
async def test_price_trends_by_category(client, price_data):
resp = await client.get("/prices/trends?category=household", headers=price_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) == 1
resp = await client.get("/prices/trends?category=nonexistent", headers=price_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) == 0
@pytest.mark.asyncio
async def test_price_increases(client, price_data):
resp = await client.get("/prices/increases", headers=price_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
increase = data[0]
assert increase["old_price"] == 12.99
assert increase["new_price"] == 14.49
assert increase["increase_pct"] > 0
@pytest.mark.asyncio
async def test_price_comparison(client, price_data):
pid = str(price_data["product"].id)
resp = await client.get(f"/prices/comparison?product_ids={pid}", headers=price_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
assert data[0]["product_name"] == "Tide Pods 42ct"
assert len(data[0]["prices"]) >= 1
+94
View File
@@ -0,0 +1,94 @@
"""Integration tests for product endpoints."""
import uuid
from datetime import date
from decimal import Decimal
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
@pytest.fixture
async def product_data(db_engine, auth_headers):
"""Seed products and price history."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
store = Store(name="Meijer", slug="meijer")
product = NormalizedProduct(
canonical_name="Cheerios 18oz",
category="pantry",
brand="General Mills",
upc_variants=["016000275263"],
)
session.add_all([store, product])
await session.commit()
await session.refresh(store)
await session.refresh(product)
ph1 = PriceHistory(
normalized_product_id=product.id,
store_id=store.id,
observed_date=date(2026, 3, 1),
regular_price=Decimal("4.99"),
source="receipt",
)
ph2 = PriceHistory(
normalized_product_id=product.id,
store_id=store.id,
observed_date=date(2026, 3, 10),
regular_price=Decimal("5.49"),
source="receipt",
)
session.add_all([ph1, ph2])
await session.commit()
return {"product": product, "store": store, "headers": auth_headers}
@pytest.mark.asyncio
async def test_list_products(client, product_data):
resp = await client.get("/products", headers=product_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
assert data[0]["name"] == "Cheerios 18oz"
@pytest.mark.asyncio
async def test_search_products(client, product_data):
resp = await client.get("/products?q=Cheerios", headers=product_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) == 1
resp = await client.get("/products?q=nonexistent", headers=product_data["headers"])
assert resp.status_code == 200
assert len(resp.json()) == 0
@pytest.mark.asyncio
async def test_get_product_detail(client, product_data):
pid = str(product_data["product"].id)
resp = await client.get(f"/products/{pid}", headers=product_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "Cheerios 18oz"
assert data["brand"] == "General Mills"
assert len(data["prices_by_store"]) >= 1
@pytest.mark.asyncio
async def test_get_product_not_found(client, auth_headers):
resp = await client.get(f"/products/{uuid.uuid4()}", headers=auth_headers)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_get_product_prices(client, product_data):
pid = str(product_data["product"].id)
resp = await client.get(f"/products/{pid}/prices", headers=product_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["product_name"] == "Cheerios 18oz"
assert len(data["data_points"]) == 2
+73
View File
@@ -0,0 +1,73 @@
"""Integration tests for public endpoints (no auth)."""
import uuid
from datetime import date
from decimal import Decimal
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from cartsnitch_api.models import NormalizedProduct, PriceHistory, Store
@pytest.fixture
async def public_data(db_engine):
"""Seed data for public endpoints."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
store = Store(name="Target", slug="target")
product = NormalizedProduct(
canonical_name="Skippy PB 16oz",
category="pantry",
brand="Skippy",
)
session.add_all([store, product])
await session.commit()
await session.refresh(store)
await session.refresh(product)
ph = PriceHistory(
normalized_product_id=product.id,
store_id=store.id,
observed_date=date(2026, 3, 5),
regular_price=Decimal("3.99"),
source="receipt",
)
session.add(ph)
await session.commit()
return {"product": product, "store": store}
@pytest.mark.asyncio
async def test_public_trend(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/trends/{pid}")
assert resp.status_code == 200
data = resp.json()
assert data["product_name"] == "Skippy PB 16oz"
assert len(data["data_points"]) == 1
@pytest.mark.asyncio
async def test_public_trend_not_found(client):
resp = await client.get(f"/public/trends/{uuid.uuid4()}")
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_public_store_comparison(client, public_data):
pid = str(public_data["product"].id)
resp = await client.get(f"/public/store-comparison?product_ids={pid}")
assert resp.status_code == 200
data = resp.json()
assert len(data["products"]) == 1
@pytest.mark.asyncio
async def test_public_inflation(client, public_data):
resp = await client.get("/public/inflation")
assert resp.status_code == 200
data = resp.json()
assert "categories" in data
assert "cartsnitch_index" in data
+95
View File
@@ -0,0 +1,95 @@
"""Integration tests for purchase endpoints."""
import uuid
from datetime import date
from decimal import Decimal
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from cartsnitch_api.auth.jwt import create_access_token
from cartsnitch_api.models import Purchase, PurchaseItem, Store, User
@pytest.fixture
async def purchase_data(db_engine):
"""Seed a user, store, purchase, and items."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
from cartsnitch_api.auth.passwords import hash_password
user = User(
email="buyer@example.com",
hashed_password=hash_password("testpass123"),
display_name="Buyer",
)
store = Store(name="Kroger", slug="kroger")
session.add_all([user, store])
await session.commit()
await session.refresh(user)
await session.refresh(store)
purchase = Purchase(
user_id=user.id,
store_id=store.id,
receipt_id="receipt-001",
purchase_date=date(2026, 3, 10),
total=Decimal("42.50"),
)
session.add(purchase)
await session.commit()
await session.refresh(purchase)
item = PurchaseItem(
purchase_id=purchase.id,
product_name_raw="Organic Milk 1gal",
quantity=Decimal("1"),
unit_price=Decimal("5.99"),
extended_price=Decimal("5.99"),
)
session.add(item)
await session.commit()
token = create_access_token(user.id)
return {
"user": user,
"store": store,
"purchase": purchase,
"headers": {"Authorization": f"Bearer {token}"},
}
@pytest.mark.asyncio
async def test_list_purchases(client, purchase_data):
resp = await client.get("/purchases", headers=purchase_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["store_name"] == "Kroger"
assert data[0]["total"] == 42.50
@pytest.mark.asyncio
async def test_get_purchase_detail(client, purchase_data):
pid = str(purchase_data["purchase"].id)
resp = await client.get(f"/purchases/{pid}", headers=purchase_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert len(data["line_items"]) == 1
assert data["line_items"][0]["name"] == "Organic Milk 1gal"
@pytest.mark.asyncio
async def test_get_purchase_not_found(client, auth_headers):
resp = await client.get(f"/purchases/{uuid.uuid4()}", headers=auth_headers)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_purchase_stats(client, purchase_data):
resp = await client.get("/purchases/stats", headers=purchase_data["headers"])
assert resp.status_code == 200
data = resp.json()
assert data["total_spent"] == 42.50
assert data["purchase_count"] == 1
assert "Kroger" in data["by_store"]
+77
View File
@@ -0,0 +1,77 @@
"""Integration tests for store endpoints."""
import pytest
from cartsnitch_api.models import Store
@pytest.fixture
async def seeded_store(db_engine):
"""Insert a test store directly into the DB."""
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
store = Store(name="Meijer", slug="meijer", logo_url=None, website_url=None)
session.add(store)
await session.commit()
await session.refresh(store)
return store
@pytest.mark.asyncio
async def test_list_stores(client, seeded_store):
resp = await client.get("/stores")
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 1
assert data[0]["slug"] == "meijer"
@pytest.mark.asyncio
async def test_list_user_stores_empty(client, auth_headers):
resp = await client.get("/me/stores", headers=auth_headers)
assert resp.status_code == 200
assert resp.json() == []
@pytest.mark.asyncio
async def test_connect_and_disconnect_store(client, auth_headers, seeded_store):
# Connect
resp = await client.post(
"/me/stores/meijer/connect",
headers=auth_headers,
json={"credentials": None},
)
assert resp.status_code == 201
assert resp.json()["connected"] is True
# List should show connected
resp = await client.get("/me/stores", headers=auth_headers)
assert resp.status_code == 200
assert len(resp.json()) == 1
# Disconnect
resp = await client.delete("/me/stores/meijer", headers=auth_headers)
assert resp.status_code == 204
# List should be empty again
resp = await client.get("/me/stores", headers=auth_headers)
assert resp.json() == []
@pytest.mark.asyncio
async def test_connect_nonexistent_store(client, auth_headers):
resp = await client.post(
"/me/stores/nonexistent/connect",
headers=auth_headers,
json={},
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_connect_duplicate_store(client, auth_headers, seeded_store):
await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
resp = await client.post("/me/stores/meijer/connect", headers=auth_headers, json={})
assert resp.status_code == 409
View File