forked from cartsnitch/cartsnitch
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1267caf43c | |||
| 015401861a | |||
| 9891e1aefb | |||
| 69ad161e36 | |||
| 485f890df3 | |||
| bf3ed0ede3 | |||
| 3f41eb7346 | |||
| 6cbd1ef298 | |||
| 94214f762e | |||
| 562c6ef6f6 | |||
| ccc8189d88 | |||
| 86594e4a8e | |||
| c2f1a83c1d | |||
| 6f8e5a9577 | |||
| bbfa816e57 | |||
| 5904eb03a2 | |||
| 87b6433ff7 | |||
| d7c9938f7e | |||
| 02434060ee |
@@ -13,13 +13,14 @@ class Settings(BaseSettings):
|
|||||||
)
|
)
|
||||||
redis_url: str = "redis://localhost:6379/0"
|
redis_url: str = "redis://localhost:6379/0"
|
||||||
|
|
||||||
jwt_secret_key: str
|
jwt_secret_key: str = "change-me-in-production"
|
||||||
jwt_algorithm: str = "HS256"
|
jwt_algorithm: str = "HS256"
|
||||||
jwt_access_token_expire_minutes: int = 15
|
jwt_access_token_expire_minutes: int = 15
|
||||||
jwt_refresh_token_expire_days: int = 7
|
jwt_refresh_token_expire_days: int = 7
|
||||||
|
|
||||||
service_key: str
|
service_key: str = "change-me-in-production"
|
||||||
fernet_key: str
|
# Valid Fernet key for local dev — MUST be overridden in production
|
||||||
|
fernet_key: str = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||||
|
|
||||||
auth_service_url: str = "http://auth:3001"
|
auth_service_url: str = "http://auth:3001"
|
||||||
|
|
||||||
@@ -34,26 +35,9 @@ class Settings(BaseSettings):
|
|||||||
rate_limit_window_seconds: int = 60
|
rate_limit_window_seconds: int = 60
|
||||||
rate_limit_enabled: bool = True
|
rate_limit_enabled: bool = True
|
||||||
|
|
||||||
_PLACEHOLDER_VALUES = {"change-me-in-production"}
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_secrets(self):
|
def validate_fernet_key(self):
|
||||||
if not self.jwt_secret_key or self.jwt_secret_key in self._PLACEHOLDER_VALUES:
|
"""Validate fernet_key is a valid 32-byte url-safe base64 key at startup."""
|
||||||
raise ValueError(
|
|
||||||
"CARTSNITCH_JWT_SECRET_KEY must be set to a secure value. "
|
|
||||||
'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
|
||||||
)
|
|
||||||
if not self.service_key or self.service_key in self._PLACEHOLDER_VALUES:
|
|
||||||
raise ValueError(
|
|
||||||
"CARTSNITCH_SERVICE_KEY must be set to a secure value. "
|
|
||||||
'Generate one with: python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
|
||||||
)
|
|
||||||
if not self.fernet_key or self.fernet_key in self._PLACEHOLDER_VALUES:
|
|
||||||
raise ValueError(
|
|
||||||
"CARTSNITCH_FERNET_KEY must be set to a valid Fernet key. "
|
|
||||||
"Generate one with: python -c "
|
|
||||||
"'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
|
decoded = base64.urlsafe_b64decode(self.fernet_key.encode())
|
||||||
if len(decoded) != 32:
|
if len(decoded) != 32:
|
||||||
|
|||||||
@@ -18,14 +18,10 @@ router = APIRouter(prefix="/public", tags=["public"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/trends/{product_id}", response_model=PublicTrendResponse)
|
@router.get("/trends/{product_id}", response_model=PublicTrendResponse)
|
||||||
async def public_price_trend(
|
async def public_price_trend(product_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||||
product_id: UUID,
|
|
||||||
days: int = Query(90, ge=1, le=365),
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
|
||||||
svc = PublicService(db)
|
svc = PublicService(db)
|
||||||
try:
|
try:
|
||||||
return await svc.get_trend(product_id, days=days)
|
return await svc.get_trend(product_id)
|
||||||
except LookupError:
|
except LookupError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
status_code=status.HTTP_404_NOT_FOUND, detail="Product not found"
|
||||||
@@ -35,7 +31,6 @@ async def public_price_trend(
|
|||||||
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
|
@router.get("/store-comparison", response_model=PublicStoreComparisonResponse)
|
||||||
async def public_store_comparison(
|
async def public_store_comparison(
|
||||||
product_ids: Annotated[list[UUID], Query(max_length=20)],
|
product_ids: Annotated[list[UUID], Query(max_length=20)],
|
||||||
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
|
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
if not product_ids:
|
if not product_ids:
|
||||||
@@ -44,14 +39,10 @@ async def public_store_comparison(
|
|||||||
detail="At least one product_id is required",
|
detail="At least one product_id is required",
|
||||||
)
|
)
|
||||||
svc = PublicService(db)
|
svc = PublicService(db)
|
||||||
return await svc.get_store_comparison(product_ids, category=category)
|
return await svc.get_store_comparison(product_ids)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/inflation", response_model=PublicInflationResponse)
|
@router.get("/inflation", response_model=PublicInflationResponse)
|
||||||
async def public_inflation(
|
async def public_inflation(db: AsyncSession = Depends(get_db)):
|
||||||
category: str | None = Query(None, max_length=100, pattern=r"^[a-zA-Z0-9 _-]+$"),
|
|
||||||
period: str = Query("all-time", pattern=r"^(all-time|1y|6m|3m|1m)$"),
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
|
||||||
svc = PublicService(db)
|
svc = PublicService(db)
|
||||||
return await svc.get_inflation(category=category, period=period)
|
return await svc.get_inflation()
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Public service — unauthenticated price transparency endpoints."""
|
"""Public service — unauthenticated price transparency endpoints."""
|
||||||
|
|
||||||
from datetime import date, timedelta
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import and_, func, select
|
from sqlalchemy import and_, func, select
|
||||||
@@ -14,7 +13,7 @@ class PublicService:
|
|||||||
def __init__(self, db: AsyncSession) -> None:
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
async def get_trend(self, product_id: UUID, days: int = 90) -> dict:
|
async def get_trend(self, product_id: UUID) -> dict:
|
||||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||||
|
|
||||||
result = await self.db.execute(
|
result = await self.db.execute(
|
||||||
@@ -24,13 +23,9 @@ class PublicService:
|
|||||||
if not product:
|
if not product:
|
||||||
raise LookupError("Product not found")
|
raise LookupError("Product not found")
|
||||||
|
|
||||||
date_threshold = date.today() - timedelta(days=days)
|
|
||||||
prices_result = await self.db.execute(
|
prices_result = await self.db.execute(
|
||||||
select(PriceHistory)
|
select(PriceHistory)
|
||||||
.where(
|
.where(PriceHistory.normalized_product_id == product_id)
|
||||||
PriceHistory.normalized_product_id == product_id,
|
|
||||||
PriceHistory.observed_date >= date_threshold,
|
|
||||||
)
|
|
||||||
.options(selectinload(PriceHistory.store))
|
.options(selectinload(PriceHistory.store))
|
||||||
.order_by(PriceHistory.observed_date)
|
.order_by(PriceHistory.observed_date)
|
||||||
)
|
)
|
||||||
@@ -50,25 +45,20 @@ class PublicService:
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_store_comparison(
|
async def get_store_comparison(self, product_ids: list[UUID]) -> dict:
|
||||||
self, product_ids: list[UUID], category: str | None = None
|
|
||||||
) -> dict:
|
|
||||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||||
|
|
||||||
if not product_ids:
|
if not product_ids:
|
||||||
return {"products": []}
|
return {"products": []}
|
||||||
|
|
||||||
product_query = select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
# Fetch all products in one query
|
||||||
if category:
|
prod_result = await self.db.execute(
|
||||||
product_query = product_query.where(NormalizedProduct.category == category)
|
select(NormalizedProduct).where(NormalizedProduct.id.in_(product_ids))
|
||||||
prod_result = await self.db.execute(product_query)
|
)
|
||||||
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
products_by_id = {p.id: p for p in prod_result.scalars().all()}
|
||||||
|
|
||||||
if not products_by_id:
|
# Latest prices for all requested products in one query
|
||||||
return {"products": []}
|
subq = latest_price_per_store(product_ids)
|
||||||
|
|
||||||
filtered_product_ids = list(products_by_id.keys())
|
|
||||||
subq = latest_price_per_store(filtered_product_ids)
|
|
||||||
prices_result = await self.db.execute(
|
prices_result = await self.db.execute(
|
||||||
select(PriceHistory)
|
select(PriceHistory)
|
||||||
.join(
|
.join(
|
||||||
@@ -79,17 +69,18 @@ class PublicService:
|
|||||||
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
PriceHistory.normalized_product_id == subq.c.normalized_product_id,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.where(PriceHistory.normalized_product_id.in_(filtered_product_ids))
|
.where(PriceHistory.normalized_product_id.in_(product_ids))
|
||||||
.options(selectinload(PriceHistory.store))
|
.options(selectinload(PriceHistory.store))
|
||||||
)
|
)
|
||||||
all_prices = prices_result.scalars().all()
|
all_prices = prices_result.scalars().all()
|
||||||
|
|
||||||
|
# Group by product
|
||||||
prices_by_product: dict[UUID, list] = {}
|
prices_by_product: dict[UUID, list] = {}
|
||||||
for ph in all_prices:
|
for ph in all_prices:
|
||||||
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
prices_by_product.setdefault(ph.normalized_product_id, []).append(ph)
|
||||||
|
|
||||||
products = []
|
products = []
|
||||||
for pid in filtered_product_ids:
|
for pid in product_ids:
|
||||||
product = products_by_id.get(pid)
|
product = products_by_id.get(pid)
|
||||||
if not product:
|
if not product:
|
||||||
continue
|
continue
|
||||||
@@ -111,29 +102,19 @@ class PublicService:
|
|||||||
|
|
||||||
return {"products": products}
|
return {"products": products}
|
||||||
|
|
||||||
async def get_inflation(self, category: str | None = None, period: str = "all-time") -> dict:
|
async def get_inflation(self) -> dict:
|
||||||
"""Aggregate price change stats. Compares average prices across periods."""
|
"""Aggregate price change stats. Compares average prices across periods."""
|
||||||
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
from cartsnitch_api.models import NormalizedProduct, PriceHistory
|
||||||
|
|
||||||
date_threshold = None
|
# Get average prices grouped by category for recent vs older data
|
||||||
if period != "all-time":
|
result = await self.db.execute(
|
||||||
days_map = {"1y": 365, "6m": 180, "3m": 90, "1m": 30}
|
select(
|
||||||
days = days_map.get(period, 365)
|
NormalizedProduct.category,
|
||||||
date_threshold = date.today() - timedelta(days=days)
|
func.avg(PriceHistory.regular_price),
|
||||||
|
)
|
||||||
query = select(
|
.join(NormalizedProduct)
|
||||||
NormalizedProduct.category,
|
.group_by(NormalizedProduct.category)
|
||||||
func.avg(PriceHistory.regular_price),
|
)
|
||||||
).join(NormalizedProduct)
|
|
||||||
|
|
||||||
if category:
|
|
||||||
query = query.where(NormalizedProduct.category == category)
|
|
||||||
if date_threshold:
|
|
||||||
query = query.where(PriceHistory.observed_date >= date_threshold)
|
|
||||||
|
|
||||||
query = query.group_by(NormalizedProduct.category)
|
|
||||||
|
|
||||||
result = await self.db.execute(query)
|
|
||||||
categories = {}
|
categories = {}
|
||||||
for row in result.all():
|
for row in result.all():
|
||||||
cat, avg_price = row
|
cat, avg_price = row
|
||||||
@@ -141,7 +122,7 @@ class PublicService:
|
|||||||
categories[cat] = float(avg_price) if avg_price else 0.0
|
categories[cat] = float(avg_price) if avg_price else 0.0
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"period": period,
|
"period": "all-time",
|
||||||
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
|
"cartsnitch_index": sum(categories.values()) / max(len(categories), 1),
|
||||||
"cpi_baseline": 100.0,
|
"cpi_baseline": 100.0,
|
||||||
"categories": categories,
|
"categories": categories,
|
||||||
|
|||||||
+7
-34
@@ -19,25 +19,6 @@ from cartsnitch_api.database import get_db
|
|||||||
from cartsnitch_api.main import create_app
|
from cartsnitch_api.main import create_app
|
||||||
from cartsnitch_api.models import Base
|
from cartsnitch_api.models import Base
|
||||||
|
|
||||||
TEST_JWT_SECRET = secrets.token_urlsafe(32)
|
|
||||||
TEST_SERVICE_KEY = secrets.token_urlsafe(32)
|
|
||||||
TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def setup_test_settings():
|
|
||||||
original_jwt = cartsnitch_settings.jwt_secret_key
|
|
||||||
original_service = cartsnitch_settings.service_key
|
|
||||||
original_fernet = cartsnitch_settings.fernet_key
|
|
||||||
cartsnitch_settings.jwt_secret_key = TEST_JWT_SECRET
|
|
||||||
cartsnitch_settings.service_key = TEST_SERVICE_KEY
|
|
||||||
cartsnitch_settings.fernet_key = TEST_FERNET_KEY
|
|
||||||
yield
|
|
||||||
cartsnitch_settings.jwt_secret_key = original_jwt
|
|
||||||
cartsnitch_settings.service_key = original_service
|
|
||||||
cartsnitch_settings.fernet_key = original_fernet
|
|
||||||
|
|
||||||
|
|
||||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||||
|
|
||||||
|
|
||||||
@@ -79,8 +60,7 @@ async def db_engine():
|
|||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
# Create Better-Auth tables (not managed by SQLAlchemy models)
|
||||||
await conn.execute(
|
await conn.execute(text("""
|
||||||
text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS sessions (
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
token TEXT NOT NULL UNIQUE,
|
token TEXT NOT NULL UNIQUE,
|
||||||
@@ -91,10 +71,8 @@ async def db_engine():
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
"""))
|
||||||
)
|
await conn.execute(text("""
|
||||||
await conn.execute(
|
|
||||||
text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS accounts (
|
CREATE TABLE IF NOT EXISTS accounts (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
@@ -110,10 +88,8 @@ async def db_engine():
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
"""))
|
||||||
)
|
await conn.execute(text("""
|
||||||
await conn.execute(
|
|
||||||
text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS verifications (
|
CREATE TABLE IF NOT EXISTS verifications (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
identifier TEXT NOT NULL,
|
identifier TEXT NOT NULL,
|
||||||
@@ -122,8 +98,7 @@ async def db_engine():
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
"""))
|
||||||
)
|
|
||||||
|
|
||||||
yield engine
|
yield engine
|
||||||
|
|
||||||
@@ -158,9 +133,7 @@ async def client(db_engine):
|
|||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
async def _create_test_user_and_session(
|
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
||||||
client: AsyncClient, db_engine, **user_overrides
|
|
||||||
) -> tuple[dict, str]:
|
|
||||||
"""Create a test user and a valid session directly in the DB.
|
"""Create a test user and a valid session directly in the DB.
|
||||||
|
|
||||||
Returns (user_dict, session_token). Better-Auth stores the raw token
|
Returns (user_dict, session_token). Better-Auth stores the raw token
|
||||||
|
|||||||
@@ -71,97 +71,3 @@ async def test_public_inflation(client, public_data):
|
|||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert "categories" in data
|
assert "categories" in data
|
||||||
assert "cartsnitch_index" in data
|
assert "cartsnitch_index" in data
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trend_invalid_uuid(client):
|
|
||||||
resp = await client.get("/public/trends/not-a-uuid")
|
|
||||||
assert resp.status_code == 422
|
|
||||||
assert "detail" in resp.json()
|
|
||||||
assert "stack" not in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trend_days_zero(client, public_data):
|
|
||||||
pid = str(public_data["product"].id)
|
|
||||||
resp = await client.get(f"/public/trends/{pid}?days=0")
|
|
||||||
assert resp.status_code == 422
|
|
||||||
assert "detail" in resp.json()
|
|
||||||
assert "stack" not in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trend_days_negative(client, public_data):
|
|
||||||
pid = str(public_data["product"].id)
|
|
||||||
resp = await client.get(f"/public/trends/{pid}?days=-1")
|
|
||||||
assert resp.status_code == 422
|
|
||||||
assert "detail" in resp.json()
|
|
||||||
assert "stack" not in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trend_days_over_max(client, public_data):
|
|
||||||
pid = str(public_data["product"].id)
|
|
||||||
resp = await client.get(f"/public/trends/{pid}?days=999")
|
|
||||||
assert resp.status_code == 422
|
|
||||||
assert "detail" in resp.json()
|
|
||||||
assert "stack" not in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trend_days_valid(client, public_data):
|
|
||||||
pid = str(public_data["product"].id)
|
|
||||||
resp = await client.get(f"/public/trends/{pid}?days=30")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "product_name" in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_comparison_empty_list(client):
|
|
||||||
resp = await client.get("/public/store-comparison")
|
|
||||||
assert resp.status_code == 400
|
|
||||||
assert "detail" in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_comparison_category_xss(client, public_data):
|
|
||||||
pid = str(public_data["product"].id)
|
|
||||||
resp = await client.get(
|
|
||||||
f"/public/store-comparison?product_ids={pid}&category=<script>alert(1)</script>"
|
|
||||||
)
|
|
||||||
assert resp.status_code == 422
|
|
||||||
assert "detail" in resp.json()
|
|
||||||
assert "stack" not in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_comparison_category_sql_injection(client, public_data):
|
|
||||||
pid = str(public_data["product"].id)
|
|
||||||
resp = await client.get(f"/public/store-comparison?product_ids={pid}&category='; DROP TABLE--")
|
|
||||||
assert resp.status_code == 422
|
|
||||||
assert "detail" in resp.json()
|
|
||||||
assert "stack" not in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inflation_invalid_period(client, public_data):
|
|
||||||
resp = await client.get("/public/inflation?period=10years")
|
|
||||||
assert resp.status_code == 422
|
|
||||||
assert "detail" in resp.json()
|
|
||||||
assert "stack" not in resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inflation_valid_periods(client, public_data):
|
|
||||||
for period in ["all-time", "1y", "6m", "3m", "1m"]:
|
|
||||||
resp = await client.get(f"/public/inflation?period={period}")
|
|
||||||
assert resp.status_code == 200, f"period={period} failed"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inflation_category_too_long(client, public_data):
|
|
||||||
long_category = "x" * 200
|
|
||||||
resp = await client.get(f"/public/inflation?category={long_category}")
|
|
||||||
assert resp.status_code == 422
|
|
||||||
assert "detail" in resp.json()
|
|
||||||
assert "stack" not in resp.json()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user