diff --git a/api/alembic/versions/009_add_gin_index_upc_variants.py b/api/alembic/versions/009_add_gin_index_upc_variants.py
new file mode 100644
index 0000000..82f1e97
--- /dev/null
+++ b/api/alembic/versions/009_add_gin_index_upc_variants.py
@@ -0,0 +1,42 @@
+"""Add GIN index on upc_variants and alter column to JSONB.
+
+Revision ID: 009_add_gin_index_upc_variants
+Revises: 008_create_domain_tables
+Create Date: 2026-04-14
+"""
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import postgresql
+
+revision = "009_add_gin_index_upc_variants"
+down_revision = "008_create_domain_tables"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    """Convert upc_variants to JSONB and index it with GIN."""
+    op.alter_column(
+        "normalized_products",
+        "upc_variants",
+        type_=postgresql.JSONB(),
+        postgresql_using="upc_variants::jsonb",
+    )
+    op.create_index(
+        "ix_normalized_products_upc_variants_gin",
+        "normalized_products",
+        ["upc_variants"],
+        postgresql_using="gin",
+    )
+
+
+def downgrade() -> None:
+    """Drop the GIN index and convert upc_variants back to JSON."""
+    op.drop_index("ix_normalized_products_upc_variants_gin", table_name="normalized_products")
+    op.alter_column(
+        "normalized_products",
+        "upc_variants",
+        type_=sa.JSON(),
+        postgresql_using="upc_variants::json",
+    )
diff --git a/common/src/cartsnitch_common/models/product.py b/common/src/cartsnitch_common/models/product.py
index 215e57e..61e0134 100644
--- a/common/src/cartsnitch_common/models/product.py
+++ b/common/src/cartsnitch_common/models/product.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING
 
 from sqlalchemy import JSON, String
+from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from cartsnitch_common.constants import ProductCategory, SizeUnit
@@ -26,7 +27,9 @@ class NormalizedProduct(UUIDPrimaryKeyMixin, TimestampMixin, Base):
     brand: Mapped[str | None] = mapped_column(String(200))
     size: Mapped[str | None] = mapped_column(String(50))
     size_unit: Mapped[SizeUnit | None] = mapped_column(String(10))
-    upc_variants: Mapped[list[str] | None] = mapped_column(JSON, default=list)
+    upc_variants: Mapped[list[str] | None] = mapped_column(
+        JSON().with_variant(JSONB(), "postgresql"), default=list
+    )
 
     # Relationships
     purchase_items: Mapped[list["PurchaseItem"]] = relationship(back_populates="normalized_product")
diff --git a/receiptwitness/src/receiptwitness/pipeline/normalization.py b/receiptwitness/src/receiptwitness/pipeline/normalization.py
index c1fade9..a714020 100644
--- a/receiptwitness/src/receiptwitness/pipeline/normalization.py
+++ b/receiptwitness/src/receiptwitness/pipeline/normalization.py
@@ -5,12 +5,14 @@ Matches products across retailers by:
 2. Fuzzy name matching via token-based Jaccard similarity (lower confidence)
 """
 
+import json
 import re
 from dataclasses import dataclass
 from enum import StrEnum
 
 from cartsnitch_common.models.product import NormalizedProduct
-from sqlalchemy import select
+from sqlalchemy import String, cast, select
+from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.orm import Session
 
 
@@ -96,17 +98,33 @@ def jaccard_similarity(a: str, b: str) -> float:
 def match_by_upc(session: Session, upc: str) -> MatchResult | None:
     """Find a normalized product by exact UPC match.
 
-    Loads products with upc_variants and checks membership in Python
-    for cross-database compatibility (works on both PostgreSQL and SQLite).
+    Uses PostgreSQL JSONB containment (@>) for production efficiency.
+    Falls back to a LIKE on the serialized JSON text for other dialects
+    (SQLite in tests).
+
+    Args:
+        session: Session used to run the lookup.
+        upc: UPC code to look up.
+
+    Returns:
+        A UPC-method MatchResult with confidence 1.0, or None when no
+        product lists the UPC.
     """
-    # TODO: Use PostgreSQL JSON containment query (@>) for production.
-    # Current approach loads all products into memory — acceptable for tests
-    # and small datasets, but will not scale.
-    stmt = select(NormalizedProduct).where(NormalizedProduct.upc_variants.is_not(None))
-    products = session.execute(stmt).scalars().all()
-    for product in products:
-        if product.upc_variants and upc in product.upc_variants:
-            return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC)
+    if session.get_bind().dialect.name == "postgresql":
+        # Bind the Python list itself so the JSONB type serializes it once;
+        # a pre-dumped JSON string would be encoded again and compared as a
+        # JSON string scalar instead of an array.
+        condition = cast(NormalizedProduct.upc_variants, JSONB).contains([upc])
+    else:
+        # Match the quoted JSON token so a UPC cannot hit as a substring of a
+        # longer variant; autoescape keeps % and _ in the input literal.
+        condition = cast(NormalizedProduct.upc_variants, String).contains(
+            json.dumps(upc), autoescape=True
+        )
+    stmt = select(NormalizedProduct).where(condition)
+    product = session.execute(stmt).scalars().first()
+    if product is not None:
+        return MatchResult(product=product, confidence=1.0, method=MatchMethod.UPC)
     return None
 
 