forked from cartsnitch/cartsnitch
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0be7ccd4b4 | |||
| 6d37cecdba | |||
| 3745f5be69 | |||
| abec954320 | |||
| ec9deb515b | |||
| cfed9b0482 | |||
| 25edd8d5e3 | |||
| bd3cb3b9ab | |||
| 3bedc651c6 | |||
| 138033be9b | |||
| 8ddefe82e4 | |||
| def921f115 |
+1
-1
@@ -6,7 +6,7 @@ from logging.config import fileConfig
|
|||||||
from sqlalchemy import engine_from_config, pool
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from cartsnitch_api.models.base import Base # noqa: F401 — imports all models for autogenerate
|
from cartsnitch_api.models import Base # noqa: F401 — imports all models for autogenerate
|
||||||
|
|
||||||
config = context.config
|
config = context.config
|
||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
|
|||||||
@@ -19,12 +19,15 @@ bearer_scheme = HTTPBearer(auto_error=False)
|
|||||||
|
|
||||||
# Better-Auth session cookie name
|
# Better-Auth session cookie name
|
||||||
SESSION_COOKIE_NAME = "better-auth.session_token"
|
SESSION_COOKIE_NAME = "better-auth.session_token"
|
||||||
|
# Secure prefix used by better-auth on HTTPS deployments
|
||||||
|
SECURE_SESSION_COOKIE_NAME = "__Secure-better-auth.session_token"
|
||||||
|
|
||||||
|
|
||||||
async def _validate_session_token(token: str, db: AsyncSession) -> str:
|
async def _validate_session_token(token: str, db: AsyncSession) -> str:
|
||||||
"""Validate a Better-Auth session token against the sessions table.
|
"""Validate a Better-Auth session token against the sessions table.
|
||||||
|
|
||||||
Returns the user_id (as str) if the session is valid and not expired.
|
Better-Auth stores the raw token in the DB. The cookie/Bearer header
|
||||||
|
carries the same raw token, so we compare directly.
|
||||||
"""
|
"""
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
|
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
|
||||||
@@ -65,14 +68,17 @@ async def get_current_user(
|
|||||||
"""
|
"""
|
||||||
token: str | None = None
|
token: str | None = None
|
||||||
|
|
||||||
# 1. Check session cookie
|
# 1. Check session cookie — prefer __Secure- variant (HTTPS) over plain (HTTP dev)
|
||||||
cookie_token = request.cookies.get(SESSION_COOKIE_NAME)
|
cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(SESSION_COOKIE_NAME)
|
||||||
if cookie_token:
|
if cookie_token:
|
||||||
token = cookie_token
|
# Better-Auth cookie format is "token.sessionId" — extract just the token part
|
||||||
|
token = cookie_token.split(".")[0] if "." in cookie_token else cookie_token
|
||||||
|
|
||||||
# 2. Fall back to Bearer header
|
# 2. Fall back to Bearer header
|
||||||
if not token and credentials:
|
if not token and credentials:
|
||||||
token = credentials.credentials
|
# Callers might pass the compound value here too
|
||||||
|
raw = credentials.credentials
|
||||||
|
token = raw.split(".")[0] if "." in raw else raw
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
import base64
|
import base64
|
||||||
|
|
||||||
from pydantic import model_validator
|
from pydantic import AliasChoices, Field, model_validator
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model_config = {"env_prefix": "CARTSNITCH_"}
|
model_config = {"env_prefix": "CARTSNITCH_"}
|
||||||
|
|
||||||
database_url: str = "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
|
database_url: str = Field(
|
||||||
|
default="postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
|
||||||
|
validation_alias=AliasChoices("CARTSNITCH_DATABASE_URL", "DATABASE_URL"),
|
||||||
|
)
|
||||||
redis_url: str = "redis://localhost:6379/0"
|
redis_url: str = "redis://localhost:6379/0"
|
||||||
|
|
||||||
jwt_secret_key: str = "change-me-in-production"
|
jwt_secret_key: str = "change-me-in-production"
|
||||||
@@ -49,5 +52,12 @@ class Settings(BaseSettings):
|
|||||||
) from None
|
) from None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def normalize_database_url(self):
|
||||||
|
"""Normalize postgresql:// → postgresql+asyncpg:// for the asyncpg driver."""
|
||||||
|
if self.database_url.startswith("postgresql://"):
|
||||||
|
self.database_url = self.database_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -136,7 +136,8 @@ async def client(db_engine):
|
|||||||
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
||||||
"""Create a test user and a valid session directly in the DB.
|
"""Create a test user and a valid session directly in the DB.
|
||||||
|
|
||||||
Returns (user_dict, session_token).
|
Returns (user_dict, session_token). Better-Auth stores the raw token
|
||||||
|
in the DB, so we insert it as-is.
|
||||||
"""
|
"""
|
||||||
user_id = str(uuid.uuid4())
|
user_id = str(uuid.uuid4())
|
||||||
email = user_overrides.get("email", "test@example.com")
|
email = user_overrides.get("email", "test@example.com")
|
||||||
|
|||||||
@@ -71,6 +71,56 @@ async def test_delete_me(client, auth_headers):
|
|||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_compound_cookie(client, db_engine):
|
||||||
|
"""Compound cookie value (token.sessionId) must be parsed to extract the token part."""
|
||||||
|
from tests.conftest import _create_test_user_and_session
|
||||||
|
|
||||||
|
_, session_token = await _create_test_user_and_session(
|
||||||
|
client, db_engine, email="compound@example.com", display_name="Compound User"
|
||||||
|
)
|
||||||
|
compound = f"{session_token}.B0atkJCFxK1rZlwWPMK97nVO2LnyDun7"
|
||||||
|
resp = await client.get(
|
||||||
|
"/auth/me",
|
||||||
|
headers={"Cookie": f"better-auth.session_token={compound}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["email"] == "compound@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_raw_token_cookie(client, db_engine):
|
||||||
|
"""Raw token (no dot) in cookie must still work — regression guard."""
|
||||||
|
from tests.conftest import _create_test_user_and_session
|
||||||
|
|
||||||
|
_, session_token = await _create_test_user_and_session(
|
||||||
|
client, db_engine, email="rawcookie@example.com", display_name="Raw Cookie User"
|
||||||
|
)
|
||||||
|
resp = await client.get(
|
||||||
|
"/auth/me",
|
||||||
|
headers={"Cookie": f"better-auth.session_token={session_token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["email"] == "rawcookie@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_compound_bearer(client, db_engine):
|
||||||
|
"""Compound Bearer token (token.sessionId) must be parsed to extract the token part."""
|
||||||
|
from tests.conftest import _create_test_user_and_session
|
||||||
|
|
||||||
|
_, session_token = await _create_test_user_and_session(
|
||||||
|
client, db_engine, email="compoundbearer@example.com", display_name="Compound Bearer User"
|
||||||
|
)
|
||||||
|
compound = f"{session_token}.B0atkJCFxK1rZlwWPMK97nVO2LnyDun7"
|
||||||
|
resp = await client.get(
|
||||||
|
"/auth/me",
|
||||||
|
headers={"Authorization": f"Bearer {compound}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["email"] == "compoundbearer@example.com"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_expired_session_rejected(client, db_engine):
|
async def test_expired_session_rejected(client, db_engine):
|
||||||
"""Expired sessions must be rejected."""
|
"""Expired sessions must be rejected."""
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
"""Tests for Settings config, specifically the database_url env var fallback."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from cartsnitch_api.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_url_prefers_cartsnitch_prefix():
|
||||||
|
"""CARTSNITCH_DATABASE_URL takes precedence over DATABASE_URL."""
|
||||||
|
env = {
|
||||||
|
"CARTSNITCH_DATABASE_URL": "postgresql+asyncpg://user1:pass1@host1:5432/db1",
|
||||||
|
"DATABASE_URL": "postgresql://user2:pass2@host2:5432/db2",
|
||||||
|
}
|
||||||
|
settings = Settings(**env)
|
||||||
|
assert settings.database_url == "postgresql+asyncpg://user1:pass1@host1:5432/db1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_url_falls_back_to_database_url():
|
||||||
|
"""When CARTSNITCH_DATABASE_URL is absent, DATABASE_URL is accepted."""
|
||||||
|
env = {
|
||||||
|
"DATABASE_URL": "postgresql://user:pass@dbhost:5432/mydb",
|
||||||
|
}
|
||||||
|
settings = Settings(**env)
|
||||||
|
assert settings.database_url == "postgresql+asyncpg://user:pass@dbhost:5432/mydb"
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_url_normalizes_plain_postgresql_prefix():
|
||||||
|
"""DATABASE_URL with plain postgresql:// is normalized to postgresql+asyncpg://."""
|
||||||
|
env = {
|
||||||
|
"DATABASE_URL": "postgresql://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
|
||||||
|
}
|
||||||
|
settings = Settings(**env)
|
||||||
|
assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_url_preserves_asyncpg_prefix():
|
||||||
|
"""CARTSNITCH_DATABASE_URL with postgresql+asyncpg:// is left unchanged."""
|
||||||
|
env = {
|
||||||
|
"CARTSNITCH_DATABASE_URL": "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch",
|
||||||
|
}
|
||||||
|
settings = Settings(**env)
|
||||||
|
assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_url_default():
|
||||||
|
"""When neither env var is set, the hardcoded default is used."""
|
||||||
|
settings = Settings()
|
||||||
|
assert settings.database_url == "postgresql+asyncpg://cartsnitch:cartsnitch@localhost:5432/cartsnitch"
|
||||||
Reference in New Issue
Block a user