feat: merge cartsnitch/api into api/ subdirectory
Consolidate API gateway service into monorepo. Squashed from https://github.com/cartsnitch/api main (89bacb1). Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
"""CORS middleware configuration."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
def add_cors_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Structured error responses and error monitoring.
|
||||
|
||||
Ensures all errors return a consistent JSON shape and never leak stack traces.
|
||||
Provides hooks for error monitoring/alerting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger("cartsnitch_api.errors")
|
||||
|
||||
|
||||
def _error_response(
|
||||
status_code: int,
|
||||
detail: str,
|
||||
code: str | None = None,
|
||||
errors: list[dict] | None = None,
|
||||
) -> JSONResponse:
|
||||
"""Build a consistent error response."""
|
||||
body: dict = {"detail": detail}
|
||||
if code:
|
||||
body["code"] = code
|
||||
if errors:
|
||||
body["errors"] = errors
|
||||
return JSONResponse(status_code=status_code, content=body)
|
||||
|
||||
|
||||
def add_error_handlers(app: FastAPI) -> None:
|
||||
"""Register global exception handlers for consistent error responses."""
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_error_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
"""Return 422 with structured field-level error details."""
|
||||
field_errors = []
|
||||
for err in exc.errors():
|
||||
loc = err.get("loc", ())
|
||||
field_errors.append(
|
||||
{
|
||||
"field": ".".join(str(p) for p in loc[1:]) if len(loc) > 1 else str(loc),
|
||||
"message": err.get("msg", "Invalid value"),
|
||||
"type": err.get("type", "value_error"),
|
||||
}
|
||||
)
|
||||
return _error_response(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Validation error",
|
||||
code="VALIDATION_ERROR",
|
||||
errors=field_errors,
|
||||
)
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
|
||||
"""Wrap HTTP exceptions (Starlette and FastAPI) in consistent format."""
|
||||
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
||||
return _error_response(
|
||||
status_code=exc.status_code,
|
||||
detail=detail,
|
||||
code=_status_to_code(exc.status_code),
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Catch-all: log full traceback, return safe 500 to client."""
|
||||
logger.error(
|
||||
"Unhandled exception on %s %s: %s\n%s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
_notify_error_monitor(request, exc)
|
||||
|
||||
return _error_response(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error",
|
||||
code="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
def _status_to_code(status_code: int) -> str:
|
||||
"""Map HTTP status code to a machine-readable error code."""
|
||||
mapping = {
|
||||
400: "BAD_REQUEST",
|
||||
401: "UNAUTHORIZED",
|
||||
403: "FORBIDDEN",
|
||||
404: "NOT_FOUND",
|
||||
409: "CONFLICT",
|
||||
422: "VALIDATION_ERROR",
|
||||
429: "RATE_LIMITED",
|
||||
502: "BAD_GATEWAY",
|
||||
503: "SERVICE_UNAVAILABLE",
|
||||
}
|
||||
return mapping.get(status_code, f"HTTP_{status_code}")
|
||||
|
||||
|
||||
# ---------- Error Monitoring ----------
|
||||
|
||||
|
||||
class _ErrorMonitor:
|
||||
"""Simple error counter for monitoring and alerting hooks.
|
||||
|
||||
Tracks error counts and rates. In production, this would forward
|
||||
to an external monitoring service (Prometheus, Sentry, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.error_counts: dict[int, int] = {}
|
||||
self.recent_5xx: list[dict] = []
|
||||
self._max_recent = 100
|
||||
|
||||
def record(self, status_code: int, path: str, method: str, error: str | None = None) -> None:
|
||||
self.error_counts[status_code] = self.error_counts.get(status_code, 0) + 1
|
||||
|
||||
if status_code >= 500:
|
||||
entry = {
|
||||
"timestamp": time.time(),
|
||||
"status": status_code,
|
||||
"path": path,
|
||||
"method": method,
|
||||
"error": error,
|
||||
}
|
||||
self.recent_5xx.append(entry)
|
||||
if len(self.recent_5xx) > self._max_recent:
|
||||
self.recent_5xx = self.recent_5xx[-self._max_recent :]
|
||||
|
||||
logger.warning(
|
||||
"5xx error recorded: %s %s -> %d (%s)",
|
||||
method,
|
||||
path,
|
||||
status_code,
|
||||
error or "unknown",
|
||||
)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
return {
|
||||
"error_counts": dict(self.error_counts),
|
||||
"recent_5xx_count": len(self.recent_5xx),
|
||||
}
|
||||
|
||||
|
||||
_monitor = _ErrorMonitor()
|
||||
|
||||
|
||||
def get_error_monitor() -> _ErrorMonitor:
|
||||
"""Access the global error monitor (for health/metrics endpoints)."""
|
||||
return _monitor
|
||||
|
||||
|
||||
def _notify_error_monitor(request: Request, exc: Exception) -> None:
|
||||
"""Record unhandled exception in the error monitor."""
|
||||
_monitor.record(
|
||||
status_code=500,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
error=str(exc)[:200],
|
||||
)
|
||||
|
||||
|
||||
class ErrorMonitorMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to track all 4xx/5xx responses for monitoring."""
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable],
|
||||
):
|
||||
response = await call_next(request)
|
||||
|
||||
if response.status_code >= 400:
|
||||
_monitor.record(
|
||||
status_code=response.status_code,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def add_error_monitor_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(ErrorMonitorMiddleware)
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Rate limiting middleware for public and authenticated endpoints.
|
||||
|
||||
Uses in-memory sliding window as fallback, Redis/DragonflyDB when available.
|
||||
Per-IP limiting on public endpoints, per-token limiting on authenticated endpoints.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from cartsnitch_api.config import settings
|
||||
|
||||
|
||||
class _SlidingWindowCounter:
|
||||
"""Thread-safe in-memory sliding window rate limiter."""
|
||||
|
||||
def __init__(self, max_requests: int, window_seconds: int) -> None:
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
self._hits: dict[str, list[float]] = defaultdict(list)
|
||||
self._lock = Lock()
|
||||
|
||||
def is_allowed(self, key: str) -> tuple[bool, int, int]:
|
||||
"""Check if request is allowed. Returns (allowed, remaining, retry_after)."""
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
with self._lock:
|
||||
# Prune expired entries
|
||||
self._hits[key] = [t for t in self._hits[key] if t > cutoff]
|
||||
|
||||
current_count = len(self._hits[key])
|
||||
if current_count >= self.max_requests:
|
||||
retry_after = int(self._hits[key][0] - cutoff) + 1
|
||||
return False, 0, retry_after
|
||||
|
||||
self._hits[key].append(now)
|
||||
remaining = self.max_requests - current_count - 1
|
||||
return True, remaining, 0
|
||||
|
||||
|
||||
# Module-level counters — one for public (per-IP), one for auth (per-token)
|
||||
_public_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests,
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
_auth_limiter = _SlidingWindowCounter(
|
||||
max_requests=settings.rate_limit_requests * 5, # 300/min for authenticated users
|
||||
window_seconds=settings.rate_limit_window_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract client IP, respecting X-Forwarded-For behind a reverse proxy."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _get_rate_limit_key(request: Request) -> tuple[str, _SlidingWindowCounter]:
|
||||
"""Determine rate limit key and which limiter to use."""
|
||||
if request.url.path.startswith("/public"):
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
# For authenticated endpoints, use Bearer token as key if present
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
# Use last 16 chars of token as key to avoid storing full tokens
|
||||
return f"token:{token[-16:]}", _auth_limiter
|
||||
|
||||
# Fallback to IP for unauthenticated non-public endpoints
|
||||
return f"ip:{_get_client_ip(request)}", _public_limiter
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip rate limiting when disabled (e.g. in tests) or for health checks
|
||||
if not settings.rate_limit_enabled or request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
key, limiter = _get_rate_limit_key(request)
|
||||
allowed, remaining, retry_after = limiter.is_allowed(key)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={
|
||||
"detail": "Rate limit exceeded",
|
||||
"code": "RATE_LIMITED",
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"X-RateLimit-Limit": str(limiter.max_requests),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
|
||||
def add_rate_limit_middleware(app: FastAPI) -> None:
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
Reference in New Issue
Block a user