Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1aff898545 | |||
| 24f0dd0e67 |
@@ -0,0 +1,38 @@
|
||||
"""Add GIN index on upc_variants and alter column to JSONB.
|
||||
|
||||
Revision ID: 009_add_gin_index_upc_variants
|
||||
Revises: 008_create_domain_tables
|
||||
Create Date: 2026-04-14
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "009_add_gin_index_upc_variants"
|
||||
down_revision = "008_create_domain_tables"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"normalized_products",
|
||||
"upc_variants",
|
||||
type_=sa.dialects.postgresql.JSONB(),
|
||||
postgresql_using="upc_variants::jsonb",
|
||||
)
|
||||
op.create_index(
|
||||
"ix_normalized_products_upc_variants_gin",
|
||||
"normalized_products",
|
||||
["upc_variants"],
|
||||
postgresql_using="gin",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_normalized_products_upc_variants_gin", table_name="normalized_products")
|
||||
op.alter_column(
|
||||
"normalized_products",
|
||||
"upc_variants",
|
||||
type_=sa.JSON(),
|
||||
)
|
||||
@@ -1,51 +1,26 @@
|
||||
"""Redis/DragonflyDB caching helpers."""
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class CacheClient:
|
||||
"""Redis/DragonflyDB caching with connection pooling.
|
||||
"""Stub for Redis/DragonflyDB caching.
|
||||
|
||||
Will be used for expensive queries: price trends, product comparisons.
|
||||
Cache invalidation via Redis pub/sub events from other services.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pool: redis.ConnectionPool | None = None
|
||||
self._client: redis.Redis | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the Redis connection pool."""
|
||||
self._pool = redis.ConnectionPool.from_url(
|
||||
settings.redis_url,
|
||||
max_connections=20,
|
||||
decode_responses=True,
|
||||
)
|
||||
self._client = redis.Redis(connection_pool=self._pool)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the Redis connection pool."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
if self._pool:
|
||||
await self._pool.aclose()
|
||||
self.url = settings.redis_url
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
if not self._client:
|
||||
return None
|
||||
return await self._client.get(key)
|
||||
# TODO: implement with redis-py async
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: str, ttl_seconds: int = 300) -> None:
|
||||
if not self._client:
|
||||
return
|
||||
await self._client.set(key, value, ex=ttl_seconds)
|
||||
# TODO: implement with redis-py async
|
||||
pass
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
if not self._client:
|
||||
return
|
||||
await self._client.delete(key)
|
||||
|
||||
|
||||
cache_client = CacheClient()
|
||||
# TODO: implement with redis-py async
|
||||
pass
|
||||
|
||||
@@ -6,14 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=False,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
engine = create_async_engine(settings.database_url, echo=False)
|
||||
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@@ -21,8 +14,3 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI dependency that yields an async DB session."""
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def dispose_engine() -> None:
|
||||
"""Dispose the database engine, closing all pooled connections."""
|
||||
await engine.dispose()
|
||||
|
||||
@@ -5,8 +5,6 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
from cartsnitch_api.auth.routes import router as auth_router
|
||||
from cartsnitch_api.cache import cache_client
|
||||
from cartsnitch_api.database import dispose_engine
|
||||
from cartsnitch_api.middleware.cors import add_cors_middleware
|
||||
from cartsnitch_api.middleware.error_handler import add_error_handlers, add_error_monitor_middleware
|
||||
from cartsnitch_api.middleware.rate_limit import add_rate_limit_middleware
|
||||
@@ -25,10 +23,9 @@ from cartsnitch_api.routes.user import router as user_router
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await cache_client.initialize()
|
||||
# TODO: initialize DB session pool, Redis connection, service clients
|
||||
yield
|
||||
await cache_client.close()
|
||||
await dispose_engine()
|
||||
# TODO: cleanup connections
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
|
||||
@@ -4,7 +4,6 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
|
||||
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
@@ -72,8 +71,8 @@ def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
return f"token:{token_hash}", _auth_limiter
|
||||
# Use last 16 chars of token as key to avoid storing full tokens
|
||||
return f"token:{token[-16:]}", _auth_limiter
|
||||
|
||||
# Fallback to IP for unauthenticated non-public endpoints
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key
|
||||
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter
|
||||
|
||||
|
||||
class TestSlidingWindowCounter:
|
||||
@@ -55,32 +53,3 @@ async def test_health_skips_rate_limit(client):
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert "x-ratelimit-limit" not in resp.headers
|
||||
|
||||
|
||||
class TestGetRateLimitKey:
|
||||
def _make_request(self, auth_header: str = "") -> MagicMock:
|
||||
req = MagicMock()
|
||||
req.url.path = "/purchases"
|
||||
req.headers = {"authorization": auth_header} if auth_header else {}
|
||||
return req
|
||||
|
||||
def test_distinct_tokens_produce_distinct_keys(self):
|
||||
req1 = self._make_request("Bearer token_alpha_12345")
|
||||
req2 = self._make_request("Bearer token_beta_67890")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 != key2
|
||||
|
||||
def test_same_token_produces_same_key(self):
|
||||
req1 = self._make_request("Bearer same_token_value_abc")
|
||||
req2 = self._make_request("Bearer same_token_value_abc")
|
||||
key1, _ = _get_rate_limit_key(req1)
|
||||
key2, _ = _get_rate_limit_key(req2)
|
||||
assert key1 == key2
|
||||
|
||||
def test_key_does_not_contain_raw_token_suffix(self):
|
||||
raw_token = "my_secret_jwt_token_xyz"
|
||||
req = self._make_request(f"Bearer {raw_token}")
|
||||
key, _ = _get_rate_limit_key(req)
|
||||
assert raw_token[-16:] not in key
|
||||
assert raw_token not in key
|
||||
|
||||
+6
-12
@@ -4,23 +4,17 @@ import pg from "pg";
|
||||
|
||||
const { Pool } = pg;
|
||||
|
||||
const pool = new Pool({
|
||||
connectionString:
|
||||
process.env.DATABASE_URL ??
|
||||
"postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
|
||||
});
|
||||
|
||||
const secret = process.env.BETTER_AUTH_SECRET;
|
||||
if (!secret) {
|
||||
throw new Error("BETTER_AUTH_SECRET environment variable is required");
|
||||
}
|
||||
|
||||
const databaseUrl = process.env.DATABASE_URL;
|
||||
if (!databaseUrl) {
|
||||
console.warn(
|
||||
"WARNING: DATABASE_URL is not set — using default localhost connection. " +
|
||||
"Set DATABASE_URL for production deployments."
|
||||
);
|
||||
}
|
||||
|
||||
export const pool = new Pool({
|
||||
connectionString: databaseUrl ?? "postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
|
||||
});
|
||||
|
||||
export const auth = betterAuth({
|
||||
database: pool,
|
||||
basePath: "/auth",
|
||||
|
||||
+3
-17
@@ -1,6 +1,6 @@
|
||||
import { createServer } from "node:http";
|
||||
import { toNodeHandler } from "better-auth/node";
|
||||
import { auth, pool } from "./auth.js";
|
||||
import { auth } from "./auth.js";
|
||||
|
||||
const port = parseInt(process.env.PORT ?? "3001", 10);
|
||||
|
||||
@@ -9,22 +9,8 @@ const handler = toNodeHandler(auth);
|
||||
const server = createServer(async (req, res) => {
|
||||
// Health check
|
||||
if (req.url === "/health" && req.method === "GET") {
|
||||
try {
|
||||
const client = await pool.connect();
|
||||
try {
|
||||
await Promise.race([
|
||||
client.query("SELECT 1"),
|
||||
new Promise((_, reject) => setTimeout(() => reject(new Error("DB timeout")), 2000)),
|
||||
]);
|
||||
} finally {
|
||||
client.release();
|
||||
}
|
||||
res.writeHead(200, { "Content-Type": "application/json" });
|
||||
res.end(JSON.stringify({ status: "ok", db: "connected" }));
|
||||
} catch {
|
||||
res.writeHead(503, { "Content-Type": "application/json" });
|
||||
res.end(JSON.stringify({ status: "error", db: "unreachable" }));
|
||||
}
|
||||
res.writeHead(200, { "Content-Type": "application/json" });
|
||||
res.end(JSON.stringify({ status: "ok" }));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 24 KiB |
-1
Submodule cartsnitch deleted from a53daddb9a
@@ -3,6 +3,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from cartsnitch_common.constants import ProductCategory, SizeUnit
|
||||
@@ -26,7 +27,9 @@ class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
|
||||
brand: Mapped[str | None] = mapped_column(String(200))
|
||||
size: Mapped[str | None] = mapped_column(String(50))
|
||||
size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
|
||||
upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list)
|
||||
upc_variants: Mapped[list[str] | None] = mapped_column(
|
||||
JSON().with_variant(JSONB(), "postgresql"), default=list
|
||||
)
|
||||
|
||||
# Relationships
|
||||
purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
"""Service-specific configuration for ReceiptWitness."""
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
_PLACEHOLDER_VALUES = {"change-me-in-production"}
|
||||
|
||||
|
||||
class ReceiptWitnessSettings(BaseSettings):
|
||||
model_config = {"env_prefix": "RW_"}
|
||||
|
||||
@@ -34,34 +30,5 @@ class ReceiptWitnessSettings(BaseSettings):
|
||||
# Mailgun inbound email webhook
|
||||
mailgun_webhook_signing_key: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_required_vars(self):
|
||||
errors = []
|
||||
if not self.session_encryption_key or self.session_encryption_key in _PLACEHOLDER_VALUES:
|
||||
errors.append(
|
||||
"RW_SESSION_ENCRYPTION_KEY must be set to a secure value. "
|
||||
'Generate one with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"'
|
||||
)
|
||||
if self.notifications_enabled and not self.resend_api_key:
|
||||
errors.append(
|
||||
"RW_RESEND_API_KEY must be set when RW_NOTIFICATIONS_ENABLED=true. "
|
||||
"Get an API key from https://resend.com/api-keys"
|
||||
)
|
||||
if errors:
|
||||
raise ValueError(
|
||||
"ReceiptWitness startup failed — missing required config:\n"
|
||||
+ "\n".join(f" - {e}" for e in errors)
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class _LazySettings:
|
||||
_instance: ReceiptWitnessSettings | None = None
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if _LazySettings._instance is None:
|
||||
_LazySettings._instance = ReceiptWitnessSettings()
|
||||
return getattr(_LazySettings._instance, name)
|
||||
|
||||
|
||||
settings = _LazySettings()
|
||||
settings = ReceiptWitnessSettings()
|
||||
|
||||
@@ -5,12 +5,14 @@ Matches products across retailers by:
|
||||
2. Fuzzy name matching via token-based Jaccard similarity (lower confidence)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
from cartsnitch_common.models.product import NormalizedProduct
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import cast, func, select, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@@ -96,17 +98,24 @@ def jaccard_similarity(a: str, b: str) -> float:
|
||||
def match_by_upc(session: Session, upc: str) -> MatchResult | None:
|
||||
"""Find a normalized product by exact UPC match.
|
||||
|
||||
Loads products with upc_variants and checks membership in Python
|
||||
for cross-database compatibility (works on both PostgreSQL and SQLite).
|
||||
Uses PostgreSQL JSONB containment (@>) for production efficiency.
|
||||
Falls back to LIKE on SQLite for test compatibility.
|
||||
"""
|
||||
# TODO: Use PostgreSQL JSON containment query (@>) for production.
|
||||
# Current approach loads all products into memory — acceptable for tests
|
||||
# and small datasets, but will not scale.
|
||||
stmt = select(NormalizedProduct).where(NormalizedProduct.upc_variants.is_not(None))
|
||||
products = session.execute(stmt).scalars().all()
|
||||
for product in products:
|
||||
if product.upc_variants and upc in product.upc_variants:
|
||||
return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC)
|
||||
dialect_name = session.bind.dialect.name if session.bind else "default"
|
||||
if dialect_name == "postgresql":
|
||||
stmt = select(NormalizedProduct).where(
|
||||
cast(NormalizedProduct.upc_variants, JSONB).op("@>")(
|
||||
func.cast(json.dumps([upc]), JSONB)
|
||||
)
|
||||
)
|
||||
else:
|
||||
stmt = select(NormalizedProduct).where(
|
||||
NormalizedProduct.upc_variants.is_not(None),
|
||||
cast(NormalizedProduct.upc_variants, String).contains(upc),
|
||||
)
|
||||
product = session.execute(stmt).scalars().first()
|
||||
if product:
|
||||
return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
"""Shared test fixtures."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
|
||||
os.environ.setdefault("RW_SESSION_ENCRYPTION_KEY", "test-secret-key-for-unit-tests-only-32bytes!")
|
||||
os.environ.setdefault("RW_MAILGUN_WEBHOOK_SIGNING_KEY", "test-mailgun-signing-key")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def meijer_receipt_data() -> dict:
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
import pytest
|
||||
from receiptwitness.config import ReceiptWitnessSettings
|
||||
|
||||
|
||||
def test_valid_config():
|
||||
s = ReceiptWitnessSettings(
|
||||
session_encryption_key="7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8="
|
||||
)
|
||||
assert s.session_encryption_key
|
||||
|
||||
|
||||
def test_missing_session_encryption_key_raises():
|
||||
with pytest.raises(ValueError, match="RW_SESSION_ENCRYPTION_KEY"):
|
||||
ReceiptWitnessSettings(session_encryption_key="")
|
||||
|
||||
|
||||
def test_placeholder_session_encryption_key_raises():
|
||||
with pytest.raises(ValueError, match="RW_SESSION_ENCRYPTION_KEY"):
|
||||
ReceiptWitnessSettings(session_encryption_key="change-me-in-production")
|
||||
|
||||
|
||||
def test_notifications_enabled_without_resend_key_raises():
|
||||
with pytest.raises(ValueError, match="RW_RESEND_API_KEY"):
|
||||
ReceiptWitnessSettings(
|
||||
session_encryption_key="7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=",
|
||||
notifications_enabled=True,
|
||||
resend_api_key="",
|
||||
)
|
||||
|
||||
|
||||
def test_notifications_disabled_without_resend_key_ok():
|
||||
s = ReceiptWitnessSettings(
|
||||
session_encryption_key="7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=",
|
||||
notifications_enabled=False,
|
||||
resend_api_key="",
|
||||
)
|
||||
assert s.notifications_enabled is False
|
||||
|
||||
|
||||
def test_notifications_enabled_with_resend_key_ok():
|
||||
s = ReceiptWitnessSettings(
|
||||
session_encryption_key="7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=",
|
||||
notifications_enabled=True,
|
||||
resend_api_key="re_test_1234567890",
|
||||
)
|
||||
assert s.resend_api_key == "re_test_1234567890"
|
||||
Reference in New Issue
Block a user