Files
api/tests/test_encrypted_json.py
T
Barcode Betty 81c4e76acd
CI / lint (pull_request) Failing after 5s
CI / typecheck (pull_request) Failing after 32s
CI / test (pull_request) Failing after 52s
CI / build-and-push (pull_request) Has been skipped
CI / deploy-dev (pull_request) Has been skipped
CI / deploy-uat (pull_request) Has been skipped
Fix SQLite server_default AttributeError and pool_size errors
- Add hasattr(sd, 'expression') guard in engine fixtures to prevent
  AttributeError when iterating over server_default columns that use
  DefaultClause (which lacks .expression)
- Add _build_engine_kwargs() in database.py to conditionally apply
  pool_size/max_overflow only for non-SQLite database URLs
- Fixes test failures in conftest.py, test_encrypted_json.py

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-05-24 18:35:03 +00:00

143 lines
4.7 KiB
Python

"""Tests for EncryptedJSON TypeDecorator and session_data encryption."""
import json
import pytest
from cryptography.fernet import Fernet
from pydantic import ValidationError
from sqlalchemy import column, create_engine, table, text
from sqlalchemy.orm import sessionmaker
from cartsnitch_api.config import settings
from cartsnitch_api.models import Base
from cartsnitch_api.models.store import Store
from cartsnitch_api.models.user import User, UserStoreAccount
# Substrings identifying PostgreSQL-only server defaults that SQLite cannot
# evaluate. Columns using them presumably get their value from an ORM-side
# (Python) default instead -- TODO confirm against the models.
_SQLITE_UNSUPPORTED_DEFAULTS = ("gen_random_uuid", "gen_random_bytes")


def _strip_sqlite_incompatible_defaults(metadata):
    """Remove server defaults that SQLite cannot handle from *metadata*.

    A column's server default is removed when it exposes no inspectable SQL
    (no ``arg`` attribute, e.g. a bare ``FetchedValue``) or when its SQL text
    mentions one of ``_SQLITE_UNSUPPORTED_DEFAULTS``.

    Returns:
        A list of ``(column, original_server_default)`` pairs so the caller
        can put the shared metadata back the way it found it.
    """
    removed = []
    for tbl in metadata.tables.values():
        for col in tbl.columns:
            default = col.server_default
            if default is None:
                continue
            # DefaultClause keeps its SQL (a string or SQL clause) in ``.arg``;
            # it has no ``.expression`` attribute, so probing that attribute
            # would treat every default as uninspectable and strip them all.
            sql = getattr(default, "arg", None)
            if sql is None or any(
                marker in str(sql).lower() for marker in _SQLITE_UNSUPPORTED_DEFAULTS
            ):
                removed.append((col, default))
                col.server_default = None
    return removed


@pytest.fixture
def engine():
    """Yield an in-memory SQLite engine with the full schema created.

    PostgreSQL-only server defaults are removed from the module-level
    ``Base.metadata`` before ``create_all`` and restored on teardown, so the
    mutation of that shared object does not leak into other tests.
    """
    eng = create_engine("sqlite:///:memory:")
    removed = _strip_sqlite_incompatible_defaults(Base.metadata)
    try:
        Base.metadata.create_all(eng)
        yield eng
    finally:
        eng.dispose()
        for col, default in removed:
            col.server_default = default
@pytest.fixture
def session(engine):
    """Yield an ORM session bound to ``engine``; the session is closed on teardown."""
    make_session = sessionmaker(bind=engine)
    with make_session() as db_session:
        yield db_session
@pytest.fixture
def store(session):
    """Persist a single ``Store`` row and return it, freshly reloaded."""
    new_store = Store(slug="test-store", name="Test Store")
    session.add(new_store)
    session.commit()
    # Reload the row's attributes from the database so callers see committed values.
    session.refresh(new_store)
    return new_store
