Files
trebuchet/src/ai/claude-executor.ts
T
ajmallesh 023cc953db security: tighten Docker isolation and subprocess env
- Pin @playwright/mcp to 0.0.68 instead of @latest to prevent supply chain risk
- Restrict MCP subprocess env to allowlist (PATH, HOME, NODE_PATH, DISPLAY, XDG_*) instead of spreading process.env
- Add path traversal guard to @include() directive in prompt templates
- Bind all Docker ports to 127.0.0.1 to prevent network exposure
- Remove ipc: host — shm_size: 2gb already covers Chromium shared memory needs
- Add prompt injection disclaimer for untrusted repositories to README
2026-03-06 17:20:39 -08:00

434 lines
13 KiB
TypeScript

// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
// Production Claude agent execution with retry, git checkpoints, and audit logging
import { fs, path } from 'zx';
import { query } from '@anthropic-ai/claude-agent-sdk';
import { isRetryableError, PentestError } from '../services/error-handling.js';
import { isSpendingCapBehavior } from '../utils/billing-detection.js';
import { Timer } from '../utils/metrics.js';
import { formatTimestamp } from '../utils/formatting.js';
import { AGENT_VALIDATORS, MCP_AGENT_MAPPING } from '../session-manager.js';
import { AuditSession } from '../audit/index.js';
import { createShannonHelperServer } from '../../mcp-server/dist/index.js';
import { AGENTS } from '../session-manager.js';
import type { AgentName } from '../types/index.js';
import { dispatchMessage } from './message-handlers.js';
import { detectExecutionContext, formatErrorOutput, formatCompletionMessage } from './output-formatters.js';
import { createProgressManager } from './progress-manager.js';
import { createAuditLogger } from './audit-logger.js';
import { getActualModelName } from './router-utils.js';
import { resolveModel, type ModelTier } from './models.js';
import type { ActivityLogger } from '../types/activity-logger.js';
// Ambient declaration for the process-wide `global.SHANNON_DISABLE_LOADER` flag.
// In this file it is passed to createProgressManager (falling back to `false` when
// unset) and, when truthy, enables heartbeat log lines in processMessageStream.
declare global {
var SHANNON_DISABLE_LOADER: boolean | undefined;
}
// Outcome of a single runClaudePrompt execution.
// As populated by runClaudePrompt: success results carry result/turns/model/
// partialCost/apiErrorDetected; failure results carry error/errorType/prompt/retryable.
export interface ClaudePromptResult {
// Final result text from the message stream; null when the stream ended without a completion message.
result?: string | null | undefined;
// True when the stream was processed without throwing; false when an error was caught.
success: boolean;
// Value returned by Timer.stop() (units are defined by Timer — presumably milliseconds; TODO confirm).
duration: number;
// Number of assistant messages seen in the stream (set on success only).
turns?: number | undefined;
// Cost reported by the completion message; remains 0 if no completion message was processed.
cost: number;
// Model name captured from the stream, after getActualModelName() translation.
model?: string | undefined;
// Set equal to `cost` on the success path of runClaudePrompt.
partialCost?: number | undefined;
// True if any streamed message was flagged as an API error while the run still completed.
apiErrorDetected?: boolean | undefined;
// Message of the caught error (failure only).
error?: string | undefined;
// Constructor name of the caught error, e.g. 'PentestError' (failure only).
errorType?: string | undefined;
// First 100 characters of the full prompt followed by '...' (failure only).
prompt?: string | undefined;
// Result of isRetryableError() for the caught error (failure only).
retryable?: boolean | undefined;
}
// Definition of an MCP server that is started as a subprocess and spoken to over stdio.
interface StdioMcpServer {
// Discriminant identifying the stdio transport.
type: 'stdio';
// Executable to launch (buildMcpServers uses 'npx').
command: string;
// Command-line arguments passed to `command`.
args: string[];
// Environment variables supplied for the subprocess.
env: Record<string, string>;
}
// Any server entry buildMcpServers can register: the value returned by
// createShannonHelperServer, or a stdio subprocess definition.
type McpServer = ReturnType<typeof createShannonHelperServer> | StdioMcpServer;
/**
 * Builds the MCP server map for one agent run.
 *
 * The `shannon-helper` server returned by `createShannonHelperServer` is always
 * registered. When the agent's prompt template has an entry in
 * `MCP_AGENT_MAPPING`, a Playwright MCP stdio server is additionally registered
 * under the mapped name, with Docker-specific Chromium arguments when
 * `SHANNON_DOCKER === 'true'`.
 *
 * @param sourceDir - Directory handed to `createShannonHelperServer`.
 * @param agentName - Key into `AGENTS`, or `null` to register only `shannon-helper`.
 *   A name that is not present in `AGENTS` is logged as a warning and treated
 *   like `null` instead of crashing with a TypeError.
 * @param logger - Receives the assignment info line and the unknown-agent warning.
 * @returns Map of server name to server configuration.
 */
