Files
trebuchet/apps/worker/src/ai/claude-executor.ts
T
Chris Farhood 085624b287
CI / Build & push API image (pull_request) Has been skipped
CI / Type-check & lint (pull_request) Successful in 18s
CI / Build & push worker image (pull_request) Has been skipped
feat: backport Opus 4.7 + adaptive thinking, remove scan tools, add --help to scripts
Backport upstream Shannon PRs #325, #327, #328:

- Update large model default to claude-opus-4-7, add adaptive thinking
  configuration (auto-enabled on Opus 4.6/4.7, opt-out via
  CLAUDE_ADAPTIVE_THINKING=false), filter thinking blocks from message
  content, bump claude-agent-sdk to ^0.2.114
- Remove unused scan tools (nmap, subfinder, whatweb, schemathesis) from
  Dockerfile, prompts, and docs; remove dead 'tool' error type from
  PentestErrorType; redact URLs in preflight info logs
- Add --help flag to save-deliverable and generate-totp CLI scripts

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-05-20 00:26:25 +00:00

403 lines
14 KiB
TypeScript

// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
// Production Claude agent execution with retry, git checkpoints, and audit logging
import { type JsonSchemaOutputFormat, query } from '@anthropic-ai/claude-agent-sdk';
import { fs, path } from 'zx';
import type { AuditSession } from '../audit/index.js';
import { deliverablesDir } from '../paths.js';
import { isRetryableError, PentestError } from '../services/error-handling.js';
import { AGENT_VALIDATORS } from '../session-manager.js';
import type { ActivityLogger } from '../types/activity-logger.js';
import { isSpendingCapBehavior } from '../utils/billing-detection.js';
import { formatTimestamp } from '../utils/formatting.js';
import { Timer } from '../utils/metrics.js';
import { createAuditLogger } from './audit-logger.js';
import { dispatchMessage } from './message-handlers.js';
import { type ModelTier, resolveModel, supportsAdaptiveThinking } from './models.js';
import { detectExecutionContext, formatCompletionMessage, formatErrorOutput } from './output-formatters.js';
import { createProgressManager } from './progress-manager.js';
// Process-wide switch read by this module. It is passed to createProgressManager
// (presumably as its "disable loader" flag — confirm in progress-manager), and when
// set, processMessageStream emits periodic heartbeat log lines instead of a loader.
declare global {
/** When true, the interactive progress loader is disabled and heartbeat logging is used instead. */
var SHANNON_DISABLE_LOADER: boolean | undefined;
}
/**
 * Outcome of a single `runClaudePrompt` attempt.
 *
 * On success `success` is true and `result` / `structuredOutput` carry the agent's
 * output; on failure `success` is false and `error` / `errorType` / `retryable`
 * describe what went wrong.
 */
export interface ClaudePromptResult {
/** Final text result reported by the agent, if any. */
result?: string | null | undefined;
/** True when the SDK run and post-processing finished without an error being raised. */
success: boolean;
/** Elapsed time from the agent `Timer` (presumably milliseconds — confirm in utils/metrics). */
duration: number;
/** Number of assistant messages observed (one per turn). */
turns?: number | undefined;
/** Total cost reported for the run (stays 0 if no completion message was received). */
cost: number;
/** Model identifier associated with the run. */
model?: string | undefined;
/** Cost accrued so far; on success this mirrors `cost`. */
partialCost?: number | undefined;
/** True if an API error was seen in the stream; deliverables are validated before treating it as a failure. */
apiErrorDetected?: boolean | undefined;
/** Error message when `success` is false. */
error?: string | undefined;
/** Constructor name of the error when `success` is false. */
errorType?: string | undefined;
/** Truncated prompt (first 100 characters plus "...") attached to failures for diagnostics. */
prompt?: string | undefined;
/** Whether the failure was classified as retryable by `isRetryableError`. */
retryable?: boolean | undefined;
/** Structured output returned by the SDK when an `outputFormat` was supplied. */
structuredOutput?: unknown;
}
/**
 * Prints pre-formatted output to stdout, issuing one `console.log` call per entry
 * so each line is emitted exactly as the formatter produced it.
 */
function outputLines(lines: string[]): void {
  for (let index = 0; index < lines.length; index += 1) {
    console.log(lines[index]);
  }
}
/**
 * Appends one JSON line describing `err` to `error.log` inside the deliverables
 * directory for `sourceDir`.
 *
 * Best-effort: any failure while building or writing the entry is swallowed so
 * that error reporting can never mask the original error.
 *
 * @param err - Error being recorded; optional `code` / `status` are included when present.
 * @param sourceDir - Repository directory; the log lives under `deliverablesDir(sourceDir)`.
 * @param fullPrompt - Prompt that was running; only its first 200 characters are stored.
 * @param duration - Elapsed time from the agent timer.
 */