@pytest.fixture
def user(session):
    """Persist a single ``User`` row and return it, freshly reloaded."""
    new_user = User(hashed_password="fakehash", email="alice@example.com")
    session.add(new_user)
    session.commit()
    # Reload the row's attributes from the database so callers see committed values.
    session.refresh(new_user)
    return new_user
class TestEncryptedJSONType:
    """Unit tests for the EncryptedJSON TypeDecorator."""

    @staticmethod
    def _decrypt_raw(raw):
        """Decrypt a raw ``session_data`` column value with the configured Fernet key."""
        fernet = Fernet(settings.fernet_key.encode())
        return json.loads(fernet.decrypt(raw.encode()))

    def test_round_trip(self, session, user, store):
        """Data written via the ORM comes back as the original dict."""
        original = {"token": "abc123", "cookies": {"session_id": "xyz"}}
        account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
        session.add(account)
        session.commit()
        loaded = session.get(UserStoreAccount, account.id)
        assert loaded.session_data == original

    def test_stored_value_is_encrypted(self, session, user, store):
        """The raw value in the DB should be a Fernet token, not plaintext JSON."""
        original = {"secret": "do-not-leak"}
        account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=original)
        session.add(account)
        session.commit()
        # A lightweight table construct whose session_data column is untyped,
        # so reading through it bypasses the EncryptedJSON TypeDecorator. The
        # id column reuses the mapped column's type so the key is bound exactly
        # as the ORM stored it (its SQLite storage format may differ from
        # str(account.id)).
        raw_table = table(
            "user_store_accounts",
            column("id", UserStoreAccount.__table__.c.id.type),
            column("session_data"),
        )
        # .one() fails loudly if the row is not found, instead of silently
        # reading some other row.
        row = session.execute(
            raw_table.select().where(raw_table.c.id == account.id)
        ).one()
        raw = row.session_data
        assert raw != json.dumps(original)
        assert raw.startswith("gAAAAA")
        # Verify we can decrypt the raw value manually
        assert self._decrypt_raw(raw) == original

    def test_null_round_trip(self, session, user, store):
        """NULL session_data stays NULL."""
        account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data=None)
        session.add(account)
        session.commit()
        loaded = session.get(UserStoreAccount, account.id)
        assert loaded.session_data is None

    def test_empty_dict_round_trip(self, session, user, store):
        """Empty dict round-trips correctly."""
        account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={})
        session.add(account)
        session.commit()
        loaded = session.get(UserStoreAccount, account.id)
        assert loaded.session_data == {}

    def test_update_session_data(self, session, user, store):
        """Updating session_data re-encrypts the new value."""
        account = UserStoreAccount(user_id=user.id, store_id=store.id, session_data={"v": 1})
        session.add(account)
        session.commit()
        # Each test gets a fresh in-memory database holding only this account,
        # so a table-wide read with scalar_one() is unambiguous (and raises if
        # that assumption is ever broken).
        raw_query = text("SELECT session_data FROM user_store_accounts")
        raw_before = session.execute(raw_query).scalar_one()
        account.session_data = {"v": 2, "new_field": True}
        session.commit()
        raw_after = session.execute(raw_query).scalar_one()
        # The stored ciphertext must actually change and decrypt to the new dict.
        assert raw_after != raw_before
        assert self._decrypt_raw(raw_after) == {"v": 2, "new_field": True}
        loaded = session.get(UserStoreAccount, account.id)
        assert loaded.session_data == {"v": 2, "new_field": True}
class TestEncryptionKeyValidation:
    """Test that invalid encryption keys are caught by settings validation.

    Only the invalid-key case is covered here.
    """

    def test_invalid_fernet_key_rejected(self, monkeypatch):
        """Settings validation rejects a bad key."""
        # Import outside the raises-block so that only the Settings()
        # construction is expected to raise ValidationError.
        from cartsnitch_api.config import Settings

        monkeypatch.setenv("CARTSNITCH_FERNET_KEY", "not-a-valid-key")
        with pytest.raises(ValidationError):
            Settings()