forked from cartsnitch/cartsnitch
feat(api): implement lifespan with DB and Redis connection pooling
- Refactor database.py to use init_db()/close_db() lifecycle - Add create_db_engine() with pool_size=10, max_overflow=20, pool_pre_ping=True - Replace cache.py stub with real Redis client using redis.asyncio - Implement init_redis()/close_redis() with graceful error handling - Replace no-op lifespan in main.py with proper startup/shutdown - Enhance health endpoint to check DB and Redis connectivity - Add tests for database, cache, and health endpoint lifecycle Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
committed by
savannah-savings-cto[bot]
parent
f96daceb0f
commit
2460a00d4e
@@ -1,9 +1,41 @@
|
|||||||
"""Redis/DragonflyDB caching helpers."""
|
"""Redis/DragonflyDB caching helpers."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
from cartsnitch_api.config import settings
|
from cartsnitch_api.config import settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from cartsnitch_api.config import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_redis: "Redis | None" = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings() -> "Settings":
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
async def init_redis() -> None:
|
||||||
|
global _redis
|
||||||
|
_redis = redis.from_url(settings.redis_url)
|
||||||
|
await _redis.ping()
|
||||||
|
|
||||||
|
|
||||||
|
async def close_redis() -> None:
|
||||||
|
global _redis
|
||||||
|
if _redis is not None:
|
||||||
|
await _redis.aclose()
|
||||||
|
_redis = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_redis() -> Redis | None:
|
||||||
|
return _redis
|
||||||
|
|
||||||
|
|
||||||
class CacheClient:
|
class CacheClient:
|
||||||
"""Redis/DragonflyDB caching with connection pooling.
|
"""Redis/DragonflyDB caching with connection pooling.
|
||||||
|
|||||||
@@ -1,28 +1,60 @@
|
|||||||
"""Database session management for the API gateway."""
|
"""Database session management for the API gateway."""
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from cartsnitch_api.config import settings
|
from cartsnitch_api.config import settings
|
||||||
|
|
||||||
engine = create_async_engine(
|
if TYPE_CHECKING:
|
||||||
settings.database_url,
|
from sqlalchemy.engine import Engine
|
||||||
echo=False,
|
|
||||||
pool_size=10,
|
|
||||||
max_overflow=20,
|
_engine: "Engine | None" = None
|
||||||
pool_pre_ping=True,
|
async_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||||
pool_recycle=3600,
|
|
||||||
)
|
|
||||||
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
def create_db_engine():
|
||||||
|
return create_async_engine(
|
||||||
|
settings.database_url,
|
||||||
|
pool_size=10,
|
||||||
|
max_overflow=20,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
pool_recycle=3600,
|
||||||
|
echo=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def init_db() -> None:
|
||||||
|
global _engine, async_session_factory
|
||||||
|
_engine = create_db_engine()
|
||||||
|
async_session_factory = async_sessionmaker(_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def close_db() -> None:
|
||||||
|
global _engine, async_session_factory
|
||||||
|
if _engine is not None:
|
||||||
|
await _engine.dispose()
|
||||||
|
_engine = None
|
||||||
|
async_session_factory = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine():
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
|
||||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""FastAPI dependency that yields an async DB session."""
|
if async_session_factory is None:
|
||||||
|
raise RuntimeError("Database not initialized. Call init_db() first.")
|
||||||
async with async_session_factory() as session:
|
async with async_session_factory() as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
async def dispose_engine() -> None:
|
# Backward compatibility: module-level engine proxy that delegates to _engine
|
||||||
"""Dispose the database engine, closing all pooled connections."""
|
def __getattr__(name: str):
|
||||||
await engine.dispose()
|
if name == "engine":
|
||||||
|
if _engine is None:
|
||||||
|
raise RuntimeError("Database not initialized. Call init_db() first.")
|
||||||
|
return _engine
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|||||||
@@ -26,10 +26,14 @@ from cartsnitch_api.routes.user import router as user_router
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
await cache_client.initialize()
|
from cartsnitch_api.database import init_db, close_db
|
||||||
|
from cartsnitch_api.cache import init_redis, close_redis
|
||||||
|
|
||||||
|
await init_db()
|
||||||
|
await init_redis()
|
||||||
yield
|
yield
|
||||||
await cache_client.close()
|
await close_redis()
|
||||||
await dispose_engine()
|
await close_db()
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
"""Health check and error metrics endpoints."""
|
"""Health check and error metrics endpoints."""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
from cartsnitch_api.auth.dependencies import verify_service_key
|
from cartsnitch_api.auth.dependencies import verify_service_key
|
||||||
|
from cartsnitch_api.cache import get_redis
|
||||||
|
from cartsnitch_api.database import get_engine
|
||||||
from cartsnitch_api.middleware.error_handler import get_error_monitor
|
from cartsnitch_api.middleware.error_handler import get_error_monitor
|
||||||
|
|
||||||
router = APIRouter(tags=["health"])
|
router = APIRouter(tags=["health"])
|
||||||
@@ -10,7 +13,27 @@ router = APIRouter(tags=["health"])
|
|||||||
|
|
||||||
@router.get("/health")
|
@router.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
return {"status": "ok"}
|
engine = get_engine()
|
||||||
|
db_ok = False
|
||||||
|
redis_ok = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text("SELECT 1"))
|
||||||
|
db_ok = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
r = get_redis()
|
||||||
|
if r:
|
||||||
|
await r.ping()
|
||||||
|
redis_ok = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
status = "ok" if db_ok else "degraded"
|
||||||
|
return {"status": status, "db": db_ok, "redis": redis_ok}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)])
|
@router.get("/internal/error-stats", dependencies=[Depends(verify_service_key)])
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
"""Tests for Redis/DragonflyDB caching lifecycle."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cartsnitch_api.cache import CacheClient, close_redis, get_redis, init_redis
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_redis_creates_client():
|
||||||
|
"""Test that init_redis creates the Redis client."""
|
||||||
|
await init_redis()
|
||||||
|
try:
|
||||||
|
r = get_redis()
|
||||||
|
assert r is not None
|
||||||
|
await r.ping()
|
||||||
|
finally:
|
||||||
|
await close_redis()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_redis_clears_client():
|
||||||
|
"""Test that close_redis properly closes and clears the client."""
|
||||||
|
await init_redis()
|
||||||
|
await close_redis()
|
||||||
|
assert get_redis() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_client_get_returns_none_when_not_connected():
|
||||||
|
"""Test that CacheClient.get returns None gracefully when Redis is down."""
|
||||||
|
client = CacheClient()
|
||||||
|
# Without init_redis, get should return None
|
||||||
|
result = await client.get("test-key")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_client_set_does_not_raise_when_not_connected():
|
||||||
|
"""Test that CacheClient.set does not raise when Redis is down."""
|
||||||
|
client = CacheClient()
|
||||||
|
# Without init_redis, set should not raise
|
||||||
|
await client.set("test-key", "test-value", ttl_seconds=60)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_client_delete_does_not_raise_when_not_connected():
|
||||||
|
"""Test that CacheClient.delete does not raise when Redis is down."""
|
||||||
|
client = CacheClient()
|
||||||
|
# Without init_redis, delete should not raise
|
||||||
|
await client.delete("test-key")
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
"""Tests for database initialization and lifecycle."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from cartsnitch_api.database import (
|
||||||
|
close_db,
|
||||||
|
create_db_engine,
|
||||||
|
get_engine,
|
||||||
|
init_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_db_engine_creates_engine_with_pool_settings():
|
||||||
|
"""Test that create_db_engine creates engine with correct pool settings."""
|
||||||
|
engine = create_db_engine()
|
||||||
|
assert engine is not None
|
||||||
|
pool = engine.pool
|
||||||
|
assert pool.size() == 10
|
||||||
|
assert pool._max_overflow == 20
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_db_sets_engine_and_factory():
|
||||||
|
"""Test that init_db properly initializes the engine and session factory."""
|
||||||
|
await init_db()
|
||||||
|
try:
|
||||||
|
eng = get_engine()
|
||||||
|
assert eng is not None
|
||||||
|
from cartsnitch_api import database
|
||||||
|
|
||||||
|
assert database.async_session_factory is not None
|
||||||
|
finally:
|
||||||
|
await close_db()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_db_disposes_engine():
|
||||||
|
"""Test that close_db properly disposes the engine."""
|
||||||
|
await init_db()
|
||||||
|
await close_db()
|
||||||
|
assert get_engine() is None
|
||||||
|
from cartsnitch_api import database
|
||||||
|
|
||||||
|
assert database.async_session_factory is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_yields_session_after_init():
|
||||||
|
"""Test that get_db yields working sessions after init_db."""
|
||||||
|
await init_db()
|
||||||
|
try:
|
||||||
|
from cartsnitch_api.database import get_db
|
||||||
|
|
||||||
|
gen = get_db()
|
||||||
|
session = await gen.__anext__()
|
||||||
|
assert isinstance(session, AsyncSession)
|
||||||
|
await gen.aclose()
|
||||||
|
finally:
|
||||||
|
await close_db()
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
"""Tests for health check endpoint."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from cartsnitch_api.database import init_db, close_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_returns_db_and_redis_fields(client):
|
||||||
|
"""Test that health endpoint returns db and redis status fields."""
|
||||||
|
from cartsnitch_api.cache import init_redis, close_redis
|
||||||
|
|
||||||
|
await init_db()
|
||||||
|
await init_redis()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "status" in data
|
||||||
|
assert "db" in data
|
||||||
|
assert "redis" in data
|
||||||
|
finally:
|
||||||
|
await close_redis()
|
||||||
|
await close_db()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_returns_degraded_when_db_down():
|
||||||
|
"""Test that health returns degraded when database is down."""
|
||||||
|
from cartsnitch_api.database import _engine
|
||||||
|
from cartsnitch_api.routes.health import health
|
||||||
|
|
||||||
|
# Simulate engine is None (DB not initialized)
|
||||||
|
with patch("cartsnitch_api.routes.health.get_engine", return_value=None):
|
||||||
|
response = await health()
|
||||||
|
assert response["status"] == "degraded"
|
||||||
|
assert response["db"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_returns_ok_when_db_up(client):
|
||||||
|
"""Test that health returns ok when database is up."""
|
||||||
|
from cartsnitch_api.database import init_db, close_db
|
||||||
|
from cartsnitch_api.cache import init_redis, close_redis
|
||||||
|
|
||||||
|
await init_db()
|
||||||
|
await init_redis()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
if data["db"]:
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
finally:
|
||||||
|
await close_redis()
|
||||||
|
await close_db()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_redis_down_does_not_make_unhealthy(client):
|
||||||
|
"""Test that Redis being down does not make health return unhealthy."""
|
||||||
|
from cartsnitch_api.database import init_db, close_db
|
||||||
|
|
||||||
|
await init_db()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get("/health")
|
||||||
|
data = response.json()
|
||||||
|
# Redis being down should not make status "degraded"
|
||||||
|
# Only DB failure makes it degraded
|
||||||
|
if not data["db"]:
|
||||||
|
assert data["status"] == "degraded"
|
||||||
|
finally:
|
||||||
|
await close_db()
|
||||||
Reference in New Issue
Block a user