async function writeErrorLog(
  err: Error & { code?: string; status?: number },
  sourceDir: string,
  fullPrompt: string,
  duration: number,
): Promise<void> {
  try {
    const timestamp = formatTimestamp();
    const name = err.constructor.name;
    const { message, code, status, stack } = err;
    // Key order here determines the serialized JSON field order.
    const errorDetails = { name, message, code, status, stack };
    const promptPreview = `${fullPrompt.slice(0, 200)}...`;
    const contextDetails = {
      sourceDir,
      prompt: promptPreview,
      retryable: isRetryableError(err),
    };
    const entry = {
      timestamp,
      agent: 'claude-executor',
      error: errorDetails,
      context: contextDetails,
      duration,
    };
    const target = path.join(deliverablesDir(sourceDir), 'error.log');
    const serialized = `${JSON.stringify(entry)}\n`;
    await fs.appendFile(target, serialized);
  } catch {
    // Intentionally ignored: writing the error log must never raise on its own.
  }
}
/**
 * Checks that an agent run produced usable output and, when a validator is
 * registered for `agentName` in `AGENT_VALIDATORS`, that the deliverables pass it.
 *
 * @param result - Result returned by `runClaudePrompt`.
 * @param agentName - Key used to look up a validator; `null` or empty skips the lookup.
 * @param sourceDir - Directory handed to the validator.
 * @param logger - Receives progress and failure messages.
 * @returns `true` when the run succeeded with output and either validation passed or
 *   no validator exists; `false` otherwise, including when the validator throws.
 */
export async function validateAgentOutput(
  result: ClaudePromptResult,
  agentName: string | null,
  sourceDir: string,
  logger: ActivityLogger,
): Promise<boolean> {
  logger.info(`Validating ${agentName} agent output`);
  try {
    // A run is "productive" when it succeeded and returned text OR structured output.
    const productive = result.success && (Boolean(result.result) || result.structuredOutput !== undefined);
    if (!productive) {
      logger.error('Validation failed: Agent execution was unsuccessful');
      return false;
    }

    const validate = agentName ? AGENT_VALIDATORS[agentName as keyof typeof AGENT_VALIDATORS] : undefined;
    if (!validate) {
      // Unknown agents are not blocked: a successful run is accepted as-is.
      logger.warn(`No validator found for agent "${agentName}" - assuming success`);
      logger.info('Validation passed: Unknown agent with successful result');
      return true;
    }

    logger.info(`Using validator for agent: ${agentName}`, { sourceDir });
    const passed = await validate(sourceDir, logger);
    if (!passed) {
      logger.error('Validation failed: Missing required deliverable files');
    } else {
      logger.info('Validation passed: Required files/structure present');
    }
    return passed;
  } catch (caught) {
    const reason = caught instanceof Error ? caught.message : String(caught);
    logger.error(`Validation failed with error: ${reason}`);
    return false;
  }
}
/**
 * Low-level, single-attempt SDK execution. Handles message streaming, progress
 * display, and audit logging. Exported so Temporal activities can run exactly one
 * attempt and own the retry policy themselves.
 *
 * Errors raised while streaming or post-processing the result are recorded
 * (audit log, console, and `error.log` in the deliverables directory) and
 * returned as `{ success: false, ... }` rather than thrown.
 *
 * @param prompt - Prompt text for the agent.
 * @param sourceDir - Working directory for the agent (SDK `cwd`).
 * @param context - Optional text prepended to `prompt`, separated by a blank line.
 * @param description - Label used for logs, progress output and the timer name.
 * @param _agentName - Unused; kept for call-site compatibility.
 * @param auditSession - Audit session to record into, or `null`.
 * @param logger - Activity logger.
 * @param modelTier - Tier resolved via `providerConfig.modelOverrides`, else `resolveModel`.
 * @param outputFormat - Optional JSON-schema output format (structured output).
 * @param apiKey - Anthropic API key; takes precedence over env and provider config.
 * @param deliverablesSubdir - Exported to the agent as `SHANNON_DELIVERABLES_SUBDIR`.
 * @param providerConfig - Structured provider settings (Bedrock / Vertex / LiteLLM / API key).
 * @returns The agent result on success, or a failure record whose `retryable` flag
 *   comes from `isRetryableError`.
 */
