diff --git a/packages/agent-runtime/src/__tests__/main-prompt.test.ts b/packages/agent-runtime/src/__tests__/main-prompt.test.ts index 17b4f99e1..f68e13147 100644 --- a/packages/agent-runtime/src/__tests__/main-prompt.test.ts +++ b/packages/agent-runtime/src/__tests__/main-prompt.test.ts @@ -375,6 +375,7 @@ describe('mainPrompt', () => { it('should update consecutiveAssistantMessages when new prompt is received', async () => { const sessionState = getInitialSessionState(mockFileContext) sessionState.mainAgentState.stepsRemaining = 12 + const initialStepsRemaining = sessionState.mainAgentState.stepsRemaining const action = { type: 'prompt' as const, @@ -394,7 +395,7 @@ describe('mainPrompt', () => { // When there's a new prompt, consecutiveAssistantMessages should be set to 1 expect(newSessionState.mainAgentState.stepsRemaining).toBe( - sessionState.mainAgentState.stepsRemaining - 1, + initialStepsRemaining - 1, ) }) diff --git a/packages/agent-runtime/src/run-agent-step.ts b/packages/agent-runtime/src/run-agent-step.ts index 704cedf3a..4b8267033 100644 --- a/packages/agent-runtime/src/run-agent-step.ts +++ b/packages/agent-runtime/src/run-agent-step.ts @@ -536,6 +536,17 @@ export const runAgentStep = async ( } } +/** + * Runs the agent loop. + * + * IMPORTANT: This function mutates `params.agentState` in place throughout the + * run (not just at return time). Fields like `messageHistory`, `systemPrompt`, + * `toolDefinitions`, `creditsUsed`, and `output` are updated as work progresses + * so that callers holding a reference to the same object (e.g. the SDK's + * `sessionState.mainAgentState`) see in-progress work immediately — which + * matters when an error is thrown mid-run and the normal return path is + * skipped. 
+ */ export async function loopAgentSteps( params: { addAgentStep: AddAgentStepFn @@ -800,12 +811,13 @@ export async function loopAgentSteps( return cachedAdditionalToolDefinitions } - let currentAgentState: AgentState = { - ...initialAgentState, - messageHistory: initialMessages, - systemPrompt: system, - toolDefinitions, - } + // Mutate initialAgentState so that in-progress work propagates back to the + // caller's shared reference (e.g. SDK's sessionState.mainAgentState) even if + // an error is thrown before we return. + initialAgentState.messageHistory = initialMessages + initialAgentState.systemPrompt = system + initialAgentState.toolDefinitions = toolDefinitions + let currentAgentState: AgentState = initialAgentState // Convert tool definitions to Anthropic format for accurate token counting // Tool definitions are stored as { [name]: { description, inputSchema } } @@ -908,7 +920,8 @@ export async function loopAgentSteps( } = programmaticResult n = generateN - currentAgentState = programmaticAgentState + Object.assign(initialAgentState, programmaticAgentState) + currentAgentState = initialAgentState totalSteps = stepNumber shouldEndTurn = endTurn @@ -989,7 +1002,8 @@ export async function loopAgentSteps( logger.error('No runId found for agent state after finishing agent run') } - currentAgentState = newAgentState + Object.assign(initialAgentState, newAgentState) + currentAgentState = initialAgentState shouldEndTurn = llmShouldEndTurn nResponses = generatedResponses diff --git a/sdk/src/__tests__/run-error-preserves-history.test.ts b/sdk/src/__tests__/run-error-preserves-history.test.ts new file mode 100644 index 000000000..95b72ead2 --- /dev/null +++ b/sdk/src/__tests__/run-error-preserves-history.test.ts @@ -0,0 +1,315 @@ +import * as mainPromptModule from '@codebuff/agent-runtime/main-prompt' +import { getInitialSessionState } from '@codebuff/common/types/session-state' +import { getStubProjectFileContext } from '@codebuff/common/util/file' +import { 
assistantMessage, userMessage } from '@codebuff/common/util/messages' +import { afterEach, describe, expect, it, mock, spyOn } from 'bun:test' + +import { CodebuffClient } from '../client' +import * as databaseModule from '../impl/database' + +interface ToolCallContentBlock { + type: 'tool-call' + toolCallId: string + toolName: string + input: Record +} + +const setupDatabaseMocks = () => { + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + referral_code: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') +} + +describe('Error preserves in-progress message history', () => { + afterEach(() => { + mock.restore() + }) + + it('preserves in-progress assistant work on error (simulated via shared state mutation)', async () => { + setupDatabaseMocks() + + // Simulate the agent runtime: + // 1. Mutates the shared session state with the user message and partial work + // 2. Then throws due to a downstream timeout/service error + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const mainAgentState = params.action.sessionState.mainAgentState + + // Match the real runtime's behavior: replace messageHistory with a new + // array that keeps prior history and appends the user prompt first. The SDK + // detects runtime progress via reference inequality, so we must + // reassign the array rather than pushing into it. 
+ mainAgentState.messageHistory = [ + ...mainAgentState.messageHistory, + { + role: 'user', + content: [{ type: 'text', text: 'Fix the bug in auth.ts' }], + tags: ['USER_PROMPT'], + }, + { + role: 'assistant', + content: [ + { type: 'text', text: 'Let me read the auth file first.' }, + { + type: 'tool-call', + toolCallId: 'read-1', + toolName: 'read_files', + input: { paths: ['auth.ts'] }, + } as ToolCallContentBlock, + ], + }, + { + role: 'tool', + toolCallId: 'read-1', + toolName: 'read_files', + content: [ + { + type: 'json', + value: [{ path: 'auth.ts', content: 'const auth = ...' }], + }, + ], + }, + { + role: 'assistant', + content: [ + { type: 'text', text: 'Found the issue, writing the fix now.' }, + { + type: 'tool-call', + toolCallId: 'write-1', + toolName: 'write_file', + input: { path: 'auth.ts', content: 'const auth = fixed' }, + } as ToolCallContentBlock, + ], + }, + { + role: 'tool', + toolCallId: 'write-1', + toolName: 'write_file', + content: [{ type: 'json', value: { file: 'auth.ts', message: 'File written' } }], + }, + ] + + // Now simulate a server timeout on the next LLM call + const timeoutError = new Error('Service Unavailable') as Error & { + statusCode: number + responseBody: string + } + timeoutError.statusCode = 503 + timeoutError.responseBody = JSON.stringify({ + message: 'Request timeout after 30s', + }) + throw timeoutError + }, + ) + + const client = new CodebuffClient({ apiKey: 'test-key' }) + const result = await client.run({ + agent: 'base2', + prompt: 'Fix the bug in auth.ts', + }) + + // Error output with correct status code + expect(result.output.type).toBe('error') + const errorOutput = result.output as { + type: 'error' + message: string + statusCode?: number + } + expect(errorOutput.statusCode).toBe(503) + + const history = result.sessionState!.mainAgentState.messageHistory + + // The user's prompt should appear exactly once + const userPromptMessages = history.filter( + (m) => + m.role === 'user' && + (m.content as Array<{ 
type: string; text?: string }>).some( + (c) => c.type === 'text' && c.text?.includes('Fix the bug'), + ), + ) + expect(userPromptMessages.length).toBe(1) + + // Assistant text messages from both steps should be preserved + const firstAssistantText = history.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; text?: string }>).some( + (c) => c.type === 'text' && c.text?.includes('read the auth file'), + ), + ) + expect(firstAssistantText).toBeDefined() + + const secondAssistantText = history.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; text?: string }>).some( + (c) => c.type === 'text' && c.text?.includes('writing the fix'), + ), + ) + expect(secondAssistantText).toBeDefined() + + // Both tool calls and both tool results should be preserved + const readToolCall = history.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; toolCallId?: string }>).some( + (c) => c.type === 'tool-call' && c.toolCallId === 'read-1', + ), + ) + expect(readToolCall).toBeDefined() + + const writeToolCall = history.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; toolCallId?: string }>).some( + (c) => c.type === 'tool-call' && c.toolCallId === 'write-1', + ), + ) + expect(writeToolCall).toBeDefined() + + const readToolResult = history.find( + (m) => m.role === 'tool' && m.toolCallId === 'read-1', + ) + expect(readToolResult).toBeDefined() + + const writeToolResult = history.find( + (m) => m.role === 'tool' && m.toolCallId === 'write-1', + ) + expect(writeToolResult).toBeDefined() + }) + + it('a subsequent run after error includes the preserved in-progress history', async () => { + setupDatabaseMocks() + + // Run 1: agent does some work then hits an error + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const mainAgentState = params.action.sessionState.mainAgentState + + mainAgentState.messageHistory = [ + 
...mainAgentState.messageHistory, + { + role: 'user', + content: [{ type: 'text', text: 'Investigate the login bug' }], + tags: ['USER_PROMPT'], + }, + assistantMessage('I found the problem in auth.ts on line 42.'), + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'read-login', + toolName: 'read_files', + input: { paths: ['login.ts'] }, + } as ToolCallContentBlock, + ], + }, + { + role: 'tool', + toolCallId: 'read-login', + toolName: 'read_files', + content: [{ type: 'json', value: [{ path: 'login.ts', content: 'login code' }] }], + }, + ] + + const error = new Error('Service Unavailable') as Error & { + statusCode: number + } + error.statusCode = 503 + throw error + }, + ) + + const client = new CodebuffClient({ apiKey: 'test-key' }) + const firstResult = await client.run({ + agent: 'base2', + prompt: 'Investigate the login bug', + }) + + expect(firstResult.output.type).toBe('error') + + // Run 2: use the failed run as previousRun + mock.restore() + setupDatabaseMocks() + + let historyReceivedByRuntime: unknown[] | undefined + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const { sendAction, promptId } = params + historyReceivedByRuntime = [ + ...params.action.sessionState.mainAgentState.messageHistory, + ] + + const responseSessionState = getInitialSessionState( + getStubProjectFileContext(), + ) + responseSessionState.mainAgentState.messageHistory = [ + ...params.action.sessionState.mainAgentState.messageHistory, + userMessage('Now try again'), + assistantMessage('Continuing with the fix.'), + ] + + await sendAction({ + action: { + type: 'prompt-response', + promptId, + sessionState: responseSessionState, + output: { type: 'lastMessage', value: [] }, + }, + }) + + return { + sessionState: responseSessionState, + output: { type: 'lastMessage' as const, value: [] }, + } + }, + ) + + const secondResult = await client.run({ + agent: 'base2', + prompt: 'Now try again', + previousRun: 
firstResult, + }) + + // The runtime should have received history containing the work from the first run + expect(historyReceivedByRuntime).toBeDefined() + const receivedReadCall = historyReceivedByRuntime!.find( + (m) => + (m as { role: string }).role === 'assistant' && + ((m as { content: Array<{ type: string; toolCallId?: string }> }) + .content ?? []).some( + (c) => c.type === 'tool-call' && c.toolCallId === 'read-login', + ), + ) + expect(receivedReadCall).toBeDefined() + + const receivedToolResult = historyReceivedByRuntime!.find( + (m) => + (m as { role: string }).role === 'tool' && + (m as { toolCallId: string }).toolCallId === 'read-login', + ) + expect(receivedToolResult).toBeDefined() + + // Final result should preserve history + const finalHistory = secondResult.sessionState!.mainAgentState.messageHistory + const finalReadCall = finalHistory.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; toolCallId?: string }>).some( + (c) => c.type === 'tool-call' && c.toolCallId === 'read-login', + ), + ) + expect(finalReadCall).toBeDefined() + }) +}) diff --git a/sdk/src/run.ts b/sdk/src/run.ts index 5a18f7025..2dfcef553 100644 --- a/sdk/src/run.ts +++ b/sdk/src/run.ts @@ -282,16 +282,27 @@ async function runOnce({ } } + // The agent runtime mutates sessionState.mainAgentState as it progresses, + // replacing messageHistory with a new array once it adds the user prompt. + // Comparing array identity detects progress more robustly than length: + // context pruning could shrink history below its starting length without + // meaning the runtime never ran. + const initialMessageHistory = sessionState.mainAgentState.messageHistory + /** Calculates the current session state if cancelled. * - * This is used when callMainPrompt throws an error (the server never processed the request). - * We need to add the user's message here since the server didn't get a chance to add it. + * This is used when callMainPrompt throws an error. 
If the agent runtime made + * any progress (replaced the shared messageHistory), those messages are + * preserved. Otherwise the user's message is added so it isn't lost. */ function getCancelledSessionState(message: string): SessionState { + const runtimeMadeProgress = + sessionState.mainAgentState.messageHistory !== initialMessageHistory + const state = cloneDeep(sessionState) - // Add the user's message since the server never processed it - if (prompt || preparedContent) { + // Only add the user's message if the runtime didn't get a chance to add it. + if (!runtimeMadeProgress && (prompt || preparedContent)) { state.mainAgentState.messageHistory.push({ role: 'user' as const, content: buildUserMessageContent(prompt, params, preparedContent),