function buildMcpServers(
  sourceDir: string,
  agentName: string | null,
  logger: ActivityLogger
): Record<string, McpServer> {
  // 1. Create the shannon-helper server (always present)
  const mcpServers: Record<string, McpServer> = {
    'shannon-helper': createShannonHelperServer(sourceDir),
  };
  if (!agentName) {
    return mcpServers;
  }

  // 2. Look up the agent's Playwright MCP mapping.
  // agentName is an unchecked string, so the AgentName cast may not hold at
  // runtime; the explicit `| undefined` makes the guard below meaningful.
  const agentDefinition: (typeof AGENTS)[AgentName] | undefined = AGENTS[agentName as AgentName];
  if (!agentDefinition) {
    logger.warn(`Unknown agent "${agentName}" - no Playwright MCP server assigned`);
    return mcpServers;
  }
  const playwrightMcpName =
    MCP_AGENT_MAPPING[agentDefinition.promptTemplate as keyof typeof MCP_AGENT_MAPPING] || null;
  if (!playwrightMcpName) {
    return mcpServers;
  }
  logger.info(`Assigned ${agentName} -> ${playwrightMcpName}`);

  // 3. Configure Playwright MCP args with Docker/local browser handling
  const isDocker = process.env.SHANNON_DOCKER === 'true';
  const mcpArgs: string[] = [
    '@playwright/mcp@0.0.68',
    '--isolated',
    '--user-data-dir', `/tmp/${playwrightMcpName}`,
  ];
  if (isDocker) {
    mcpArgs.push('--executable-path', '/usr/bin/chromium-browser');
    mcpArgs.push('--browser', 'chromium');
  }

  // NOTE: Explicit allowlist — the Playwright MCP subprocess must not inherit
  // secrets (API keys, AWS tokens) from the parent process.
  const MCP_ENV_ALLOWLIST = [
    'PATH', 'HOME', 'NODE_PATH', 'DISPLAY',
    'PLAYWRIGHT_CHROMIUM_EXECUTABLE_PATH',
  ] as const;
  const envVars: Record<string, string> = { PLAYWRIGHT_HEADLESS: 'true' };
  if (isDocker) {
    envVars['PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD'] = '1';
  }
  for (const key of MCP_ENV_ALLOWLIST) {
    // Unset and empty values are both skipped.
    const value = process.env[key];
    if (value) {
      envVars[key] = value;
    }
  }
  // XDG_* variables are forwarded by prefix (including empty values).
  for (const [key, value] of Object.entries(process.env)) {
    if (key.startsWith('XDG_') && value !== undefined) {
      envVars[key] = value;
    }
  }

  mcpServers[playwrightMcpName] = {
    type: 'stdio' as const,
    command: 'npx',
    args: mcpArgs,
    env: envVars,
  };

  // 4. Return configured servers
  return mcpServers;
}
/** Prints each entry of `lines` to stdout, one `console.log` call per entry, in order. */
function outputLines(lines: string[]): void {
  for (let index = 0; index < lines.length; index += 1) {
    console.log(lines[index]);
  }
}
/**
 * Appends one JSON line describing `err` to `<sourceDir>/error.log`.
 *
 * Best-effort: any failure while building or writing the entry is swallowed so
 * that error reporting never masks the original error.
 *
 * @param err - The caught error; optional `code`/`status` are recorded when present.
 * @param sourceDir - Directory that receives `error.log`.
 * @param fullPrompt - Prompt that was running; only its first 200 characters are logged.
 * @param duration - Elapsed time value recorded verbatim in the entry.
 */