export async function runClaudePrompt(
  prompt: string,
  sourceDir: string,
  context: string = '',
  description: string = 'Claude analysis',
  _agentName: string | null = null,
  auditSession: AuditSession | null = null,
  logger: ActivityLogger,
  modelTier: ModelTier = 'medium',
  outputFormat?: JsonSchemaOutputFormat,
  apiKey?: string,
  deliverablesSubdir?: string,
  providerConfig?: import('../types/config.js').ProviderConfig,
): Promise<ClaudePromptResult> {
  // 1. Initialize timing and prompt
  const timer = new Timer(`agent-${description.toLowerCase().replace(/\s+/g, '-')}`);
  const fullPrompt = context ? `${context}\n\n${prompt}` : prompt;

  // 2. Set up progress and audit infrastructure
  const execContext = detectExecutionContext(description);
  const progress = createProgressManager(
    { description, useCleanOutput: execContext.useCleanOutput },
    global.SHANNON_DISABLE_LOADER ?? false,
  );
  const auditLogger = createAuditLogger(auditSession);
  logger.info(`Running Claude Code: ${description}...`);

  // 3. Build env vars to pass to SDK subprocesses
  const sdkEnv: Record<string, string> = {
    CLAUDE_CODE_MAX_OUTPUT_TOKENS: process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS || '64000',
    PLAYWRIGHT_MCP_OUTPUT_DIR: deliverablesSubdir
      ? path.join(sourceDir, path.dirname(deliverablesSubdir), '.playwright-cli')
      : path.join(sourceDir, '.shannon', '.playwright-cli'),
    // apiKey from ContainerConfig takes precedence over process.env
    ...(apiKey && { ANTHROPIC_API_KEY: apiKey }),
    // Deliverables subdir for save-deliverable CLI tool
    ...(deliverablesSubdir && { SHANNON_DELIVERABLES_SUBDIR: deliverablesSubdir }),
  };

  // 3a. Apply structured provider config directly to sdkEnv (no process.env mutation)
  if (providerConfig) {
    switch (providerConfig.providerType) {
      case 'bedrock':
        sdkEnv.CLAUDE_CODE_USE_BEDROCK = '1';
        if (providerConfig.awsRegion) sdkEnv.AWS_REGION = providerConfig.awsRegion;
        if (providerConfig.awsAccessKeyId) sdkEnv.AWS_ACCESS_KEY_ID = providerConfig.awsAccessKeyId;
        if (providerConfig.awsSecretAccessKey) sdkEnv.AWS_SECRET_ACCESS_KEY = providerConfig.awsSecretAccessKey;
        break;
      case 'vertex':
        sdkEnv.CLAUDE_CODE_USE_VERTEX = '1';
        if (providerConfig.gcpRegion) sdkEnv.CLOUD_ML_REGION = providerConfig.gcpRegion;
        if (providerConfig.gcpProjectId) sdkEnv.ANTHROPIC_VERTEX_PROJECT_ID = providerConfig.gcpProjectId;
        if (providerConfig.gcpCredentialsPath)
          sdkEnv.GOOGLE_APPLICATION_CREDENTIALS = providerConfig.gcpCredentialsPath;
        break;
      case 'litellm_router':
        if (providerConfig.baseUrl) sdkEnv.ANTHROPIC_BASE_URL = providerConfig.baseUrl;
        if (providerConfig.authToken) sdkEnv.ANTHROPIC_AUTH_TOKEN = providerConfig.authToken;
        break;
      default:
        // 'anthropic_api' or unset — apiKey already handled above
        if (providerConfig.apiKey && !apiKey) sdkEnv.ANTHROPIC_API_KEY = providerConfig.apiKey;
        break;
    }
  }

  // 3b. Passthrough env vars not already set by providerConfig or apiKey.
  // NOTE(review): provider *selectors* for a different provider (e.g. CLAUDE_CODE_USE_BEDROCK
  // while providerConfig selects 'vertex') can still be inherited from process.env here —
  // confirm that mixing is intended.
  const passthroughVars = [
    ...(!sdkEnv.ANTHROPIC_API_KEY ? ['ANTHROPIC_API_KEY'] : []),
    'CLAUDE_CODE_OAUTH_TOKEN',
    ...(!sdkEnv.ANTHROPIC_BASE_URL ? ['ANTHROPIC_BASE_URL'] : []),
    ...(!sdkEnv.ANTHROPIC_AUTH_TOKEN ? ['ANTHROPIC_AUTH_TOKEN'] : []),
    ...(!sdkEnv.CLAUDE_CODE_USE_BEDROCK ? ['CLAUDE_CODE_USE_BEDROCK'] : []),
    ...(!sdkEnv.AWS_REGION ? ['AWS_REGION'] : []),
    'AWS_BEARER_TOKEN_BEDROCK',
    ...(!sdkEnv.CLAUDE_CODE_USE_VERTEX ? ['CLAUDE_CODE_USE_VERTEX'] : []),
    ...(!sdkEnv.CLOUD_ML_REGION ? ['CLOUD_ML_REGION'] : []),
    ...(!sdkEnv.ANTHROPIC_VERTEX_PROJECT_ID ? ['ANTHROPIC_VERTEX_PROJECT_ID'] : []),
    ...(!sdkEnv.GOOGLE_APPLICATION_CREDENTIALS ? ['GOOGLE_APPLICATION_CREDENTIALS'] : []),
    'HOME',
    'PATH',
    'PLAYWRIGHT_MCP_EXECUTABLE_PATH',
  ];
  for (const name of passthroughVars) {
    const val = process.env[name];
    if (val) {
      sdkEnv[name] = val;
    }
  }

  // 4. Configure SDK options
  // Model override from providerConfig takes precedence over env-based resolveModel
  const model = providerConfig?.modelOverrides?.[modelTier] ?? resolveModel(modelTier);
  const adaptiveThinking = supportsAdaptiveThinking(model) && process.env.CLAUDE_ADAPTIVE_THINKING !== 'false';
  const options = {
    model,
    maxTurns: 10_000,
    cwd: sourceDir,
    permissionMode: 'bypassPermissions' as const,
    allowDangerouslySkipPermissions: true,
    settingSources: ['user'] as ('user' | 'project' | 'local')[],
    env: sdkEnv,
    ...(adaptiveThinking && { thinking: { type: 'adaptive' as const } }),
    ...(outputFormat && { outputFormat }),
  };
  if (!execContext.useCleanOutput) {
    logger.info(`SDK Options: maxTurns=${options.maxTurns}, cwd=${sourceDir}, permissions=BYPASS`);
  }

  // 5. Shared counters: processMessageStream updates `tracker` live, so the turn
  // count is still accurate in the catch block if the stream throws part-way through.
  const tracker: StreamTracker = { turnCount: 0 };
  let totalCost = 0;
  progress.start();
  try {
    // 6. Process the message stream
    const loop = await processMessageStream(
      fullPrompt,
      options,
      { execContext, description, progress, auditLogger, logger },
      timer,
      tracker,
    );
    const { turnCount, result, apiErrorDetected } = loop;
    totalCost = loop.cost;
    // Prefer the model the SDK reported; fall back to the model we requested.
    const reportedModel = loop.model ?? model;

    // === SPENDING CAP SAFEGUARD ===
    // 7. Defense-in-depth: Detect spending cap that slipped through detectApiError().
    // Uses consolidated billing detection from utils/billing-detection.ts
    if (isSpendingCapBehavior(turnCount, totalCost, result || '')) {
      throw new PentestError(
        `Spending cap likely reached (turns=${turnCount}, cost=$${totalCost}): ${result?.slice(0, 100)}`,
        'billing',
        true, // Retryable - Temporal will use 5-30 min backoff
      );
    }

    // 8. Finalize successful result
    const duration = timer.stop();
    if (apiErrorDetected) {
      logger.warn(`API Error detected in ${description} - will validate deliverables before failing`);
    }
    progress.finish(formatCompletionMessage(execContext, description, turnCount, duration));
    return {
      result,
      success: true,
      duration,
      turns: turnCount,
      cost: totalCost,
      model: reportedModel,
      partialCost: totalCost,
      apiErrorDetected,
      ...(loop.structuredOutput !== undefined && {
        structuredOutput: loop.structuredOutput,
      }),
    };
  } catch (error) {
    // 9. Handle errors — log, write error file, return failure
    const duration = timer.stop();
    // Anything can be thrown; normalize so `.message` / `.constructor.name` are safe to read.
    const err = toExecutorError(error);
    const retryable = isRetryableError(err);
    const turnCount = tracker.turnCount;
    // Stop the loader first so it is cleaned up even if audit logging fails.
    progress.stop();
    await auditLogger.logError(err, duration, turnCount);
    outputLines(formatErrorOutput(err, execContext, description, duration, sourceDir, retryable));
    await writeErrorLog(err, sourceDir, fullPrompt, duration);
    return {
      error: err.message,
      errorType: err.constructor.name,
      prompt: `${fullPrompt.slice(0, 100)}...`,
      success: false,
      duration,
      turns: turnCount,
      cost: totalCost,
      retryable,
    };
  }
}

