From 77fc47c01c2e4040d923b89a11564d2e6bb160d8 Mon Sep 17 00:00:00 2001 From: Minh-Phuc Tran Date: Mon, 30 Mar 2026 14:15:34 +0700 Subject: [PATCH 1/2] Allow tool hooks to replace call arguments --- packages/opencode/src/session/prompt.ts | 35 ++-- .../opencode/test/session/tool-hooks.test.ts | 160 ++++++++++++++++++ 2 files changed, 180 insertions(+), 15 deletions(-) create mode 100644 packages/opencode/test/session/tool-hooks.test.ts diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index a9edf838ca8c..c572ed8be5cc 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -826,19 +826,21 @@ export namespace SessionPrompt { description: item.description, inputSchema: jsonSchema(schema as any), async execute(args, options) { - const ctx = context(args, options) + const hookOutput = { + args: args ?? {}, + } await Plugin.trigger( "tool.execute.before", { tool: item.id, - sessionID: ctx.sessionID, - callID: ctx.callID, - }, - { - args, + sessionID: input.session.id, + callID: options.toolCallId, }, + hookOutput, ) - const result = await item.execute(args, ctx) + const nextArgs = hookOutput.args + const ctx = context(nextArgs, options) + const result = await item.execute(nextArgs, ctx) const output = { ...result, attachments: result.attachments?.map((attachment) => ({ @@ -854,7 +856,7 @@ export namespace SessionPrompt { tool: item.id, sessionID: ctx.sessionID, callID: ctx.callID, - args, + args: nextArgs, }, output, ) @@ -872,20 +874,23 @@ export namespace SessionPrompt { item.inputSchema = jsonSchema(transformed) // Wrap execute to add plugin hooks and format output item.execute = async (args, opts) => { - const ctx = context(args, opts) + const hookOutput = { + args: args ?? {}, + } await Plugin.trigger( "tool.execute.before", { tool: key, - sessionID: ctx.sessionID, + sessionID: input.session.id, callID: opts.toolCallId, }, - { - args, - }, + hookOutput, ) + const nextArgs = hookOutput.args + const ctx = context(nextArgs, opts) + await ctx.ask({ permission: key, metadata: {}, @@ -893,7 +898,7 @@ export namespace SessionPrompt { always: ["*"], }) - const result = await execute(args, opts) + const result = await execute(nextArgs, opts) await Plugin.trigger( "tool.execute.after", @@ -901,7 +906,7 @@ export namespace SessionPrompt { tool: key, sessionID: ctx.sessionID, callID: opts.toolCallId, - args, + args: nextArgs, }, result, ) diff --git a/packages/opencode/test/session/tool-hooks.test.ts b/packages/opencode/test/session/tool-hooks.test.ts new file mode 100644 index 000000000000..c4b5d082446c --- /dev/null +++ b/packages/opencode/test/session/tool-hooks.test.ts @@ -0,0 +1,160 @@ +import { afterEach, describe, expect, mock, spyOn, test } from "bun:test" +import { jsonSchema } from "ai" +import z from "zod" +import { MCP } from "../../src/mcp" +import { Permission } from "../../src/permission" +import { Plugin } from "../../src/plugin" +import type { Provider } from "../../src/provider/provider" +import { ProviderID } from "../../src/provider/schema" +import { SessionPrompt } from "../../src/session/prompt" +import { MessageID, SessionID } from "../../src/session/schema" +import { ToolRegistry } from "../../src/tool/registry" +import { Truncate } from "../../src/tool/truncate" + +afterEach(() => { + mock.restore() +}) + +function createModel(): Provider.Model { + return { + id: "test-model", + providerID: ProviderID.make("test"), + name: "Test", + limit: { + context: 100_000, + output: 32_000, + }, + cost: { input: 0, output: 0, cache: { read: 0, write: 0 } }, + capabilities: { + toolcall: true, + attachment: false, + reasoning: false, + temperature: true, + input: { text: true, image: false, audio: false, video: false }, + output: { text: true, image: false, audio: false, video: false }, + }, + api: { + id: "test-model", + npm: "@ai-sdk/openai", + }, + options: {}, + } as Provider.Model +} + +function input() { + const model = createModel() + return { + agent: { + name: "build", + permission: [], + } as any, + model, + session: { + id: SessionID.descending(), + permission: [], + } as any, + processor: { + message: { + id: MessageID.ascending(), + sessionID: SessionID.descending(), + role: "assistant", + time: { created: Date.now() }, + parentID: MessageID.ascending(), + mode: "build", + agent: "build", + path: { cwd: "/", root: "/" }, + cost: 0, + tokens: { + input: 0, + output: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: model.id, + providerID: model.providerID, + finish: "end_turn", + } as any, + partFromToolCall() { + return undefined + }, + process: async () => "stop" as const, + }, + bypassAgentCheck: false, + messages: [], + } as Parameters<(typeof SessionPrompt)["resolveTools"]>[0] +} + +describe("session.prompt tool hooks", () => { + test("tool.execute.before can replace undefined args for native tools", async () => { + let captured: unknown + + spyOn(ToolRegistry, "tools").mockResolvedValue([ + { + id: "native_demo", + description: "Native demo", + parameters: z.object({}), + execute: async (args: unknown) => { + captured = args + return { title: "", output: "ok", metadata: {} } + }, + } as any, + ]) + spyOn(MCP, "tools").mockResolvedValue({}) + spyOn(Truncate, "output").mockResolvedValue({ content: "ok", truncated: false }) + spyOn(Plugin, "trigger").mockImplementation(async (name, _input, output) => { + if (name === "tool.execute.before") { + ;(output as { args: unknown }).args = { __patched: true } + } + return output + }) + + const tools = await SessionPrompt.resolveTools(input()) + if (!tools.native_demo?.execute) throw new Error("native_demo execute missing") + await tools.native_demo.execute(undefined, { + toolCallId: "call-native", + abortSignal: new AbortController().signal, + } as any) + + expect(captured).toEqual({ __patched: true }) + }) + + test("tool.execute.before can replace undefined args for MCP tools", async () => { + let captured: unknown + + spyOn(ToolRegistry, "tools").mockResolvedValue([]) + spyOn(MCP, "tools").mockResolvedValue({ + openkitten_echo: { + description: "MCP demo", + inputSchema: jsonSchema({ + type: "object", + properties: {}, + additionalProperties: false, + }), + execute: async (args: unknown) => { + captured = args + return { + content: [{ type: "text", text: "ok" }], + metadata: {}, + } + }, + } as any, + }) + spyOn(Permission, "ask").mockResolvedValue(undefined) + spyOn(Truncate, "output").mockResolvedValue({ content: "ok", truncated: false }) + spyOn(Plugin, "trigger").mockImplementation(async (name, _input, output) => { + if (name === "tool.execute.before") { + ;(output as { args: unknown }).args = { __patched: true } + } + return output + }) + + const tools = await SessionPrompt.resolveTools(input()) + if (!tools.openkitten_echo?.execute) throw new Error("openkitten_echo execute missing") + await tools.openkitten_echo.execute(undefined, { + toolCallId: "call-mcp", + abortSignal: new AbortController().signal, + } as any) + + expect(captured).toEqual({ __patched: true }) + }) +}) From 8d1351dc0f372394978c2b352afaea2af79887b4 Mon Sep 17 00:00:00 2001 From: Minh-Phuc Tran Date: Mon, 30 Mar 2026 14:18:14 +0700 Subject: [PATCH 2/2] Tighten tool hook regression test types --- .../opencode/test/session/tool-hooks.test.ts | 115 +++++++++++------- 1 file changed, 69 insertions(+), 46 deletions(-) diff --git a/packages/opencode/test/session/tool-hooks.test.ts b/packages/opencode/test/session/tool-hooks.test.ts index c4b5d082446c..589c9d8982e4 100644 --- a/packages/opencode/test/session/tool-hooks.test.ts +++ b/packages/opencode/test/session/tool-hooks.test.ts @@ -1,13 +1,18 @@ import { afterEach, describe, expect, mock, spyOn, test } from "bun:test" import { jsonSchema } from "ai" import z from "zod" +import type { Agent } from "../../src/agent/agent" import { MCP } from "../../src/mcp" import { Permission } from "../../src/permission" import { Plugin } from "../../src/plugin" +import { ProjectID } from "../../src/project/schema" import type { Provider } from "../../src/provider/provider" import { ProviderID } from "../../src/provider/schema" +import type { MessageV2 } from "../../src/session/message-v2" +import type { SessionProcessor } from "../../src/session/processor" import { SessionPrompt } from "../../src/session/prompt" import { MessageID, SessionID } from "../../src/session/schema" +import type { Session } from "../../src/session" import { ToolRegistry } from "../../src/tool/registry" import { Truncate } from "../../src/tool/truncate" @@ -41,47 +46,63 @@ function createModel(): Provider.Model { } as Provider.Model } -function input() { +function input(): Parameters<(typeof SessionPrompt)["resolveTools"]>[0] { const model = createModel() + const agent: Agent.Info = { + name: "build", + mode: "primary", + permission: [], + options: {}, + } + const session: Session.Info = { + id: SessionID.descending(), + slug: "tool-hooks", + projectID: ProjectID.global, + directory: "/", + title: "Tool hooks", + version: "1", + time: { + created: Date.now(), + updated: Date.now(), + }, + permission: [], + } + const message: MessageV2.Assistant = { + id: MessageID.ascending(), + sessionID: SessionID.descending(), + role: "assistant", + time: { created: Date.now() }, + parentID: MessageID.ascending(), + mode: "build", + agent: "build", + path: { cwd: "/", root: "/" }, + cost: 0, + tokens: { + input: 0, + output: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: model.id, + providerID: model.providerID, + finish: "end_turn", + } + const processor: SessionProcessor.Info = { + message, + partFromToolCall() { + return undefined + }, + process: async () => "stop", + } + return { - agent: { - name: "build", - permission: [], - } as any, + agent, model, - session: { - id: SessionID.descending(), - permission: [], - } as any, - processor: { - message: { - id: MessageID.ascending(), - sessionID: SessionID.descending(), - role: "assistant", - time: { created: Date.now() }, - parentID: MessageID.ascending(), - mode: "build", - agent: "build", - path: { cwd: "/", root: "/" }, - cost: 0, - tokens: { - input: 0, - output: 0, - reasoning: 0, - cache: { read: 0, write: 0 }, - }, - modelID: model.id, - providerID: model.providerID, - finish: "end_turn", - } as any, - partFromToolCall() { - return undefined - }, - process: async () => "stop" as const, - }, + session, + processor, bypassAgentCheck: false, messages: [], - } as Parameters<(typeof SessionPrompt)["resolveTools"]>[0] + } } describe("session.prompt tool hooks", () => { @@ -97,13 +118,13 @@ describe("session.prompt tool hooks", () => { captured = args return { title: "", output: "ok", metadata: {} } }, - } as any, - ]) + }, + ] as never) spyOn(MCP, "tools").mockResolvedValue({}) spyOn(Truncate, "output").mockResolvedValue({ content: "ok", truncated: false }) spyOn(Plugin, "trigger").mockImplementation(async (name, _input, output) => { - if (name === "tool.execute.before") { - ;(output as { args: unknown }).args = { __patched: true } + if (name === "tool.execute.before" && typeof output === "object" && output !== null && "args" in output) { + Object.assign(output, { args: { __patched: true } }) } return output }) @@ -113,7 +134,8 @@ describe("session.prompt tool hooks", () => { await tools.native_demo.execute(undefined, { toolCallId: "call-native", abortSignal: new AbortController().signal, - } as any) + messages: [], + }) expect(captured).toEqual({ __patched: true }) }) @@ -137,13 +159,13 @@ describe("session.prompt tool hooks", () => { metadata: {}, } }, - } as any, - }) + }, + } as never) spyOn(Permission, "ask").mockResolvedValue(undefined) spyOn(Truncate, "output").mockResolvedValue({ content: "ok", truncated: false }) spyOn(Plugin, "trigger").mockImplementation(async (name, _input, output) => { - if (name === "tool.execute.before") { - ;(output as { args: unknown }).args = { __patched: true } + if (name === "tool.execute.before" && typeof output === "object" && output !== null && "args" in output) { + Object.assign(output, { args: { __patched: true } }) } return output }) @@ -153,7 +175,8 @@ describe("session.prompt tool hooks", () => { await tools.openkitten_echo.execute(undefined, { toolCallId: "call-mcp", abortSignal: new AbortController().signal, - } as any) + messages: [], + }) expect(captured).toEqual({ __patched: true }) })