async function writeErrorLog(
  err: Error & { code?: string; status?: number },
  sourceDir: string,
  fullPrompt: string,
  duration: number
): Promise<void> {
  try {
    const timestamp = formatTimestamp();
    const errorDetails = {
      name: err.constructor.name,
      message: err.message,
      code: err.code,
      status: err.status,
      stack: err.stack,
    };
    const context = {
      sourceDir,
      prompt: `${fullPrompt.slice(0, 200)}...`,
      retryable: isRetryableError(err),
    };
    const entry = { timestamp, agent: 'claude-executor', error: errorDetails, context, duration };
    const serialized = `${JSON.stringify(entry)}\n`;
    await fs.appendFile(path.join(sourceDir, 'error.log'), serialized);
  } catch {
    // Best-effort error log writing - don't propagate failures
  }
}
/**
 * Decides whether an agent run produced acceptable output.
 *
 * A run fails validation when it was unsuccessful or produced no result text.
 * Otherwise the agent's entry in `AGENT_VALIDATORS` is invoked; agents without
 * a validator (including a null/empty name) are assumed to have succeeded.
 * Any exception raised during validation is logged and reported as a failure.
 *
 * @param result - Outcome of the agent run.
 * @param agentName - Agent whose validator should be applied, or `null`.
 * @param sourceDir - Directory passed to the validator.
 * @param logger - Receives progress and failure messages.
 * @returns The validator's verdict, `true` when no validator exists, `false` on failure.
 */
export async function validateAgentOutput(
  result: ClaudePromptResult,
  agentName: string | null,
  sourceDir: string,
  logger: ActivityLogger
): Promise<boolean> {
  logger.info(`Validating ${agentName} agent output`);

  try {
    // The run itself must have succeeded and yielded some result text.
    const runProducedOutput = result.success && Boolean(result.result);
    if (!runProducedOutput) {
      logger.error('Validation failed: Agent execution was unsuccessful');
      return false;
    }

    // Resolve the per-agent validator, if one is registered.
    const validate = agentName
      ? AGENT_VALIDATORS[agentName as keyof typeof AGENT_VALIDATORS]
      : undefined;
    if (!validate) {
      logger.warn(`No validator found for agent "${agentName}" - assuming success`);
      logger.info('Validation passed: Unknown agent with successful result');
      return true;
    }

    logger.info(`Using validator for agent: ${agentName}`, { sourceDir });
    const verdict = await validate(sourceDir, logger);
    if (!verdict) {
      logger.error('Validation failed: Missing required deliverable files');
      return verdict;
    }
    logger.info('Validation passed: Required files/structure present');
    return verdict;
  } catch (error) {
    const reason = error instanceof Error ? error.message : String(error);
    logger.error(`Validation failed with error: ${reason}`);
    return false;
  }
}
/**
 * Low-level SDK execution. Handles message streaming, progress, and audit logging.
 * Exported for Temporal activities to call single-attempt execution.
 *
 * Runs one prompt through the Claude Agent SDK (via processMessageStream) and
 * converts the outcome into a ClaudePromptResult. Errors thrown while the
 * stream is being processed are caught, logged, and returned as a
 * `success: false` result rather than rethrown. Setup that happens before the
 * try block (e.g. buildMcpServers) is not covered by that handling.
 *
 * @param prompt - Prompt text for the agent.
 * @param sourceDir - Working directory for the SDK (`cwd`); also where error.log is appended.
 * @param context - Optional text prepended to `prompt`, separated by a blank line.
 * @param description - Human-readable label used for the timer name, logs, and progress output.
 * @param agentName - Agent identifier used to select MCP servers, or `null`.
 * @param auditSession - Audit session passed to createAuditLogger, or `null`.
 * @param logger - Activity logger. Required even though it follows defaulted
 *   parameters, so callers must supply (or pass `undefined` for) every earlier argument.
 * @param modelTier - Tier resolved to a concrete model by resolveModel().
 * @returns Success result with result/turns/cost/model, or failure result with
 *   error details and a retryable flag.
 */
