diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts
index fef700268dc..409e250c635 100644
--- a/src/api/providers/fetchers/modelCache.ts
+++ b/src/api/providers/fetchers/modelCache.ts
@@ -1,22 +1,22 @@
-import * as path from "path"
 import fs from "fs/promises"
+import * as path from "path"
 
 import NodeCache from "node-cache"
 
 import { safeWriteJson } from "../../../utils/safeWriteJson"
 
 import { ContextProxy } from "../../../core/config/ContextProxy"
-import { getCacheDirectoryPath } from "../../../utils/storage"
-import { RouterName, ModelRecord } from "../../../shared/api"
+import { ModelRecord, RouterName } from "../../../shared/api"
 import { fileExistsAtPath } from "../../../utils/fs"
+import { getCacheDirectoryPath } from "../../../utils/storage"
 
-import { getOpenRouterModels } from "./openrouter"
-import { getRequestyModels } from "./requesty"
+import { GetModelsOptions } from "../../../shared/api"
 import { getGlamaModels } from "./glama"
-import { getUnboundModels } from "./unbound"
 import { getLiteLLMModels } from "./litellm"
-import { GetModelsOptions } from "../../../shared/api"
-import { getOllamaModels } from "./ollama"
 import { getLMStudioModels } from "./lmstudio"
+import { getOllamaModels } from "./ollama"
+import { getOpenRouterModels } from "./openrouter"
+import { getRequestyModels } from "./requesty"
+import { getUnboundModels } from "./unbound"
 
 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
diff --git a/src/api/providers/lm-studio.ts b/src/api/providers/lm-studio.ts
index f032e2d5605..34c4edd951e 100644
--- a/src/api/providers/lm-studio.ts
+++ b/src/api/providers/lm-studio.ts
@@ -1,8 +1,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
-import axios from "axios"
 
-import { type ModelInfo, openAiModelInfoSaneDefaults, LMSTUDIO_DEFAULT_TEMPERATURE } from "@roo-code/types"
+import { LMSTUDIO_DEFAULT_TEMPERATURE, type ModelInfo, openAiModelInfoSaneDefaults } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
 
@@ -11,22 +10,41 @@ import { XmlMatcher } from "../../utils/xml-matcher"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 
+import type { ApiHandlerCreateMessageMetadata, SingleCompletionHandler } from "../index"
 import { BaseProvider } from "./base-provider"
-import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { flushModels, getModels } from "./fetchers/modelCache"
+
+type ModelInfoCaching = {
+	modelInfo: ModelInfo
+	cached: boolean
+}
 
 export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: OpenAI
+	private cachedModelInfo: ModelInfoCaching = {
+		modelInfo: openAiModelInfoSaneDefaults,
+		cached: false,
+	}
+	private lastRecacheTime: number = -1
 
 	constructor(options: ApiHandlerOptions) {
 		super()
		this.options = options
 		this.client = new OpenAI({
-			baseURL: (this.options.lmStudioBaseUrl || "http://localhost:1234") + "/v1",
+			baseURL: this.getBaseUrl() + "/v1",
 			apiKey: "noop",
 		})
 	}
 
+	private getBaseUrl(): string {
+		if (this.options.lmStudioBaseUrl && this.options.lmStudioBaseUrl.trim() !== "") {
+			return this.options.lmStudioBaseUrl.trim()
+		} else {
+			return "http://localhost:1234"
+		}
+	}
+
 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
@@ -118,6 +136,29 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 			outputTokens = 0
 		}
 
