forked from cartsnitch/cartsnitch
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 138033be9b | |||
| 8ddefe82e4 | |||
| def921f115 | |||
| f03d7a33c8 | |||
| 7bf0165fe4 | |||
| ef63c47b7c | |||
| 67e60c9ae1 | |||
| a25b673dd6 | |||
| 4e003ba3d0 | |||
| 4996ff7432 |
+8
-2
@@ -18,7 +18,7 @@ if not db_url:
|
||||
"CARTSNITCH_DATABASE_URL_SYNC must be set. "
|
||||
"Example: postgresql://user:pass@localhost:5432/cartsnitch"
|
||||
)
|
||||
config.set_main_option("sqlalchemy.url", db_url)
|
||||
config.set_main_option("sqlalchemy.url", db_url.replace("%", "%%"))
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
@@ -50,7 +50,13 @@ def run_migrations_online() -> None:
|
||||
# Create any tables defined in models but not yet created by migrations.
|
||||
# This bootstraps fresh databases that have no legacy schema.
|
||||
# checkfirst=True ensures this is a no-op on existing databases.
|
||||
Base.metadata.create_all(bind=connection, checkfirst=True)
|
||||
try:
|
||||
Base.metadata.create_all(bind=connection, checkfirst=True)
|
||||
except Exception as exc:
|
||||
import logging
|
||||
logging.getLogger("alembic.env").warning(
|
||||
"create_all failed (non-fatal, migrations should handle table creation): %s", exc
|
||||
)
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Bootstrap users table on fresh databases.
|
||||
|
||||
On fresh databases, migrations 001-006 skip users-table operations because
|
||||
the table does not exist yet. Base.metadata.create_all() in env.py is meant
|
||||
to handle this, but if it fails (import errors, etc.) the table is never
|
||||
created. This migration creates the users table with raw SQL as a safety net.
|
||||
|
||||
Revision ID: 007_bootstrap_users_table
|
||||
Revises: 006_email_inbound_token_server_default
|
||||
Create Date: 2026-04-04
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "007_bootstrap_users_table"
|
||||
down_revision = "006_email_inbound_token_server_default"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
if inspector.has_table("users"):
|
||||
return # Table already exists (non-fresh DB or create_all already ran)
|
||||
|
||||
conn.execute(text("""
|
||||
CREATE TABLE users (
|
||||
id TEXT PRIMARY KEY,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
hashed_password VARCHAR(255),
|
||||
display_name VARCHAR(100),
|
||||
email_verified BOOLEAN NOT NULL DEFAULT false,
|
||||
image TEXT,
|
||||
email_inbound_token VARCHAR(22) NOT NULL UNIQUE
|
||||
DEFAULT replace(replace(trim(trailing '=' from encode(gen_random_bytes(16), 'base64')), '+', '-'), '/', '_'),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(text("DROP TABLE IF EXISTS users"))
|
||||
@@ -4,6 +4,7 @@ Validates Better-Auth session tokens from cookies or Bearer header.
|
||||
Sessions are verified by querying the shared sessions table directly.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from datetime import UTC, datetime
|
||||
from fastapi import Cookie, Depends, Header, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
@@ -19,16 +20,21 @@ bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
# Better-Auth session cookie name
|
||||
SESSION_COOKIE_NAME = "better-auth.session_token"
|
||||
# Secure prefix used by better-auth on HTTPS deployments
|
||||
SECURE_SESSION_COOKIE_NAME = "__Secure-better-auth.session_token"
|
||||
|
||||
|
||||
async def _validate_session_token(token: str, db: AsyncSession) -> str:
|
||||
"""Validate a Better-Auth session token against the sessions table.
|
||||
|
||||
Returns the user_id (as str) if the session is valid and not expired.
|
||||
Better-Auth v1.2+ stores SHA-256(raw_token) in the DB.
|
||||
The cookie/Bearer header carries the raw token, so we hash before lookup.
|
||||
"""
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
result = await db.execute(
|
||||
text("SELECT user_id, expires_at FROM sessions WHERE token = :token"),
|
||||
{"token": token},
|
||||
{"token": token_hash},
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
@@ -65,8 +71,8 @@ async def get_current_user(
|
||||
"""
|
||||
token: str | None = None
|
||||
|
||||
# 1. Check session cookie
|
||||
cookie_token = request.cookies.get(SESSION_COOKIE_NAME)
|
||||
# 1. Check session cookie — prefer __Secure- variant (HTTPS) over plain (HTTP dev)
|
||||
cookie_token = request.cookies.get(SECURE_SESSION_COOKIE_NAME) or request.cookies.get(SESSION_COOKIE_NAME)
|
||||
if cookie_token:
|
||||
token = cookie_token
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ Session-based auth: tests create users and sessions directly in the DB,
|
||||
matching the Better-Auth session validation flow.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
@@ -136,12 +137,14 @@ async def client(db_engine):
|
||||
async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_overrides) -> tuple[dict, str]:
|
||||
"""Create a test user and a valid session directly in the DB.
|
||||
|
||||
Returns (user_dict, session_token).
|
||||
Returns (user_dict, session_token). Better-Auth v1.2+ stores SHA-256
|
||||
hashed tokens in the DB, so the token is hashed before insertion.
|
||||
"""
|
||||
user_id = str(uuid.uuid4())
|
||||
email = user_overrides.get("email", "test@example.com")
|
||||
display_name = user_overrides.get("display_name", "Test User")
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
token_hash = hashlib.sha256(session_token.encode()).hexdigest()
|
||||
session_id = str(uuid.uuid4())
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expires = (datetime.now(UTC) + timedelta(days=7)).isoformat()
|
||||
@@ -169,7 +172,7 @@ async def _create_test_user_and_session(client: AsyncClient, db_engine, **user_o
|
||||
),
|
||||
{
|
||||
"id": session_id,
|
||||
"token": session_token,
|
||||
"token": token_hash,
|
||||
"user_id": user_id,
|
||||
"expires_at": expires,
|
||||
"created_at": now,
|
||||
|
||||
@@ -74,6 +74,7 @@ async def test_delete_me(client, auth_headers):
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_session_rejected(client, db_engine):
|
||||
"""Expired sessions must be rejected."""
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
@@ -82,6 +83,7 @@ async def test_expired_session_rejected(client, db_engine):
|
||||
|
||||
user_id = str(uuid.uuid4())
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
token_hash = hashlib.sha256(session_token.encode()).hexdigest()
|
||||
now = datetime.now(UTC).isoformat()
|
||||
expired = (datetime.now(UTC) - timedelta(hours=1)).isoformat()
|
||||
|
||||
@@ -108,7 +110,7 @@ async def test_expired_session_rejected(client, db_engine):
|
||||
),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": session_token,
|
||||
"token": token_hash,
|
||||
"uid": user_id,
|
||||
"ea": expired,
|
||||
"ca": now,
|
||||
|
||||
@@ -14,7 +14,7 @@ if config.config_file_name is not None:
|
||||
|
||||
db_url = os.environ.get("CARTSNITCH_DATABASE_URL_SYNC")
|
||||
if db_url:
|
||||
config.set_main_option("sqlalchemy.url", db_url)
|
||||
config.set_main_option("sqlalchemy.url", db_url.replace("%", "%%"))
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
Reference in New Issue
Block a user