export async function runClaudePrompt(
prompt: string,
sourceDir: string,
context: string = '',
description: string = 'Claude analysis',
agentName: string | null = null,
auditSession: AuditSession | null = null,
logger: ActivityLogger,
modelTier: ModelTier = 'medium'
): Promise<ClaudePromptResult> {
// 1. Initialize timing and prompt
// Timer name is the lowercased description with whitespace runs replaced by '-'.
const timer = new Timer(`agent-${description.toLowerCase().replace(/\s+/g, '-')}`);
const fullPrompt = context ? `${context}\n\n${prompt}` : prompt;
// 2. Set up progress and audit infrastructure
const execContext = detectExecutionContext(description);
const progress = createProgressManager(
{ description, useCleanOutput: execContext.useCleanOutput },
global.SHANNON_DISABLE_LOADER ?? false
);
const auditLogger = createAuditLogger(auditSession);
logger.info(`Running Claude Code: ${description}...`);
// 3. Configure MCP servers
const mcpServers = buildMcpServers(sourceDir, agentName, logger);
// 4. Build env vars to pass to SDK subprocesses
// Only the variables named below are copied from process.env; unset or empty values are skipped.
const sdkEnv: Record<string, string> = {
CLAUDE_CODE_MAX_OUTPUT_TOKENS: process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS || '64000',
};
const passthroughVars = [
'ANTHROPIC_API_KEY',
'CLAUDE_CODE_OAUTH_TOKEN',
'ANTHROPIC_BASE_URL',
'ANTHROPIC_AUTH_TOKEN',
'CLAUDE_CODE_USE_BEDROCK',
'AWS_REGION',
'AWS_BEARER_TOKEN_BEDROCK',
'CLAUDE_CODE_USE_VERTEX',
'CLOUD_ML_REGION',
'ANTHROPIC_VERTEX_PROJECT_ID',
'GOOGLE_APPLICATION_CREDENTIALS',
'ANTHROPIC_SMALL_MODEL',
'ANTHROPIC_MEDIUM_MODEL',
'ANTHROPIC_LARGE_MODEL',
];
for (const name of passthroughVars) {
if (process.env[name]) {
sdkEnv[name] = process.env[name]!;
}
}
// 5. Configure SDK options
// NOTE(review): permission checks are bypassed for this run — presumably intentional
// for unattended execution; confirm this is only used in an isolated environment.
const options = {
model: resolveModel(modelTier),
maxTurns: 10_000,
cwd: sourceDir,
permissionMode: 'bypassPermissions' as const,
allowDangerouslySkipPermissions: true,
mcpServers,
env: sdkEnv,
};
if (!execContext.useCleanOutput) {
logger.info(`SDK Options: maxTurns=${options.maxTurns}, cwd=${sourceDir}, permissions=BYPASS`);
}
// These are only assigned after processMessageStream returns, so if the stream
// throws, the catch block below reports turnCount=0 and cost=0.
// NOTE(review): partial turn/cost progress is lost on that path — confirm acceptable.
let turnCount = 0;
let result: string | null = null;
let apiErrorDetected = false;
let totalCost = 0;
progress.start();
try {
// 6. Process the message stream
const messageLoopResult = await processMessageStream(
fullPrompt,
options,
{ execContext, description, progress, auditLogger, logger },
timer
);
turnCount = messageLoopResult.turnCount;
result = messageLoopResult.result;
apiErrorDetected = messageLoopResult.apiErrorDetected;
totalCost = messageLoopResult.cost;
const model = messageLoopResult.model;
// === SPENDING CAP SAFEGUARD ===
// 7. Defense-in-depth: Detect spending cap that slipped through detectApiError().
// Uses consolidated billing detection from utils/billing-detection.ts
if (isSpendingCapBehavior(turnCount, totalCost, result || '')) {
throw new PentestError(
`Spending cap likely reached (turns=${turnCount}, cost=$0): ${result?.slice(0, 100)}`,
'billing',
true // Retryable - Temporal will use 5-30 min backoff
);
}
// 8. Finalize successful result
const duration = timer.stop();
// An API error seen mid-stream does not fail the run here; it is surfaced via
// apiErrorDetected so the caller can decide after validating deliverables.
if (apiErrorDetected) {
logger.warn(`API Error detected in ${description} - will validate deliverables before failing`);
}
progress.finish(formatCompletionMessage(execContext, description, turnCount, duration));
return {
result,
success: true,
duration,
turns: turnCount,
cost: totalCost,
model,
partialCost: totalCost,
apiErrorDetected
};
} catch (error) {
// 9. Handle errors — log, write error file, return failure
const duration = timer.stop();
// The thrown value is assumed to be an Error (optionally carrying code/status); it is not narrowed at runtime.
const err = error as Error & { code?: string; status?: number };
await auditLogger.logError(err, duration, turnCount);
progress.stop();
outputLines(formatErrorOutput(err, execContext, description, duration, sourceDir, isRetryableError(err)));
await writeErrorLog(err, sourceDir, fullPrompt, duration);
return {
error: err.message,
errorType: err.constructor.name,
prompt: fullPrompt.slice(0, 100) + '...',
success: false,
duration,
cost: totalCost,
retryable: isRetryableError(err)
};
}
}
// Aggregate state returned by processMessageStream once the stream ends.
interface MessageLoopResult {
// Number of messages with type 'assistant' seen in the stream.
turnCount: number;
// Result text from the 'complete' dispatch; null if the stream ended without one.
result: string | null;
// True if any 'continue' dispatch reported an API error.
apiErrorDetected: boolean;
// Cost from the 'complete' dispatch; 0 if the stream ended without one.
cost: number;
// Last model name reported by a 'continue' dispatch, mapped through getActualModelName().
model?: string | undefined;
}
// Collaborators that processMessageStream uses itself and forwards to dispatchMessage
// for every streamed message.
interface MessageLoopDeps {
// Output-mode information derived from the run description.
execContext: ReturnType<typeof detectExecutionContext>;
// Human-readable label for the run, used in heartbeat log lines.
description: string;
// Progress indicator for the run.
progress: ReturnType<typeof createProgressManager>;
// Audit logger bound to the current audit session (if any).
auditLogger: ReturnType<typeof createAuditLogger>;
// Activity logger for informational output.
logger: ActivityLogger;
}
/**
 * Drains the SDK message stream for a single prompt and summarizes it.
 *
 * Every message is handed to `dispatchMessage`, whose verdict drives the loop:
 * `throw` rethrows the supplied error, `complete` records result/cost and
 * stops reading, `continue` accumulates the API-error flag and model name.
 * When `global.SHANNON_DISABLE_LOADER` is set, a heartbeat line is logged if
 * more than 30 seconds have passed since the previous one; this is only
 * checked when a message arrives.
 *
 * @param fullPrompt - Complete prompt text sent to `query`.
 * @param options - SDK options forwarded to `query` unchanged.
 * @param deps - Collaborators forwarded to `dispatchMessage`.
 * @param timer - Supplies `startTime` for the elapsed-seconds heartbeat prefix.
 * @returns Turn count, final result, API-error flag, cost, and model name.
 */