+		if (
+			!this.cachedModelInfo.cached &&
+			(this.lastRecacheTime < 0 || Date.now() - this.lastRecacheTime > 30 * 1000)
+		) {
+			// If the last attempt was more than 30 seconds ago, assume it failed and try again
+			this.lastRecacheTime = Date.now() // Update the last recache time to avoid a race condition
+
+			// We need to fetch the model info every time we open a new session
+			// to ensure we have the latest context window and other details,
+			// since LM Studio models can change their context windows on reload
+			await flushModels("lmstudio")
+			const models = await getModels({ provider: "lmstudio", baseUrl: this.getBaseUrl() })
+			if (models && models[this.getModel().id]) {
+				this.cachedModelInfo = {
+					modelInfo: models[this.getModel().id],
+					cached: true,
+				}
+			} else {
+				// If the model info is not found, still mark the result as cached to avoid retrying on every chunk
+				this.cachedModelInfo.cached = true
+			}
+		}
+
 		yield {
 			type: "usage",
 			inputTokens,
@@ -133,7 +174,7 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 	override getModel(): { id: string; info: ModelInfo } {
 		return {
 			id: this.options.lmStudioModelId || "",
-			info: openAiModelInfoSaneDefaults,
+			info: this.cachedModelInfo.modelInfo,
 		}
 	}
 
@@ -161,17 +202,3 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 		}
 	}
 }
-
-export async function getLmStudioModels(baseUrl = "http://localhost:1234") {
-	try {
-		if (!URL.canParse(baseUrl)) {
-			return []
-		}
-
-		const response = await axios.get(`${baseUrl}/v1/models`)
-		const modelsArray = response.data?.data?.map((model: any) => model.id) || []
-		return [...new Set(modelsArray)]
-	} catch (error) {
-		return []
-	}
-}
diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts
index 5272c334510..ee0ec20dc2d 100644
--- a/src/core/webview/__tests__/ClineProvider.spec.ts
+++ b/src/core/webview/__tests__/ClineProvider.spec.ts
@@ -16,6 +16,7 @@ import { Task, TaskOptions } from "../../task/Task"
 import { safeWriteJson } from "../../../utils/safeWriteJson"
 
 import { ClineProvider } from "../ClineProvider"
+import { LmStudioHandler } from "../../../api/providers"
 
 // Mock setup must come before imports
 vi.mock("../../prompts/sections/custom-instructions")
@@ -2371,6 +2372,7 @@ describe("ClineProvider - Router Models", () => {
 				unboundApiKey: "unbound-key",
 				litellmApiKey: "litellm-key",
 				litellmBaseUrl: "http://localhost:4000",
+				lmStudioBaseUrl: "http://localhost:1234",
 			},
 		} as any)
 
@@ -2404,6 +2406,10 @@ describe("ClineProvider - Router Models", () => {
 			apiKey: "litellm-key",
 			baseUrl: "http://localhost:4000",
 		})
+		expect(getModels).toHaveBeenCalledWith({
+			provider: "lmstudio",
+			baseUrl: "http://localhost:1234",
+		})
 
 		// Verify response was sent
 		expect(mockPostMessage).toHaveBeenCalledWith({
@@ -2415,7 +2421,7 @@ describe("ClineProvider - Router Models", () => {
 				unbound: mockModels,
 				litellm: mockModels,
 				ollama: {},
-				lmstudio: {},
+				lmstudio: mockModels,
 			},
 		})
 	})
@@ -2432,6 +2438,7 @@ describe("ClineProvider - Router Models", () => {
 				unboundApiKey: "unbound-key",
 				litellmApiKey: "litellm-key",
 				litellmBaseUrl: "http://localhost:4000",
+				lmStudioBaseUrl: "http://localhost:1234",
 			},
 		} as any)
 
@@ -2447,6 +2454,7 @@ describe("ClineProvider - Router Models", () => {
 			.mockResolvedValueOnce(mockModels) // glama success
 			.mockRejectedValueOnce(new Error("Unbound API error")) // unbound fail
 			.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm fail
+			.mockRejectedValueOnce(new Error("LMStudio API error")) // lmstudio fail
 
 		await messageHandler({ type: "requestRouterModels" })
 
@@ -2492,6 +2500,13 @@ describe("ClineProvider - Router Models", () => {
 			error: "LiteLLM connection failed",
 			values: { provider: "litellm" },
 		})
+
+		expect(mockPostMessage).toHaveBeenCalledWith({
+			type: "singleRouterModelFetchResponse",
+			success: false,
+			error: "LMStudio API error",
+			values: { provider: "lmstudio" },
+		})
 	})
 
 	test("handles requestRouterModels with LiteLLM values from message", async () => {
@@ -2570,7 +2585,7 @@ describe("ClineProvider - Router Models", () => {
 				unbound: mockModels,
 				litellm: {},
 				ollama: {},
-				lmstudio: {},
+				lmstudio: mockModels,
 			},
 		})
 	})
diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts
index 284ee989444..82e09775eff 100644
--- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts
+++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts
@@ -105,6 +105,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			unboundApiKey: "unbound-key",
 			litellmApiKey: "litellm-key",
 			litellmBaseUrl: "http://localhost:4000",
+			lmStudioBaseUrl: "http://localhost:1234",
 		},
 	})
 })
@@ -141,6 +142,10 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			apiKey: "litellm-key",
 			baseUrl: "http://localhost:4000",
 		})
+		expect(mockGetModels).toHaveBeenCalledWith({
+			provider: "lmstudio",
+			baseUrl: "http://localhost:1234",
+		})
 
 		// Verify response was sent
 		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
@@ -152,7 +157,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 				unbound: mockModels,
 				litellm: mockModels,
 				ollama: {},
-				lmstudio: {},
+				lmstudio: mockModels,
 			},
 		})
 	})
@@ -239,7 +244,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 				unbound: mockModels,
 				litellm: {},
 				ollama: {},
-				lmstudio: {},
+				lmstudio: mockModels,
 			},
 		})
 	})
@@ -261,6 +266,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			.mockResolvedValueOnce(mockModels) // glama
 			.mockRejectedValueOnce(new Error("Unbound API error")) // unbound
 			.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm
+			.mockRejectedValueOnce(new Error("LMStudio API error")) // lmstudio
 
 		await webviewMessageHandler(mockClineProvider, {
 			type: "requestRouterModels",
@@ -311,6 +317,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			.mockRejectedValueOnce(new Error("Glama API error")) // glama
 			.mockRejectedValueOnce(new Error("Unbound API error")) // unbound
 			.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm
+			.mockRejectedValueOnce(new Error("LMStudio API error")) // lmstudio
 
 		await webviewMessageHandler(mockClineProvider, {
 			type: "requestRouterModels",
@@ -351,6 +358,13 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			error: "LiteLLM connection failed",
 			values: { provider: "litellm" },
 		})
+
+		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
+			type: "singleRouterModelFetchResponse",
+			success: false,
+			error: "LMStudio API error",
+			values: { provider: "lmstudio" },
+		})
 	})
 
 	it("prefers config values over message values for LiteLLM", async () => {
diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index 780d40df891..9d3c3b9058f 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -1,58 +1,56 @@
-import { safeWriteJson } from "../../utils/safeWriteJson"
-import * as path from "path"
-import * as os from "os"
 import * as fs from "fs/promises"
+import * as os from "os"
 import pWaitFor from "p-wait-for"
+import * as path from "path"
 import * as vscode from "vscode"
-import * as yaml from "yaml"
 
+import { safeWriteJson } from "../../utils/safeWriteJson"
+import { CloudService } from "@roo-code/cloud"
+import { TelemetryService } from "@roo-code/telemetry"
 import {
+	type ClineMessage,
+	type GlobalState,
 	type Language,
 	type ProviderSettings,
-	type GlobalState,
-	type ClineMessage,
 	TelemetryEventName,
 } from "@roo-code/types"
-import { CloudService } from "@roo-code/cloud"
-import { TelemetryService } from "@roo-code/telemetry"
 
 import { type ApiMessage } from "../task-persistence/apiMessages"
 
-import { ClineProvider } from "./ClineProvider"
 import { changeLanguage, t } from "../../i18n"
+import { ModelRecord, RouterName, toRouterName } from "../../shared/api"
 import { Package } from "../../shared/package"
-import { RouterName, toRouterName, ModelRecord } from "../../shared/api"
 import { supportPrompt } from "../../shared/support-prompt"
+import { ClineProvider } from "./ClineProvider"
 
-import { checkoutDiffPayloadSchema, checkoutRestorePayloadSchema, WebviewMessage } from "../../shared/WebviewMessage"
-import { checkExistKey } from "../../shared/checkExistApiConfig"
-import { experimentDefault } from "../../shared/experiments"
-import { Terminal } from "../../integrations/terminal/Terminal"
-import { openFile } from "../../integrations/misc/open-file"
+import { flushModels, getModels } from "../../api/providers/fetchers/modelCache"
+import { getOpenAiModels } from "../../api/providers/openai"
+import { getVsCodeLmModels } from "../../api/providers/vscode-lm"
 import { openImage, saveImage } from "../../integrations/misc/image-handler"
+import { openFile } from "../../integrations/misc/open-file"
 import { selectImages } from "../../integrations/misc/process-images"
+import { Terminal } from "../../integrations/terminal/Terminal"
 import { getTheme } from "../../integrations/theme/getTheme"
 import { discoverChromeHostUrl, tryChromeHostUrl } from "../../services/browser/browserDiscovery"
 import { searchWorkspaceFiles } from "../../services/search/file-search"
+import { TelemetrySetting } from "../../shared/TelemetrySetting"
+import { checkoutDiffPayloadSchema, checkoutRestorePayloadSchema, WebviewMessage } from "../../shared/WebviewMessage"
+import { GetModelsOptions } from "../../shared/api"
+import { checkExistKey } from "../../shared/checkExistApiConfig"
+import { experimentDefault } from "../../shared/experiments"
+import { defaultModeSlug, Mode } from "../../shared/modes"
+import { getCommand } from "../../utils/commands"
 import { fileExistsAtPath } from "../../utils/fs"
-import { playTts, setTtsEnabled, setTtsSpeed, stopTts } from "../../utils/tts"
-import { singleCompletionHandler } from "../../utils/single-completion-handler"
 import { searchCommits } from "../../utils/git"
+import { getWorkspacePath } from "../../utils/path"
+import { singleCompletionHandler } from "../../utils/single-completion-handler"
+import { playTts, setTtsEnabled, setTtsSpeed, stopTts } from "../../utils/tts"
 import { exportSettings, importSettingsWithFeedback } from "../config/importExport"
-import { getOpenAiModels } from "../../api/providers/openai"
-import { getVsCodeLmModels } from "../../api/providers/vscode-lm"
 import { openMention } from "../mentions"
-import { TelemetrySetting } from "../../shared/TelemetrySetting"
-import { getWorkspacePath } from "../../utils/path"
-import { ensureSettingsDirectoryExists } from "../../utils/globalContext"
-import { Mode, defaultModeSlug } from "../../shared/modes"
-import { getModels, flushModels } from "../../api/providers/fetchers/modelCache"
-import { GetModelsOptions } from "../../shared/api"
 import { generateSystemPrompt } from "./generateSystemPrompt"
-import { getCommand } from "../../utils/commands"
 
 const ALLOWED_VSCODE_SETTINGS = new Set(["terminal.integrated.inheritEnv"])
-import { MarketplaceManager, MarketplaceItemType } from "../../services/marketplace"
+import { MarketplaceItemType, MarketplaceManager } from "../../services/marketplace"
 import { setPendingTodoList } from "../tools/updateTodoListTool"
 
 export const webviewMessageHandler = async (
@@ -555,6 +553,12 @@ export const webviewMessageHandler = async (
 				})
 			}
 
+			const lmStudioBaseUrl = apiConfiguration.lmStudioBaseUrl || message?.values?.lmStudioBaseUrl
+			modelFetchPromises.push({
+				key: "lmstudio",
+				options: { provider: "lmstudio", baseUrl: lmStudioBaseUrl },
+			})
+
 			const results = await Promise.allSettled(
 				modelFetchPromises.map(async ({ key, options }) => {
 					const models = await safeGetModels(options)
@@ -564,9 +568,8 @@ export const webviewMessageHandler = async (
 
 			const fetchedRouterModels: Partial<Record<RouterName, ModelRecord>> = {
 				...routerModels,
-				// Initialize ollama and lmstudio with empty objects since they use separate handlers
+				// Initialize ollama with an empty object since it uses a separate handler
 				ollama: {},
-				lmstudio: {},
 			}
 
 			results.forEach((result, index) => {
@@ -575,18 +578,14 @@ export const webviewMessageHandler = async (
 				if (result.status === "fulfilled") {
 					fetchedRouterModels[routerName] = result.value.models
 
-					// Ollama and LM Studio settings pages still need these events
+					// The Ollama settings page still needs this event
 					if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
 						provider.postMessageToWebview({
 							type: "ollamaModels",
 							ollamaModels: Object.keys(result.value.models),
 						})
-					} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
-						provider.postMessageToWebview({
-							type: "lmStudioModels",
-							lmStudioModels: Object.keys(result.value.models),
-						})
 					}
+					// LM Studio models are now included in the main routerModels message
 				} else {
 					// Handle rejection: Post a specific error message for this provider
 					const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)
@@ -633,30 +632,6 @@ export const webviewMessageHandler = async (
 			}
 			break
 		}
-		case "requestLmStudioModels": {
-			// Specific handler for LM Studio models only
-			const { apiConfiguration: lmStudioApiConfig } = await provider.getState()
-			try {
-				// Flush cache first to ensure fresh models
-				await flushModels("lmstudio")
-
-				const lmStudioModels = await getModels({
-					provider: "lmstudio",
-					baseUrl: lmStudioApiConfig.lmStudioBaseUrl,
-				})
-
-				if (Object.keys(lmStudioModels).length > 0) {
-					provider.postMessageToWebview({
-						type: "lmStudioModels",
-						lmStudioModels: Object.keys(lmStudioModels),
-					})
-				}
-			} catch (error) {
-				// Silently fail - user hasn't configured LM Studio yet
-				console.debug("LM Studio models fetch failed:", error)
-			}
-			break
-		}
 		case "requestOpenAiModels":
 			if (message?.values?.baseUrl && message?.values?.apiKey) {
 				const openAiModels = await getOpenAiModels(
diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts
index 1f56829f7b3..163beb845fa 100644
--- a/src/shared/WebviewMessage.ts
+++ b/src/shared/WebviewMessage.ts
@@ -65,7 +65,6 @@ export interface WebviewMessage {
 		| "requestRouterModels"
 		| "requestOpenAiModels"
 		| "requestOllamaModels"
-		| "requestLmStudioModels"
 		| "requestVsCodeLmModels"
 		| "openImage"
 		| "saveImage"
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index 6c6c621956c..f0d3660c441 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -193,12 +193,10 @@ const ApiOptions = ({
 				},
 			})
 		} else if (selectedProvider === "ollama") {
-			vscode.postMessage({ type: "requestOllamaModels" })
-		} else if (selectedProvider === "lmstudio") {
-			vscode.postMessage({ type: "requestLmStudioModels" })
+			vscode.postMessage({ type: "requestOllamaModels", text: apiConfiguration?.ollamaBaseUrl })
 		} else if (selectedProvider === "vscode-lm") {
 			vscode.postMessage({ type: "requestVsCodeLmModels" })
-		} else if (selectedProvider === "litellm") {
+		} else if (selectedProvider === "litellm" || selectedProvider === "lmstudio") {
 			vscode.postMessage({ type: "requestRouterModels" })
 		}
 	},
diff --git a/webview-ui/src/components/settings/providers/LMStudio.tsx b/webview-ui/src/components/settings/providers/LMStudio.tsx
index a907e43e1b0..74e1f1b193d 100644
--- a/webview-ui/src/components/settings/providers/LMStudio.tsx
+++ b/webview-ui/src/components/settings/providers/LMStudio.tsx
@@ -39,9 +39,9 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 		const message: ExtensionMessage = event.data
 
 		switch (message.type) {
-			case "lmStudioModels":
+			case "routerModels":
 				{
-					const newModels = message.lmStudioModels ?? []
+					const newModels = Object.keys(message.routerModels?.lmstudio || {})
 					setLmStudioModels(newModels)
 				}
 				break
@@ -53,7 +53,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 	// Refresh models on mount
 	useEffect(() => {
 		// Request fresh models - the handler now flushes cache automatically
-		vscode.postMessage({ type: "requestLmStudioModels" })
+		vscode.postMessage({ type: "requestRouterModels" })
 	}, [])
 
 	// Check if the selected model exists in the fetched models