diff --git a/Dockerfile b/Dockerfile index bb5d3bd..e3b4bbf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,10 +12,14 @@ RUN pip install --no-cache-dir --prefix=/install . FROM python:3.12-slim AS prod +RUN apt-get update && apt-get install -y --no-install-recommends libpq5 && rm -rf /var/lib/apt/lists/* + WORKDIR /app RUN adduser --system --group --uid 1000 app COPY --from=build /install /usr/local COPY src/ ./src/ +COPY alembic.ini ./ +COPY alembic/ ./alembic/ USER 1000 EXPOSE 8000 @@ -23,4 +27,4 @@ EXPOSE 8000 HEALTHCHECK --interval=30s --timeout=3s \ CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" -CMD ["uvicorn", "cartsnitch_api.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file +CMD ["sh", "-c", "python -m alembic upgrade head && uvicorn cartsnitch_api.main:app --host 0.0.0.0 --port 8000"] diff --git a/alembic/env.py b/alembic/env.py index 3e563e1..6844fba 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -18,7 +18,7 @@ if not db_url: "CARTSNITCH_DATABASE_URL_SYNC must be set. " "Example: postgresql://user:pass@localhost:5432/cartsnitch" ) -config.set_main_option("sqlalchemy.url", db_url) +config.set_main_option("sqlalchemy.url", db_url.replace("%", "%%")) target_metadata = Base.metadata @@ -31,6 +31,7 @@ def run_migrations_offline() -> None: target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"}, + version_table_column_width=128, ) with context.begin_transaction(): context.run_migrations() @@ -44,9 +45,20 @@ def run_migrations_online() -> None: poolclass=pool.NullPool, ) with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) + context.configure(connection=connection, target_metadata=target_metadata, version_table_column_width=128) with context.begin_transaction(): context.run_migrations() + # Create any tables defined in models but not yet created by migrations. + # This bootstraps fresh databases that have no legacy schema. + # checkfirst=True ensures this is a no-op on existing databases. + try: + Base.metadata.create_all(bind=connection, checkfirst=True) + connection.commit() + except Exception as exc: + import logging + logging.getLogger("alembic.env").warning( + "create_all failed (non-fatal, migrations should handle table creation): %s", exc + ) if context.is_offline_mode(): diff --git a/alembic/versions/001_encrypt_session_data.py b/alembic/versions/001_encrypt_session_data.py index 4932231..20c70ac 100644 --- a/alembic/versions/001_encrypt_session_data.py +++ b/alembic/versions/001_encrypt_session_data.py @@ -33,6 +33,21 @@ def _is_fernet_token(value: str) -> bool: def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + + # Fresh DB — table created by Base.metadata.create_all with correct TEXT type + if not inspector.has_table("user_store_accounts"): + return + + # Already migrated? Skip if session_data is already TEXT (not JSON) + cols = {c["name"]: c for c in inspector.get_columns("user_store_accounts")} + if "session_data" not in cols: + return + col_type = str(cols["session_data"]["type"]).lower() + if "text" in col_type and "json" not in col_type: + return # already TEXT — nothing to do + # Change column type from JSON to TEXT to hold Fernet ciphertext op.alter_column( "user_store_accounts", @@ -43,7 +58,6 @@ def upgrade() -> None: postgresql_using="session_data::text", ) - conn = op.get_bind() rows = conn.execute( text("SELECT id, session_data FROM user_store_accounts WHERE session_data IS NOT NULL") ).fetchall() diff --git a/alembic/versions/002_better_auth_tables.py b/alembic/versions/002_better_auth_tables.py index aa5dd93..efa283f 100644 --- a/alembic/versions/002_better_auth_tables.py +++ b/alembic/versions/002_better_auth_tables.py @@ -21,81 +21,94 @@ depends_on = None def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + # --- Extend users table for Better-Auth compatibility --- - op.add_column("users", sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false")) - op.add_column("users", sa.Column("image", sa.Text(), nullable=True)) + # Guard: on a fresh DB Base.metadata.create_all (called in env.py after migrations) + # creates the users table with all columns, so migration 002 must not re-run add_column. + if inspector.has_table("users"): + existing_user_cols = [c["name"] for c in inspector.get_columns("users")] + if "email_verified" not in existing_user_cols: + op.add_column("users", sa.Column("email_verified", sa.Boolean(), nullable=False, server_default="false")) + if "image" not in existing_user_cols: + op.add_column("users", sa.Column("image", sa.Text(), nullable=True)) # --- Create sessions table --- - op.create_table( - "sessions", - sa.Column("id", sa.Text(), nullable=False), - sa.Column("token", sa.Text(), nullable=False), - sa.Column("user_id", sa.Text(), nullable=False), - sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("ip_address", sa.Text(), nullable=True), - sa.Column("user_agent", sa.Text(), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index("ix_sessions_token", "sessions", ["token"], unique=True) - op.create_index("ix_sessions_user_id", "sessions", ["user_id"]) + if not inspector.has_table("sessions"): + op.create_table( + "sessions", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("token", sa.Text(), nullable=False), + sa.Column("user_id", sa.Text(), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("ip_address", sa.Text(), nullable=True), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_sessions_token", "sessions", ["token"], unique=True) + op.create_index("ix_sessions_user_id", "sessions", ["user_id"]) # --- Create accounts table --- - op.create_table( - "accounts", - sa.Column("id", sa.Text(), nullable=False), - sa.Column("user_id", sa.Text(), nullable=False), - sa.Column("account_id", sa.Text(), nullable=False), - sa.Column("provider_id", sa.Text(), nullable=False), - sa.Column("access_token", sa.Text(), nullable=True), - sa.Column("refresh_token", sa.Text(), nullable=True), - sa.Column("access_token_expires_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("refresh_token_expires_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("scope", sa.Text(), nullable=True), - sa.Column("id_token", sa.Text(), nullable=True), - sa.Column("password", sa.Text(), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index("ix_accounts_user_id", "accounts", ["user_id"]) + if not inspector.has_table("accounts"): + op.create_table( + "accounts", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("user_id", sa.Text(), nullable=False), + sa.Column("account_id", sa.Text(), nullable=False), + sa.Column("provider_id", sa.Text(), nullable=False), + sa.Column("access_token", sa.Text(), nullable=True), + sa.Column("refresh_token", sa.Text(), nullable=True), + sa.Column("access_token_expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("refresh_token_expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("scope", sa.Text(), nullable=True), + sa.Column("id_token", sa.Text(), nullable=True), + sa.Column("password", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_accounts_user_id", "accounts", ["user_id"]) # --- Create verifications table --- - op.create_table( - "verifications", - sa.Column("id", sa.Text(), nullable=False), - sa.Column("identifier", sa.Text(), nullable=False), - sa.Column("value", sa.Text(), nullable=False), - sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) + if not inspector.has_table("verifications"): + op.create_table( + "verifications", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("identifier", sa.Text(), nullable=False), + sa.Column("value", sa.Text(), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) # --- Migrate existing password hashes to accounts table --- - # For each user with a hashed_password, create a 'credential' account row - conn = op.get_bind() - users = conn.execute( - text("SELECT id, hashed_password FROM users WHERE hashed_password IS NOT NULL") - ).fetchall() + # Only run on existing (non-fresh) DBs that already have users table with data + if inspector.has_table("users"): + users = conn.execute( + text("SELECT id, hashed_password FROM users WHERE hashed_password IS NOT NULL") + ).fetchall() - for user_id, hashed_password in users: - user_id_str = str(user_id) - conn.execute( - text( - "INSERT INTO accounts (id, user_id, account_id, provider_id, password, created_at, updated_at) " - "VALUES (gen_random_uuid()::text, :user_id, :account_id, 'credential', :password, now(), now())" - ), - {"user_id": user_id_str, "account_id": user_id_str, "password": hashed_password}, - ) + for user_id, hashed_password in users: + user_id_str = str(user_id) + conn.execute( + text( + "INSERT INTO accounts (id, user_id, account_id, provider_id, password, created_at, updated_at) " + "VALUES (gen_random_uuid()::text, :user_id, :account_id, 'credential', :password, now(), now())" + ), + {"user_id": user_id_str, "account_id": user_id_str, "password": hashed_password}, + ) def downgrade() -> None: - op.drop_table("verifications") - op.drop_table("accounts") - op.drop_index("ix_sessions_user_id", table_name="sessions") - op.drop_index("ix_sessions_token", table_name="sessions") - op.drop_table("sessions") - op.drop_column("users", "image") - op.drop_column("users", "email_verified") + op.execute(text("DROP INDEX IF EXISTS ix_accounts_user_id")) + op.execute(text("DROP TABLE IF EXISTS verifications")) + op.execute(text("DROP TABLE IF EXISTS accounts")) + op.execute(text("DROP INDEX IF EXISTS ix_sessions_user_id")) + op.execute(text("DROP INDEX IF EXISTS ix_sessions_token")) + op.execute(text("DROP TABLE IF EXISTS sessions")) + op.execute(text("ALTER TABLE users DROP COLUMN IF EXISTS image")) + op.execute(text("ALTER TABLE users DROP COLUMN IF EXISTS email_verified")) diff --git a/alembic/versions/003_make_users_hashed_password_nullable.py b/alembic/versions/003_make_users_hashed_password_nullable.py index 8aec2bc..573b0ad 100644 --- a/alembic/versions/003_make_users_hashed_password_nullable.py +++ b/alembic/versions/003_make_users_hashed_password_nullable.py @@ -19,8 +19,25 @@ depends_on = None def upgrade() -> None: - op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=True) + conn = op.get_bind() + inspector = sa.inspect(conn) + + # Fresh DB — nothing to alter + if not inspector.has_table("users"): + return + + cols = {c["name"]: c for c in inspector.get_columns("users")} + if "hashed_password" in cols and not cols["hashed_password"]["nullable"]: + op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=True) def downgrade() -> None: - op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=False) + conn = op.get_bind() + inspector = sa.inspect(conn) + + if not inspector.has_table("users"): + return + + cols = {c["name"]: c for c in inspector.get_columns("users")} + if "hashed_password" in cols and cols["hashed_password"]["nullable"]: + op.alter_column("users", "hashed_password", existing_type=sa.String(255), nullable=False) diff --git a/alembic/versions/004_fix_user_id_text.py b/alembic/versions/004_fix_user_id_text.py index a52bf9d..648333c 100644 --- a/alembic/versions/004_fix_user_id_text.py +++ b/alembic/versions/004_fix_user_id_text.py @@ -25,7 +25,21 @@ depends_on = None def upgrade() -> None: - # Step 1: Drop existing FK constraints + conn = op.get_bind() + inspector = sa.inspect(conn) + + # Fresh DB — no tables yet, nothing to convert + if not inspector.has_table("users"): + return + + # Check if already TEXT (Base.metadata.create_all uses TEXT for fresh DB) + users_cols = {c["name"]: c for c in inspector.get_columns("users")} + if "id" in users_cols: + id_type = str(users_cols["id"]["type"]).lower() + if "text" in id_type and "uuid" not in id_type: + return # already TEXT — nothing to do + + # Step 1: Drop existing FK constraints (ignore if they don't exist) op.execute(text("ALTER TABLE user_store_accounts DROP CONSTRAINT IF EXISTS user_store_accounts_user_id_fkey")) op.execute(text("ALTER TABLE purchases DROP CONSTRAINT IF EXISTS purchases_user_id_fkey")) diff --git a/alembic/versions/005_add_email_inbound_token.py b/alembic/versions/005_add_email_inbound_token.py index 4fb7c2c..c5cc2a9 100644 --- a/alembic/versions/005_add_email_inbound_token.py +++ b/alembic/versions/005_add_email_inbound_token.py @@ -18,6 +18,15 @@ depends_on = None def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + # Guard: on a fresh DB Base.metadata.create_all creates users table with the column already present + if not inspector.has_table("users"): + return + existing_cols = [c["name"] for c in inspector.get_columns("users")] + if "email_inbound_token" in existing_cols: + return + # Add column nullable first so existing rows can be backfilled op.add_column( "users", @@ -25,11 +34,10 @@ def upgrade() -> None: ) # Backfill existing users with unique tokens - connection = op.get_bind() - result = connection.execute(sa.text("SELECT id FROM users WHERE email_inbound_token IS NULL")) + result = conn.execute(sa.text("SELECT id FROM users WHERE email_inbound_token IS NULL")) for (user_id,) in result: token = secrets.token_urlsafe(16) - connection.execute( + conn.execute( sa.text("UPDATE users SET email_inbound_token = :token WHERE id = :id"), {"token": token, "id": user_id}, ) diff --git a/alembic/versions/006_email_inbound_token_server_default.py b/alembic/versions/006_email_inbound_token_server_default.py new file mode 100644 index 0000000..e090016 --- /dev/null +++ b/alembic/versions/006_email_inbound_token_server_default.py @@ -0,0 +1,42 @@ +"""Add server_default to users.email_inbound_token. + +Revision ID: 006_email_inbound_token_server_default +Revises: 005_add_email_inbound_token +Create Date: 2026-04-04 +""" + +import sqlalchemy as sa +from alembic import op + +revision = "006_email_inbound_token_server_default" +down_revision = "005_add_email_inbound_token" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + # Guard: on a fresh DB Base.metadata.create_all already sets the server_default + if not inspector.has_table("users"): + return + cols = {c["name"]: c for c in inspector.get_columns("users")} + if "email_inbound_token" not in cols: + return + if cols["email_inbound_token"].get("default") is not None: + return + op.alter_column( + "users", + "email_inbound_token", + server_default=sa.text( + "replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_')" + ), + ) + + +def downgrade() -> None: + op.alter_column( + "users", + "email_inbound_token", + server_default=None, + ) diff --git a/alembic/versions/007_bootstrap_users_table.py b/alembic/versions/007_bootstrap_users_table.py new file mode 100644 index 0000000..e9695c0 --- /dev/null +++ b/alembic/versions/007_bootstrap_users_table.py @@ -0,0 +1,47 @@ +"""Bootstrap users table on fresh databases. + +On fresh databases, migrations 001-006 skip users-table operations because +the table does not exist yet. Base.metadata.create_all() in env.py is meant +to handle this, but if it fails (import errors, etc.) the table is never +created. This migration creates the users table with raw SQL as a safety net. + +Revision ID: 007_bootstrap_users_table +Revises: 006_email_inbound_token_server_default +Create Date: 2026-04-04 +""" + +import sqlalchemy as sa +from sqlalchemy import text + +from alembic import op + +revision = "007_bootstrap_users_table" +down_revision = "006_email_inbound_token_server_default" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + if inspector.has_table("users"): + return # Table already exists (non-fresh DB or create_all already ran) + + conn.execute(text(""" + CREATE TABLE users ( + id TEXT PRIMARY KEY, + email VARCHAR(255) NOT NULL UNIQUE, + hashed_password VARCHAR(255), + display_name VARCHAR(100), + email_verified BOOLEAN NOT NULL DEFAULT false, + image TEXT, + email_inbound_token VARCHAR(22) NOT NULL UNIQUE + DEFAULT replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_'), + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ) + """)) + + +def downgrade() -> None: + op.execute(text("DROP TABLE IF EXISTS users")) diff --git a/alembic/versions/008_create_domain_tables.py b/alembic/versions/008_create_domain_tables.py new file mode 100644 index 0000000..021c5bf --- /dev/null +++ b/alembic/versions/008_create_domain_tables.py @@ -0,0 +1,210 @@ +"""Create domain tables (stores, purchases, coupons, etc.). + +Revision ID: 008_create_domain_tables +Revises: 007_bootstrap_users_table +Create Date: 2026-04-04 +""" + +import sqlalchemy as sa +from sqlalchemy import text + +from alembic import op + +revision = "008_create_domain_tables" +down_revision = "007_bootstrap_users_table" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + + # 1. stores + if not inspector.has_table("stores"): + op.create_table( + "stores", + sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True), + sa.Column("name", sa.String(100), nullable=False), + sa.Column("slug", sa.String(20), nullable=False, unique=True), + sa.Column("logo_url", sa.String(500), nullable=True), + sa.Column("website_url", sa.String(500), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + ) + + # 2. store_locations + if not inspector.has_table("store_locations"): + op.create_table( + "store_locations", + sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True), + sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False), + sa.Column("address", sa.String(300), nullable=False), + sa.Column("city", sa.String(100), nullable=False), + sa.Column("state", sa.String(2), nullable=False), + sa.Column("zip", sa.String(10), nullable=False), + sa.Column("lat", sa.Float(), nullable=True), + sa.Column("lng", sa.Float(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + ) + + # 3. normalized_products + if not inspector.has_table("normalized_products"): + op.create_table( + "normalized_products", + sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True), + sa.Column("canonical_name", sa.String(300), nullable=False), + sa.Column("category", sa.String(50), nullable=True), + sa.Column("subcategory", sa.String(100), nullable=True), + sa.Column("brand", sa.String(200), nullable=True), + sa.Column("size", sa.String(50), nullable=True), + sa.Column("size_unit", sa.String(10), nullable=True), + sa.Column("upc_variants", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + ) + + # 4. purchases + if not inspector.has_table("purchases"): + op.create_table( + "purchases", + sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True), + sa.Column("user_id", sa.Text(), sa.ForeignKey("users.id"), nullable=False), + sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False), + sa.Column("store_location_id", sa.Uuid(), sa.ForeignKey("store_locations.id"), nullable=True), + sa.Column("receipt_id", sa.String(200), nullable=False), + sa.Column("purchase_date", sa.Date(), nullable=False), + sa.Column("total", sa.Numeric(10, 2), nullable=False), + sa.Column("subtotal", sa.Numeric(10, 2), nullable=True), + sa.Column("tax", sa.Numeric(10, 2), nullable=True), + sa.Column("savings_total", sa.Numeric(10, 2), nullable=True), + sa.Column("source_url", sa.String(500), nullable=True), + sa.Column("raw_data", sa.JSON(), nullable=True), + sa.Column("ingested_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.UniqueConstraint("user_id", "store_id", "receipt_id", name="uq_purchase_receipt"), + sa.Index("ix_purchases_user_store", "user_id", "store_id"), + ) + + # 5. purchase_items + if not inspector.has_table("purchase_items"): + op.create_table( + "purchase_items", + sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True), + sa.Column("purchase_id", sa.Uuid(), sa.ForeignKey("purchases.id"), nullable=False), + sa.Column("product_name_raw", sa.String(300), nullable=False), + sa.Column("upc", sa.String(20), nullable=True), + sa.Column("quantity", sa.Numeric(10, 3), nullable=False), + sa.Column("unit_price", sa.Numeric(10, 2), nullable=False), + sa.Column("extended_price", sa.Numeric(10, 2), nullable=False), + sa.Column("regular_price", sa.Numeric(10, 2), nullable=True), + sa.Column("sale_price", sa.Numeric(10, 2), nullable=True), + sa.Column("coupon_discount", sa.Numeric(10, 2), nullable=True), + sa.Column("loyalty_discount", sa.Numeric(10, 2), nullable=True), + sa.Column("category_raw", sa.String(100), nullable=True), + sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + ) + + # 6. coupons + if not inspector.has_table("coupons"): + op.create_table( + "coupons", + sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True), + sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False), + sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=True), + sa.Column("title", sa.String(300), nullable=False), + sa.Column("description", sa.String(1000), nullable=True), + sa.Column("discount_type", sa.String(20), nullable=False), + sa.Column("discount_value", sa.Numeric(10, 2), nullable=True), + sa.Column("min_purchase", sa.Numeric(10, 2), nullable=True), + sa.Column("valid_from", sa.Date(), nullable=True), + sa.Column("valid_to", sa.Date(), nullable=True), + sa.Column("requires_clip", sa.Boolean(), server_default=text("false"), nullable=False), + sa.Column("coupon_code", sa.String(100), nullable=True), + sa.Column("source_url", sa.String(500), nullable=True), + sa.Column("scraped_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + ) + + # 7. price_history + if not inspector.has_table("price_history"): + op.create_table( + "price_history", + sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True), + sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=False), + sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False), + sa.Column("observed_date", sa.Date(), nullable=False), + sa.Column("regular_price", sa.Numeric(10, 2), nullable=False), + sa.Column("sale_price", sa.Numeric(10, 2), nullable=True), + sa.Column("loyalty_price", sa.Numeric(10, 2), nullable=True), + sa.Column("coupon_price", sa.Numeric(10, 2), nullable=True), + sa.Column("source", sa.String(20), nullable=False), + sa.Column("purchase_item_id", sa.Uuid(), sa.ForeignKey("purchase_items.id"), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Index("ix_price_history_product_store_date", "normalized_product_id", "store_id", "observed_date"), + ) + + # 8. shrinkflation_events + if not inspector.has_table("shrinkflation_events"): + op.create_table( + "shrinkflation_events", + sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True), + sa.Column("normalized_product_id", sa.Uuid(), sa.ForeignKey("normalized_products.id"), nullable=False), + sa.Column("detected_date", sa.Date(), nullable=False), + sa.Column("old_size", sa.String(50), nullable=False), + sa.Column("new_size", sa.String(50), nullable=False), + sa.Column("old_unit", sa.String(10), nullable=True), + sa.Column("new_unit", sa.String(10), nullable=True), + sa.Column("price_at_old_size", sa.Numeric(10, 2), nullable=True), + sa.Column("price_at_new_size", sa.Numeric(10, 2), nullable=True), + sa.Column("confidence", sa.Numeric(3, 2), server_default=text("1.00"), nullable=False), + sa.Column("notes", sa.String(1000), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + ) + + # 9. user_store_accounts + if not inspector.has_table("user_store_accounts"): + op.create_table( + "user_store_accounts", + sa.Column("id", sa.Uuid(), server_default=text("gen_random_uuid()"), primary_key=True), + sa.Column("user_id", sa.Text(), sa.ForeignKey("users.id"), nullable=False), + sa.Column("store_id", sa.Uuid(), sa.ForeignKey("stores.id"), nullable=False), + sa.Column("session_data", sa.JSON(), nullable=True), + sa.Column("session_expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("status", sa.String(20), server_default=text("'active'"), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.UniqueConstraint("user_id", "store_id", name="uq_user_store_account"), + ) + + +def downgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + + if inspector.has_table("user_store_accounts"): + op.drop_table("user_store_accounts") + if inspector.has_table("shrinkflation_events"): + op.drop_table("shrinkflation_events") + if inspector.has_table("price_history"): + op.drop_table("price_history") + if inspector.has_table("coupons"): + op.drop_table("coupons") + if inspector.has_table("purchase_items"): + op.drop_table("purchase_items") + if inspector.has_table("purchases"): + op.drop_table("purchases") + if inspector.has_table("normalized_products"): + op.drop_table("normalized_products") + if inspector.has_table("store_locations"): + op.drop_table("store_locations") + if inspector.has_table("stores"): + op.drop_table("stores") diff --git a/src/cartsnitch_api/auth/dependencies.py b/src/cartsnitch_api/auth/dependencies.py index 6fe1db4..5040741 100644 --- a/src/cartsnitch_api/auth/dependencies.py +++ b/src/cartsnitch_api/auth/dependencies.py @@ -19,12 +19,15 @@ bearer_scheme = HTTPBearer(auto_error=False) # Better-Auth session cookie name SESSION_COOKIE_NAME = "better-auth.session_token" +# Secure prefix used by better-auth on HTTPS deployments +SECURE_SESSION_COOKIE_NAME = "__Secure-better-auth.session_token" async def _validate_session_token(token: str, db: AsyncSession) -> str: """Validate a Better-Auth session token against the sessions table. - Returns the user_id (as str) if the session is valid and not expired. + Better-Auth stores the raw token in the DB. The cookie/Bearer header + carries the same raw token, so we compare directly. """ result = await db.execute( text("SELECT user_id, expires_at FROM sessions WHERE token = :token"), @@ -65,14 +68,17 @@ async def get_current_user( """ token: str | None = None - # 1. Check session cookie - cookie_token = request.cookies.get(SESSION_COOKIE_NAME) + # 1. Check session cookie — prefer __Secure- variant (HTTPS) over plain (HTTP dev) + cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(SESSION_COOKIE_NAME) if cookie_token: - token = cookie_token + # Better-Auth cookie format is "token.sessionId" — extract just the token part + token = cookie_token.split(".")[0] if "." in cookie_token else cookie_token # 2. Fall back to Bearer header if not token and credentials: - token = credentials.credentials + # Callers might pass the compound value here too + raw = credentials.credentials + token = raw.split(".")[0] if "." in raw else raw if not token: raise HTTPException( diff --git a/src/cartsnitch_api/config.py b/src/cartsnitch_api/config.py index 5111997..7642deb 100644 --- a/src/cartsnitch_api/config.py +++ b/src/cartsnitch_api/config.py @@ -1,13 +1,16 @@ import base64 -from pydantic import model_validator +from pydantic import AliasChoices, Field, model_validator from pydantic_settings import BaseSettings class Settings(BaseSettings): model_config = {"env_prefix": "CARTSNITCH_"} - database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch" + database_url: str = Field( + default="postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch", + validation_alias=AliasChoices("CARTSNITCH_DATABASE_URL", "DATABASE_URL"), + ) redis_url: str = "redis://localhost:6379/0" jwt_secret_key: str = "change-me-in-production" @@ -49,5 +52,12 @@ class Settings(BaseSettings): ) from None return self + @model_validator(mode="after") + def normalize_database_url(self): + """Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver.""" + if self.database_url.startswith("postgresql://"): + self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1) + return self + settings = Settings() diff --git a/src/cartsnitch_api/models/user.py b/src/cartsnitch_api/models/user.py index 89390a3..9cbd4e8 100644 --- a/src/cartsnitch_api/models/user.py +++ b/src/cartsnitch_api/models/user.py @@ -4,7 +4,8 @@ import secrets from datetime import datetime from typing import TYPE_CHECKING -from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint +import sqlalchemy as sa +from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from cartsnitch_api.constants import AccountStatus @@ -23,13 +24,20 @@ class User(TimestampMixin, Base): id: Mapped[str] = mapped_column(Text, primary_key=True) email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) - hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) + hashed_password: Mapped[str | None] = mapped_column(String(255), nullable=True) display_name: Mapped[str | None] = mapped_column(String(100)) + email_verified: Mapped[bool] = mapped_column( + Boolean, nullable=False, server_default="false" + ) + image: Mapped[str | None] = mapped_column(Text, nullable=True) email_inbound_token: Mapped[str] = mapped_column( String(22), nullable=False, unique=True, default=lambda: secrets.token_urlsafe(16), + server_default=sa.text( + "replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_')" + ), ) # Relationships diff --git a/tests/conftest.py b/tests/conftest.py index 61810e1..bb84c20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -136,7 +136,8 @@ async def client(db_engine): async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]: """Create a test user and a valid session directly in the DB. - Returns (user_dict, session_token). + Returns (user_dict, session_token). Better-Auth stores the raw token + in the DB, so we insert it as-is. """ user_id = str(uuid.uuid4()) email = user_overrides.get("email", "test@example.com") diff --git a/tests/test_auth/test_auth_endpoints.py b/tests/test_auth/test_auth_endpoints.py index 7b096ae..9b55a4c 100644 --- a/tests/test_auth/test_auth_endpoints.py +++ b/tests/test_auth/test_auth_endpoints.py @@ -71,6 +71,56 @@ async def test_delete_me(client, auth_headers): assert resp.status_code == 404 +@pytest.mark.asyncio +async def test_get_me_compound_cookie(client, db_engine): + """Compound cookie value (token.sessionId) must be parsed to extract the token part.""" + from tests.conftest import _create_test_user_and_session + + _, session_token = await _create_test_user_and_session( + client, db_engine, email="compound@example.com", display_name="Compound User" + ) + compound = f"{session_token}.B0atkJCFxK1rZlwWPMK97nVO2LnyDun7" + resp = await client.get( + "/auth/me", + headers={"Cookie": f"better-auth.session_token={compound}"}, + ) + assert resp.status_code == 200 + assert resp.json()["email"] == "compound@example.com" + + +@pytest.mark.asyncio +async def test_get_me_raw_token_cookie(client, db_engine): + """Raw token (no dot) in cookie must still work — regression guard.""" + from tests.conftest import _create_test_user_and_session + + _, session_token = await _create_test_user_and_session( + client, db_engine, email="rawcookie@example.com", display_name="Raw Cookie User" + ) + resp = await client.get( + "/auth/me", + headers={"Cookie": f"better-auth.session_token={session_token}"}, + ) + assert resp.status_code == 200 + assert resp.json()["email"] == "rawcookie@example.com" + + +@pytest.mark.asyncio +async def test_get_me_compound_bearer(client, db_engine): + """Compound Bearer token (token.sessionId) must be parsed to extract the token part.""" + from tests.conftest import _create_test_user_and_session + + _, session_token = await _create_test_user_and_session( + client, db_engine, email="compoundbearer@example.com", display_name="Compound Bearer User" + ) + compound = f"{session_token}.B0atkJCFxK1rZlwWPMK97nVO2LnyDun7" + resp = await client.get( + "/auth/me", + headers={"Authorization": f"Bearer {compound}"}, + ) + assert resp.status_code == 200 + assert resp.json()["email"] == "compoundbearer@example.com" + + @pytest.mark.asyncio async def test_expired_session_rejected(client, db_engine): """Expired sessions must be rejected.""" diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..f594bc2 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,48 @@ +"""Tests for Settings config, specifically the database_url env var fallback.""" + +import os + +from cartsnitch_api.config import Settings + + +def test_database_url_prefers_cartsnitch_prefix(): + """CARTSNITCH_DATABASE_URL takes precedence over DATABASE_URL.""" + env = { + "CARTSNITCH_DATABASE_URL": "postgresql+asyncpg://user1:pass1@host1:5432/db1", + "DATABASE_URL": "postgresql://user2:pass2@host2:5432/db2", + } + settings = Settings(**env) + assert settings.database_url == "postgresql+asyncpg://user1:pass1@host1:5432/db1" + + +def test_database_url_falls_back_to_database_url(): + """When CARTSNITCH_DATABASE_URL is absent, DATABASE_URL is accepted.""" + env = { + "DATABASE_URL": "postgresql://user:pass@dbhost:5432/mydb", + } + settings = Settings(**env) + assert settings.database_url == "postgresql+asyncpg://user:pass@dbhost:5432/mydb" + + +def test_database_url_normalizes_plain_postgresql_prefix(): + """DATABASE_URL with plain postgresql:// is normalized to postgresql+asyncpg://.""" + env = { + "DATABASE_URL": "postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch", + } + settings = Settings(**env) + assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch" + + +def test_database_url_preserves_asyncpg_prefix(): + """CARTSNITCH_DATABASE_URL with postgresql+asyncpg:// is left unchanged.""" + env = { + "CARTSNITCH_DATABASE_URL": "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch", + } + settings = Settings(**env) + assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch" + + +def test_database_url_default(): + """When neither env var is set, the hardcoded default is used.""" + settings = Settings() + assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"