diff --git a/packages/app/src/components/prompt-input.tsx b/packages/app/src/components/prompt-input.tsx index 3ee8f43513be..a79129b0f58f 100644 --- a/packages/app/src/components/prompt-input.tsx +++ b/packages/app/src/components/prompt-input.tsx @@ -237,13 +237,14 @@ export const PromptInput: Component = (props) => { ) const [store, setStore] = createStore<{ - popover: "at" | "slash" | null + popover: "at" | "slash" | "at:model" | null historyIndex: number savedPrompt: PromptHistoryEntry | null placeholder: number draggingType: "image" | "@mention" | null mode: "normal" | "shell" applyingHistory: boolean + agentForModel: string | undefined }>({ popover: null, historyIndex: -1, @@ -252,6 +253,7 @@ export const PromptInput: Component = (props) => { draggingType: null, mode: "normal", applyingHistory: false, + agentForModel: undefined, }) const buttonsSpring = useSpring(() => (store.mode === "normal" ? 1 : 0), { visualDuration: 0.2, bounce: 0 }) @@ -509,18 +511,50 @@ export const PromptInput: Component = (props) => { ) const agentNames = createMemo(() => local.agent.list().map((agent) => agent.name)) + const modelList = createMemo(() => { + if (!store.agentForModel) return [] + return providers.connected().flatMap((provider) => + Object.entries(provider.models) + .filter(([_, info]) => info.status !== "deprecated") + .map( + ([modelId, info]): AtOption => ({ + type: "model", + providerID: provider.id, + modelID: modelId, + modelName: info.name ?? modelId, + providerName: provider.name, + display: `${provider.id}/${modelId}`, + }), + ), + ) + }) + const handleAtSelect = (option: AtOption | undefined) => { if (!option) return if (option.type === "agent") { addPart({ type: "agent", name: option.name, content: "@" + option.name, start: 0, end: 0 }) - } else { + } else if (option.type === "model" && store.agentForModel) { + const agentName = store.agentForModel + const content = `@${agentName}:${option.providerID}/${option.modelID}` + addPart({ + type: "agent", + name: agentName, + model: { providerID: option.providerID, modelID: option.modelID }, + content, + start: 0, + end: 0, + }) + setStore("agentForModel", undefined) + } else if (option.type === "file") { addPart({ type: "file", path: option.path, content: "@" + option.path, start: 0, end: 0 }) } } const atKey = (x: AtOption | undefined) => { if (!x) return "" - return x.type === "agent" ? `agent:${x.name}` : `file:${x.path}` + if (x.type === "agent") return `agent:${x.name}` + if (x.type === "model") return `model:${x.providerID}/${x.modelID}` + return `file:${x.path}` } const { @@ -531,6 +565,10 @@ export const PromptInput: Component = (props) => { onKeyDown: atOnKeyDown, } = useFilteredList({ items: async (query) => { + // When in model selection mode, return models instead of agents/files + if (store.agentForModel) { + return modelList() + } const agents = agentList() const open = recent() const seen = new Set(open) @@ -545,7 +583,8 @@ export const PromptInput: Component = (props) => { filterKeys: ["display"], groupBy: (item) => { if (item.type === "agent") return "agent" - if (item.recent) return "recent" + if (item.type === "model") return "model" + if (item.type === "file" && item.recent) return "recent" return "file" }, sortGroupsBy: (a, b) => { @@ -618,7 +657,13 @@ export const PromptInput: Component = (props) => { pill.textContent = part.content pill.setAttribute("data-type", part.type) if (part.type === "file") pill.setAttribute("data-path", part.path) - if (part.type === "agent") pill.setAttribute("data-name", part.name) + if (part.type === "agent") { + pill.setAttribute("data-name", part.name) + if (part.model) { + pill.setAttribute("data-model-provider", part.model.providerID) + pill.setAttribute("data-model-id", part.model.modelID) + } + } pill.setAttribute("contenteditable", "false") pill.style.userSelect = "text" pill.style.cursor = "default" @@ -674,7 +719,7 @@ export const PromptInput: Component = (props) => { }) const selectPopoverActive = () => { - if (store.popover === "at") { + if (store.popover === "at" || store.popover === "at:model") { const items = atFlat() if (items.length === 0) return const active = atActive() @@ -746,9 +791,13 @@ export const PromptInput: Component = (props) => { const pushAgent = (agent: HTMLElement) => { const content = agent.textContent ?? "" + const providerID = agent.dataset.modelProvider + const modelID = agent.dataset.modelId + const model = providerID && modelID ? { providerID, modelID } : undefined parts.push({ type: "agent", name: agent.dataset.name!, + model, content, start: position, end: position + content.length, @@ -824,20 +873,89 @@ export const PromptInput: Component = (props) => { const shellMode = store.mode === "shell" if (!shellMode) { + // Check for @agentname : pattern (space before colon) and collapse it + const spaceColonMatch = rawText.substring(0, cursorPosition).match(/@(\S+) :$/) + if (spaceColonMatch) { + const agentName = spaceColonMatch[1] + const validAgent = sync.data.agent.find( + (a) => !a.hidden && a.mode !== "primary" && a.name.toLowerCase() === agentName.toLowerCase(), + ) + if (validAgent) { + // Remove the space before the colon by manipulating the DOM + const selection = window.getSelection() + if (selection && selection.rangeCount > 0) { + const range = selection.getRangeAt(0) + // Move back 2 positions (past the colon and space), then delete the space + const spacePos = cursorPosition - 2 + setRangeEdge(editorRef, range, "start", spacePos) + setRangeEdge(editorRef, range, "end", spacePos + 1) + range.deleteContents() + // Position cursor after the colon (which is now at spacePos) + const newCursorPos = spacePos + 1 + setRangeEdge(editorRef, range, "start", newCursorPos) + setRangeEdge(editorRef, range, "end", newCursorPos) + range.collapse(true) + selection.removeAllRanges() + selection.addRange(range) + // Re-parse to get updated parts + const updatedParts = parseFromDOM() + // Trigger model selection mode + setStore("agentForModel", validAgent.name) + setStore("popover", "at:model") + atOnInput("") + // Update prompt with new cursor position + mirror.input = true + prompt.set([...updatedParts, ...images], newCursorPos) + queueScroll() + return + } + } + } + const atMatch = rawText.substring(0, cursorPosition).match(/@(\S*)$/) const slashMatch = rawText.match(/^\/(\S*)$/) if (atMatch) { - atOnInput(atMatch[1]) - setStore("popover", "at") + const afterAt = atMatch[1] + // Check if the user typed @agent: to enter model selection mode + // Match against known agent names followed by colon + const validAgent = sync.data.agent.find( + (a) => !a.hidden && a.mode !== "primary" && afterAt.toLowerCase().startsWith(a.name.toLowerCase() + ":"), + ) + if (validAgent) { + // Extract the part after agent: for filtering models + const colonIndex = + afterAt.toLowerCase().indexOf(validAgent.name.toLowerCase() + ":") + validAgent.name.length + 1 + const modelFilter = afterAt.slice(colonIndex) + setStore("agentForModel", validAgent.name) + setStore("popover", "at:model") + atOnInput(modelFilter) + } else if (store.agentForModel) { + // Check if user deleted the colon + if (!afterAt.includes(":")) { + setStore("agentForModel", undefined) + setStore("popover", "at") + atOnInput(afterAt) + } else { + // Still in model mode, extract model filter + const colonIdx = afterAt.indexOf(":") + const modelFilter = afterAt.slice(colonIdx + 1) + atOnInput(modelFilter) + } + } else { + setStore("popover", "at") + atOnInput(afterAt) + } } else if (slashMatch) { slashOnInput(slashMatch[1]) setStore("popover", "slash") } else { - closePopover() + setStore("popover", null) + setStore("agentForModel", undefined) } } else { - closePopover() + setStore("popover", null) + setStore("agentForModel", undefined) } resetHistoryNavigation() @@ -1091,7 +1209,7 @@ export const PromptInput: Component = (props) => { const nav = event.key === "ArrowUp" || event.key === "ArrowDown" || event.key === "Enter" const ctrlNav = ctrl && (event.key === "n" || event.key === "p") if (nav || ctrlNav) { - if (store.popover === "at") { + if (store.popover === "at" || store.popover === "at:model") { atOnKeyDown(event) event.preventDefault() return @@ -1106,7 +1224,8 @@ export const PromptInput: Component = (props) => { if (ctrl && event.code === "KeyG") { if (store.popover) { - closePopover() + setStore("popover", null) + setStore("agentForModel", undefined) event.preventDefault() return } @@ -1157,6 +1276,7 @@ export const PromptInput: Component = (props) => { onSlashSelect={handleSlashSelect} commandKeybind={command.keybind} t={(key) => language.t(key as Parameters[0])} + agentForModel={store.agentForModel} /> void atFlat: AtOption[] atActive?: string @@ -31,6 +34,7 @@ type PromptPopoverProps = { onSlashSelect: (item: SlashCommand) => void commandKeybind: (id: string) => string | undefined t: (key: string) => string + agentForModel?: string } export const PromptPopover: Component = (props) => { @@ -69,6 +73,8 @@ export const PromptPopover: Component = (props) => { ) } + if (item.type === "model") return null + const isDirectory = item.path.endsWith("/") const directory = isDirectory ? item.path : getDirectory(item.path) const filename = isDirectory ? "" : getFilename(item.path) @@ -93,6 +99,42 @@ export const PromptPopover: Component = (props) => { + + 0} + fallback={
{props.t("prompt.popover.emptyResults")}
} + > + + {(item) => { + const model = item as { + type: "model" + providerID: string + modelID: string + modelName: string + providerName: string + display: string + } + return ( + + ) + }} + +
+
0} diff --git a/packages/app/src/context/prompt.tsx b/packages/app/src/context/prompt.tsx index fb8226559113..0f8b5fafa5f4 100644 --- a/packages/app/src/context/prompt.tsx +++ b/packages/app/src/context/prompt.tsx @@ -25,6 +25,10 @@ export interface FileAttachmentPart extends PartBase { export interface AgentPart extends PartBase { type: "agent" name: string + model?: { + providerID: string + modelID: string + } } export interface ImageAttachmentPart { @@ -67,7 +71,12 @@ function isPartEqual(partA: ContentPart, partB: ContentPart) { case "file": return partB.type === "file" && partA.path === partB.path && isSelectionEqual(partA.selection, partB.selection) case "agent": - return partB.type === "agent" && partA.name === partB.name + return ( + partB.type === "agent" && + partA.name === partB.name && + partA.model?.providerID === partB.model?.providerID && + partA.model?.modelID === partB.model?.modelID + ) case "image": return partB.type === "image" && partA.id === partB.id } @@ -89,7 +98,7 @@ function cloneSelection(selection?: FileSelection) { function clonePart(part: ContentPart): ContentPart { if (part.type === "text") return { ...part } if (part.type === "image") return { ...part } - if (part.type === "agent") return { ...part } + if (part.type === "agent") return { ...part, model: part.model ? { ...part.model } : undefined } return { ...part, selection: cloneSelection(part.selection), diff --git a/packages/opencode/src/cli/cmd/tui/component/prompt/autocomplete.tsx b/packages/opencode/src/cli/cmd/tui/component/prompt/autocomplete.tsx index 3240afab326a..bbc8d4a319a7 100644 --- a/packages/opencode/src/cli/cmd/tui/component/prompt/autocomplete.tsx +++ b/packages/opencode/src/cli/cmd/tui/component/prompt/autocomplete.tsx @@ -49,7 +49,7 @@ function extractLineRange(input: string) { export type AutocompleteRef = { onInput: (value: string) => void onKeyDown: (e: KeyEvent) => void - visible: false | "@" | "/" + visible: false | "@" | "/" | "@:model" } export type AutocompleteOption = { @@ -73,6 +73,7 @@ export function Autocomplete(props: { ref: (ref: AutocompleteRef) => void fileStyleId: number agentStyleId: number + modelStyleId: number promptPartTypeId: () => number }) { const sdk = useSDK() @@ -87,6 +88,7 @@ export function Autocomplete(props: { selected: 0, visible: false as AutocompleteRef["visible"], input: "keyboard" as "keyboard" | "mouse", + agentForModel: undefined as string | undefined, }) const [positionTick, setPositionTick] = createSignal(0) @@ -168,50 +170,88 @@ export function Autocomplete(props: { const extmarkStart = store.index const extmarkEnd = extmarkStart + Bun.stringWidth(virtualText) - const styleId = part.type === "file" ? props.fileStyleId : part.type === "agent" ? props.agentStyleId : undefined + if (part.type === "agent" && part.model) { + const colonIndex = text.indexOf(":") + const agentText = "@" + text.slice(0, colonIndex) + const modelText = text.slice(colonIndex) + + const agentEnd = extmarkStart + Bun.stringWidth(agentText) + const modelStart = agentEnd + const modelEnd = modelStart + Bun.stringWidth(modelText) + + const agentExtmarkId = input.extmarks.create({ + start: extmarkStart, + end: agentEnd, + virtual: true, + styleId: props.agentStyleId, + typeId: props.promptPartTypeId(), + }) - const extmarkId = input.extmarks.create({ - start: extmarkStart, - end: extmarkEnd, - virtual: true, - styleId, - typeId: props.promptPartTypeId(), - }) + const modelExtmarkId = input.extmarks.create({ + start: modelStart, + end: modelEnd, + virtual: true, + styleId: props.modelStyleId, + typeId: props.promptPartTypeId(), + }) - props.setPrompt((draft) => { - if (part.type === "file") { - const existingIndex = draft.parts.findIndex((p) => p.type === "file" && "url" in p && p.url === part.url) - if (existingIndex !== -1) { - const existing = draft.parts[existingIndex] - if ( - part.source?.text && - existing && - "source" in existing && - existing.source && - "text" in existing.source && - existing.source.text - ) { - existing.source.text.start = extmarkStart - existing.source.text.end = extmarkEnd - existing.source.text.value = virtualText + props.setPrompt((draft) => { + if (part.source) { + part.source.start = extmarkStart + part.source.end = extmarkEnd + part.source.value = virtualText + } + const partIndex = draft.parts.length + draft.parts.push(part) + props.setExtmark(partIndex, agentExtmarkId) + props.setExtmark(partIndex, modelExtmarkId) + }) + } else { + const styleId = part.type === "file" ? props.fileStyleId : part.type === "agent" ? props.agentStyleId : undefined + + const extmarkId = input.extmarks.create({ + start: extmarkStart, + end: extmarkEnd, + virtual: true, + styleId, + typeId: props.promptPartTypeId(), + }) + + props.setPrompt((draft) => { + if (part.type === "file") { + const existingIndex = draft.parts.findIndex((p) => p.type === "file" && "url" in p && p.url === part.url) + if (existingIndex !== -1) { + const existing = draft.parts[existingIndex] + if ( + part.source?.text && + existing && + "source" in existing && + existing.source && + "text" in existing.source && + existing.source.text + ) { + existing.source.text.start = extmarkStart + existing.source.text.end = extmarkEnd + existing.source.text.value = virtualText + } + return } - return } - } - if (part.type === "file" && part.source?.text) { - part.source.text.start = extmarkStart - part.source.text.end = extmarkEnd - part.source.text.value = virtualText - } else if (part.type === "agent" && part.source) { - part.source.start = extmarkStart - part.source.end = extmarkEnd - part.source.value = virtualText - } - const partIndex = draft.parts.length - draft.parts.push(part) - props.setExtmark(partIndex, extmarkId) - }) + if (part.type === "file" && part.source?.text) { + part.source.text.start = extmarkStart + part.source.text.end = extmarkEnd + part.source.text.value = virtualText + } else if (part.type === "agent" && part.source) { + part.source.start = extmarkStart + part.source.end = extmarkEnd + part.source.value = virtualText + } + const partIndex = draft.parts.length + draft.parts.push(part) + props.setExtmark(partIndex, extmarkId) + }) + } if (part.type === "file" && part.source && part.source.type === "file") { frecency.updateFrecency(part.source.path) @@ -221,7 +261,7 @@ export function Autocomplete(props: { const [files] = createResource( () => search(), async (query) => { - if (!store.visible || store.visible === "/") return [] + if (!store.visible || store.visible === "/" || store.agentForModel) return [] const { lineRange, baseQuery } = extractLineRange(query ?? "") @@ -353,6 +393,34 @@ export function Autocomplete(props: { ) }) + const models = createMemo(() => { + if (!store.agentForModel) return [] + const agentName = store.agentForModel + const width = props.anchor().width - 4 + + return sync.data.provider.flatMap((provider) => + Object.entries(provider.models) + .filter(([_, info]) => info.status !== "deprecated") + .map(([modelId, info]): AutocompleteOption => { + const modelRef = `${provider.id}/${modelId}` + const displayText = `@${agentName}:${modelRef}` + return { + display: Locale.truncateMiddle(displayText, width), + value: modelRef, + onSelect: () => { + insertPart(`${agentName}:${modelRef}`, { + type: "agent", + name: agentName, + model: { providerID: provider.id, modelID: modelId }, + source: { start: 0, end: 0, value: "" }, + }) + setStore("agentForModel", undefined) + }, + } + }), + ) + }) + const commands = createMemo((): AutocompleteOption[] => { const results: AutocompleteOption[] = [...command.slashes()] @@ -383,6 +451,28 @@ export function Autocomplete(props: { }) const options = createMemo((prev: AutocompleteOption[] | undefined) => { + if (store.agentForModel) { + const modelOptions = models() + const currentFilter = filter() + + if (files.loading && prev && prev.length > 0) { + return prev + } + + if (!currentFilter) return modelOptions + + const colonIndex = currentFilter.indexOf(":") + const modelFilter = colonIndex !== -1 ? currentFilter.slice(colonIndex + 1) : "" + + if (!modelFilter) return modelOptions + + const result = fuzzysort.go(modelFilter, modelOptions, { + keys: [(obj) => (obj.value ?? obj.display).trimEnd(), "description"], + limit: 10, + }) + return result.map((arr) => arr.obj) + } + const filesValue = files() const agentsValue = agents() const commandsValue = commands() @@ -488,13 +578,13 @@ export function Autocomplete(props: { if (store.visible === "/" && !text.endsWith(" ") && text.startsWith("/")) { const cursor = props.input().logicalCursor props.input().deleteRange(0, 0, cursor.row, cursor.col) - // Sync the prompt store immediately since onContentChange is async props.setPrompt((draft) => { draft.input = props.input().plainText }) } command.keybinds(true) setStore("visible", false) + setStore("agentForModel", undefined) } onMount(() => { @@ -504,31 +594,85 @@ export function Autocomplete(props: { }, onInput(value) { if (store.visible) { + if (store.agentForModel) { + const textAfterTrigger = value.slice(store.index + 1, props.input().cursorOffset) + if (!textAfterTrigger.includes(":")) { + setStore("agentForModel", undefined) + setStore("visible", "@") + } + } + if ( - // Typed text before the trigger props.input().cursorOffset <= store.index || - // There is a space between the trigger and the cursor props.input().getTextRange(store.index, props.input().cursorOffset).match(/\s/) || - // "/" is not the sole content (store.visible === "/" && value.match(/^\S+\s+\S+\s*$/)) ) { hide() + return + } + + if (store.visible === "@" && !store.agentForModel) { + const offset = props.input().cursorOffset + const textAfterTrigger = value.slice(store.index + 1, offset) + // Match against known agent names followed by colon to handle agents with colons in their names + const validAgent = sync.data.agent.find( + (a) => + !a.hidden && + a.mode !== "primary" && + textAfterTrigger.toLowerCase().startsWith(a.name.toLowerCase() + ":"), + ) + if (validAgent) { + setStore("agentForModel", validAgent.name) + setStore("visible", "@:model") + } } return } - // Check if autocomplete should reopen (e.g., after backspace deleted a space) const offset = props.input().cursorOffset if (offset === 0) return - // Check for "/" at position 0 - reopen slash commands + // Collapse space before colon when typing `:` after `@agentname ` + // This allows ergonomic model override: select agent (adds space), then type `:` to override model + const charBeforeCursor = offset > 0 ? value[offset - 1] : undefined + if (charBeforeCursor === ":" && offset >= 3) { + const charBeforeColon = value[offset - 2] + if (charBeforeColon === " ") { + const textBeforeSpace = value.slice(0, offset - 2) + const atIdx = textBeforeSpace.lastIndexOf("@") + if (atIdx !== -1) { + const agentName = textBeforeSpace.slice(atIdx + 1) + const validAgent = sync.data.agent.find( + (a) => !a.hidden && a.mode !== "primary" && a.name.toLowerCase() === agentName.toLowerCase(), + ) + if (validAgent) { + // Delete the space before the colon + const input = props.input() + const spacePos = offset - 2 + input.cursorOffset = spacePos + const startCursor = input.logicalCursor + input.cursorOffset = spacePos + 1 + const endCursor = input.logicalCursor + input.deleteRange(startCursor.row, startCursor.col, endCursor.row, endCursor.col) + // After deletion, colon is now at spacePos, so position cursor after it + input.cursorOffset = spacePos + 1 + // Trigger model autocomplete + show("@") + setStore("index", atIdx) + setStore("agentForModel", validAgent.name) + setStore("visible", "@:model") + return + } + } + } + } + if (value.startsWith("/") && !value.slice(0, offset).match(/\s/)) { show("/") setStore("index", 0) return } - // Check for "@" trigger - find the nearest "@" before cursor with no whitespace between const text = value.slice(0, offset) const idx = text.lastIndexOf("@") if (idx === -1) return @@ -538,6 +682,17 @@ export function Autocomplete(props: { if ((before === undefined || /\s/.test(before)) && !between.match(/\s/)) { show("@") setStore("index", idx) + + const textAfterAt = between.slice(1) + // Match against known agent names followed by colon to handle agents with colons in their names + const validAgent = sync.data.agent.find( + (a) => + !a.hidden && a.mode !== "primary" && textAfterAt.toLowerCase().startsWith(a.name.toLowerCase() + ":"), + ) + if (validAgent) { + setStore("agentForModel", validAgent.name) + setStore("visible", "@:model") + } } }, onKeyDown(e: KeyEvent) { diff --git a/packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx b/packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx index c85426cc2471..27e557ba6215 100644 --- a/packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx +++ b/packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx @@ -94,6 +94,7 @@ export function Prompt(props: PromptProps) { const fileStyleId = syntax().getStyleId("extmark.file")! const agentStyleId = syntax().getStyleId("extmark.agent")! + const modelStyleId = syntax().getStyleId("extmark.model")! const pasteStyleId = syntax().getStyleId("extmark.paste")! let promptPartTypeId = 0 @@ -814,6 +815,7 @@ export function Prompt(props: PromptProps) { value={store.prompt.input} fileStyleId={fileStyleId} agentStyleId={agentStyleId} + modelStyleId={modelStyleId} promptPartTypeId={() => promptPartTypeId} /> (anchor = r)} visible={props.visible !== false}> diff --git a/packages/opencode/src/cli/cmd/tui/context/theme.tsx b/packages/opencode/src/cli/cmd/tui/context/theme.tsx index 2320c08ccc6e..c9445d932622 100644 --- a/packages/opencode/src/cli/cmd/tui/context/theme.tsx +++ b/packages/opencode/src/cli/cmd/tui/context/theme.tsx @@ -675,6 +675,13 @@ function getSyntaxRules(theme: Theme) { bold: true, }, }, + { + scope: ["extmark.model"], + style: { + foreground: theme.textMuted, + bold: true, + }, + }, { scope: ["extmark.paste"], style: { diff --git a/packages/opencode/src/config/config.ts b/packages/opencode/src/config/config.ts index 2b8aa9e03010..7175cddc4f7f 100644 --- a/packages/opencode/src/config/config.ts +++ b/packages/opencode/src/config/config.ts @@ -648,6 +648,7 @@ export namespace Config { list: PermissionRule.optional(), bash: PermissionRule.optional(), task: PermissionRule.optional(), + model: PermissionRule.optional(), external_directory: PermissionRule.optional(), todowrite: PermissionAction.optional(), todoread: PermissionAction.optional(), diff --git a/packages/opencode/src/session/message-v2.ts b/packages/opencode/src/session/message-v2.ts index 90abf54526a7..9d0b1031e318 100644 --- a/packages/opencode/src/session/message-v2.ts +++ b/packages/opencode/src/session/message-v2.ts @@ -185,6 +185,12 @@ export namespace MessageV2 { export const AgentPart = PartBase.extend({ type: z.literal("agent"), name: z.string(), + model: z + .object({ + providerID: z.string(), + modelID: z.string(), + }) + .optional(), source: z .object({ value: z.string(), diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index b8be93b6be00..9502db001163 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -403,6 +403,7 @@ export namespace SessionPrompt { description: task.description, subagent_type: task.agent, command: task.command, + model: task.model, } await Plugin.trigger( "tool.execute.before", @@ -421,7 +422,7 @@ export namespace SessionPrompt { sessionID: sessionID, abort, callID: part.callID, - extra: { bypassAgentCheck: true }, + extra: { bypassAgentCheck: true, bypassModelCheck: task.model !== undefined }, messages: msgs, async metadata(input) { await Session.updatePart({ @@ -600,6 +601,8 @@ export namespace SessionPrompt { // Check if user explicitly invoked an agent via @ in this turn const lastUserMsg = msgs.findLast((m) => m.info.role === "user") const bypassAgentCheck = lastUserMsg?.parts.some((p) => p.type === "agent") ?? false + // Check if user explicitly included a model override in @agent:provider/model syntax + const bypassModelCheck = lastUserMsg?.parts.some((p) => p.type === "agent" && p.model !== undefined) ?? false const tools = await resolveTools({ agent, @@ -608,6 +611,7 @@ export namespace SessionPrompt { tools: lastUser.tools, processor, bypassAgentCheck, + bypassModelCheck, messages: msgs, }) @@ -746,6 +750,7 @@ export namespace SessionPrompt { tools?: Record processor: SessionProcessor.Info bypassAgentCheck: boolean + bypassModelCheck: boolean messages: MessageV2.WithParts[] }) { using _ = log.time("resolveTools") @@ -756,7 +761,7 @@ export namespace SessionPrompt { abort: options.abortSignal!, messageID: input.processor.message.id, callID: options.toolCallId, - extra: { model: input.model, bypassAgentCheck: input.bypassAgentCheck }, + extra: { model: input.model, bypassAgentCheck: input.bypassAgentCheck, bypassModelCheck: input.bypassModelCheck }, agent: input.agent.name, messages: input.messages, metadata: async (val: { title?: string; metadata?: any }) => { @@ -1269,6 +1274,7 @@ export namespace SessionPrompt { // Check if this agent would be denied by task permission const perm = PermissionNext.evaluate("task", part.name, agent.permission) const hint = perm.action === "deny" ? " . Invoked by user; guaranteed to exist." : "" + const modelHint = part.model ? ` Use model override: ${part.model.providerID}/${part.model.modelID}.` : "" return [ { ...part, @@ -1280,11 +1286,10 @@ export namespace SessionPrompt { sessionID: input.sessionID, type: "text", synthetic: true, - // An extra space is added here. Otherwise the 'Use' gets appended - // to user's last word; making a combined word text: " Use the above message and context to generate a prompt and call the task tool with subagent: " + part.name + + modelHint + hint, }, ] diff --git a/packages/opencode/src/tool/task.ts b/packages/opencode/src/tool/task.ts index 68e44eb97e48..9ac039b048ba 100644 --- a/packages/opencode/src/tool/task.ts +++ b/packages/opencode/src/tool/task.ts @@ -11,6 +11,7 @@ import { iife } from "@/util/iife" import { defer } from "@/util/defer" import { Config } from "../config/config" import { PermissionNext } from "@/permission/next" +import { Provider } from "../provider/provider" const parameters = z.object({ description: z.string().describe("A short (3-5 words) description of the task"), @@ -23,10 +24,18 @@ const parameters = z.object({ ) .optional(), command: z.string().describe("The command that triggered this task").optional(), + model: z + .object({ + providerID: z.string(), + modelID: z.string(), + }) + .describe("Override the model for this task") + .optional(), }) export const TaskTool = Tool.define("task", async (ctx) => { const agents = await Agent.list().then((x) => x.filter((a) => a.mode !== "primary")) + const providers = await Provider.list() // Filter agents by permissions if agent provided const caller = ctx?.agent @@ -34,12 +43,22 @@ export const TaskTool = Tool.define("task", async (ctx) => { ? agents.filter((a) => PermissionNext.evaluate("task", a.name, caller.permission).action !== "deny") : agents + // Build models list from configured providers, filtered by caller permissions + const allModels = Object.entries(providers).flatMap(([providerID, provider]) => + Object.entries(provider.models) + .filter(([_, info]) => info.status !== "deprecated") + .map(([modelID]) => `${providerID}/${modelID}`), + ) + const models = caller + ? allModels.filter((m) => PermissionNext.evaluate("model", m, caller.permission).action !== "deny") + : allModels + const description = DESCRIPTION.replace( "{agents}", accessibleAgents .map((a) => `- ${a.name}: ${a.description ?? "This subagent should only be called manually by the user."}`) .join("\n"), - ) + ).replace("{models}", models.join(", ")) return { description, parameters, @@ -104,9 +123,24 @@ export const TaskTool = Tool.define("task", async (ctx) => { const msg = await MessageV2.get({ sessionID: ctx.sessionID, messageID: ctx.messageID }) if (msg.info.role !== "assistant") throw new Error("Not an assistant message") - const model = agent.model ?? { - modelID: msg.info.modelID, - providerID: msg.info.providerID, + const model = params.model ?? + agent.model ?? { + modelID: msg.info.modelID, + providerID: msg.info.providerID, + } + + // Check model permission when LLM explicitly overrides the model + if (params.model && !ctx.extra?.bypassModelCheck) { + const modelPattern = `${params.model.providerID}/${params.model.modelID}` + await ctx.ask({ + permission: "model", + patterns: [modelPattern], + always: ["*"], + metadata: { + providerID: params.model.providerID, + modelID: params.model.modelID, + }, + }) } ctx.metadata({ diff --git a/packages/opencode/src/tool/task.txt b/packages/opencode/src/tool/task.txt index 585cce8f9d0a..02fac2787bda 100644 --- a/packages/opencode/src/tool/task.txt +++ b/packages/opencode/src/tool/task.txt @@ -3,7 +3,10 @@ Launch a new agent to handle complex, multistep tasks autonomously. Available agent types and the tools they have access to: {agents} -When using the Task tool, you must specify a subagent_type parameter to select which agent type to use. +When using the Task tool, you must specify a subagent_type parameter to select which agent type to use. You can optionally specify a model parameter to override the default model for this task. + +Available models for the model parameter: +{models} When to use the Task tool: - When you are instructed to execute custom slash commands. Use the Task tool with the slash command invocation as the entire prompt. The slash command can take arguments. For example: Task(description="Check the file", prompt="/check-file path/to/file.py") diff --git a/packages/opencode/test/permission-task.test.ts b/packages/opencode/test/permission-task.test.ts index 3d592a3d981a..bdafdaf3c01e 100644 --- a/packages/opencode/test/permission-task.test.ts +++ b/packages/opencode/test/permission-task.test.ts @@ -317,3 +317,139 @@ describe("permission.task with real config files", () => { }) }) }) + +describe("PermissionNext.evaluate for permission.model", () => { + const createRuleset = (rules: Record): PermissionNext.Ruleset => + Object.entries(rules).map(([pattern, action]) => ({ + permission: "model", + pattern, + action, + })) + + test("returns ask when no match (default)", () => { + expect(PermissionNext.evaluate("model", "anthropic/claude-sonnet-4-20250514", []).action).toBe("ask") + }) + + test("returns deny for explicit deny", () => { + const ruleset = createRuleset({ "anthropic/claude-opus-4-20250514": "deny" }) + expect(PermissionNext.evaluate("model", "anthropic/claude-opus-4-20250514", ruleset).action).toBe("deny") + }) + + test("returns allow for explicit allow", () => { + const ruleset = createRuleset({ "anthropic/claude-sonnet-4-20250514": "allow" }) + expect(PermissionNext.evaluate("model", "anthropic/claude-sonnet-4-20250514", ruleset).action).toBe("allow") + }) + + test("matches provider wildcard patterns", () => { + const ruleset = createRuleset({ "anthropic/*": "allow", "openai/*": "deny" }) + expect(PermissionNext.evaluate("model", "anthropic/claude-sonnet-4-20250514", ruleset).action).toBe("allow") + expect(PermissionNext.evaluate("model", "anthropic/claude-opus-4-20250514", ruleset).action).toBe("allow") + expect(PermissionNext.evaluate("model", "openai/gpt-4o", ruleset).action).toBe("deny") + expect(PermissionNext.evaluate("model", "openai/o1", ruleset).action).toBe("deny") + }) + + test("matches model name wildcard patterns", () => { + const ruleset = createRuleset({ "*/claude-opus-*": "deny", "*/o1*": "ask" }) + expect(PermissionNext.evaluate("model", "anthropic/claude-opus-4-20250514", ruleset).action).toBe("deny") + expect(PermissionNext.evaluate("model", "openai/o1", ruleset).action).toBe("ask") + expect(PermissionNext.evaluate("model", "openai/o1-preview", ruleset).action).toBe("ask") + }) + + test("later rules take precedence (last match wins)", () => { + const ruleset = createRuleset({ + "anthropic/*": "deny", + "anthropic/claude-sonnet-4-20250514": "allow", + }) + expect(PermissionNext.evaluate("model", "anthropic/claude-sonnet-4-20250514", ruleset).action).toBe("allow") + expect(PermissionNext.evaluate("model", "anthropic/claude-opus-4-20250514", ruleset).action).toBe("deny") + }) + + test("matches global wildcard", () => { + expect(PermissionNext.evaluate("model", "any/model", createRuleset({ "*": "allow" })).action).toBe("allow") + expect(PermissionNext.evaluate("model", "any/model", createRuleset({ "*": "deny" })).action).toBe("deny") + expect(PermissionNext.evaluate("model", "any/model", createRuleset({ "*": "ask" })).action).toBe("ask") + }) +}) + +describe("permission.model with real config files", () => { + test("loads model permissions from opencode.json config", async () => { + await using tmp = await tmpdir({ + git: true, + config: { + permission: { + model: { + "*": "allow", + "anthropic/claude-opus-4-*": "deny", + }, + }, + }, + }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const config = await Config.get() + const ruleset = PermissionNext.fromConfig(config.permission ?? {}) + expect(PermissionNext.evaluate("model", "anthropic/claude-sonnet-4-20250514", ruleset).action).toBe("allow") + expect(PermissionNext.evaluate("model", "openai/gpt-4o", ruleset).action).toBe("allow") + expect(PermissionNext.evaluate("model", "anthropic/claude-opus-4-20250514", ruleset).action).toBe("deny") + }, + }) + }) + + test("loads model permissions with provider wildcards from config", async () => { + await using tmp = await tmpdir({ + git: true, + config: { + permission: { + model: { + "*": "deny", + "anthropic/*": "allow", + }, + }, + }, + }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const config = await Config.get() + const ruleset = PermissionNext.fromConfig(config.permission ?? {}) + expect(PermissionNext.evaluate("model", "anthropic/claude-sonnet-4-20250514", ruleset).action).toBe("allow") + expect(PermissionNext.evaluate("model", "openai/gpt-4o", ruleset).action).toBe("deny") + }, + }) + }) + + test("mixed permission config with model and other tools", async () => { + await using tmp = await tmpdir({ + git: true, + config: { + permission: { + bash: "allow", + task: { + "*": "allow", + }, + model: { + "*": "allow", + "openai/o1*": "ask", + }, + }, + }, + }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const config = await Config.get() + const ruleset = PermissionNext.fromConfig(config.permission ?? {}) + + // Verify model permissions + expect(PermissionNext.evaluate("model", "anthropic/claude-sonnet-4-20250514", ruleset).action).toBe("allow") + expect(PermissionNext.evaluate("model", "openai/o1", ruleset).action).toBe("ask") + expect(PermissionNext.evaluate("model", "openai/o1-preview", ruleset).action).toBe("ask") + + // Verify other tool permissions still work + expect(PermissionNext.evaluate("bash", "*", ruleset).action).toBe("allow") + expect(PermissionNext.evaluate("task", "general", ruleset).action).toBe("allow") + }, + }) + }) +}) diff --git a/packages/opencode/test/session/message-v2.test.ts b/packages/opencode/test/session/message-v2.test.ts index 1a7c75c05f87..98ce21f313d9 100644 --- a/packages/opencode/test/session/message-v2.test.ts +++ b/packages/opencode/test/session/message-v2.test.ts @@ -894,3 +894,118 @@ describe("session.message-v2.fromError", () => { }) }) }) + +describe("session.message-v2.AgentPart", () => { + test("parses AgentPart with model field", () => { + const part = { + id: "prt_000000000001", + sessionID: "ses_000000000001", + messageID: "msg_000000000001", + type: "agent", + name: "explore", + model: { + providerID: "anthropic", + modelID: "claude-sonnet-4-20250514", + }, + } + const result = MessageV2.AgentPart.parse(part) + expect(result.model).toEqual({ + providerID: "anthropic", + modelID: "claude-sonnet-4-20250514", + }) + }) + + test("parses AgentPart without model field", () => { + const part = { + id: "prt_000000000001", + sessionID: "ses_000000000001", + messageID: "msg_000000000001", + type: "agent", + name: "explore", + } + const result = MessageV2.AgentPart.parse(part) + expect(result.model).toBeUndefined() + }) + + test("parses AgentPart with source field", () => { + const part = { + id: "prt_000000000001", + sessionID: "ses_000000000001", + messageID: "msg_000000000001", + type: "agent", + name: "explore", + model: { + providerID: "openai", + modelID: "gpt-4o", + }, + source: { + value: "@explore:openai/gpt-4o", + start: 0, + end: 22, + }, + } + const result = MessageV2.AgentPart.parse(part) + expect(result.name).toBe("explore") + expect(result.model).toEqual({ + providerID: "openai", + modelID: "gpt-4o", + }) + expect(result.source).toEqual({ + value: "@explore:openai/gpt-4o", + start: 0, + end: 22, + }) + }) + + test("rejects AgentPart with invalid model field", () => { + const part = { + id: "prt_000000000001", + sessionID: "ses_000000000001", + messageID: "msg_000000000001", + type: "agent", + name: "explore", + model: { + providerID: "anthropic", + // missing modelID + }, + } + expect(() => MessageV2.AgentPart.parse(part)).toThrow() + }) +}) + +describe("session.message-v2.SubtaskPart", () => { + test("parses SubtaskPart with model field", () => { + const part = { + id: "prt_000000000001", + sessionID: "ses_000000000001", + messageID: "msg_000000000001", + type: "subtask", + prompt: "do something", + description: "task desc", + agent: "explore", + model: { + providerID: "anthropic", + modelID: "claude-sonnet-4-20250514", + }, + } + const result = MessageV2.SubtaskPart.parse(part) + expect(result.model).toEqual({ + providerID: "anthropic", + modelID: "claude-sonnet-4-20250514", + }) + }) + + test("parses SubtaskPart without model field", () => { + const part = { + id: "prt_000000000001", + sessionID: "ses_000000000001", + messageID: "msg_000000000001", + type: "subtask", + prompt: "do something", + description: "task desc", + agent: "explore", + } + const result = MessageV2.SubtaskPart.parse(part) + expect(result.model).toBeUndefined() + }) +}) diff --git a/packages/opencode/test/tool/task.test.ts b/packages/opencode/test/tool/task.test.ts new file mode 100644 index 000000000000..2c5e9299e71d --- /dev/null +++ b/packages/opencode/test/tool/task.test.ts @@ -0,0 +1,173 @@ +import { describe, expect, test } from "bun:test" +import path from "path" +import { TaskTool } from "../../src/tool/task" +import { Instance } from "../../src/project/instance" +import { tmpdir } from "../fixture/fixture" +import { PermissionNext } from "../../src/permission/next" +import { Agent } from "../../src/agent/agent" + +describe("tool.task", () => { + test("task tool schema includes model parameter", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const tool = await TaskTool.init() + const schema = tool.parameters + + // Verify the model parameter exists and is optional + const shape = schema.shape + expect(shape.model).toBeDefined() + + // Verify the model schema structure - unwrap optional + const modelShape = shape.model._def.innerType.shape + expect(modelShape.providerID).toBeDefined() + expect(modelShape.modelID).toBeDefined() + }, + }) + }) + + test("task tool description includes available models list", async () => { + await using tmp = await tmpdir({ + git: true, + config: { + provider: { + anthropic: { + id: "anthropic", + }, + }, + }, + }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const tool = await TaskTool.init() + + // Verify the description does not contain the placeholder + expect(tool.description).not.toContain("{models}") + + // Verify the description mentions model parameter usage + expect(tool.description).toContain("model parameter") + expect(tool.description).toContain("Available models") + }, + }) + }) + + test("task tool accepts valid model parameter", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const tool = await TaskTool.init() + + // Valid input with model should parse successfully + const result = tool.parameters.safeParse({ + description: "test task", + prompt: "do something", + subagent_type: "explore", + model: { + providerID: "anthropic", + modelID: "claude-sonnet-4-20250514", + }, + }) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.model).toEqual({ + providerID: "anthropic", + modelID: "claude-sonnet-4-20250514", + }) + } + }, + }) + }) + + test("task tool accepts input without model parameter", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const tool = await TaskTool.init() + + // Valid input without model should parse successfully + const result = tool.parameters.safeParse({ + description: "test task", + prompt: "do something", + subagent_type: "explore", + }) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.model).toBeUndefined() + } + }, + }) + }) + + test("task tool rejects invalid model parameter", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const tool = await TaskTool.init() + + // Invalid model (missing modelID) should fail validation + const result = tool.parameters.safeParse({ + description: "test task", + prompt: "do something", + subagent_type: "explore", + model: { + providerID: "anthropic", + // missing modelID + }, + }) + expect(result.success).toBe(false) + }, + }) + }) + + test("task tool filters models list based on agent permissions", async () => { + await using tmp = await tmpdir({ + git: true, + config: { + provider: { + anthropic: { + id: "anthropic", + }, + }, + }, + }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + // Create a mock agent with model permissions that deny all models + const agentWithAllModelsDenied: Agent.Info = { + name: "test-agent", + mode: "primary", + permission: [{ permission: "model", pattern: "*", action: "deny" }], + options: {}, + } + + // Initialize with the agent context - all models denied + const toolDenied = await TaskTool.init({ agent: agentWithAllModelsDenied }) + + // When all models are denied, the models list should be empty + // Use regex to handle both \n and \r\n line endings (Windows CI) + expect(toolDenied.description).toMatch(/Available models for the model parameter:\r?\n\r?\n/) + + // Create a mock agent with no model restrictions + const agentWithNoRestrictions: Agent.Info = { + name: "test-agent", + mode: "primary", + permission: [], + options: {}, + } + + // Initialize with no restrictions - should have models + const toolAllowed = await TaskTool.init({ agent: agentWithNoRestrictions }) + + // With no restrictions, there should be models (not empty after the header) + expect(toolAllowed.description).not.toMatch(/Available models for the model parameter:\r?\n\r?\n/) + expect(toolAllowed.description).toContain("Available models for the model parameter:") + }, + }) + }) +}) diff --git a/packages/sdk/js/src/v2/gen/types.gen.ts b/packages/sdk/js/src/v2/gen/types.gen.ts index 30568c96df12..003d28d4db47 100644 --- a/packages/sdk/js/src/v2/gen/types.gen.ts +++ b/packages/sdk/js/src/v2/gen/types.gen.ts @@ -480,6 +480,10 @@ export type AgentPart = { messageID: string type: "agent" name: string + model?: { + providerID: string + modelID: string + } source?: { value: string start: number @@ -1058,6 +1062,7 @@ export type PermissionConfig = list?: PermissionRuleConfig bash?: PermissionRuleConfig task?: PermissionRuleConfig + model?: PermissionRuleConfig external_directory?: PermissionRuleConfig todowrite?: PermissionActionConfig todoread?: PermissionActionConfig @@ -1743,6 +1748,10 @@ export type AgentPartInput = { id?: string type: "agent" name: string + model?: { + providerID: string + modelID: string + } source?: { value: string start: number diff --git a/packages/sdk/openapi.json b/packages/sdk/openapi.json index a5e10f0d735f..d2eb7ce5cad6 100644 --- a/packages/sdk/openapi.json +++ b/packages/sdk/openapi.json @@ -8302,6 +8302,18 @@ "name": { "type": "string" }, + "model": { + "type": "object", + "properties": { + "providerID": { + "type": "string" + }, + "modelID": { + "type": "string" + } + }, + "required": ["providerID", "modelID"] + }, "source": { "type": "object", "properties": { @@ -11453,6 +11465,18 @@ "name": { "type": "string" }, + "model": { + "type": "object", + "properties": { + "providerID": { + "type": "string" + }, + "modelID": { + "type": "string" + } + }, + "required": ["providerID", "modelID"] + }, "source": { "type": "object", "properties": {