backport: provider extensions and drop claude-code-router mode
Cherry-pick of KeygraphHQ/shannon#295 (581c208). Upstream changes: removes router mode from CLI/worker, adds provider extensions, new report-output-provider and checkpoint-provider interfaces, refactored workflow orchestration. Conflicts resolved: kept our README.md, CLAUDE.md, and deleted compose files. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -21,7 +21,6 @@ import { dispatchMessage } from './message-handlers.js';
|
||||
import { type ModelTier, resolveModel } from './models.js';
|
||||
import { detectExecutionContext, formatCompletionMessage, formatErrorOutput } from './output-formatters.js';
|
||||
import { createProgressManager } from './progress-manager.js';
|
||||
import { getActualModelName } from './router-utils.js';
|
||||
|
||||
declare global {
|
||||
var SHANNON_DISABLE_LOADER: boolean | undefined;
|
||||
@@ -184,7 +183,6 @@ export async function runClaudePrompt(
|
||||
case 'litellm_router':
|
||||
if (providerConfig.baseUrl) sdkEnv.ANTHROPIC_BASE_URL = providerConfig.baseUrl;
|
||||
if (providerConfig.authToken) sdkEnv.ANTHROPIC_AUTH_TOKEN = providerConfig.authToken;
|
||||
if (providerConfig.routerDefault) sdkEnv.ROUTER_DEFAULT = providerConfig.routerDefault;
|
||||
break;
|
||||
default:
|
||||
// 'anthropic_api' or unset — apiKey already handled above
|
||||
@@ -385,9 +383,8 @@ async function processMessageStream(
|
||||
if (dispatchResult.apiErrorDetected) {
|
||||
apiErrorDetected = true;
|
||||
}
|
||||
// Capture model from SystemInitMessage, but override with router model if applicable
|
||||
if (dispatchResult.model) {
|
||||
model = getActualModelName(dispatchResult.model);
|
||||
model = dispatchResult.model;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import {
|
||||
formatToolUseOutput,
|
||||
} from './output-formatters.js';
|
||||
import type { ProgressManager } from './progress-manager.js';
|
||||
import { getActualModelName } from './router-utils.js';
|
||||
import type {
|
||||
ApiErrorDetection,
|
||||
AssistantMessage,
|
||||
@@ -309,12 +308,10 @@ export async function dispatchMessage(
|
||||
case 'system': {
|
||||
if (message.subtype === 'init') {
|
||||
const initMsg = message as SystemInitMessage;
|
||||
const actualModel = getActualModelName(initMsg.model);
|
||||
if (!execContext.useCleanOutput) {
|
||||
logger.info(`Model: ${actualModel}, Permission: ${initMsg.permissionMode}`);
|
||||
logger.info(`Model: ${initMsg.model}, Permission: ${initMsg.permissionMode}`);
|
||||
}
|
||||
// Return actual model for tracking in audit logs
|
||||
return { type: 'continue', model: actualModel };
|
||||
return { type: 'continue', model: initMsg.model };
|
||||
}
|
||||
return { type: 'continue' };
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
// Copyright (C) 2025 Keygraph, Inc.
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License version 3
|
||||
// as published by the Free Software Foundation.
|
||||
|
||||
/**
 * Get the actual model name being used.
 *
 * When using claude-code-router, the SDK reports its configured model (claude-sonnet)
 * but the actual model is determined by the ROUTER_DEFAULT env var.
 *
 * ROUTER_DEFAULT format: "provider,model" (e.g., "gemini,gemini-2.5-pro").
 * Everything after the first comma is treated as the model, so model names
 * that themselves contain commas are preserved.
 *
 * @param sdkReportedModel - Model name reported by the SDK; used as the fallback.
 * @returns The router model when both ANTHROPIC_BASE_URL and ROUTER_DEFAULT are
 *   set and ROUTER_DEFAULT names a non-empty model; otherwise `sdkReportedModel`
 *   unchanged (which may be undefined).
 */
export function getActualModelName(sdkReportedModel?: string): string | undefined {
  const routerBaseUrl = process.env.ANTHROPIC_BASE_URL;
  const routerDefault = process.env.ROUTER_DEFAULT;

  // Router mode is only considered active when both variables are set.
  if (routerBaseUrl && routerDefault) {
    const separatorIndex = routerDefault.indexOf(',');
    if (separatorIndex !== -1) {
      const routerModel = routerDefault.slice(separatorIndex + 1).trim();
      // "provider," or "provider,  " carries no model name: reporting an empty
      // string would hide the real model, so fall through to the SDK value.
      if (routerModel) {
        return routerModel;
      }
    }
  }

  // Fall back to SDK-reported model
  return sdkReportedModel;
}
|
||||
@@ -1,21 +1,59 @@
|
||||
/**
|
||||
* CheckpointProvider — injectable interface for external state persistence.
|
||||
*
|
||||
* Called after each agent completes to allow external progress tracking.
|
||||
* During the concurrent vulnerability-exploitation phase, 5 pipelines run
|
||||
* in parallel — onAgentComplete fires per-agent for granular progress.
|
||||
* Called before and after each agent to support skip-guard (resume) and
|
||||
* post-agent artifact persistence. During the concurrent vulnerability-exploitation
|
||||
* phase, 5 pipelines run in parallel — methods fire per-agent for granular control.
|
||||
*
|
||||
* Default: no-op.
|
||||
* Default: no-op (skip nothing, persist nothing).
|
||||
*/
|
||||
|
||||
import type { PipelineState } from '../temporal/shared.js';
|
||||
import type { AgentMetrics, PipelineState } from '../temporal/shared.js';
|
||||
|
||||
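/** Result of a pre-agent skip check. A skip is only honored when `metrics` is also supplied; `{ skip: true }` without metrics causes the agent to run normally. */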
|
||||
export interface SkipDecision {
|
||||
readonly skip: boolean;
|
||||
readonly metrics?: AgentMetrics; // Required when skip=true
|
||||
}
|
||||
|
||||
/** File-system context passed after agent completion for artifact persistence. */
|
||||
export interface CheckpointContext {
|
||||
readonly repoPath: string;
|
||||
readonly sessionId: string;
|
||||
readonly deliverablesSubdir: string;
|
||||
readonly outputPath?: string;
|
||||
}
|
||||
|
||||
export interface CheckpointProvider {
|
||||
onAgentComplete(agentName: string, phase: string, state: PipelineState): Promise<void>;
|
||||
/**
|
||||
* Called before an agent activity executes.
|
||||
* Return { skip: true, metrics } to skip the agent (e.g., output files already exist).
|
||||
* Return { skip: false } to run normally.
|
||||
*/
|
||||
shouldSkipAgent(
|
||||
agentName: string,
|
||||
repoPath: string,
|
||||
deliverablesSubdir: string,
|
||||
): Promise<SkipDecision>;
|
||||
|
||||
/**
|
||||
* Called after an agent activity succeeds.
|
||||
* Receives pipeline state and optional file context for artifact persistence.
|
||||
*/
|
||||
onAgentComplete(
|
||||
agentName: string,
|
||||
phase: string,
|
||||
state: PipelineState,
|
||||
context?: CheckpointContext,
|
||||
): Promise<void>;
|
||||
}
|
||||
|
||||
/** Default no-op implementation — no external checkpointing. */
|
||||
export class NoOpCheckpointProvider implements CheckpointProvider {
|
||||
async shouldSkipAgent(): Promise<SkipDecision> {
|
||||
return { skip: false };
|
||||
}
|
||||
|
||||
async onAgentComplete(): Promise<void> {
|
||||
// No-op
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
/**
|
||||
* FindingsProvider — injectable interface for external findings integration.
|
||||
*
|
||||
* Allows external security data (SAST, SCA, secrets, etc.) to be merged
|
||||
* Allows external security data from consumer-supplied sources to be merged
|
||||
* into the exploitation pipeline between vulnerability analysis and exploitation.
|
||||
*
|
||||
* Default: no-op returning { mergedCount: 0 }.
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
* Consumers can provide alternate implementations via the DI container.
|
||||
*/
|
||||
|
||||
export type { CheckpointProvider } from './checkpoint-provider.js';
|
||||
export type { CheckpointProvider, CheckpointContext, SkipDecision } from './checkpoint-provider.js';
|
||||
export { NoOpCheckpointProvider } from './checkpoint-provider.js';
|
||||
export type { FindingsProvider } from './findings-provider.js';
|
||||
export { NoOpFindingsProvider } from './findings-provider.js';
|
||||
export type { ReportOutputProvider } from './report-output-provider.js';
|
||||
export { NoOpReportOutputProvider } from './report-output-provider.js';
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
/**
|
||||
* ReportOutputProvider — injectable interface for emitting an optional
|
||||
* additional artifact alongside the assembled markdown report.
|
||||
*
|
||||
* Runs after the report agent has finalized
|
||||
* `comprehensive_security_assessment_report.md`. Consumers can override to
|
||||
* produce derived outputs; the default no-op produces nothing.
|
||||
*/
|
||||
|
||||
import type { ActivityInput } from '../temporal/activities.js';
|
||||
import type { ActivityLogger } from '../types/activity-logger.js';
|
||||
|
||||
export interface ReportOutputProvider {
|
||||
generate(input: ActivityInput, logger: ActivityLogger): Promise<{ outputPath?: string }>;
|
||||
}
|
||||
|
||||
/** Default no-op implementation — no additional output produced. */
|
||||
export class NoOpReportOutputProvider implements ReportOutputProvider {
|
||||
async generate(): Promise<{ outputPath?: string }> {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,8 @@ import type { CheckpointProvider } from '../interfaces/checkpoint-provider.js';
|
||||
import { NoOpCheckpointProvider } from '../interfaces/checkpoint-provider.js';
|
||||
import type { FindingsProvider } from '../interfaces/findings-provider.js';
|
||||
import { NoOpFindingsProvider } from '../interfaces/findings-provider.js';
|
||||
import type { ReportOutputProvider } from '../interfaces/report-output-provider.js';
|
||||
import { NoOpReportOutputProvider } from '../interfaces/report-output-provider.js';
|
||||
import type { ContainerConfig } from '../types/config.js';
|
||||
import { AgentExecutionService } from './agent-execution.js';
|
||||
import { ConfigLoaderService } from './config-loader.js';
|
||||
@@ -40,6 +42,7 @@ export interface ContainerDependencies {
|
||||
readonly config: ContainerConfig;
|
||||
readonly findingsProvider?: FindingsProvider;
|
||||
readonly checkpointProvider?: CheckpointProvider;
|
||||
readonly reportOutputProvider?: ReportOutputProvider;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -59,6 +62,7 @@ export class Container {
|
||||
readonly exploitationChecker: ExploitationCheckerService;
|
||||
readonly findingsProvider: FindingsProvider;
|
||||
readonly checkpointProvider: CheckpointProvider;
|
||||
readonly reportOutputProvider: ReportOutputProvider;
|
||||
|
||||
constructor(deps: ContainerDependencies) {
|
||||
this.sessionMetadata = deps.sessionMetadata;
|
||||
@@ -72,6 +76,7 @@ export class Container {
|
||||
// Wire providers with default no-ops when not provided
|
||||
this.findingsProvider = deps.findingsProvider ?? new NoOpFindingsProvider();
|
||||
this.checkpointProvider = deps.checkpointProvider ?? new NoOpCheckpointProvider();
|
||||
this.reportOutputProvider = deps.reportOutputProvider ?? new NoOpReportOutputProvider();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,6 +92,32 @@ const DEFAULT_CONFIG: ContainerConfig = {
|
||||
auditDir: './workspaces',
|
||||
};
|
||||
|
||||
/**
|
||||
* Factory function for creating containers.
|
||||
*
|
||||
* Default: creates a plain Container with NoOp providers. Consumers can call
|
||||
* setContainerFactory() at worker startup to inject custom provider
|
||||
* implementations into every container.
|
||||
*/
|
||||
type ContainerFactory = (
|
||||
workflowId: string,
|
||||
sessionMetadata: SessionMetadata,
|
||||
config: ContainerConfig,
|
||||
) => Container;
|
||||
|
||||
let containerFactory: ContainerFactory = (_workflowId, sessionMetadata, config) =>
|
||||
new Container({ sessionMetadata, config });
|
||||
|
||||
/**
|
||||
* Override the default container factory.
|
||||
*
|
||||
* Call once at worker startup to inject providers into all containers
|
||||
* created during the worker's lifetime.
|
||||
*/
|
||||
export function setContainerFactory(factory: ContainerFactory): void {
|
||||
containerFactory = factory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create a Container for a workflow.
|
||||
*
|
||||
@@ -106,7 +137,7 @@ export function getOrCreateContainer(
|
||||
let container = containers.get(workflowId);
|
||||
|
||||
if (!container) {
|
||||
container = new Container({ sessionMetadata, config });
|
||||
container = containerFactory(workflowId, sessionMetadata, config);
|
||||
containers.set(workflowId, container);
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,9 @@ export { AgentExecutionService } from './agent-execution.js';
|
||||
|
||||
export { ConfigLoaderService } from './config-loader.js';
|
||||
export type { ContainerDependencies } from './container.js';
|
||||
export { Container, getContainer, getOrCreateContainer, removeContainer } from './container.js';
|
||||
export { Container, getContainer, getOrCreateContainer, removeContainer, setContainerFactory } from './container.js';
|
||||
export { ExploitationCheckerService } from './exploitation-checker.js';
|
||||
export { loadPrompt } from './prompt-manager.js';
|
||||
export { assembleFinalReport, injectModelIntoReport } from './reporting.js';
|
||||
export type { ClaudePromptResult } from '../ai/claude-executor.js';
|
||||
export { runClaudePrompt } from '../ai/claude-executor.js';
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
* Checks run sequentially, cheapest first:
|
||||
* 1. Repository path exists and contains .git
|
||||
* 2. Config file parses and validates (if provided)
|
||||
* 3. Credentials validate via Claude Agent SDK query (API key, OAuth, Bedrock, Vertex AI, or router mode)
|
||||
* 3. Credentials validate via Claude Agent SDK query (API key, OAuth, Bedrock, or Vertex AI)
|
||||
* 4. Target URL is reachable from the container (DNS + HTTP)
|
||||
*/
|
||||
|
||||
@@ -473,7 +473,7 @@ async function validateTargetUrl(targetUrl: string, logger: ActivityLogger): Pro
|
||||
*
|
||||
* 1. Repository path exists and contains .git
|
||||
* 2. Config file parses and validates (if configPath provided)
|
||||
* 3. Credentials validate (API key, OAuth, or router mode)
|
||||
* 3. Credentials validate (API key, OAuth, Bedrock, or Vertex AI)
|
||||
* 4. Target URL is reachable from the container
|
||||
*
|
||||
* Returns on first failure.
|
||||
|
||||
@@ -17,7 +17,11 @@ interface DeliverableFile {
|
||||
}
|
||||
|
||||
// Assemble final report from specialist deliverables (reads the deliverable files from disk, so not a pure function)
|
||||
export async function assembleFinalReport(sourceDir: string, logger: ActivityLogger): Promise<string> {
|
||||
export async function assembleFinalReport(
|
||||
sourceDir: string,
|
||||
deliverablesSubdir: string | undefined,
|
||||
logger: ActivityLogger,
|
||||
): Promise<string> {
|
||||
const deliverableFiles: DeliverableFile[] = [
|
||||
{ name: 'Injection', path: 'injection_exploitation_evidence.md', required: false },
|
||||
{ name: 'XSS', path: 'xss_exploitation_evidence.md', required: false },
|
||||
@@ -29,7 +33,7 @@ export async function assembleFinalReport(sourceDir: string, logger: ActivityLog
|
||||
const sections: string[] = [];
|
||||
|
||||
for (const file of deliverableFiles) {
|
||||
const filePath = path.join(deliverablesDir(sourceDir), file.path);
|
||||
const filePath = path.join(deliverablesDir(sourceDir, deliverablesSubdir), file.path);
|
||||
try {
|
||||
if (await fs.pathExists(filePath)) {
|
||||
const content = await fs.readFile(filePath, 'utf8');
|
||||
@@ -56,7 +60,7 @@ export async function assembleFinalReport(sourceDir: string, logger: ActivityLog
|
||||
}
|
||||
|
||||
const finalContent = sections.join('\n\n');
|
||||
const outputDir = deliverablesDir(sourceDir);
|
||||
const outputDir = deliverablesDir(sourceDir, deliverablesSubdir);
|
||||
const finalReportPath = path.join(outputDir, 'comprehensive_security_assessment_report.md');
|
||||
|
||||
try {
|
||||
@@ -82,6 +86,7 @@ export async function assembleFinalReport(sourceDir: string, logger: ActivityLog
|
||||
*/
|
||||
export async function injectModelIntoReport(
|
||||
repoPath: string,
|
||||
deliverablesSubdir: string | undefined,
|
||||
outputPath: string,
|
||||
logger: ActivityLogger,
|
||||
): Promise<void> {
|
||||
@@ -118,7 +123,7 @@ export async function injectModelIntoReport(
|
||||
logger.info(`Injecting model info into report: ${modelStr}`);
|
||||
|
||||
// 3. Read the final report
|
||||
const reportPath = path.join(deliverablesDir(repoPath), 'comprehensive_security_assessment_report.md');
|
||||
const reportPath = path.join(deliverablesDir(repoPath, deliverablesSubdir), 'comprehensive_security_assessment_report.md');
|
||||
|
||||
if (!(await fs.pathExists(reportPath))) {
|
||||
logger.warn('Final report not found, skipping model injection');
|
||||
|
||||
@@ -103,7 +103,6 @@ export const AGENTS: Readonly<Record<AgentName, AgentDefinition>> = Object.freez
|
||||
prerequisites: ['injection-exploit', 'xss-exploit', 'auth-exploit', 'ssrf-exploit', 'authz-exploit'],
|
||||
promptTemplate: 'report-executive',
|
||||
deliverableFilename: 'comprehensive_security_assessment_report.md',
|
||||
modelTier: 'small',
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -22,7 +22,8 @@ import { AuditSession } from '../audit/index.js';
|
||||
import type { ResumeAttempt } from '../audit/metrics-tracker.js';
|
||||
import type { SessionMetadata } from '../audit/utils.js';
|
||||
import type { WorkflowSummary } from '../audit/workflow-logger.js';
|
||||
import { DEFAULT_DELIVERABLES_SUBDIR, deliverablesDir } from '../paths.js';
|
||||
import type { ContainerConfig, ProviderConfig } from '../types/config.js';
|
||||
import type { CheckpointContext } from '../interfaces/checkpoint-provider.js';
|
||||
import { getContainer, getOrCreateContainer, removeContainer } from '../services/container.js';
|
||||
import { classifyErrorForTemporal, PentestError } from '../services/error-handling.js';
|
||||
import { ExploitationCheckerService } from '../services/exploitation-checker.js';
|
||||
@@ -33,9 +34,9 @@ import { assembleFinalReport, injectModelIntoReport } from '../services/reportin
|
||||
import { AGENTS } from '../session-manager.js';
|
||||
import type { AgentName } from '../types/agents.js';
|
||||
import { ALL_AGENTS } from '../types/agents.js';
|
||||
import type { ContainerConfig, ProviderConfig } from '../types/config.js';
|
||||
import { ErrorCode } from '../types/errors.js';
|
||||
import { isErr } from '../types/result.js';
|
||||
import { DEFAULT_DELIVERABLES_SUBDIR, deliverablesDir } from '../paths.js';
|
||||
import { fileExists, readJson } from '../utils/file-io.js';
|
||||
import { createActivityLogger } from './activity-logger.js';
|
||||
import type { AgentMetrics, PipelineState, ResumeState } from './shared.js';
|
||||
@@ -131,6 +132,20 @@ function buildContainerConfig(input: ActivityInput): ContainerConfig {
|
||||
*/
|
||||
async function runAgentActivity(agentName: AgentName, input: ActivityInput): Promise<AgentMetrics> {
|
||||
const { repoPath, configPath, pipelineTestingMode = false, workflowId, webUrl } = input;
|
||||
|
||||
// Skip guard: the checkpoint provider decides whether to run the agent.
|
||||
// The default NoOp provider always returns { skip: false }.
|
||||
const skipContainer = getContainer(workflowId) ??
|
||||
getOrCreateContainer(workflowId, buildSessionMetadata(input), buildContainerConfig(input));
|
||||
const decision = await skipContainer.checkpointProvider.shouldSkipAgent(
|
||||
agentName,
|
||||
repoPath,
|
||||
input.deliverablesSubdir ?? DEFAULT_DELIVERABLES_SUBDIR,
|
||||
);
|
||||
if (decision.skip && decision.metrics) {
|
||||
return decision.metrics;
|
||||
}
|
||||
|
||||
const startTime = Date.now();
|
||||
const attemptNumber = Context.current().info.attempt;
|
||||
|
||||
@@ -288,7 +303,7 @@ export async function runReportAgent(input: ActivityInput): Promise<AgentMetrics
|
||||
* Runs cheap checks before any agent execution:
|
||||
* 1. Repository path exists with .git
|
||||
* 2. Config file validates (if provided)
|
||||
* 3. Credential validation (API key, OAuth, or router mode)
|
||||
* 3. Credential validation (API key, OAuth, Bedrock, or Vertex AI)
|
||||
* 4. Target URL reachable from the container
|
||||
*
|
||||
* NOT using runAgentActivity — preflight doesn't run an agent via the SDK.
|
||||
@@ -306,15 +321,7 @@ export async function runPreflightValidation(input: ActivityInput): Promise<void
|
||||
const logger = createActivityLogger();
|
||||
logger.info('Running preflight validation...', { attempt: attemptNumber });
|
||||
|
||||
const result = await runPreflightChecks(
|
||||
input.webUrl,
|
||||
input.repoPath,
|
||||
input.configPath,
|
||||
logger,
|
||||
input.skipGitCheck,
|
||||
input.apiKey,
|
||||
input.providerConfig,
|
||||
);
|
||||
const result = await runPreflightChecks(input.webUrl, input.repoPath, input.configPath, logger, input.skipGitCheck, input.apiKey, input.providerConfig);
|
||||
|
||||
if (isErr(result)) {
|
||||
const classified = classifyErrorForTemporal(result.error);
|
||||
@@ -386,11 +393,11 @@ export async function initDeliverableGit(input: ActivityInput): Promise<void> {
|
||||
* Assemble the final report by concatenating exploitation evidence files.
|
||||
*/
|
||||
export async function assembleReportActivity(input: ActivityInput): Promise<void> {
|
||||
const { repoPath } = input;
|
||||
const { repoPath, deliverablesSubdir } = input;
|
||||
const logger = createActivityLogger();
|
||||
logger.info('Assembling deliverables from specialist agents...');
|
||||
try {
|
||||
await assembleFinalReport(repoPath, logger);
|
||||
await assembleFinalReport(repoPath, deliverablesSubdir, logger);
|
||||
} catch (error) {
|
||||
const err = error as Error;
|
||||
logger.warn(`Error assembling final report: ${err.message}`);
|
||||
@@ -401,11 +408,11 @@ export async function assembleReportActivity(input: ActivityInput): Promise<void
|
||||
* Inject model metadata into the final report.
|
||||
*/
|
||||
export async function injectReportMetadataActivity(input: ActivityInput): Promise<void> {
|
||||
const { repoPath, sessionId, outputPath } = input;
|
||||
const { repoPath, sessionId, outputPath, deliverablesSubdir } = input;
|
||||
const logger = createActivityLogger();
|
||||
const effectiveOutputPath = outputPath ? path.join(outputPath, sessionId) : path.join('./workspaces', sessionId);
|
||||
try {
|
||||
await injectModelIntoReport(repoPath, effectiveOutputPath, logger);
|
||||
await injectModelIntoReport(repoPath, deliverablesSubdir, effectiveOutputPath, logger);
|
||||
} catch (error) {
|
||||
const err = error as Error;
|
||||
logger.warn(`Error injecting model into report: ${err.message}`);
|
||||
@@ -593,6 +600,18 @@ export async function restoreGitCheckpoint(
|
||||
const logger = createActivityLogger();
|
||||
logger.info(`Restoring deliverables to ${checkpointHash}...`);
|
||||
|
||||
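// Validate hash exists before attempting reset. NOTE(review): this check runs in repoPath while the reset below runs in deliverablesPath — confirm both resolve to the same git repository.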
|
||||
try {
|
||||
await executeGitCommandWithRetry(
|
||||
['git', 'rev-parse', '--verify', checkpointHash],
|
||||
repoPath,
|
||||
'verify checkpoint hash exists'
|
||||
);
|
||||
} catch {
|
||||
logger.info(`Checkpoint hash not found in clone, skipping git reset: ${checkpointHash}`);
|
||||
return;
|
||||
}
|
||||
|
||||
await executeGitCommandWithRetry(
|
||||
['git', 'reset', '--hard', checkpointHash],
|
||||
deliverablesPath,
|
||||
@@ -744,5 +763,42 @@ export async function saveCheckpoint(
|
||||
): Promise<void> {
|
||||
const container = getContainer(input.workflowId);
|
||||
if (!container?.checkpointProvider) return;
|
||||
return container.checkpointProvider.onAgentComplete(agentName, phase, state);
|
||||
|
||||
const context: CheckpointContext = {
|
||||
repoPath: input.repoPath,
|
||||
sessionId: input.sessionId,
|
||||
deliverablesSubdir: input.deliverablesSubdir ?? DEFAULT_DELIVERABLES_SUBDIR,
|
||||
...(input.outputPath !== undefined && { outputPath: input.outputPath }),
|
||||
};
|
||||
|
||||
return container.checkpointProvider.onAgentComplete(agentName, phase, state, context);
|
||||
}
|
||||
|
||||
/**
 * Generate an optional additional output alongside the assembled markdown report.
 *
 * Delegates to the ReportOutputProvider registered in the DI container.
 * Default: no-op. Consumers can override this activity at the worker level
 * to emit derived outputs from the final report.
 *
 * The container is looked up by workflow ID and created on demand when none is
 * registered yet (the same get-or-create fallback runAgentActivity uses for its
 * skip guard). Without the fallback the provider would be silently skipped
 * whenever no earlier activity had created a container in this worker.
 *
 * @param input - Activity input; `promptDir`, when relative, is resolved against
 *   SHANNON_WORKER_ROOT (or the current working directory) before being passed on.
 * @returns Resolves once the provider has finished. Errors thrown by the
 *   provider propagate to the caller.
 */
export async function generateReportOutputActivity(input: ActivityInput): Promise<void> {
  const container =
    getContainer(input.workflowId) ??
    getOrCreateContainer(input.workflowId, buildSessionMetadata(input), buildContainerConfig(input));

  const logger = createActivityLogger();

  // Resolve promptDir against the worker root so providers are cwd-independent.
  const resolvedInput: ActivityInput = {
    ...input,
    ...(input.promptDir !== undefined && {
      promptDir: path.isAbsolute(input.promptDir)
        ? input.promptDir
        : path.resolve(process.env.SHANNON_WORKER_ROOT ?? process.cwd(), input.promptDir),
    }),
  };

  const result = await container.reportOutputProvider.generate(resolvedInput, logger);

  // The default no-op provider returns no outputPath, so nothing is logged for it.
  if (result.outputPath) {
    logger.info(`Report output written to ${result.outputPath}`);
  }
}
|
||||
|
||||
@@ -25,10 +25,10 @@ export interface PipelineInput {
|
||||
deliverablesSubdir?: string; // Override deliverables path (default: '.shannon/deliverables')
|
||||
auditDir?: string; // Override audit log directory (default: './workspaces')
|
||||
promptDir?: string; // Override prompt template directory
|
||||
sastSarifPath?: string; // Path to SARIF file (gates SAST-enhanced mode)
|
||||
sastSarifPath?: string; // Optional path for consumer-supplied findings input
|
||||
checkpointsEnabled?: boolean; // Enable checkpoint activities (default: false)
|
||||
skipGitCheck?: boolean; // Skip .git directory validation in preflight (e.g. when .git is removed after clone)
|
||||
providerConfig?: ProviderConfig; // LLM provider configuration (Bedrock, Vertex, LiteLLM, etc.)
|
||||
providerConfig?: ProviderConfig; // LLM provider configuration (Bedrock, Vertex, etc.)
|
||||
}
|
||||
|
||||
export interface ResumeState {
|
||||
|
||||
@@ -332,30 +332,14 @@ export async function pentestPipeline(input: PipelineInput): Promise<PipelineSta
|
||||
];
|
||||
}
|
||||
|
||||
// Aggregate results from settled pipeline promises into workflow state
|
||||
// Aggregate errors from settled pipeline promises.
|
||||
// Metrics and completedAgents are updated incrementally inside runVulnExploitPipeline
|
||||
// so that getProgress queries reflect real-time status during execution.
|
||||
function aggregatePipelineResults(results: PromiseSettledResult<VulnExploitPipelineResult>[]): void {
|
||||
const failedPipelines: string[] = [];
|
||||
|
||||
for (const result of results) {
|
||||
if (result.status === 'fulfilled') {
|
||||
const { vulnType, vulnMetrics, exploitMetrics } = result.value;
|
||||
|
||||
const vulnAgentName = `${vulnType}-vuln`;
|
||||
if (vulnMetrics) {
|
||||
state.agentMetrics[vulnAgentName] = vulnMetrics;
|
||||
state.completedAgents.push(vulnAgentName);
|
||||
} else if (shouldSkip(vulnAgentName)) {
|
||||
state.completedAgents.push(vulnAgentName);
|
||||
}
|
||||
|
||||
const exploitAgentName = `${vulnType}-exploit`;
|
||||
if (exploitMetrics) {
|
||||
state.agentMetrics[exploitAgentName] = exploitMetrics;
|
||||
state.completedAgents.push(exploitAgentName);
|
||||
} else if (shouldSkip(exploitAgentName)) {
|
||||
state.completedAgents.push(exploitAgentName);
|
||||
}
|
||||
} else {
|
||||
if (result.status === 'rejected') {
|
||||
const errorMsg = result.reason instanceof Error ? result.reason.message : String(result.reason);
|
||||
failedPipelines.push(errorMsg);
|
||||
}
|
||||
@@ -442,14 +426,17 @@ export async function pentestPipeline(input: PipelineInput): Promise<PipelineSta
|
||||
let vulnMetrics: AgentMetrics | null = null;
|
||||
if (!shouldSkip(vulnAgentName)) {
|
||||
vulnMetrics = await runVulnAgent();
|
||||
state.agentMetrics[vulnAgentName] = vulnMetrics;
|
||||
state.completedAgents.push(vulnAgentName);
|
||||
if (input.checkpointsEnabled) {
|
||||
await a.saveCheckpoint(activityInput, vulnAgentName, 'vulnerability-analysis', state);
|
||||
}
|
||||
} else {
|
||||
log.info(`Skipping ${vulnAgentName} (already complete)`);
|
||||
state.completedAgents.push(vulnAgentName);
|
||||
}
|
||||
|
||||
// 1.5. Merge external findings (SAST, SCA, etc.) into exploitation queue
|
||||
// 1.5. Merge external findings from consumer provider into exploitation queue
|
||||
await a.mergeFindingsIntoQueue(activityInput, vulnType);
|
||||
|
||||
// 2. Check exploitation queue for actionable findings
|
||||
@@ -460,11 +447,14 @@ export async function pentestPipeline(input: PipelineInput): Promise<PipelineSta
|
||||
if (decision.shouldExploit) {
|
||||
if (!shouldSkip(exploitAgentName)) {
|
||||
exploitMetrics = await runExploitAgent();
|
||||
state.agentMetrics[exploitAgentName] = exploitMetrics;
|
||||
state.completedAgents.push(exploitAgentName);
|
||||
if (input.checkpointsEnabled) {
|
||||
await a.saveCheckpoint(activityInput, exploitAgentName, 'exploitation', state);
|
||||
}
|
||||
} else {
|
||||
log.info(`Skipping ${exploitAgentName} (already complete)`);
|
||||
state.completedAgents.push(exploitAgentName);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -526,6 +516,13 @@ export async function pentestPipeline(input: PipelineInput): Promise<PipelineSta
|
||||
state.completedAgents.push('report');
|
||||
}
|
||||
|
||||
// Runs after the skip gate so consumer providers still execute on resume.
|
||||
await a.generateReportOutputActivity(activityInput);
|
||||
|
||||
if (input.checkpointsEnabled) {
|
||||
await a.saveCheckpoint(activityInput, 'report-output', 'reporting', state);
|
||||
}
|
||||
|
||||
state.status = 'completed';
|
||||
state.currentPhase = null;
|
||||
state.currentAgent = null;
|
||||
|
||||
@@ -80,7 +80,6 @@ export interface ProviderConfig {
|
||||
readonly gcpCredentialsPath?: string;
|
||||
readonly baseUrl?: string;
|
||||
readonly authToken?: string;
|
||||
readonly routerDefault?: string;
|
||||
readonly modelOverrides?: Record<string, string>;
|
||||
readonly supportsStructuredOutput?: boolean;
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ export const BILLING_TEXT_PATTERNS = [
|
||||
'cap reached',
|
||||
'budget exceeded',
|
||||
'usage limit',
|
||||
'resets',
|
||||
] as const;
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user