diff --git a/packages/opencode/src/provider/models.ts b/packages/opencode/src/provider/models.ts
index 2d787588b0b5..f73c0c4c5621 100644
--- a/packages/opencode/src/provider/models.ts
+++ b/packages/opencode/src/provider/models.ts
@@ -13,6 +13,19 @@ import { Hash } from "@/util/hash"
 // Falls back to undefined in dev mode when snapshot doesn't exist
 /* @ts-ignore */
 
+// Cache format types for prompt caching
+export const CacheFormat = z.enum(["anthropic", "openrouter", "bedrock", "openaiCompatible"])
+export type CacheFormat = z.infer<typeof CacheFormat>
+
+export const Caching = z.union([
+  z.boolean(),
+  z.object({
+    format: CacheFormat.optional(),
+    positions: z.array(z.enum(["system", "first", "last"])).optional(),
+  }),
+])
+export type Caching = z.infer<typeof Caching>
+
 export namespace ModelsDev {
   const log = Log.create({ service: "models.dev" })
   const source = url()
@@ -94,6 +107,8 @@ export namespace ModelsDev {
       .optional(),
     status: z.enum(["alpha", "beta", "deprecated"]).optional(),
     provider: z.object({ npm: z.string().optional(), api: z.string().optional() }).optional(),
+    variants: z.record(z.string(), z.record(z.string(), z.any())).optional(),
+    caching: Caching.optional(),
   })
   export type Model = z.infer<typeof Model>
 
diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts
index 004fb77f91a0..96446c820b61 100644
--- a/packages/opencode/src/provider/provider.ts
+++ b/packages/opencode/src/provider/provider.ts
@@ -9,8 +9,8 @@ import { Npm } from "../npm"
 import { Hash } from "../util/hash"
 import { Plugin } from "../plugin"
 import { NamedError } from "@opencode-ai/util/error"
+import { ModelsDev, Caching } from "./models"
 import { type LanguageModelV3 } from "@ai-sdk/provider"
-import { ModelsDev } from "./models"
 import { Auth } from "../auth"
 import { Env } from "../env"
 import { Instance } from "../project/instance"
@@ -882,6 +882,7 @@ export namespace Provider {
           headers: z.record(z.string(), z.string()),
           release_date: z.string(),
           variants: z.record(z.string(), z.record(z.string(), z.any())).optional(),
+          caching: Caching.optional(),
         })
         .meta({
           ref: "Model",
         })
@@ -991,6 +992,7 @@
           },
           release_date: model.release_date,
           variants: {},
+          caching: model.caching,
         }
         m.variants = mapValues(ProviderTransform.variants(m), (v) => v)
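For reference, a minimal standalone sketch (not part of the diff) of what the new `Caching` schema accepts — either a boolean shorthand or a structured object pinning the wire format and cache positions:

```ts
import { z } from "zod"

// Mirrors the schema added in models.ts above, inlined for illustration.
const CacheFormat = z.enum(["anthropic", "openrouter", "bedrock", "openaiCompatible"])
const Caching = z.union([
  z.boolean(),
  z.object({
    format: CacheFormat.optional(),
    positions: z.array(z.enum(["system", "first", "last"])).optional(),
  }),
])

Caching.parse(true) // boolean shorthand: enable caching, format inferred, default positions
Caching.parse({ format: "bedrock", positions: ["system", "last"] }) // structured form
Caching.parse({}) // also valid: both fields are optional
```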
diff --git a/packages/opencode/src/provider/transform.ts b/packages/opencode/src/provider/transform.ts
index 111832099216..3fddff9f4e61 100644
--- a/packages/opencode/src/provider/transform.ts
+++ b/packages/opencode/src/provider/transform.ts
@@ -190,9 +190,50 @@ export namespace ProviderTransform {
   }
 
   function applyCaching(msgs: ModelMessage[], model: Provider.Model): ModelMessage[] {
-    const system = msgs.filter((msg) => msg.role === "system").slice(0, 2)
-    const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
-
+    // Determine cache format from model.caching config or infer from provider
+    const npm = model.api.npm
+    const providerID = model.providerID
+
+    // Get format from explicit config or infer from provider
+    let format: "anthropic" | "openrouter" | "bedrock" | "openaiCompatible" | undefined
+    if (model.caching && typeof model.caching === "object" && model.caching.format) {
+      format = model.caching.format
+    } else if (npm === "@ai-sdk/amazon-bedrock" || providerID.includes("bedrock")) {
+      format = "bedrock"
+    } else if (npm === "@ai-sdk/anthropic" || providerID === "anthropic") {
+      format = "anthropic"
+    } else if (npm === "@openrouter/ai-sdk-provider" || providerID === "openrouter") {
+      format = "openrouter"
+    } else {
+      // Default to openaiCompatible for other providers (kiro-gateway, etc.)
+      format = "openaiCompatible"
+    }
+
+    // Determine positions to cache
+    let positions: ("system" | "first" | "last")[] = ["system", "last"]
+    if (model.caching && typeof model.caching === "object" && model.caching.positions) {
+      positions = model.caching.positions
+    }
+
+    // Select messages to cache based on positions
+    const messagesToCache: ModelMessage[] = []
+    const systemMsgs = msgs.filter((msg) => msg.role === "system")
+    const nonSystemMsgs = msgs.filter((msg) => msg.role !== "system")
+
+    if (positions.includes("system")) {
+      messagesToCache.push(...systemMsgs.slice(0, 2))
+    }
+    if (positions.includes("first") && nonSystemMsgs.length > 0) {
+      messagesToCache.push(nonSystemMsgs[0])
+    }
+    if (positions.includes("last") && nonSystemMsgs.length > 0) {
+      const lastMsg = nonSystemMsgs[nonSystemMsgs.length - 1]
+      if (!messagesToCache.includes(lastMsg)) {
+        messagesToCache.push(lastMsg)
+      }
+    }
+
+    // Build provider options for all formats (SDK will pick the right one)
     const providerOptions = {
       anthropic: {
         cacheControl: { type: "ephemeral" },
@@ -211,11 +252,11 @@ export namespace ProviderTransform {
       },
     }
 
-    for (const msg of unique([...system, ...final])) {
-      const useMessageLevelOptions =
-        model.providerID === "anthropic" ||
-        model.providerID.includes("bedrock") ||
-        model.api.npm === "@ai-sdk/amazon-bedrock"
+    // Determine if we should use message-level or content-level options
+    // Anthropic and Bedrock use message-level, others use content-level
+    const useMessageLevelOptions = format === "anthropic" || format === "bedrock"
+
+    for (const msg of unique(messagesToCache)) {
       const shouldUseContentOptions =
         !useMessageLevelOptions && Array.isArray(msg.content) && msg.content.length > 0
       if (shouldUseContentOptions) {
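To make the new selection step easier to audit in isolation, here is a standalone restatement of the position handling above (names are illustrative, not part of the diff). Overlap between "first" and "last" in short conversations is guarded both by the `includes` check here and by the `unique()` call in the hunk above:

```ts
type ChatMessage = { role: string; content: unknown }
type Position = "system" | "first" | "last"

// Restates applyCaching's selection step: up to two system messages,
// plus the first and/or last non-system message, per the configured positions.
function selectCacheTargets(msgs: ChatMessage[], positions: Position[]): ChatMessage[] {
  const system = msgs.filter((m) => m.role === "system")
  const rest = msgs.filter((m) => m.role !== "system")
  const out: ChatMessage[] = []
  if (positions.includes("system")) out.push(...system.slice(0, 2))
  if (positions.includes("first") && rest.length > 0) out.push(rest[0])
  if (positions.includes("last") && rest.length > 0) {
    const last = rest[rest.length - 1]
    if (!out.includes(last)) out.push(last)
  }
  return out
}
```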
@@ -278,15 +319,36 @@ export namespace ProviderTransform {
   export function message(msgs: ModelMessage[], model: Provider.Model, options: Record<string, any>) {
     msgs = unsupportedParts(msgs, model)
     msgs = normalizeMessages(msgs, model, options)
-    if (
-      (model.providerID === "anthropic" ||
-        model.providerID === "google-vertex-anthropic" ||
-        model.api.id.includes("anthropic") ||
+
+    // Apply caching when:
+    // 1. Explicitly enabled via model.caching (true or object)
+    // 2. Auto-detected: Anthropic models (not gateway)
+    // 3. Auto-detected: Amazon Bedrock models that support caching (Claude, Nova)
+    const isAnthropic = model.providerID === "anthropic" ||
+      model.providerID === "google-vertex-anthropic" ||
+      model.api.id.includes("anthropic") ||
+      model.api.id.includes("claude") ||
+      model.id.includes("anthropic") ||
+      model.id.includes("claude") ||
+      model.api.npm === "@ai-sdk/anthropic"
+    const isBedrock = model.api.npm === "@ai-sdk/amazon-bedrock" || model.providerID.includes("bedrock")
+    const isBedrockCacheEligible = isBedrock && (
+      // Explicit caching option (true or false via model.options.caching)
+      (model.options?.caching === true) ||
+      // Explicit caching option via model.caching
+      (model.caching === true) ||
+      // Auto-detect Claude/Nova models (unless caching is explicitly disabled)
+      (model.options?.caching !== false && (
        model.api.id.includes("claude") ||
-        model.id.includes("anthropic") ||
+        model.api.id.includes("nova") ||
        model.id.includes("claude") ||
-        model.api.npm === "@ai-sdk/anthropic") &&
-      model.api.npm !== "@ai-sdk/gateway"
+        model.id.includes("nova")
+      ))
+    )
+
+    if (model.caching
+      || (isAnthropic && model.api.npm !== "@ai-sdk/gateway" && model.options?.caching !== false)
+      || isBedrockCacheEligible
     ) {
       msgs = applyCaching(msgs, model)
     }
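Since the gate above mixes three signals (explicit `model.caching`, Anthropic auto-detection, and Bedrock eligibility), a standalone restatement may help review — a sketch with the model narrowed to just the fields the condition reads, not part of the diff:

```ts
// Restates the condition guarding applyCaching above.
function shouldApplyCaching(model: {
  id: string
  providerID: string
  caching?: boolean | { format?: string; positions?: string[] }
  options?: { caching?: boolean }
  api: { id: string; npm?: string }
}): boolean {
  const isAnthropic =
    model.providerID === "anthropic" ||
    model.providerID === "google-vertex-anthropic" ||
    model.api.id.includes("anthropic") ||
    model.api.id.includes("claude") ||
    model.id.includes("anthropic") ||
    model.id.includes("claude") ||
    model.api.npm === "@ai-sdk/anthropic"
  const isBedrock = model.api.npm === "@ai-sdk/amazon-bedrock" || model.providerID.includes("bedrock")
  const isBedrockCacheEligible =
    isBedrock &&
    (model.options?.caching === true ||
      model.caching === true ||
      (model.options?.caching !== false &&
        ["claude", "nova"].some((k) => model.api.id.includes(k) || model.id.includes(k))))
  return Boolean(
    model.caching ||
      (isAnthropic && model.api.npm !== "@ai-sdk/gateway" && model.options?.caching !== false) ||
      isBedrockCacheEligible,
  )
}
```

One design note: any truthy `model.caching` (including a `{ format, positions }` object) short-circuits auto-detection entirely, while `options.caching: false` only suppresses the auto-detected paths.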
diff --git a/packages/opencode/test/provider/transform.test.ts b/packages/opencode/test/provider/transform.test.ts
index 0aee396f44a3..8532845245e2 100644
--- a/packages/opencode/test/provider/transform.test.ts
+++ b/packages/opencode/test/provider/transform.test.ts
@@ -1623,38 +1623,120 @@ describe("ProviderTransform.message - providerOptions key remapping", () => {
   })
 })
 
-describe("ProviderTransform.message - claude w/bedrock custom inference profile", () => {
-  test("adds cachePoint", () => {
-    const model = {
-      id: "amazon-bedrock/custom-claude-sonnet-4.5",
-      providerID: "amazon-bedrock",
+describe("ProviderTransform.message - bedrock prompt caching", () => {
+  const createBedrockModel = (apiId: string, providerID = "amazon-bedrock") =>
+    ({
+      id: `${providerID}/${apiId}`,
+      providerID,
       api: {
-        id: "arn:aws:bedrock:xxx:yyy:application-inference-profile/zzz",
-        url: "https://api.test.com",
+        id: apiId,
+        url: "https://bedrock.amazonaws.com",
         npm: "@ai-sdk/amazon-bedrock",
       },
-      name: "Custom inference profile",
+      name: apiId,
       capabilities: {},
       options: {},
       headers: {},
-    } as any
+    }) as any
 
-    const msgs = [
-      {
-        role: "user",
-        content: "Hello",
-      },
-    ] as any[]
+  test("Claude models on Bedrock get prompt caching", () => {
+    const model = createBedrockModel("anthropic.claude-3-5-sonnet-20241022-v2:0")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" })
+  })
 
+  test("Amazon Nova models get prompt caching", () => {
+    const model = createBedrockModel("amazon.nova-pro-v1:0")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
     const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" })
+  })
 
-    expect(result[0].providerOptions?.bedrock).toEqual(
-      expect.objectContaining({
-        cachePoint: {
-          type: "default",
-        },
-      }),
-    )
+  test("Nova models with nova- prefix get prompt caching", () => {
+    const model = createBedrockModel("nova-lite-v1:0")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" })
+  })
+
+  test("Llama models on Bedrock do NOT get prompt caching", () => {
+    const model = createBedrockModel("meta.llama3-70b-instruct-v1:0")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined()
+  })
+
+  test("Mistral models on Bedrock do NOT get prompt caching", () => {
+    const model = createBedrockModel("mistral.mistral-large-2402-v1:0")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined()
+  })
+
+  test("Cohere models on Bedrock do NOT get prompt caching", () => {
+    const model = createBedrockModel("cohere.command-r-plus-v1:0")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined()
+  })
+
+  test("Custom ARN with Claude in name gets prompt caching", () => {
+    const model = createBedrockModel("arn:aws:bedrock:us-east-1:123456789:custom-model/my-claude-finetune")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" })
+  })
+
+  test("Custom ARN without Claude in name does NOT get prompt caching", () => {
+    const model = createBedrockModel("arn:aws:bedrock:us-east-1:123456789:custom-model/my-llama-model")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined()
+  })
+
+  test("Cross-region inference profiles with Claude get prompt caching", () => {
+    const model = createBedrockModel("us.anthropic.claude-3-5-sonnet-20241022-v2:0")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" })
+  })
+
+  test("Application inference profile gets prompt caching when Claude-based", () => {
+    const model = createBedrockModel("arn:aws:bedrock:us-east-1:123456789:application-inference-profile/my-claude-profile")
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" })
+  })
+
+  test("Application inference profile with options.caching=true gets prompt caching", () => {
+    const model = {
+      ...createBedrockModel("arn:aws:bedrock:eu-west-1:995555607786:application-inference-profile/bzg00wo23901"),
+      options: { caching: true },
+    }
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
+    expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" })
+  })
+
+  test("Custom ARN with options.caching=true gets prompt caching", () => {
+    const model = {
+      ...createBedrockModel("arn:aws:bedrock:us-east-1:123456789:custom-model/my-custom-model"),
+      options: { caching: true },
+    }
+    const msgs = [{ role: "user", content: "Hello" }] as any[]
+    const result = ProviderTransform.message(msgs, model, {})
type: "default" }) + }) + + test("Claude model with options.caching=false does NOT get prompt caching", () => { + const model = { + ...createBedrockModel("anthropic.claude-3-5-sonnet-20241022-v2:0"), + options: { caching: false }, + } + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined() }) })