diff --git a/apps/api/src/__tests__/portalSession.test.ts b/apps/api/src/__tests__/portalSession.test.ts new file mode 100644 index 0000000..0c69c1f --- /dev/null +++ b/apps/api/src/__tests__/portalSession.test.ts @@ -0,0 +1,158 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { Hono } from "hono"; +import { validatePortalSession } from "../middleware/portalSession.js"; +import { portalAuditMiddleware } from "../middleware/portalAudit.js"; + +const CLIENT_ID = "550e8400-e29b-41d4-a716-446655440001"; +const SESSION_ID = "770e8400-e29b-41d4-a716-446655440003"; + +const futureDate = () => new Date(Date.now() + 30 * 60 * 1000); +const pastDate = () => new Date(Date.now() - 5 * 60 * 1000); + +const ACTIVE_SESSION = { + id: SESSION_ID, + clientId: CLIENT_ID, + status: "active" as const, + expiresAt: futureDate(), + createdAt: new Date(), +}; + +const EXPIRED_SESSION = { + id: SESSION_ID, + clientId: CLIENT_ID, + status: "active" as const, + expiresAt: pastDate(), + createdAt: new Date(), +}; + +let selectSessionRow: Record | null = null; +let insertedAuditLogs: Array> = []; + +function resetMock() { + selectSessionRow = null; + insertedAuditLogs = []; +} + +vi.mock("@groombook/db", () => { + function makeChainable(data: unknown[]): unknown { + const arr = [...data]; + const chain = new Proxy(arr, { + get(target, prop) { + if (prop === "where" || prop === "orderBy" || prop === "limit") { + return () => chain; + } + // @ts-expect-error proxy + return target[prop]; + }, + }); + return chain; + } + + const impersonationSessions = new Proxy( + { _name: "impersonationSessions" }, + { get: (t, p) => (p === "_name" ? "impersonationSessions" : { table: "impersonationSessions", column: p }) } + ); + + const impersonationAuditLogs = new Proxy( + { _name: "impersonationAuditLogs" }, + { get: (t, p) => (p === "_name" ? "impersonationAuditLogs" : { table: "impersonationAuditLogs", column: p }) } + ); + + return { + getDb: () => ({ + select: () => ({ + from: (table: { _name: string }) => { + if (table._name === "impersonationSessions") { + return makeChainable(selectSessionRow ? [selectSessionRow] : []); + } + return makeChainable([]); + }, + }), + insert: () => ({ + values: (vals: Record) => { + insertedAuditLogs.push(vals); + return { + returning: () => [{ id: "audit-log-uuid-1", ...vals }], + }; + }, + }), + }), + impersonationSessions, + impersonationAuditLogs, + eq: vi.fn(), + and: vi.fn(), + }; +}); + +const app = new Hono(); +app.use(validatePortalSession); +app.use(portalAuditMiddleware); +app.get("/test", (c) => c.json({ ok: true })); + +function makeRequest(path: string, headers?: Record) { + return app.request(path, { headers }); +} + +beforeEach(() => resetMock()); + +// ─── validatePortalSession tests ────────────────────────────────────────────── + +describe("validatePortalSession", () => { + it("calls next and sets context variables for valid active session", async () => { + selectSessionRow = ACTIVE_SESSION; + const res = await makeRequest("/test", { "X-Impersonation-Session-Id": SESSION_ID }); + expect(res.status).toBe(200); + const body = await res.json(); + expect(body.ok).toBe(true); + }); + + it("returns 401 when X-Impersonation-Session-Id header is missing", async () => { + const res = await makeRequest("/test"); + expect(res.status).toBe(401); + const body = await res.json(); + expect(body.error).toBe("Unauthorized"); + }); + + it("returns 401 when session is expired", async () => { + selectSessionRow = EXPIRED_SESSION; + const res = await makeRequest("/test", { "X-Impersonation-Session-Id": SESSION_ID }); + expect(res.status).toBe(401); + const body = await res.json(); + expect(body.error).toBe("Unauthorized"); + }); + + it("returns 401 when session is not found", async () => { + selectSessionRow = null; + const res = await makeRequest("/test", { "X-Impersonation-Session-Id": SESSION_ID }); + expect(res.status).toBe(401); + const body = await res.json(); + expect(body.error).toBe("Unauthorized"); + }); +}); + +// ─── portalAuditMiddleware tests ────────────────────────────────────────────── + +describe("portalAuditMiddleware", () => { + it("inserts audit log entry after successful request", async () => { + selectSessionRow = ACTIVE_SESSION; + const res = await makeRequest("/test", { "X-Impersonation-Session-Id": SESSION_ID }); + expect(res.status).toBe(200); + expect(insertedAuditLogs).toHaveLength(1); + expect(insertedAuditLogs[0].sessionId).toBe(SESSION_ID); + expect(insertedAuditLogs[0].action).toBe("GET /test"); + expect(insertedAuditLogs[0].pageVisited).toBe("/test"); + expect(insertedAuditLogs[0].metadata).toEqual({ method: "GET", statusCode: 200 }); + }); + + it("does not throw when audit log insert fails", async () => { + selectSessionRow = ACTIVE_SESSION; + const res = await makeRequest("/test", { "X-Impersonation-Session-Id": SESSION_ID }); + expect(res.status).toBe(200); + }); + + it("does not insert audit log when portalSessionId is not set", async () => { + const res = await makeRequest("/test"); + expect(res.status).toBe(401); + expect(insertedAuditLogs).toHaveLength(0); + }); +}); diff --git a/apps/api/src/middleware/portalAudit.ts b/apps/api/src/middleware/portalAudit.ts new file mode 100644 index 0000000..e2f93f6 --- /dev/null +++ b/apps/api/src/middleware/portalAudit.ts @@ -0,0 +1,28 @@ +import type { MiddlewareHandler } from "hono"; +import { getDb, impersonationAuditLogs } from "@groombook/db"; +import type { PortalSessionEnv } from "./portalSession.js"; + +export const portalAuditMiddleware: MiddlewareHandler = async ( + c, + next +) => { + await next(); + + const sessionId = c.get("portalSessionId"); + if (!sessionId) return; + + const action = `${c.req.method} ${c.req.path}`; + const metadata = { method: c.req.method, statusCode: c.res.status }; + + try { + const db = getDb(); + await db.insert(impersonationAuditLogs).values({ + sessionId, + action, + pageVisited: c.req.path, + metadata, + }); + } catch (err) { + console.error("[portalAudit] failed to insert audit log:", err); + } +}; \ No newline at end of file diff --git a/apps/api/src/middleware/portalSession.ts b/apps/api/src/middleware/portalSession.ts new file mode 100644 index 0000000..6031138 --- /dev/null +++ b/apps/api/src/middleware/portalSession.ts @@ -0,0 +1,39 @@ +import type { MiddlewareHandler } from "hono"; +import { and, eq, getDb, impersonationSessions } from "@groombook/db"; + +export interface PortalSessionEnv { + Variables: { + portalClientId: string; + portalSessionId: string; + }; +} + +export const validatePortalSession: MiddlewareHandler = async ( + c, + next +) => { + const sessionId = c.req.header("X-Impersonation-Session-Id"); + if (!sessionId) { + return c.json({ error: "Unauthorized" }, 401); + } + + const db = getDb(); + const [session] = await db + .select() + .from(impersonationSessions) + .where( + and( + eq(impersonationSessions.id, sessionId), + eq(impersonationSessions.status, "active") + ) + ) + .limit(1); + + if (!session || session.expiresAt <= new Date()) { + return c.json({ error: "Unauthorized" }, 401); + } + + c.set("portalClientId", session.clientId); + c.set("portalSessionId", session.id); + await next(); +}; \ No newline at end of file diff --git a/apps/api/src/routes/portal.ts b/apps/api/src/routes/portal.ts index 8b10b56..53325aa 100644 --- a/apps/api/src/routes/portal.ts +++ b/apps/api/src/routes/portal.ts @@ -1,33 +1,25 @@ import { Hono } from "hono"; import { zValidator } from "@hono/zod-validator"; import { z } from "zod/v3"; -import { and, eq, inArray } from "@groombook/db"; +import { eq, inArray } from "@groombook/db"; import { getDb, appointments, impersonationSessions, waitlistEntries, clients, pets, services, staff, invoices, invoiceLineItems } from "@groombook/db"; import type { AppEnv } from "../middleware/rbac.js"; +import type { PortalSessionEnv } from "../middleware/portalSession.js"; +import { validatePortalSession } from "../middleware/portalSession.js"; +import { portalAuditMiddleware } from "../middleware/portalAudit.js"; -export const portalRouter = new Hono(); +type PortalEnv = AppEnv & PortalSessionEnv; -// ─── Session helper ─────────────────────────────────────────────────────────── +export const portalRouter = new Hono(); -async function getClientIdFromSession(sessionId: string | null | undefined): Promise { - if (!sessionId) return null; - const db = getDb(); - const [session] = await db - .select() - .from(impersonationSessions) - .where(and(eq(impersonationSessions.id, sessionId), eq(impersonationSessions.status, "active"))) - .limit(1); - if (!session || session.expiresAt <= new Date()) return null; - return session.clientId; -} +portalRouter.use(validatePortalSession); +portalRouter.use(portalAuditMiddleware); // ─── GET routes ────────────────────────────────────────────────────────────── portalRouter.get("/me", async (c) => { const db = getDb(); - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) return c.json({ error: "Unauthorized" }, 401); + const clientId = c.get("portalClientId"); const [client] = await db.select().from(clients).where(eq(clients.id, clientId)).limit(1); if (!client) return c.json({ error: "Not found" }, 404); @@ -49,9 +41,7 @@ portalRouter.get("/services", async (c) => { portalRouter.get("/appointments", async (c) => { const db = getDb(); - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) return c.json({ error: "Unauthorized" }, 401); + const clientId = c.get("portalClientId"); const now = new Date(); const allAppts = await db @@ -101,9 +91,7 @@ portalRouter.get("/appointments", async (c) => { portalRouter.get("/pets", async (c) => { const db = getDb(); - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) return c.json({ error: "Unauthorized" }, 401); + const clientId = c.get("portalClientId"); const clientPets = await db.select().from(pets).where(eq(pets.clientId, clientId)); return c.json(clientPets.map(p => ({ id: p.id, name: p.name, breed: p.breed, weightKg: p.weightKg, dateOfBirth: p.dateOfBirth, photoKey: p.photoKey, groomingNotes: p.groomingNotes }))); @@ -111,9 +99,7 @@ portalRouter.get("/pets", async (c) => { portalRouter.get("/invoices", async (c) => { const db = getDb(); - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) return c.json({ error: "Unauthorized" }, 401); + const clientId = c.get("portalClientId"); const clientInvoices = await db.select().from(invoices).where(eq(invoices.clientId, clientId)); const invoiceIds = clientInvoices.map(i => i.id); @@ -137,7 +123,6 @@ portalRouter.get("/invoices", async (c) => { // ─── Appointment action routes ──────────────────────────────────────────────── const customerNotesSchema = z.object({ - // .min(1) prevents empty strings — clearing notes is not a supported use case customerNotes: z.string().min(1).max(500), }); @@ -148,12 +133,7 @@ portalRouter.patch( const db = getDb(); const id = c.req.param("id"); const body = c.req.valid("json"); - - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) { - return c.json({ error: "Unauthorized" }, 401); - } + const clientId = c.get("portalClientId"); const [appt] = await db .select() @@ -196,12 +176,7 @@ portalRouter.patch( portalRouter.post("/appointments/:id/confirm", async (c) => { const db = getDb(); const id = c.req.param("id"); - - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) { - return c.json({ error: "Unauthorized" }, 401); - } + const clientId = c.get("portalClientId"); const [appt] = await db .select() @@ -250,12 +225,7 @@ portalRouter.post("/appointments/:id/confirm", async (c) => { portalRouter.post("/appointments/:id/cancel", async (c) => { const db = getDb(); const id = c.req.param("id"); - - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) { - return c.json({ error: "Unauthorized" }, 401); - } + const clientId = c.get("portalClientId"); const [appt] = await db .select() @@ -276,7 +246,7 @@ portalRouter.post("/appointments/:id/cancel", async (c) => { } if (appt.status === "cancelled" || appt.status === "completed") { - return c.json({ error: "Appointment is already cancelled or completed" }, 422); + return c.json({ error: "Cannot cancel a cancelled or completed appointment" }, 422); } const [updated] = await db @@ -319,28 +289,7 @@ portalRouter.post( async (c) => { const db = getDb(); const body = c.req.valid("json"); - const sessionId = c.req.header("X-Impersonation-Session-Id"); - - let clientId: string | null = null; - if (sessionId) { - const [session] = await db - .select() - .from(impersonationSessions) - .where( - and( - eq(impersonationSessions.id, sessionId), - eq(impersonationSessions.status, "active") - ) - ) - .limit(1); - if (session && session.expiresAt > new Date()) { - clientId = session.clientId; - } - } - - if (!clientId) { - return c.json({ error: "Unauthorized" }, 401); - } + const clientId = c.get("portalClientId"); const [entry] = await db .insert(waitlistEntries) @@ -364,26 +313,7 @@ portalRouter.patch( const db = getDb(); const id = c.req.param("id"); const body = c.req.valid("json"); - const sessionId = c.req.header("X-Impersonation-Session-Id"); - - if (!sessionId) { - return c.json({ error: "Unauthorized" }, 401); - } - - const [session] = await db - .select() - .from(impersonationSessions) - .where( - and( - eq(impersonationSessions.id, sessionId), - eq(impersonationSessions.status, "active") - ) - ) - .limit(1); - - if (!session || session.expiresAt <= new Date()) { - return c.json({ error: "Unauthorized" }, 401); - } + const clientId = c.get("portalClientId"); const [existing] = await db .select() @@ -392,7 +322,7 @@ portalRouter.patch( .limit(1); if (!existing) return c.json({ error: "Not found" }, 404); - if (existing.clientId !== session.clientId) { + if (existing.clientId !== clientId) { return c.json({ error: "Forbidden" }, 403); } @@ -414,26 +344,7 @@ portalRouter.patch( portalRouter.delete("/waitlist/:id", async (c) => { const db = getDb(); const id = c.req.param("id"); - const sessionId = c.req.header("X-Impersonation-Session-Id"); - - if (!sessionId) { - return c.json({ error: "Unauthorized" }, 401); - } - - const [session] = await db - .select() - .from(impersonationSessions) - .where( - and( - eq(impersonationSessions.id, sessionId), - eq(impersonationSessions.status, "active") - ) - ) - .limit(1); - - if (!session || session.expiresAt <= new Date()) { - return c.json({ error: "Unauthorized" }, 401); - } + const clientId = c.get("portalClientId"); const [entry] = await db .select() @@ -442,7 +353,7 @@ portalRouter.delete("/waitlist/:id", async (c) => { .limit(1); if (!entry) return c.json({ error: "Not found" }, 404); - if (entry.clientId !== session.clientId) { + if (entry.clientId !== clientId) { return c.json({ error: "Forbidden" }, 403); } @@ -475,9 +386,7 @@ portalRouter.post( async (c) => { const db = getDb(); const body = c.req.valid("json"); - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) return c.json({ error: "Unauthorized" }, 401); + const clientId = c.get("portalClientId"); const invoiceRows = await db .select() @@ -514,9 +423,7 @@ portalRouter.post( ); portalRouter.get("/payment-methods", async (c) => { - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) return c.json({ error: "Unauthorized" }, 401); + const clientId = c.get("portalClientId"); const methods = await listPaymentMethods(clientId); if (methods === null) return c.json({ error: "Payment service unavailable" }, 503); @@ -524,9 +431,7 @@ portalRouter.get("/payment-methods", async (c) => { }); portalRouter.post("/payment-methods", async (c) => { - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) return c.json({ error: "Unauthorized" }, 401); + const clientId = c.get("portalClientId"); const stripePublishableKey = process.env.STRIPE_PUBLISHABLE_KEY ?? ""; const customerId = await getOrCreateStripeCustomer(clientId); @@ -539,9 +444,7 @@ portalRouter.post("/payment-methods", async (c) => { }); portalRouter.delete("/payment-methods/:id", async (c) => { - const sessionId = c.req.header("X-Impersonation-Session-Id"); - const clientId = await getClientIdFromSession(sessionId); - if (!clientId) return c.json({ error: "Unauthorized" }, 401); + const clientId = c.get("portalClientId"); const paymentMethodId = c.req.param("id"); @@ -580,7 +483,6 @@ portalRouter.post( const db = getDb(); const body = c.req.valid("json"); - // Verify client exists const [client] = await db .select() .from(clients) @@ -590,10 +492,6 @@ portalRouter.post( return c.json({ error: "Client not found" }, 404); } - // Find a staff record to associate with the dev impersonation session. - // Use the demo-manager if it exists (created by seed with known ID), - // otherwise fall back to the first active staff record. - // This avoids hardcoding a UUID that may not exist in all environments. const DEMO_STAFF_ID = "00000000-0000-0000-0000-000000000001"; let staffId = DEMO_STAFF_ID; @@ -604,7 +502,6 @@ portalRouter.post( .limit(1); if (!demoStaff) { - // Fall back to any active staff member const [firstStaff] = await db .select({ id: staff.id }) .from(staff) @@ -622,10 +519,10 @@ portalRouter.post( staffId, clientId: body.clientId, reason: "dev-mode-client-portal", - expiresAt: new Date(Date.now() + 24 * 60 * 60 * 1000), // 24 hours + expiresAt: new Date(Date.now() + 24 * 60 * 60 * 1000), }) .returning(); return c.json(session, 201); } -); \ No newline at end of file +);