/** Aggregated outcome of one pass over the SDK message stream. */
interface MessageLoopResult {
  /** Number of assistant messages seen. */
  turnCount: number;
  /** Final text result from the completion message, or null if none arrived. */
  result: string | null;
  /** True if any dispatched message flagged an API error. */
  apiErrorDetected: boolean;
  /** Cost from the completion message (0 if none arrived). */
  cost: number;
  /** Most recent model identifier reported by the stream, if any. */
  model?: string | undefined;
  /** Structured output from the completion message, when present. */
  structuredOutput?: unknown;
}

/** Collaborators forwarded to `dispatchMessage` for every streamed message. */
interface MessageLoopDeps {
  execContext: ReturnType<typeof detectExecutionContext>;
  description: string;
  progress: ReturnType<typeof createProgressManager>;
  auditLogger: ReturnType<typeof createAuditLogger>;
  logger: ActivityLogger;
}

/** Mutable counters shared with the caller so they survive a thrown stream. */
interface StreamTracker {
  turnCount: number;
}

/**
 * Streams SDK messages for `fullPrompt`, dispatching each one through
 * `dispatchMessage` until a completion result arrives or the stream ends.
 *
 * `tracker.turnCount` is incremented live (once per assistant message), so a
 * caller holding the same object still knows how far the run got if this
 * function throws.
 *
 * @throws The error supplied by `dispatchMessage` when it signals `'throw'`, or any
 *   error raised by the SDK stream itself.
 */
