diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index a9edf838ca8c..c572ed8be5cc 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -826,19 +826,21 @@ export namespace SessionPrompt { description: item.description, inputSchema: jsonSchema(schema as any), async execute(args, options) { - const ctx = context(args, options) + const hookOutput = { + args: args ?? {}, + } await Plugin.trigger( "tool.execute.before", { tool: item.id, - sessionID: ctx.sessionID, - callID: ctx.callID, - }, - { - args, + sessionID: input.session.id, + callID: options.toolCallId, }, + hookOutput, ) - const result = await item.execute(args, ctx) + const nextArgs = hookOutput.args + const ctx = context(nextArgs, options) + const result = await item.execute(nextArgs, ctx) const output = { ...result, attachments: result.attachments?.map((attachment) => ({ @@ -854,7 +856,7 @@ export namespace SessionPrompt { tool: item.id, sessionID: ctx.sessionID, callID: ctx.callID, - args, + args: nextArgs, }, output, ) @@ -872,20 +874,23 @@ export namespace SessionPrompt { item.inputSchema = jsonSchema(transformed) // Wrap execute to add plugin hooks and format output item.execute = async (args, opts) => { - const ctx = context(args, opts) + const hookOutput = { + args: args ?? {}, + } await Plugin.trigger( "tool.execute.before", { tool: key, - sessionID: ctx.sessionID, + sessionID: input.session.id, callID: opts.toolCallId, }, - { - args, - }, + hookOutput, ) + const nextArgs = hookOutput.args + const ctx = context(nextArgs, opts) + await ctx.ask({ permission: key, metadata: {}, @@ -893,7 +898,7 @@ export namespace SessionPrompt { always: ["*"], }) - const result = await execute(args, opts) + const result = await execute(nextArgs, opts) await Plugin.trigger( "tool.execute.after", @@ -901,7 +906,7 @@ export namespace SessionPrompt { tool: key, sessionID: ctx.sessionID, callID: opts.toolCallId, - args, + args: nextArgs, }, result, ) diff --git a/packages/opencode/test/session/tool-hooks.test.ts b/packages/opencode/test/session/tool-hooks.test.ts new file mode 100644 index 000000000000..589c9d8982e4 --- /dev/null +++ b/packages/opencode/test/session/tool-hooks.test.ts @@ -0,0 +1,183 @@ +import { afterEach, describe, expect, mock, spyOn, test } from "bun:test" +import { jsonSchema } from "ai" +import z from "zod" +import type { Agent } from "../../src/agent/agent" +import { MCP } from "../../src/mcp" +import { Permission } from "../../src/permission" +import { Plugin } from "../../src/plugin" +import { ProjectID } from "../../src/project/schema" +import type { Provider } from "../../src/provider/provider" +import { ProviderID } from "../../src/provider/schema" +import type { MessageV2 } from "../../src/session/message-v2" +import type { SessionProcessor } from "../../src/session/processor" +import { SessionPrompt } from "../../src/session/prompt" +import { MessageID, SessionID } from "../../src/session/schema" +import type { Session } from "../../src/session" +import { ToolRegistry } from "../../src/tool/registry" +import { Truncate } from "../../src/tool/truncate" + +afterEach(() => { + mock.restore() +}) + +function createModel(): Provider.Model { + return { + id: "test-model", + providerID: ProviderID.make("test"), + name: "Test", + limit: { + context: 100_000, + output: 32_000, + }, + cost: { input: 0, output: 0, cache: { read: 0, write: 0 } }, + capabilities: { + toolcall: true, + attachment: false, + reasoning: false, + temperature: true, + input: { text: true, image: false, audio: false, video: false }, + output: { text: true, image: false, audio: false, video: false }, + }, + api: { + id: "test-model", + npm: "@ai-sdk/openai", + }, + options: {}, + } as Provider.Model +} + +function input(): Parameters<(typeof SessionPrompt)["resolveTools"]>[0] { + const model = createModel() + const agent: Agent.Info = { + name: "build", + mode: "primary", + permission: [], + options: {}, + } + const session: Session.Info = { + id: SessionID.descending(), + slug: "tool-hooks", + projectID: ProjectID.global, + directory: "/", + title: "Tool hooks", + version: "1", + time: { + created: Date.now(), + updated: Date.now(), + }, + permission: [], + } + const message: MessageV2.Assistant = { + id: MessageID.ascending(), + sessionID: SessionID.descending(), + role: "assistant", + time: { created: Date.now() }, + parentID: MessageID.ascending(), + mode: "build", + agent: "build", + path: { cwd: "/", root: "/" }, + cost: 0, + tokens: { + input: 0, + output: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: model.id, + providerID: model.providerID, + finish: "end_turn", + } + const processor: SessionProcessor.Info = { + message, + partFromToolCall() { + return undefined + }, + process: async () => "stop", + } + + return { + agent, + model, + session, + processor, + bypassAgentCheck: false, + messages: [], + } +} + +describe("session.prompt tool hooks", () => { + test("tool.execute.before can replace undefined args for native tools", async () => { + let captured: unknown + + spyOn(ToolRegistry, "tools").mockResolvedValue([ + { + id: "native_demo", + description: "Native demo", + parameters: z.object({}), + execute: async (args: unknown) => { + captured = args + return { title: "", output: "ok", metadata: {} } + }, + }, + ] as never) + spyOn(MCP, "tools").mockResolvedValue({}) + spyOn(Truncate, "output").mockResolvedValue({ content: "ok", truncated: false }) + spyOn(Plugin, "trigger").mockImplementation(async (name, _input, output) => { + if (name === "tool.execute.before" && typeof output === "object" && output !== null && "args" in output) { + Object.assign(output, { args: { __patched: true } }) + } + return output + }) + + const tools = await SessionPrompt.resolveTools(input()) + if (!tools.native_demo?.execute) throw new Error("native_demo execute missing") + await tools.native_demo.execute(undefined, { + toolCallId: "call-native", + abortSignal: new AbortController().signal, + messages: [], + }) + + expect(captured).toEqual({ __patched: true }) + }) + + test("tool.execute.before can replace undefined args for MCP tools", async () => { + let captured: unknown + + spyOn(ToolRegistry, "tools").mockResolvedValue([]) + spyOn(MCP, "tools").mockResolvedValue({ + openkitten_echo: { + description: "MCP demo", + inputSchema: jsonSchema({ + type: "object", + properties: {}, + additionalProperties: false, + }), + execute: async (args: unknown) => { + captured = args + return { + content: [{ type: "text", text: "ok" }], + metadata: {}, + } + }, + }, + } as never) + spyOn(Permission, "ask").mockResolvedValue(undefined) + spyOn(Truncate, "output").mockResolvedValue({ content: "ok", truncated: false }) + spyOn(Plugin, "trigger").mockImplementation(async (name, _input, output) => { + if (name === "tool.execute.before" && typeof output === "object" && output !== null && "args" in output) { + Object.assign(output, { args: { __patched: true } }) + } + return output + }) + + const tools = await SessionPrompt.resolveTools(input()) + if (!tools.openkitten_echo?.execute) throw new Error("openkitten_echo execute missing") + await tools.openkitten_echo.execute(undefined, { + toolCallId: "call-mcp", + abortSignal: new AbortController().signal, + messages: [], + }) + + expect(captured).toEqual({ __patched: true }) + }) +})