diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index dca8085c5b2e..76a7c918aa49 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -424,7 +424,14 @@ export namespace SessionPrompt { sessionID: sessionID, abort, callID: part.callID, - extra: { bypassAgentCheck: true }, + // Reuse the subtask model resolved upstream instead of re-resolving against the agent. + extra: { + bypassAgentCheck: true, + preferredModel: { + providerID: taskModel.providerID, + modelID: taskModel.id, + }, + }, messages: msgs, async metadata(input) { part = (await Session.updatePart({ @@ -1834,13 +1841,13 @@ NOTE: At any point in time through this workflow you should feel free to ask the if (command.model) { return Provider.parseModel(command.model) } + if (input.model) return Provider.parseModel(input.model) if (command.agent) { const cmdAgent = await Agent.get(command.agent) if (cmdAgent?.model) { return cmdAgent.model } } - if (input.model) return Provider.parseModel(input.model) return await lastModel(input.sessionID) })() diff --git a/packages/opencode/src/tool/task.ts b/packages/opencode/src/tool/task.ts index e3781126d0c1..b6394f7df620 100644 --- a/packages/opencode/src/tool/task.ts +++ b/packages/opencode/src/tool/task.ts @@ -105,7 +105,14 @@ export const TaskTool = Tool.define("task", async (ctx) => { const msg = await MessageV2.get({ sessionID: ctx.sessionID, messageID: ctx.messageID }) if (msg.info.role !== "assistant") throw new Error("Not an assistant message") - const model = agent.model ?? { + const preferredModel = ctx.extra?.preferredModel as + | { + providerID: MessageV2.Assistant["providerID"] + modelID: MessageV2.Assistant["modelID"] + } + | undefined + + const model = preferredModel ?? agent.model ?? { modelID: msg.info.modelID, providerID: msg.info.providerID, } diff --git a/packages/opencode/test/session/prompt.test.ts b/packages/opencode/test/session/prompt.test.ts index 3986271dab96..853f9feef970 100644 --- a/packages/opencode/test/session/prompt.test.ts +++ b/packages/opencode/test/session/prompt.test.ts @@ -1,11 +1,12 @@ import path from "path" -import { describe, expect, test } from "bun:test" +import { describe, expect, spyOn, test } from "bun:test" import { fileURLToPath } from "url" import { Instance } from "../../src/project/instance" import { ModelID, ProviderID } from "../../src/provider/schema" import { Session } from "../../src/session" import { MessageV2 } from "../../src/session/message-v2" import { SessionPrompt } from "../../src/session/prompt" +import { MessageID } from "../../src/session/schema" import { Log } from "../../src/util/log" import { tmpdir } from "../fixture/fixture" @@ -179,6 +180,10 @@ describe("session.prompt agent variant", () => { parts: [{ type: "text", text: "hello" }], }) if (other.info.role !== "user") throw new Error("expected user message") + expect(other.info.model).toEqual({ + providerID: ProviderID.make("opencode"), + modelID: ModelID.make("kimi-k2.5-free"), + }) expect(other.info.variant).toBeUndefined() const match = await SessionPrompt.prompt({ @@ -210,3 +215,81 @@ describe("session.prompt agent variant", () => { } }) }) + +describe("session.command model precedence", () => { + test("preserves explicit model on subtask commands before execution continues", async () => { + const prev = process.env.OPENAI_API_KEY + process.env.OPENAI_API_KEY = "test-openai-key" + + try { + await using tmp = await tmpdir({ + git: true, + config: { + agent: { + general: { + model: "openai/gpt-5.2", + }, + }, + command: { + delegated: { + agent: "general", + subtask: true, + template: "delegate this task", + }, + }, + }, + }) + + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const explicitModel = "opencode/kimi-k2.5-free" + const expectedModel = { + providerID: ProviderID.make("opencode"), + modelID: ModelID.make("kimi-k2.5-free"), + } + const stop = new Error("stop after user message") + const originalUpdateMessage = Session.updateMessage + let userMessageID: MessageID | undefined + + const updateSpy = spyOn(Session, "updateMessage").mockImplementation( + (async (message: any) => { + if (message.role === "user") { + userMessageID = message.id + return await originalUpdateMessage(message) + } + throw stop + }) as any, + ) + + try { + await SessionPrompt.command({ + sessionID: session.id, + command: "delegated", + arguments: "", + model: explicitModel, + }) + throw new Error("expected command execution to stop early") + } catch (error) { + expect(error).toBe(stop) + } finally { + updateSpy.mockRestore() + } + + if (!userMessageID) throw new Error("expected user message to be created") + + const stored = await MessageV2.get({ sessionID: session.id, messageID: userMessageID }) + const subtask = stored.parts.find((part) => part.type === "subtask") + if (!subtask || subtask.type !== "subtask") throw new Error("expected subtask part") + expect(subtask.model).toEqual(expectedModel) + + await Session.remove(session.id) + }, + }) + } finally { + if (prev === undefined) delete process.env.OPENAI_API_KEY + else process.env.OPENAI_API_KEY = prev + } + }) +}) diff --git a/packages/opencode/test/tool/task.test.ts b/packages/opencode/test/tool/task.test.ts index aae48a30ab3f..7629dfd10208 100644 --- a/packages/opencode/test/tool/task.test.ts +++ b/packages/opencode/test/tool/task.test.ts @@ -1,6 +1,10 @@ -import { afterEach, describe, expect, test } from "bun:test" +import { afterEach, describe, expect, spyOn, test } from "bun:test" import { Agent } from "../../src/agent/agent" import { Instance } from "../../src/project/instance" +import { ModelID, ProviderID } from "../../src/provider/schema" +import { Session } from "../../src/session" +import { SessionPrompt } from "../../src/session/prompt" +import { MessageID, PartID } from "../../src/session/schema" import { TaskTool } from "../../src/tool/task" import { tmpdir } from "../fixture/fixture" @@ -46,4 +50,114 @@ describe("tool.task", () => { }, }) }) + + test("prefers the resolved subtask model over the subagent fallback model", async () => { + await using tmp = await tmpdir({ + git: true, + config: { + agent: { + general: { + model: "openai/gpt-5.2", + }, + }, + }, + }) + + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const explicitModel = { + providerID: ProviderID.make("opencode"), + modelID: ModelID.make("kimi-k2.5-free"), + } + + const userMessage = await Session.updateMessage({ + id: MessageID.ascending(), + sessionID: session.id, + role: "user", + time: { + created: Date.now(), + }, + agent: "build", + model: explicitModel, + }) + + const assistantMessage = await Session.updateMessage({ + id: MessageID.ascending(), + sessionID: session.id, + parentID: userMessage.id, + role: "assistant", + mode: "build", + agent: "build", + path: { + cwd: Instance.directory, + root: Instance.worktree, + }, + cost: 0, + time: { + created: Date.now(), + }, + tokens: { + input: 0, + output: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: explicitModel.modelID, + providerID: explicitModel.providerID, + }) + + let capturedPrompt: Parameters[0] | undefined + + const promptSpy = spyOn(SessionPrompt, "prompt").mockImplementation( + (async (input: any) => { + capturedPrompt = input + return { + info: assistantMessage, + parts: [ + { + id: PartID.ascending(), + messageID: assistantMessage.id, + sessionID: assistantMessage.sessionID, + type: "text", + text: "done", + }, + ], + } + }) as any, + ) + + try { + const tool = await TaskTool.init() + await tool.execute( + { + description: "delegate", + prompt: "delegate this task", + subagent_type: "general", + }, + { + sessionID: session.id, + messageID: assistantMessage.id, + agent: "build", + abort: new AbortController().signal, + extra: { + bypassAgentCheck: true, + preferredModel: explicitModel, + }, + messages: [], + metadata() {}, + ask: async () => {}, + }, + ) + } finally { + promptSpy.mockRestore() + } + + expect(capturedPrompt?.model).toEqual(explicitModel) + + await Session.remove(session.id) + }, + }) + }) })