async function processMessageStream(
  fullPrompt: string,
  options: NonNullable<Parameters<typeof query>[0]['options']>,
  deps: MessageLoopDeps,
  timer: Timer,
  tracker: StreamTracker = { turnCount: 0 },
): Promise<MessageLoopResult> {
  const { execContext, description, progress, auditLogger, logger } = deps;
  const HEARTBEAT_INTERVAL_MS = 30_000;
  let result: string | null = null;
  let apiErrorDetected = false;
  let cost = 0;
  let model: string | undefined;
  let structuredOutput: unknown;
  let lastHeartbeat = Date.now();

  for await (const message of query({ prompt: fullPrompt, options })) {
    // Heartbeat logging when loader is disabled
    const now = Date.now();
    if (global.SHANNON_DISABLE_LOADER && now - lastHeartbeat > HEARTBEAT_INTERVAL_MS) {
      logger.info(
        `[${Math.floor((now - timer.startTime) / 1000)}s] ${description} running... (Turn ${tracker.turnCount})`,
      );
      lastHeartbeat = now;
    }

    // Increment turn count for assistant messages
    if (message.type === 'assistant') {
      tracker.turnCount++;
    }

    const dispatchResult = await dispatchMessage(message as { type: string; subtype?: string }, tracker.turnCount, {
      execContext,
      description,
      progress,
      auditLogger,
      logger,
    });

    if (dispatchResult.type === 'throw') {
      throw dispatchResult.error;
    }

    if (dispatchResult.type === 'complete') {
      result = dispatchResult.result;
      cost = dispatchResult.cost;
      if (dispatchResult.structuredOutput !== undefined) {
        structuredOutput = dispatchResult.structuredOutput;
      }
      break;
    }

    if (dispatchResult.type === 'continue') {
      if (dispatchResult.apiErrorDetected) {
        apiErrorDetected = true;
      }
      if (dispatchResult.model) {
        model = dispatchResult.model;
      }
    }
  }

  return {
    turnCount: tracker.turnCount,
    result,
    apiErrorDetected,
    cost,
    model,
    ...(structuredOutput !== undefined && { structuredOutput }),
  };
}

/** Error shape used on the failure path: an `Error` that may carry `code` / `status`. */
type ExecutorError = Error & { code?: string; status?: number };

/**
 * Coerces any thrown value into an `Error` so the failure path can safely read
 * `.message` and `.constructor.name` (a thrown `undefined` would otherwise crash).
 * Error-like objects with a string `message` keep their own fields (e.g. `code`, `status`).
 */
function toExecutorError(value: unknown): ExecutorError {
  if (value instanceof Error) {
    return value as ExecutorError;
  }
  if (typeof value === 'object' && value !== null) {
    const candidate = value as { message?: unknown };
    if (typeof candidate.message === 'string') {
      return Object.assign(new Error(candidate.message), value) as ExecutorError;
    }
  }
  return new Error(String(value)) as ExecutorError;
}