From 03549249feb19e5417bdab5cd62ea8fc015b8c70 Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Wed, 29 Oct 2025 15:18:41 -0500 Subject: [PATCH] feat: Add provider filtering support to router models backend Allows frontend to request specific subset of router models instead of fetching all providers. This significantly reduces payload sizes and memory usage when only specific providers are needed. - Honor message.values.providers filter in requestRouterModels handler - Fetch only requested providers when filter is present - Maintain backward compatibility with existing aggregate behavior - Add comprehensive test coverage for filtering logic --- ...webviewMessageHandler.routerModels.spec.ts | 167 ++++++++++++++++++ src/core/webview/webviewMessageHandler.ts | 81 +++++---- 2 files changed, 211 insertions(+), 37 deletions(-) create mode 100644 src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts diff --git a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts new file mode 100644 index 00000000000..7954dc14a26 --- /dev/null +++ b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts @@ -0,0 +1,167 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { webviewMessageHandler } from "../webviewMessageHandler" +import type { ClineProvider } from "../ClineProvider" + +// Mock vscode (minimal) +vi.mock("vscode", () => ({ + window: { + showErrorMessage: vi.fn(), + showWarningMessage: vi.fn(), + showInformationMessage: vi.fn(), + }, + workspace: { + workspaceFolders: undefined, + getConfiguration: vi.fn(() => ({ + get: vi.fn(), + update: vi.fn(), + })), + }, + env: { + clipboard: { writeText: vi.fn() }, + openExternal: vi.fn(), + }, + commands: { + executeCommand: vi.fn(), + }, + Uri: { + parse: vi.fn((s: string) => ({ toString: () => s })), + file: vi.fn((p: string) => ({ fsPath: p })), + }, + ConfigurationTarget: { + Global: 1, + 
Workspace: 2, + WorkspaceFolder: 3, + }, +})) + +// Mock modelCache getModels/flushModels used by the handler +const getModelsMock = vi.fn() +vi.mock("../../../api/providers/fetchers/modelCache", () => ({ + getModels: (...args: any[]) => getModelsMock(...args), + flushModels: vi.fn(), +})) + +describe("webviewMessageHandler - requestRouterModels providers filter", () => { + let mockProvider: ClineProvider & { + postMessageToWebview: ReturnType<typeof vi.fn> + getState: ReturnType<typeof vi.fn> + contextProxy: any + log: ReturnType<typeof vi.fn> + } + + beforeEach(() => { + vi.clearAllMocks() + + mockProvider = { + // Only methods used by this code path + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }), + contextProxy: { + getValue: vi.fn(), + setValue: vi.fn(), + globalStorageUri: { fsPath: "/mock/storage" }, + }, + log: vi.fn(), + } as any + + // Default mock: return distinct model maps per provider so we can verify keys + getModelsMock.mockImplementation(async (options: any) => { + switch (options?.provider) { + case "roo": + return { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } } + case "openrouter": + return { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } } + case "requesty": + return { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } } + case "deepinfra": + return { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } } + case "glama": + return { "glama/model": { contextWindow: 8192, supportsPromptCache: false } } + case "unbound": + return { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } } + case "vercel-ai-gateway": + return { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } } + case "io-intelligence": + return { "io/model": { contextWindow: 8192, supportsPromptCache: false } } + case "litellm": + return { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } } + default: + return {} + } + }) + }) + + 
it("fetches only requested provider when values.providers is present (['roo'])", async () => { + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + values: { providers: ["roo"] }, + } as any, + ) + + // Should post a single routerModels message + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith( + expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }), + ) + + const call = (mockProvider.postMessageToWebview as any).mock.calls.find( + (c: any[]) => c[0]?.type === "routerModels", + ) + expect(call).toBeTruthy() + const payload = call[0] + const routerModels = payload.routerModels as Record<string, Record<string, any>> + + // Only "roo" key should be present + const keys = Object.keys(routerModels) + expect(keys).toEqual(["roo"]) + expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet") + + // getModels should have been called exactly once for roo + const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider) + expect(providersCalled).toEqual(["roo"]) + }) + + it("defaults to aggregate fetching when no providers filter is sent", async () => { + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + } as any, + ) + + const call = (mockProvider.postMessageToWebview as any).mock.calls.find( + (c: any[]) => c[0]?.type === "routerModels", + ) + expect(call).toBeTruthy() + const routerModels = call[0].routerModels as Record<string, Record<string, any>> + + // Aggregate handler initializes many known routers - ensure a few expected keys exist + expect(routerModels).toHaveProperty("openrouter") + expect(routerModels).toHaveProperty("roo") + expect(routerModels).toHaveProperty("requesty") + }) + + it("supports filtering another single provider (['openrouter'])", async () => { + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + values: { providers: ["openrouter"] }, + } as any, + ) + + const call = (mockProvider.postMessageToWebview as 
any).mock.calls.find( + (c: any[]) => c[0]?.type === "routerModels", + ) + expect(call).toBeTruthy() + const routerModels = call[0].routerModels as Record<string, Record<string, any>> + const keys = Object.keys(routerModels) + + expect(keys).toEqual(["openrouter"]) + expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5") + + const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider) + expect(providersCalled).toEqual(["openrouter"]) + }) +}) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index e32b818a96e..4a149593554 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -757,20 +757,38 @@ export const webviewMessageHandler = async ( case "requestRouterModels": const { apiConfiguration } = await provider.getState() - const routerModels: Record<RouterName, ModelRecord> = { - openrouter: {}, - "vercel-ai-gateway": {}, - huggingface: {}, - litellm: {}, - deepinfra: {}, - "io-intelligence": {}, - requesty: {}, - unbound: {}, - glama: {}, - ollama: {}, - lmstudio: {}, - roo: {}, - } + // Optional providers filter coming from the webview + const providersFilterRaw = Array.isArray(message?.values?.providers) ? message.values.providers : undefined + const requestedProviders = providersFilterRaw + ?.filter((p: unknown) => typeof p === "string") + .map((p: string) => { + try { + return toRouterName(p) + } catch { + return undefined + } + }) + .filter((p): p is RouterName => !!p) + + const hasFilter = !!requestedProviders && requestedProviders.length > 0 + const requestedSet = new Set(requestedProviders || []) + + const routerModels: Record<RouterName, ModelRecord> = hasFilter + ? 
({} as Record<RouterName, ModelRecord>) + : { + openrouter: {}, + "vercel-ai-gateway": {}, + huggingface: {}, + litellm: {}, + deepinfra: {}, + "io-intelligence": {}, + requesty: {}, + unbound: {}, + glama: {}, + ollama: {}, + lmstudio: {}, + roo: {}, + } const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => { try { @@ -785,7 +803,8 @@ export const webviewMessageHandler = async ( } } - const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [ + // Base candidates (only those handled by this aggregate fetcher) + const candidates: { key: RouterName; options: GetModelsOptions }[] = [ { key: "openrouter", options: { provider: "openrouter" } }, { key: "requesty", @@ -818,29 +837,28 @@ export const webviewMessageHandler = async ( }, ] - // Add IO Intelligence if API key is provided. - const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey - - if (ioIntelligenceApiKey) { - modelFetchPromises.push({ + // IO Intelligence is conditional on api key + if (apiConfiguration.ioIntelligenceApiKey) { + candidates.push({ key: "io-intelligence", - options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey }, + options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey }, }) } - // Don't fetch Ollama and LM Studio models by default anymore. - // They have their own specific handlers: requestOllamaModels and requestLmStudioModels. - + // LiteLLM is conditional on baseUrl+apiKey const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl if (litellmApiKey && litellmBaseUrl) { - modelFetchPromises.push({ + candidates.push({ key: "litellm", options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, }) } + // Apply providers filter (if any) + const modelFetchPromises = candidates.filter(({ key }) => (!hasFilter ? 
true : requestedSet.has(key))) + + const results = await Promise.allSettled( modelFetchPromises.map(async ({ key, options }) => { const models = await safeGetModels(options) @@ -854,18 +872,7 @@ export const webviewMessageHandler = async ( if (result.status === "fulfilled") { routerModels[routerName] = result.value.models - // Ollama and LM Studio settings pages still need these events. - if (routerName === "ollama" && Object.keys(result.value.models).length > 0) { - provider.postMessageToWebview({ - type: "ollamaModels", - ollamaModels: result.value.models, - }) - } else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) { - provider.postMessageToWebview({ - type: "lmStudioModels", - lmStudioModels: result.value.models, - }) - } + // Ollama and LM Studio models are not fetched here; their settings pages use the dedicated requestOllamaModels and requestLmStudioModels handlers. } else { // Handle rejection: Post a specific error message for this provider. const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)