Compare commits

..

2 Commits

Author SHA1 Message Date
Paperclip 1aff898545 fix: update vite to 6.4.2 to patch audit vulnerabilities
Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-14 14:31:02 +00:00
CartSnitch Engineer Bot 24f0dd0e67 fix: replace N+1 UPC query with SQL containment in normalization
- Add PostgreSQL JSONB containment (@>) query for match_by_upc
- Add SQLite LIKE fallback for test compatibility
- Update upc_variants column to JSONB with variant for cross-db support
- Add GIN index migration for upc_variants

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-14 11:59:28 +00:00
6 changed files with 68 additions and 50 deletions
@@ -0,0 +1,38 @@
"""Add GIN index on upc_variants and alter column to JSONB.
Revision ID: 009_add_gin_index_upc_variants
Revises: 008_create_domain_tables
Create Date: 2026-04-14
"""
import sqlalchemy as sa
from alembic import op
revision = "009_add_gin_index_upc_variants"
down_revision = "008_create_domain_tables"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"normalized_products",
"upc_variants",
type_=sa.dialects.postgresql.JSONB(),
postgresql_using="upc_variants::jsonb",
)
op.create_index(
"ix_normalized_products_upc_variants_gin",
"normalized_products",
["upc_variants"],
postgresql_using="gin",
)
def downgrade() -> None:
op.drop_index("ix_normalized_products_upc_variants_gin", table_name="normalized_products")
op.alter_column(
"normalized_products",
"upc_variants",
type_=sa.JSON(),
)
@@ -4,7 +4,6 @@ Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints. Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
""" """
import hashlib
import time import time
from collections import defaultdict from collections import defaultdict
from threading import Lock from threading import Lock
@@ -72,8 +71,8 @@ def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
auth_header = request.headers.get("authorization", "") auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):
token = auth_header[7:] token = auth_header[7:]
token_hash = hashlib.sha256(token.encode()).hexdigest() # Use last 16 chars of token as key to avoid storing full tokens
return f"token:{token_hash}", _auth_limiter return f"token:{token[-16:]}", _auth_limiter
# Fallback to IP for unauthenticated non-public endpoints # Fallback to IP for unauthenticated non-public endpoints
return f"ip:{_get_client_ip(request)}", _public_limiter return f"ip:{_get_client_ip(request)}", _public_limiter
+1 -32
View File
@@ -1,10 +1,8 @@
"""Tests for rate limiting middleware.""" """Tests for rate limiting middleware."""
from unittest.mock import MagicMock
import pytest import pytest
from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter, _get_rate_limit_key from cartsnitch_api.middleware.rate_limit import _SlidingWindowCounter
class TestSlidingWindowCounter: class TestSlidingWindowCounter:
@@ -55,32 +53,3 @@ async def test_health_skips_rate_limit(client):
resp = await client.get("/health") resp = await client.get("/health")
assert resp.status_code == 200 assert resp.status_code == 200
assert "x-ratelimit-limit" not in resp.headers assert "x-ratelimit-limit" not in resp.headers
class TestGetRateLimitKey:
def _make_request(self, auth_header: str = "") -> MagicMock:
req = MagicMock()
req.url.path = "/purchases"
req.headers = {"authorization": auth_header} if auth_header else {}
return req
def test_distinct_tokens_produce_distinct_keys(self):
req1 = self._make_request("Bearer token_alpha_12345")
req2 = self._make_request("Bearer token_beta_67890")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 != key2
def test_same_token_produces_same_key(self):
req1 = self._make_request("Bearer same_token_value_abc")
req2 = self._make_request("Bearer same_token_value_abc")
key1, _ = _get_rate_limit_key(req1)
key2, _ = _get_rate_limit_key(req2)
assert key1 == key2
def test_key_does_not_contain_raw_token_suffix(self):
raw_token = "my_secret_jwt_token_xyz"
req = self._make_request(f"Bearer {raw_token}")
key, _ = _get_rate_limit_key(req)
assert raw_token[-16:] not in key
assert raw_token not in key
@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqlalchemy import JSON, String from sqlalchemy import JSON, String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from cartsnitch_common.constants import ProductCategory, SizeUnit from cartsnitch_common.constants import ProductCategory, SizeUnit
@@ -26,7 +27,9 @@ class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
brand: Mapped[str | None] = mapped_column(String(200)) brand: Mapped[str | None] = mapped_column(String(200))
size: Mapped[str | None] = mapped_column(String(50)) size: Mapped[str | None] = mapped_column(String(50))
size_unit: Mapped[SizeUnit | None] = mapped_column(String(10)) size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list) upc_variants: Mapped[list[str] | None] = mapped_column(
JSON().with_variant(JSONB(), "postgresql"), default=list
)
# Relationships # Relationships
purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product") purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
+3 -3
View File
@@ -9805,9 +9805,9 @@
} }
}, },
"node_modules/vite": { "node_modules/vite": {
"version": "6.4.1", "version": "6.4.2",
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz", "resolved": "https://registry.npmjs.org/vite/-/vite-6.4.2.tgz",
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==", "integrity": "sha512-2N/55r4JDJ4gdrCvGgINMy+HH3iRpNIz8K6SFwVsA+JbQScLiC+clmAxBgwiSPgcG9U15QmvqCGWzMbqda5zGQ==",
"devOptional": true, "devOptional": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
@@ -5,12 +5,14 @@ Matches products across retailers by:
2. Fuzzy name matching via token-based Jaccard similarity (lower confidence) 2. Fuzzy name matching via token-based Jaccard similarity (lower confidence)
""" """
import json
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from enum import StrEnum from enum import StrEnum
from cartsnitch_common.models.product import NormalizedProduct from cartsnitch_common.models.product import NormalizedProduct
from sqlalchemy import select from sqlalchemy import cast, func, select, String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -96,17 +98,24 @@ def jaccard_similarity(a: str, b: str) -> float:
def match_by_upc(session: Session, upc: str) -> MatchResult | None: def match_by_upc(session: Session, upc: str) -> MatchResult | None:
"""Find a normalized product by exact UPC match. """Find a normalized product by exact UPC match.
Loads products with upc_variants and checks membership in Python Uses PostgreSQL JSONB containment (@>) for production efficiency.
for cross-database compatibility (works on both PostgreSQL and SQLite). Falls back to LIKE on SQLite for test compatibility.
""" """
# TODO: Use PostgreSQL JSON containment query (@>) for production. dialect_name = session.bind.dialect.name if session.bind else "default"
# Current approach loads all products into memory — acceptable for tests if dialect_name == "postgresql":
# and small datasets, but will not scale. stmt = select(NormalizedProduct).where(
stmt = select(NormalizedProduct).where(NormalizedProduct.upc_variants.is_not(None)) cast(NormalizedProduct.upc_variants, JSONB).op("@>")(
products = session.execute(stmt).scalars().all() func.cast(json.dumps([upc]), JSONB)
for product in products: )
if product.upc_variants and upc in product.upc_variants: )
return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC) else:
stmt = select(NormalizedProduct).where(
NormalizedProduct.upc_variants.is_not(None),
cast(NormalizedProduct.upc_variants, String).contains(upc),
)
product = session.execute(stmt).scalars().first()
if product:
return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC)
return None return None