diff --git a/src/cartsnitch_api/auth/passwords.py b/src/cartsnitch_api/auth/passwords.py index 180f994..1205107 100644 --- a/src/cartsnitch_api/auth/passwords.py +++ b/src/cartsnitch_api/auth/passwords.py @@ -4,8 +4,8 @@ import bcrypt def hash_password(password: str) -> str: - return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + return str(bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()) def verify_password(plain_password: str, hashed_password: str) -> bool: - return bcrypt.checkpw(plain_password.encode(), hashed_password.encode()) + return bool(bcrypt.checkpw(plain_password.encode(), hashed_password.encode())) diff --git a/src/cartsnitch_api/cache.py b/src/cartsnitch_api/cache.py index 319cb8d..6766a8c 100644 --- a/src/cartsnitch_api/cache.py +++ b/src/cartsnitch_api/cache.py @@ -35,7 +35,12 @@ class CacheClient: async def get(self, key: str) -> str | None: if not self._client: return None - return await self._client.get(key) + value = await self._client.get(key) + if value is None: + return None + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + return value async def set(self, key: str, value: str, ttl_seconds: int = 300) -> None: if not self._client: diff --git a/src/cartsnitch_api/config.py b/src/cartsnitch_api/config.py index c71d753..b82aa37 100644 --- a/src/cartsnitch_api/config.py +++ b/src/cartsnitch_api/config.py @@ -86,4 +86,4 @@ class Settings(BaseSettings): return self -settings = Settings() +settings = Settings() # type: ignore[call-arg] diff --git a/src/cartsnitch_api/database.py b/src/cartsnitch_api/database.py index 3c6043c..5334b84 100644 --- a/src/cartsnitch_api/database.py +++ b/src/cartsnitch_api/database.py @@ -6,14 +6,22 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn from cartsnitch_api.config import settings -engine = create_async_engine( - settings.database_url, - echo=False, - pool_size=10, - max_overflow=20, - pool_pre_ping=True, - pool_recycle=3600, -) + +def _build_engine_kwargs() -> dict: + url = settings.database_url + kwargs: dict = {"echo": False} + if not url.startswith("sqlite"): + kwargs.update( + pool_size=10, + max_overflow=20, + pool_timeout=30, + pool_pre_ping=True, + pool_recycle=3600, + ) + return kwargs + + +engine = create_async_engine(settings.database_url, **_build_engine_kwargs()) async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) diff --git a/src/cartsnitch_api/middleware/rate_limit.py b/src/cartsnitch_api/middleware/rate_limit.py index af3dd4b..b32f760 100644 --- a/src/cartsnitch_api/middleware/rate_limit.py +++ b/src/cartsnitch_api/middleware/rate_limit.py @@ -25,6 +25,9 @@ logger = logging.getLogger(__name__) class RateLimitBackend(Protocol): """Protocol for rate limit backends.""" + max_requests: int + window_seconds: int + async def is_allowed(self, key: str) -> tuple[bool, int, int]: """Check if request is allowed. Returns (allowed, remaining, retry_after).""" @@ -82,7 +85,8 @@ class RedisSlidingWindow: if current_count >= self.max_requests: oldest = await self.redis.zrange(key, 0, 0, withscores=True) if oldest: - retry_after = int((oldest[0][1] - cutoff) / 1000) + 1 + oldest_score = float(oldest[0][1]) + retry_after = int((oldest_score - cutoff) / 1000) + 1 else: retry_after = self.window_seconds return False, 0, retry_after @@ -114,6 +118,10 @@ if settings.rate_limit_redis_enabled: logger.warning("Failed to connect to Redis for rate limiting, using in-memory: %s", e) _use_redis = False +_public_limiter: RateLimitBackend +_auth_limiter: RateLimitBackend +_auth_strict_limiter: RateLimitBackend + if _use_redis and _redis_client: _public_limiter = RedisSlidingWindow( _redis_client, settings.rate_limit_requests, settings.rate_limit_window_seconds diff --git a/src/cartsnitch_api/routes/health.py b/src/cartsnitch_api/routes/health.py index 0574b10..dce47f2 100644 --- a/src/cartsnitch_api/routes/health.py +++ b/src/cartsnitch_api/routes/health.py @@ -1,16 +1,40 @@ """Health check and error metrics endpoints.""" -from fastapi import APIRouter, Depends +import logging + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession from cartsnitch_api.auth.dependencies import verify_service_key +from cartsnitch_api.database import get_db from cartsnitch_api.middleware.error_handler import get_error_monitor +logger = logging.getLogger(__name__) + router = APIRouter(tags=["health"]) @router.get("/health") -async def health(): - return {"status": "ok"} +async def health(db: AsyncSession = Depends(get_db)): + """Liveness + DB connectivity probe. + + Returns HTTP 200 when the API process is responsive *and* the database + is reachable, so Kubernetes readiness probes can correctly route traffic + away from pods that have lost their database connection. + + Returns HTTP 503 when the database is unreachable so K8s marks the pod + unhealthy and stops sending traffic to it. + """ + try: + await db.execute(text("SELECT 1")) + except Exception as exc: + logger.exception("Health check failed: database unreachable") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail={"status": "unavailable", "database": "disconnected"}, + ) from exc + return {"status": "ok", "database": "connected"} @router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)]) diff --git a/tests/conftest.py b/tests/conftest.py index 6439552..c9dc552 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,15 @@ from cartsnitch_api.database import get_db from cartsnitch_api.main import create_app from cartsnitch_api.models import Base + +def _set_timestamp_defaults(mapper, connection, target): + """Populate created_at/updated_at before insert for SQLite compatibility.""" + now = datetime.now(UTC) + for col in [c for c in mapper.columns if c.key in ("created_at", "updated_at")]: + if getattr(target, col.key, None) is None: + setattr(target, col.key, now) + + TEST_JWT_SECRET = secrets.token_urlsafe(32) TEST_SERVICE_KEY = secrets.token_urlsafe(32) TEST_FERNET_KEY = "7reF42nmTwbdN21PBoubGp7h_FU8qSimstmlaMLoRK8=" @@ -53,18 +62,27 @@ def disable_rate_limiting(): def engine(): """Sync in-memory SQLite engine for model unit tests. - Strips PostgreSQL-specific server_default expressions so SQLite can - handle all column inserts without missing-function errors. + Strips PostgreSQL-specific server_default expressions and provides + Python-side defaults for SQLite compatibility. """ eng = create_engine("sqlite:///:memory:") - for table in Base.metadata.tables.values(): - for col in table.columns.values(): + for tbl in Base.metadata.tables.values(): + for col in tbl.columns.values(): sd = col.server_default if sd is not None: - expr_str = str(sd.expression).lower() - if "gen_random_uuid" in expr_str or "gen_random_bytes" in expr_str: + if not hasattr(sd, "expression"): col.server_default = None + continue + expr_str = str(sd.expression).lower() + # Strip PostgreSQL-specific defaults + if any(x in expr_str for x in ["gen_random_uuid", "gen_random_bytes", "now()"]): + col.server_default = None + + # Register event listener to populate timestamps on insert + for cls in Base.registry._class_registry.values(): + if hasattr(cls, "__mapper__"): + event.listen(cls, "before_insert", _set_timestamp_defaults) Base.metadata.create_all(eng) yield eng @@ -89,13 +107,22 @@ async def db_engine(): cursor.execute("PRAGMA foreign_keys=ON") cursor.close() - for table in Base.metadata.tables.values(): - for col in table.columns.values(): + for tbl in Base.metadata.tables.values(): + for col in tbl.columns.values(): sd = col.server_default if sd is not None: - expr_str = str(sd.expression).lower() - if "gen_random_uuid" in expr_str or "gen_random_bytes" in expr_str: + if not hasattr(sd, "expression"): col.server_default = None + continue + expr_str = str(sd.expression).lower() + # Strip PostgreSQL-specific defaults + if any(x in expr_str for x in ["gen_random_uuid", "gen_random_bytes", "now()"]): + col.server_default = None + + # Register event listener to populate timestamps on insert + for cls in Base.registry._class_registry.values(): + if hasattr(cls, "__mapper__"): + event.listen(cls, "before_insert", _set_timestamp_defaults) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) diff --git a/tests/test_encrypted_json.py b/tests/test_encrypted_json.py index 07cf44c..08b16d7 100644 --- a/tests/test_encrypted_json.py +++ b/tests/test_encrypted_json.py @@ -18,10 +18,13 @@ from cartsnitch_api.models.user import User, UserStoreAccount def engine(): eng = create_engine("sqlite:///:memory:") - for table in Base.metadata.tables.values(): - for col in table.columns.values(): + for tbl in Base.metadata.tables.values(): + for col in tbl.columns.values(): sd = col.server_default if sd is not None: + if not hasattr(sd, "expression"): + col.server_default = None + continue expr_str = str(sd.expression).lower() if "gen_random_uuid" in expr_str or "gen_random_bytes" in expr_str: col.server_default = None