async function processMessageStream(
  fullPrompt: string,
  options: NonNullable<Parameters<typeof query>[0]['options']>,
  deps: MessageLoopDeps,
  timer: Timer
): Promise<MessageLoopResult> {
  const { execContext, description, progress, auditLogger, logger } = deps;
  const heartbeatEveryMs = 30000;

  let assistantTurns = 0;
  let finalResult: string | null = null;
  let sawApiError = false;
  let finalCost = 0;
  let modelName: string | undefined;
  let previousHeartbeat = Date.now();

  messages: for await (const message of query({ prompt: fullPrompt, options })) {
    // Heartbeat logging when loader is disabled
    const receivedAt = Date.now();
    const heartbeatDue = receivedAt - previousHeartbeat > heartbeatEveryMs;
    if (global.SHANNON_DISABLE_LOADER && heartbeatDue) {
      const elapsedSeconds = Math.floor((receivedAt - timer.startTime) / 1000);
      logger.info(`[${elapsedSeconds}s] ${description} running... (Turn ${assistantTurns})`);
      previousHeartbeat = receivedAt;
    }

    // Only assistant messages count as turns
    if (message.type === 'assistant') {
      assistantTurns += 1;
    }

    const outcome = await dispatchMessage(
      message as { type: string; subtype?: string },
      assistantTurns,
      { execContext, description, progress, auditLogger, logger }
    );

    switch (outcome.type) {
      case 'throw':
        throw outcome.error;
      case 'complete':
        finalResult = outcome.result;
        finalCost = outcome.cost;
        break messages;
      case 'continue':
        if (outcome.apiErrorDetected) {
          sawApiError = true;
        }
        // Capture model from SystemInitMessage, but override with router model if applicable
        if (outcome.model) {
          modelName = getActualModelName(outcome.model);
        }
        break;
    }
  }

  return {
    turnCount: assistantTurns,
    result: finalResult,
    apiErrorDetected: sawApiError,
    cost: finalCost,
    model: modelName,
  };
}