From 05deaf59638c8a9f0e9c234afedb381917200b3c Mon Sep 17 00:00:00 2001 From: Emanuel Ehmki Date: Mon, 30 Mar 2026 18:15:07 +0200 Subject: [PATCH] fix(session): abort-safe stream processing and resilient cleanup - Replace for-await with Promise.race abort-race pattern to prevent hangs when abort fires during outstanding tool results - Wrap post-loop Snapshot.patch in try/catch to prevent zombie sessions when snapshot cleanup fails - Make doom loop permission prompt abort-aware - Add sweep() for recursive child session abort on cancellation --- packages/opencode/src/session/processor.ts | 672 ++++++++++-------- .../test/session/iterator-cleanup.test.ts | 165 +++++ 2 files changed, 534 insertions(+), 303 deletions(-) create mode 100644 packages/opencode/test/session/iterator-cleanup.test.ts diff --git a/packages/opencode/src/session/processor.ts b/packages/opencode/src/session/processor.ts index 8200dea7564d..c315195043e3 100644 --- a/packages/opencode/src/session/processor.ts +++ b/packages/opencode/src/session/processor.ts @@ -5,6 +5,7 @@ import { Agent } from "@/agent/agent" import { Snapshot } from "@/snapshot" import { SessionSummary } from "./summary" import { Bus } from "@/bus" +import { BusEvent } from "@/bus/bus-event" import { SessionRetry } from "./retry" import { SessionStatus } from "./status" import { Plugin } from "@/plugin" @@ -14,13 +15,47 @@ import { Config } from "@/config/config" import { SessionCompaction } from "./compaction" import { PermissionNext } from "@/permission" import { Question } from "@/question" -import { PartID } from "./schema" -import type { SessionID, MessageID } from "./schema" +import { PartID, SessionID } from "./schema" +import type { MessageID } from "./schema" +import z from "zod" export namespace SessionProcessor { + export const Event = { + CancelRequested: BusEvent.define("session.prompt.cancel", z.object({ sessionID: SessionID.zod })), + } const DOOM_LOOP_THRESHOLD = 3 const log = Log.create({ service: "session.processor" }) + /** Recursively mark running tool parts as "error" for a child session and its descendants. */ + async function abortChildren(sessionID: SessionID, visited = new Set()) { + if (visited.has(sessionID)) return + visited.add(sessionID) + Bus.publish(Event.CancelRequested, { sessionID }) + const msgs = await Session.messages({ sessionID }) + for (const msg of msgs) { + for (const part of msg.parts) { + if (part.type !== "tool") continue + if (part.state.status === "completed" || part.state.status === "error") continue + // If this is a task tool with a child session, recurse first + if (part.tool === "task" && part.state.status === "running" && part.state.metadata?.sessionId) { + await abortChildren(part.state.metadata.sessionId, visited) + } + await Session.updatePart({ + ...part, + state: { + ...part.state, + status: "error", + error: "Tool execution aborted", + time: { + start: part.state.status === "running" ? part.state.time.start : Date.now(), + end: Date.now(), + }, + }, + }) + } + } + } + export type Info = Awaited> export type Result = Awaited> @@ -36,6 +71,30 @@ export namespace SessionProcessor { let attempt = 0 let needsCompaction = false + /** Mark any non-terminal tool parts as "error" and abort child sessions. */ + async function sweep() { + const parts = await MessageV2.parts(input.assistantMessage.id) + for (const part of parts) { + if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") { + if (part.tool === "task" && part.state.status === "running" && part.state.metadata?.sessionId) { + await abortChildren(part.state.metadata.sessionId) + } + await Session.updatePart({ + ...part, + state: { + ...part.state, + status: "error", + error: "Tool execution aborted", + time: { + start: part.state.status === "running" ? part.state.time.start : Date.now(), + end: Date.now(), + }, + }, + }) + } + } + } + const result = { get message() { return input.assistantMessage @@ -53,303 +112,319 @@ export namespace SessionProcessor { let reasoningMap: Record = {} const stream = await LLM.stream(streamInput) - for await (const value of stream.fullStream) { - input.abort.throwIfAborted() - switch (value.type) { - case "start": - SessionStatus.set(input.sessionID, { type: "busy" }) - break - - case "reasoning-start": - if (value.id in reasoningMap) { - continue - } - const reasoningPart = { - id: PartID.ascending(), - messageID: input.assistantMessage.id, - sessionID: input.assistantMessage.sessionID, - type: "reasoning" as const, - text: "", - time: { - start: Date.now(), - }, - metadata: value.providerMetadata, - } - reasoningMap[value.id] = reasoningPart - await Session.updatePart(reasoningPart) - break - - case "reasoning-delta": - if (value.id in reasoningMap) { - const part = reasoningMap[value.id] - part.text += value.text - if (value.providerMetadata) part.metadata = value.providerMetadata - await Session.updatePartDelta({ - sessionID: part.sessionID, - messageID: part.messageID, - partID: part.id, - field: "text", - delta: value.text, - }) - } - break - - case "reasoning-end": - if (value.id in reasoningMap) { - const part = reasoningMap[value.id] - part.text = part.text.trimEnd() - - part.time = { - ...part.time, - end: Date.now(), + // Race each stream chunk against the abort signal so we don't + // hang waiting for outstanding tool results after cancellation. + const aborted = new Promise((_, reject) => { + if (input.abort.aborted) return reject(new DOMException("Aborted", "AbortError")) + input.abort.addEventListener("abort", () => reject(new DOMException("Aborted", "AbortError")), { + once: true, + }) + }) + const iter = stream.fullStream[Symbol.asyncIterator]() + try { + while (true) { + const { done, value } = await Promise.race([iter.next(), aborted]) + if (done) break + input.abort.throwIfAborted() + switch (value.type) { + case "start": + SessionStatus.set(input.sessionID, { type: "busy" }) + break + + case "reasoning-start": + if (value.id in reasoningMap) { + continue } - if (value.providerMetadata) part.metadata = value.providerMetadata - await Session.updatePart(part) - delete reasoningMap[value.id] - } - break - - case "tool-input-start": - const part = await Session.updatePart({ - id: toolcalls[value.id]?.id ?? PartID.ascending(), - messageID: input.assistantMessage.id, - sessionID: input.assistantMessage.sessionID, - type: "tool", - tool: value.toolName, - callID: value.id, - state: { - status: "pending", - input: {}, - raw: "", - }, - }) - toolcalls[value.id] = part as MessageV2.ToolPart - break - - case "tool-input-delta": - break - - case "tool-input-end": - break - - case "tool-call": { - const match = toolcalls[value.toolCallId] - if (match) { + const reasoningPart = { + id: PartID.ascending(), + messageID: input.assistantMessage.id, + sessionID: input.assistantMessage.sessionID, + type: "reasoning" as const, + text: "", + time: { + start: Date.now(), + }, + metadata: value.providerMetadata, + } + reasoningMap[value.id] = reasoningPart + await Session.updatePart(reasoningPart) + break + + case "reasoning-delta": + if (value.id in reasoningMap) { + const part = reasoningMap[value.id] + part.text += value.text + if (value.providerMetadata) part.metadata = value.providerMetadata + await Session.updatePartDelta({ + sessionID: part.sessionID, + messageID: part.messageID, + partID: part.id, + field: "text", + delta: value.text, + }) + } + break + + case "reasoning-end": + if (value.id in reasoningMap) { + const part = reasoningMap[value.id] + part.text = part.text.trimEnd() + + part.time = { + ...part.time, + end: Date.now(), + } + if (value.providerMetadata) part.metadata = value.providerMetadata + await Session.updatePart(part) + delete reasoningMap[value.id] + } + break + + case "tool-input-start": const part = await Session.updatePart({ - ...match, + id: toolcalls[value.id]?.id ?? PartID.ascending(), + messageID: input.assistantMessage.id, + sessionID: input.assistantMessage.sessionID, + type: "tool", tool: value.toolName, + callID: value.id, state: { - status: "running", - input: value.input, - time: { - start: Date.now(), - }, + status: "pending", + input: {}, + raw: "", }, - metadata: value.providerMetadata, }) - toolcalls[value.toolCallId] = part as MessageV2.ToolPart - - const parts = await MessageV2.parts(input.assistantMessage.id) - const lastThree = parts.slice(-DOOM_LOOP_THRESHOLD) - - if ( - lastThree.length === DOOM_LOOP_THRESHOLD && - lastThree.every( - (p) => - p.type === "tool" && - p.tool === value.toolName && - p.state.status !== "pending" && - JSON.stringify(p.state.input) === JSON.stringify(value.input), - ) - ) { - const agent = await Agent.get(input.assistantMessage.agent) - await PermissionNext.ask({ - permission: "doom_loop", - patterns: [value.toolName], - sessionID: input.assistantMessage.sessionID, - metadata: { - tool: value.toolName, + toolcalls[value.id] = part as MessageV2.ToolPart + break + + case "tool-input-delta": + break + + case "tool-input-end": + break + + case "tool-call": { + const match = toolcalls[value.toolCallId] + if (match) { + const part = await Session.updatePart({ + ...match, + tool: value.toolName, + state: { + status: "running", input: value.input, + time: { + start: Date.now(), + }, }, - always: [value.toolName], - ruleset: agent.permission, + metadata: value.providerMetadata, }) + toolcalls[value.toolCallId] = part as MessageV2.ToolPart + + const parts = await MessageV2.parts(input.assistantMessage.id) + const lastThree = parts.slice(-DOOM_LOOP_THRESHOLD) + + if ( + lastThree.length === DOOM_LOOP_THRESHOLD && + lastThree.every( + (p) => + p.type === "tool" && + p.tool === value.toolName && + p.state.status !== "pending" && + JSON.stringify(p.state.input) === JSON.stringify(value.input), + ) + ) { + const agent = await Agent.get(input.assistantMessage.agent) + const permission = PermissionNext.ask({ + permission: "doom_loop", + patterns: [value.toolName], + sessionID: input.assistantMessage.sessionID, + metadata: { + tool: value.toolName, + input: value.input, + }, + always: [value.toolName], + ruleset: agent.permission, + }) + await Promise.race([permission, aborted]) + } } + break } - break - } - case "tool-result": { - const match = toolcalls[value.toolCallId] - if (match && match.state.status === "running") { - await Session.updatePart({ - ...match, - state: { - status: "completed", - input: value.input ?? match.state.input, - output: value.output.output, - metadata: value.output.metadata, - title: value.output.title, - time: { - start: match.state.time.start, - end: Date.now(), - }, - attachments: value.output.attachments, - }, - }) - - delete toolcalls[value.toolCallId] - } - break - } - - case "tool-error": { - const match = toolcalls[value.toolCallId] - if (match && match.state.status === "running") { - await Session.updatePart({ - ...match, - state: { - status: "error", - input: value.input ?? match.state.input, - error: value.error instanceof Error ? value.error.message : String(value.error), - time: { - start: match.state.time.start, - end: Date.now(), + case "tool-result": { + const match = toolcalls[value.toolCallId] + if (match && match.state.status === "running") { + await Session.updatePart({ + ...match, + state: { + status: "completed", + input: value.input ?? match.state.input, + output: value.output.output, + metadata: value.output.metadata, + title: value.output.title, + time: { + start: match.state.time.start, + end: Date.now(), + }, + attachments: value.output.attachments, }, - }, - }) - - if ( - value.error instanceof PermissionNext.RejectedError || - value.error instanceof Question.RejectedError - ) { - blocked = shouldBreak + }) + + delete toolcalls[value.toolCallId] } - delete toolcalls[value.toolCallId] + break } - break - } - case "error": - throw value.error - - case "start-step": - snapshot = await Snapshot.track() - await Session.updatePart({ - id: PartID.ascending(), - messageID: input.assistantMessage.id, - sessionID: input.sessionID, - snapshot, - type: "step-start", - }) - break - - case "finish-step": - const usage = Session.getUsage({ - model: input.model, - usage: value.usage, - metadata: value.providerMetadata, - }) - input.assistantMessage.finish = value.finishReason - input.assistantMessage.cost += usage.cost - input.assistantMessage.tokens = usage.tokens - await Session.updatePart({ - id: PartID.ascending(), - reason: value.finishReason, - snapshot: await Snapshot.track(), - messageID: input.assistantMessage.id, - sessionID: input.assistantMessage.sessionID, - type: "step-finish", - tokens: usage.tokens, - cost: usage.cost, - }) - await Session.updateMessage(input.assistantMessage) - if (snapshot) { - const patch = await Snapshot.patch(snapshot) - if (patch.files.length) { + + case "tool-error": { + const match = toolcalls[value.toolCallId] + if (match && match.state.status === "running") { await Session.updatePart({ - id: PartID.ascending(), - messageID: input.assistantMessage.id, - sessionID: input.sessionID, - type: "patch", - hash: patch.hash, - files: patch.files, + ...match, + state: { + status: "error", + input: value.input ?? match.state.input, + error: value.error instanceof Error ? value.error.message : String(value.error), + time: { + start: match.state.time.start, + end: Date.now(), + }, + }, }) + + if ( + value.error instanceof PermissionNext.RejectedError || + value.error instanceof Question.RejectedError + ) { + blocked = shouldBreak + } + delete toolcalls[value.toolCallId] } - snapshot = undefined - } - SessionSummary.summarize({ - sessionID: input.sessionID, - messageID: input.assistantMessage.parentID, - }) - if ( - !input.assistantMessage.summary && - (await SessionCompaction.isOverflow({ tokens: usage.tokens, model: input.model })) - ) { - needsCompaction = true + break } - break - - case "text-start": - currentText = { - id: PartID.ascending(), - messageID: input.assistantMessage.id, - sessionID: input.assistantMessage.sessionID, - type: "text", - text: "", - time: { - start: Date.now(), - }, - metadata: value.providerMetadata, - } - await Session.updatePart(currentText) - break - - case "text-delta": - if (currentText) { - currentText.text += value.text - if (value.providerMetadata) currentText.metadata = value.providerMetadata - await Session.updatePartDelta({ - sessionID: currentText.sessionID, - messageID: currentText.messageID, - partID: currentText.id, - field: "text", - delta: value.text, + case "error": + throw value.error + + case "start-step": + snapshot = await Snapshot.track() + await Session.updatePart({ + id: PartID.ascending(), + messageID: input.assistantMessage.id, + sessionID: input.sessionID, + snapshot, + type: "step-start", }) - } - break - - case "text-end": - if (currentText) { - currentText.text = currentText.text.trimEnd() - const textOutput = await Plugin.trigger( - "experimental.text.complete", - { - sessionID: input.sessionID, - messageID: input.assistantMessage.id, - partID: currentText.id, + break + + case "finish-step": + const usage = Session.getUsage({ + model: input.model, + usage: value.usage, + metadata: value.providerMetadata, + }) + input.assistantMessage.finish = value.finishReason + input.assistantMessage.cost += usage.cost + input.assistantMessage.tokens = usage.tokens + await Session.updatePart({ + id: PartID.ascending(), + reason: value.finishReason, + snapshot: await Snapshot.track(), + messageID: input.assistantMessage.id, + sessionID: input.assistantMessage.sessionID, + type: "step-finish", + tokens: usage.tokens, + cost: usage.cost, + }) + await Session.updateMessage(input.assistantMessage) + if (snapshot) { + const patch = await Snapshot.patch(snapshot) + if (patch.files.length) { + await Session.updatePart({ + id: PartID.ascending(), + messageID: input.assistantMessage.id, + sessionID: input.sessionID, + type: "patch", + hash: patch.hash, + files: patch.files, + }) + } + snapshot = undefined + } + SessionSummary.summarize({ + sessionID: input.sessionID, + messageID: input.assistantMessage.parentID, + }) + if ( + !input.assistantMessage.summary && + (await SessionCompaction.isOverflow({ tokens: usage.tokens, model: input.model })) + ) { + needsCompaction = true + } + break + + case "text-start": + currentText = { + id: PartID.ascending(), + messageID: input.assistantMessage.id, + sessionID: input.assistantMessage.sessionID, + type: "text", + text: "", + time: { + start: Date.now(), }, - { text: currentText.text }, - ) - currentText.text = textOutput.text - currentText.time = { - start: Date.now(), - end: Date.now(), + metadata: value.providerMetadata, } - if (value.providerMetadata) currentText.metadata = value.providerMetadata await Session.updatePart(currentText) - } - currentText = undefined - break - - case "finish": - break - - default: - log.info("unhandled", { - ...value, - }) - continue + break + + case "text-delta": + if (currentText) { + currentText.text += value.text + if (value.providerMetadata) currentText.metadata = value.providerMetadata + await Session.updatePartDelta({ + sessionID: currentText.sessionID, + messageID: currentText.messageID, + partID: currentText.id, + field: "text", + delta: value.text, + }) + } + break + + case "text-end": + if (currentText) { + currentText.text = currentText.text.trimEnd() + const textOutput = await Plugin.trigger( + "experimental.text.complete", + { + sessionID: input.sessionID, + messageID: input.assistantMessage.id, + partID: currentText.id, + }, + { text: currentText.text }, + ) + currentText.text = textOutput.text + currentText.time = { + start: Date.now(), + end: Date.now(), + } + if (value.providerMetadata) currentText.metadata = value.providerMetadata + await Session.updatePart(currentText) + } + currentText = undefined + break + + case "finish": + break + + default: + log.info("unhandled", { + ...value, + }) + continue + } + if (needsCompaction) break } - if (needsCompaction) break + } finally { + await iter.return?.().catch(() => {}) } } catch (e: any) { log.error("process", { @@ -385,37 +460,28 @@ export namespace SessionProcessor { SessionStatus.set(input.sessionID, { type: "idle" }) } } + // Cleanup sweep FIRST -- mark any stuck tool parts as "error" before + // the (potentially slow) snapshot patch. This ensures parent sessions + // see child tool parts in a terminal state promptly after abort. + await sweep() if (snapshot) { - const patch = await Snapshot.patch(snapshot) - if (patch.files.length) { - await Session.updatePart({ - id: PartID.ascending(), - messageID: input.assistantMessage.id, - sessionID: input.sessionID, - type: "patch", - hash: patch.hash, - files: patch.files, - }) + try { + const patch = await Snapshot.patch(snapshot) + if (patch.files.length) { + await Session.updatePart({ + id: PartID.ascending(), + messageID: input.assistantMessage.id, + sessionID: input.sessionID, + type: "patch", + hash: patch.hash, + files: patch.files, + }) + } + } catch (e) { + log.warn("snapshot patch failed during cleanup", { error: e, sessionID: input.sessionID }) } snapshot = undefined } - const p = await MessageV2.parts(input.assistantMessage.id) - for (const part of p) { - if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") { - await Session.updatePart({ - ...part, - state: { - ...part.state, - status: "error", - error: "Tool execution aborted", - time: { - start: Date.now(), - end: Date.now(), - }, - }, - }) - } - } input.assistantMessage.time.completed = Date.now() await Session.updateMessage(input.assistantMessage) if (needsCompaction) return "compact" diff --git a/packages/opencode/test/session/iterator-cleanup.test.ts b/packages/opencode/test/session/iterator-cleanup.test.ts new file mode 100644 index 000000000000..a79222f96f6a --- /dev/null +++ b/packages/opencode/test/session/iterator-cleanup.test.ts @@ -0,0 +1,165 @@ +import { describe, expect, test } from "bun:test" + +/** + * Unit tests for the async iterator cleanup pattern used in processor.ts. + * + * The processor wraps its `while(true)` iterator consumption in try/finally + * with `await iter.return?.().catch(() => {})` in the finally block. These + * tests verify the pattern works correctly for all exit paths. + * + * Validates: Property 8 (Iterator Cleanup on Abort) + */ + +function mock(items: number[]) { + let idx = 0 + const spy = { calls: 0 } + const iter: AsyncIterator = { + async next() { + if (idx >= items.length) return { done: true as const, value: undefined } + return { done: false as const, value: items[idx++] } + }, + async return() { + spy.calls++ + return { done: true as const, value: undefined } + }, + } + return { iter, spy } +} + +describe("iterator cleanup", () => { + test("Property 8: return() called on abort", async () => { + const { iter, spy } = mock([1, 2, 3, 4, 5]) + const controller = new AbortController() + const aborted = new Promise((_, reject) => { + if (controller.signal.aborted) return reject(new DOMException("Aborted", "AbortError")) + controller.signal.addEventListener("abort", () => reject(new DOMException("Aborted", "AbortError")), { + once: true, + }) + }) + + const collected: number[] = [] + try { + while (true) { + const { done, value } = await Promise.race([iter.next(), aborted]) + if (done) break + controller.signal.throwIfAborted() + collected.push(value) + if (collected.length === 2) controller.abort() + } + } catch { + // abort error expected + } finally { + await iter.return?.().catch(() => {}) + } + + expect(spy.calls).toBe(1) + expect(collected).toEqual([1, 2]) + }) + + test("Property 8: return() called on normal completion", async () => { + const { iter, spy } = mock([1, 2, 3]) + const aborted = new Promise(() => {}) + + const collected: number[] = [] + try { + while (true) { + const { done, value } = await Promise.race([iter.next(), aborted]) + if (done) break + collected.push(value) + } + } finally { + await iter.return?.().catch(() => {}) + } + + expect(spy.calls).toBe(1) + expect(collected).toEqual([1, 2, 3]) + }) + + test("Property 8: return() error suppressed", async () => { + const iter: AsyncIterator = { + async next() { + return { done: true as const, value: undefined } + }, + async return() { + throw new Error("cleanup failed") + }, + } + + // Should not throw despite return() throwing + try { + while (true) { + const { done } = await iter.next() + if (done) break + } + } finally { + await iter.return?.().catch(() => {}) + } + + // Reaching here means the error was suppressed + expect(true).toBe(true) + }) + + test("Property 8: return() tolerates exhausted iterator", async () => { + const { iter, spy } = mock([]) + + try { + while (true) { + const { done } = await iter.next() + if (done) break + } + } finally { + await iter.return?.().catch(() => {}) + } + + expect(spy.calls).toBe(1) + }) + + test("Property 8: return() called on thrown error", async () => { + const { iter, spy } = mock([1, 2, 3]) + const aborted = new Promise(() => {}) + + let caught = false + try { + while (true) { + const { done, value } = await Promise.race([iter.next(), aborted]) + if (done) break + if (value === 2) throw new Error("processing error") + } + } catch { + caught = true + } finally { + await iter.return?.().catch(() => {}) + } + + expect(caught).toBe(true) + expect(spy.calls).toBe(1) + }) + + test("Property 8: pre-aborted signal triggers immediate rejection and return()", async () => { + const { iter, spy } = mock([1, 2, 3]) + const controller = new AbortController() + controller.abort() + const aborted = new Promise((_, reject) => { + if (controller.signal.aborted) return reject(new DOMException("Aborted", "AbortError")) + controller.signal.addEventListener("abort", () => reject(new DOMException("Aborted", "AbortError")), { + once: true, + }) + }) + + let caught = false + try { + while (true) { + const { done, value } = await Promise.race([iter.next(), aborted]) + if (done) break + controller.signal.throwIfAborted() + } + } catch { + caught = true + } finally { + await iter.return?.().catch(() => {}) + } + + expect(caught).toBe(true) + expect(spy.calls).toBe(1) + }) +})