diff --git a/packages/opencode/src/provider/model-detection.ts b/packages/opencode/src/provider/model-detection.ts new file mode 100644 index 000000000000..0dc65fd3fcd1 --- /dev/null +++ b/packages/opencode/src/provider/model-detection.ts @@ -0,0 +1,63 @@ +import z from "zod" +import { iife } from "@/util/iife" +import { Log } from "@/util/log" +import { Provider } from "./provider" + +export namespace ProviderModelDetection { + export async function detect(provider: Provider.Info): Promise { + const log = Log.create({ service: "provider.model-detection" }) + + const model = Object.values(provider.models)[0] + const providerNPM = model?.api?.npm ?? "@ai-sdk/openai-compatible" + const providerBaseURL = provider.options["baseURL"] ?? model?.api?.url ?? "" + + const detectedModels = await iife(async () => { + try { + if (providerNPM === "@ai-sdk/openai-compatible" && providerBaseURL) { + log.info("using OpenAI-compatible method", { providerID: provider.id }) + return await ProviderModelDetection.OpenAICompatible.listModels(providerBaseURL, provider) + } + } catch (error) { + log.warn(`failed to detect models\n${error}`, { providerID: provider.id }) + } + }) + + if (!detectedModels || detectedModels.length === 0) return + + log.info("detected models", { providerID: provider.id, count: detectedModels.length }) + return detectedModels + } +} + +export namespace ProviderModelDetection.OpenAICompatible { + const OpenAICompatibleResponse = z.object({ + object: z.string(), + data: z.array( + z.object({ + id: z.string(), + object: z.string().optional(), + created: z.number().optional(), + owned_by: z.string().optional(), + }), + ), + }) + type OpenAICompatibleResponse = z.infer + + export async function listModels(baseURL: string, provider: Provider.Info): Promise { + const fetchFn = provider.options["fetch"] ?? fetch + const apiKey = provider.options["apiKey"] ?? provider.key ?? "" + const headers = new Headers() + if (apiKey) headers.append("Authorization", `Bearer ${apiKey}`) + + const res = await fetchFn(`${baseURL}/models`, { + headers, + signal: AbortSignal.timeout(3 * 1000), + }) + if (!res.ok) throw new Error(`bad http status ${res.status}`) + const parsed = OpenAICompatibleResponse.parse(await res.json()) + + return parsed.data + .filter((model) => model.id && !model.id.includes("embedding") && !model.id.includes("embed")) + .map((model) => model.id) + } +} diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index a02a017e77ce..909f7b6e4b03 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -9,7 +9,7 @@ import { BunProc } from "../bun" import { Plugin } from "../plugin" import { ModelsDev } from "./models" import { NamedError } from "@opencode-ai/util/error" -import { Auth } from "../auth" +import { Auth, OAUTH_DUMMY_KEY } from "../auth" import { Env } from "../env" import { Instance } from "../project/instance" import { Flag } from "../flag/flag" @@ -694,11 +694,52 @@ export namespace Provider { source: "custom", name: provider.name, env: provider.env ?? [], - options: {}, + options: { + ...(provider.api && { baseURL: provider.api }), + }, models: mapValues(provider.models, (model) => fromModelsDevModel(provider, model)), } } + const ModelsList = z.object({ + object: z.string(), + data: z.array( + z + .object({ + id: z.string(), + object: z.string().optional(), + created: z.number().optional(), + owned_by: z.string().optional(), + }) + .catchall(z.any()), + ), + }) + type ModelsList = z.infer + + async function listModels(provider: Info) { + const baseURL = provider.options["baseURL"] + const fetchFn = (provider.options["fetch"] as typeof fetch) ?? fetch + const apiKey = provider.options["apiKey"] ?? provider.key ?? "" + const headers = new Headers() + if (apiKey && apiKey !== OAUTH_DUMMY_KEY) headers.append("Authorization", `Bearer ${apiKey}`) + const models = await fetchFn(`${baseURL}/models`, { + headers, + signal: AbortSignal.timeout(3 * 1000), + }) + .then(async (resp) => { + if (!resp.ok) return + return ModelsList.parse(await resp.json()) + }) + .catch((err) => { + log.error(`Failed to fetch models from: ${baseURL}/models`, { error: err }) + }) + if (!models) return + + return models.data + .filter((model) => model.id && !model.id.includes("embedding") && !model.id.includes("embed")) + .map((model) => model.id) + } + const state = Instance.state(async () => { using _ = log.time("state") const config = await Config.get() @@ -930,6 +971,20 @@ export namespace Provider { mergeProvider(providerID, partial) } + // detect models and prune invalid ones + await Promise.all( + Object.values(providers).map(async (provider) => { + const detected = await listModels(provider) + if (!detected) return + const detectedSet = new Set(detected) + for (const modelID of Object.keys(provider.models)) { + if (!detectedSet.has(modelID)) delete provider.models[modelID] + } + // TODO: add detected models not present in config/models.dev + // for (const modelID of detected) {} + }), + ) + for (const [providerID, provider] of Object.entries(providers)) { if (!isProviderAllowed(providerID)) { delete providers[providerID]