16 changes: 8 additions & 8 deletions src/api/providers/fetchers/modelCache.ts
@@ -1,22 +1,22 @@
-import * as path from "path"
import fs from "fs/promises"
+import * as path from "path"

import NodeCache from "node-cache"
import { safeWriteJson } from "../../../utils/safeWriteJson"

import { ContextProxy } from "../../../core/config/ContextProxy"
-import { getCacheDirectoryPath } from "../../../utils/storage"
-import { RouterName, ModelRecord } from "../../../shared/api"
+import { ModelRecord, RouterName } from "../../../shared/api"
import { fileExistsAtPath } from "../../../utils/fs"
+import { getCacheDirectoryPath } from "../../../utils/storage"

-import { getOpenRouterModels } from "./openrouter"
-import { getRequestyModels } from "./requesty"
+import { GetModelsOptions } from "../../../shared/api"
import { getGlamaModels } from "./glama"
-import { getUnboundModels } from "./unbound"
import { getLiteLLMModels } from "./litellm"
-import { GetModelsOptions } from "../../../shared/api"
-import { getOllamaModels } from "./ollama"
import { getLMStudioModels } from "./lmstudio"
+import { getOllamaModels } from "./ollama"
+import { getOpenRouterModels } from "./openrouter"
+import { getRequestyModels } from "./requesty"
+import { getUnboundModels } from "./unbound"

const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

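For orientation (not part of this diff): a NodeCache configured like the one above is typically used as a read-through cache, where getModels serves from memory until the five-minute TTL expires and only then hits the provider API. A minimal sketch of that pattern, with the helper name and key scheme assumed for illustration:

import NodeCache from "node-cache"

// Entries expire after five minutes, matching the stdTTL above.
const cache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

async function readThrough<T>(key: string, fetch: () => Promise<T>): Promise<T> {
	const hit = cache.get<T>(key)
	if (hit !== undefined) {
		return hit // cache hit: skip the network round trip
	}
	const fresh = await fetch()
	cache.set(key, fresh)
	return fresh
}

flushModels("lmstudio"), as used in lm-studio.ts below, presumably just evicts that provider's key so the next getModels call refetches.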
65 changes: 46 additions & 19 deletions src/api/providers/lm-studio.ts
@@ -1,8 +1,7 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
-import axios from "axios"

-import { type ModelInfo, openAiModelInfoSaneDefaults, LMSTUDIO_DEFAULT_TEMPERATURE } from "@roo-code/types"
+import { LMSTUDIO_DEFAULT_TEMPERATURE, type ModelInfo, openAiModelInfoSaneDefaults } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"

@@ -11,22 +10,41 @@ import { XmlMatcher } from "../../utils/xml-matcher"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"

+import type { ApiHandlerCreateMessageMetadata, SingleCompletionHandler } from "../index"
import { BaseProvider } from "./base-provider"
-import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { flushModels, getModels } from "./fetchers/modelCache"

+type ModelInfoCaching = {
+modelInfo: ModelInfo
+cached: boolean
+}

export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI
+private cachedModelInfo: ModelInfoCaching = {
+modelInfo: openAiModelInfoSaneDefaults,
+cached: false,
+}
+private lastRecacheTime: number = -1

constructor(options: ApiHandlerOptions) {
super()
this.options = options
this.client = new OpenAI({
-baseURL: (this.options.lmStudioBaseUrl || "http://localhost:1234") + "/v1",
+baseURL: this.getBaseUrl() + "/v1",
apiKey: "noop",
})
}

+private getBaseUrl(): string {
+if (this.options.lmStudioBaseUrl && this.options.lmStudioBaseUrl.trim() !== "") {
+return this.options.lmStudioBaseUrl.trim()
+} else {
+return "http://localhost:1234"
+}
+}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
@@ -118,6 +136,29 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
outputTokens = 0
}

+if (
+!this.cachedModelInfo.cached &&
+(this.lastRecacheTime < 0 || Date.now() - this.lastRecacheTime > 30 * 1000)
+) {
+// assume the previous attempt failed if we didn't get a response within 30 seconds
+this.lastRecacheTime = Date.now() // Update last recache time to avoid race condition

+// We need to fetch the model info every time we open a new session
+// to ensure we have the latest context window and other details
+// since LM Studio models can change their context windows on reload
+await flushModels("lmstudio")
+const models = await getModels({ provider: "lmstudio", baseUrl: this.getBaseUrl() })
+if (models && models[this.getModel().id]) {
+this.cachedModelInfo = {
+modelInfo: models[this.getModel().id],
+cached: true,
+}
+} else {
+// if model info is not found, still mark the result as cached to avoid retries on every chunk
+this.cachedModelInfo.cached = true
+}
+}

A project member commented:
I noticed that the model info fetching happens after the streaming has already started. This means the first message in a conversation will still use the default context window (128,000) from openAiModelInfoSaneDefaults instead of the actual model's context window.

I know you mentioned that LM Studio uses JIT model loading; the concern is that this initial request might overwhelm the context window if it contains a lot of tokens. Any ideas on how to handle this?
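One way the concern above could be handled, as a sketch only (not code from this PR): do the refresh before the request is built, e.g. at the top of createMessage, and keep the same 30-second throttle, so even the first message of a session sees the model's real context window. The class and member names below are illustrative assumptions.

type ModelInfoLike = { contextWindow: number }

class ThrottledModelInfo {
	private info: ModelInfoLike = { contextWindow: 128_000 } // default until refreshed
	private cached = false
	private lastAttempt = -1

	constructor(private fetchInfo: () => Promise<ModelInfoLike | undefined>) {}

	// Call this before building the first request, not after streaming has started.
	async ensureFresh(throttleMs = 30_000): Promise<ModelInfoLike> {
		const throttled = this.lastAttempt >= 0 && Date.now() - this.lastAttempt <= throttleMs
		if (!this.cached && !throttled) {
			this.lastAttempt = Date.now()
			const fresh = await this.fetchInfo()
			if (fresh) {
				this.info = fresh
			}
			this.cached = true // even on failure, stop refetching on every chunk
		}
		return this.info
	}
}

The PR keeps the refresh inside the streaming loop instead, which is simpler but leaves the very first request on the default 128,000-token window.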


yield {
type: "usage",
inputTokens,
@@ -133,7 +174,7 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
override getModel(): { id: string; info: ModelInfo } {
return {
id: this.options.lmStudioModelId || "",
-info: openAiModelInfoSaneDefaults,
+info: this.cachedModelInfo.modelInfo,
}
}

@@ -161,17 +202,3 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
}
}
}

-export async function getLmStudioModels(baseUrl = "http://localhost:1234") {
-try {
-if (!URL.canParse(baseUrl)) {
-return []
-}

-const response = await axios.get(`${baseUrl}/v1/models`)
-const modelsArray = response.data?.data?.map((model: any) => model.id) || []
-return [...new Set<string>(modelsArray)]
-} catch (error) {
-return []
-}
-}
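With the standalone getLmStudioModels helper above removed, model listing presumably moves into the shared fetchers (modelCache.ts imports getLMStudioModels from "./lmstudio"). A rough sketch of what such a fetcher might look like, returning a record keyed by model id; the /v1/models endpoint comes from the removed code, while the types and field mapping are assumptions:

import axios from "axios"

// Assumed minimal shapes; the real ones are ModelInfo and ModelRecord from the shared API.
type ModelInfoLike = { contextWindow: number }
type ModelRecordLike = Record<string, ModelInfoLike>

export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Promise<ModelRecordLike> {
	const models: ModelRecordLike = {}
	if (!URL.canParse(baseUrl)) {
		return models
	}
	try {
		// LM Studio exposes an OpenAI-compatible listing at /v1/models.
		const response = await axios.get(`${baseUrl}/v1/models`)
		for (const model of response.data?.data ?? []) {
			// The OpenAI-compatible payload omits context window size, so a real
			// fetcher would need LM Studio's own metadata; default it here.
			models[model.id] = { contextWindow: 128_000 }
		}
	} catch {
		// Treat an unreachable server as "no models" rather than throwing.
	}
	return models
}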
19 changes: 17 additions & 2 deletions src/core/webview/__tests__/ClineProvider.spec.ts
@@ -16,6 +16,7 @@ import { Task, TaskOptions } from "../../task/Task"
import { safeWriteJson } from "../../../utils/safeWriteJson"

import { ClineProvider } from "../ClineProvider"
+import { LmStudioHandler } from "../../../api/providers"

// Mock setup must come before imports
vi.mock("../../prompts/sections/custom-instructions")
@@ -2371,6 +2372,7 @@ describe("ClineProvider - Router Models", () => {
unboundApiKey: "unbound-key",
litellmApiKey: "litellm-key",
litellmBaseUrl: "http://localhost:4000",
+lmStudioBaseUrl: "http://localhost:1234",
},
} as any)

@@ -2404,6 +2406,10 @@ describe("ClineProvider - Router Models", () => {
apiKey: "litellm-key",
baseUrl: "http://localhost:4000",
})
+expect(getModels).toHaveBeenCalledWith({
+provider: "lmstudio",
+baseUrl: "http://localhost:1234",
+})

// Verify response was sent
expect(mockPostMessage).toHaveBeenCalledWith({
@@ -2415,7 +2421,7 @@ describe("ClineProvider - Router Models", () => {
unbound: mockModels,
litellm: mockModels,
ollama: {},
-lmstudio: {},
+lmstudio: mockModels,
},
})
})
@@ -2432,6 +2438,7 @@ describe("ClineProvider - Router Models", () => {
unboundApiKey: "unbound-key",
litellmApiKey: "litellm-key",
litellmBaseUrl: "http://localhost:4000",
+lmStudioBaseUrl: "http://localhost:1234",
},
} as any)

@@ -2447,6 +2454,7 @@ describe("ClineProvider - Router Models", () => {
.mockResolvedValueOnce(mockModels) // glama success
.mockRejectedValueOnce(new Error("Unbound API error")) // unbound fail
.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm fail
+.mockRejectedValueOnce(new Error("LMStudio API error")) // lmstudio fail

await messageHandler({ type: "requestRouterModels" })

@@ -2492,6 +2500,13 @@ describe("ClineProvider - Router Models", () => {
error: "LiteLLM connection failed",
values: { provider: "litellm" },
})

+expect(mockPostMessage).toHaveBeenCalledWith({
+type: "singleRouterModelFetchResponse",
+success: false,
+error: "LMStudio API error",
+values: { provider: "lmstudio" },
+})
})

test("handles requestRouterModels with LiteLLM values from message", async () => {
@@ -2570,7 +2585,7 @@ describe("ClineProvider - Router Models", () => {
unbound: mockModels,
litellm: {},
ollama: {},
-lmstudio: {},
+lmstudio: mockModels,
},
})
})
18 changes: 16 additions & 2 deletions src/core/webview/__tests__/webviewMessageHandler.spec.ts
@@ -105,6 +105,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
unboundApiKey: "unbound-key",
litellmApiKey: "litellm-key",
litellmBaseUrl: "http://localhost:4000",
+lmStudioBaseUrl: "http://localhost:1234",
},
})
})
@@ -141,6 +142,10 @@ describe("webviewMessageHandler - requestRouterModels", () => {
apiKey: "litellm-key",
baseUrl: "http://localhost:4000",
})
+expect(mockGetModels).toHaveBeenCalledWith({
+provider: "lmstudio",
+baseUrl: "http://localhost:1234",
+})

// Verify response was sent
expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
@@ -152,7 +157,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
unbound: mockModels,
litellm: mockModels,
ollama: {},
-lmstudio: {},
+lmstudio: mockModels,
},
})
})
@@ -239,7 +244,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
unbound: mockModels,
litellm: {},
ollama: {},
-lmstudio: {},
+lmstudio: mockModels,
},
})
})
@@ -261,6 +266,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
.mockResolvedValueOnce(mockModels) // glama
.mockRejectedValueOnce(new Error("Unbound API error")) // unbound
.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm
+.mockRejectedValueOnce(new Error("LMStudio API error")) // lmstudio

await webviewMessageHandler(mockClineProvider, {
type: "requestRouterModels",
@@ -311,6 +317,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
.mockRejectedValueOnce(new Error("Glama API error")) // glama
.mockRejectedValueOnce(new Error("Unbound API error")) // unbound
.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm
+.mockRejectedValueOnce(new Error("LMStudio API error")) // lmstudio

await webviewMessageHandler(mockClineProvider, {
type: "requestRouterModels",
@@ -351,6 +358,13 @@ describe("webviewMessageHandler - requestRouterModels", () => {
error: "LiteLLM connection failed",
values: { provider: "litellm" },
})

+expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
+type: "singleRouterModelFetchResponse",
+success: false,
+error: "LMStudio API error",
+values: { provider: "lmstudio" },
+})
})

it("prefers config values over message values for LiteLLM", async () => {
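For context on the expectations above: both spec files assume the handler fetches each provider's models independently and reports each failure as its own singleRouterModelFetchResponse message while still returning an empty record for that provider. A rough sketch of that pattern; the function name and option plumbing are assumptions, while the message shape is taken from the tests:

// Illustrative only; the real handler lives in webviewMessageHandler.ts.
type RouterModels = Record<string, Record<string, unknown>>

async function fetchRouterModels(
	specs: { provider: string; options?: Record<string, unknown> }[],
	getModels: (opts: { provider: string } & Record<string, unknown>) => Promise<Record<string, unknown>>,
	post: (msg: unknown) => void,
): Promise<RouterModels> {
	const results: RouterModels = {}
	// Fetch every provider, letting a single failure produce its own error message
	// instead of aborting the whole requestRouterModels round trip.
	const settled = await Promise.allSettled(specs.map((s) => getModels({ provider: s.provider, ...s.options })))
	settled.forEach((outcome, i) => {
		const { provider } = specs[i]
		if (outcome.status === "fulfilled") {
			results[provider] = outcome.value
		} else {
			results[provider] = {}
			post({
				type: "singleRouterModelFetchResponse",
				success: false,
				error: outcome.reason instanceof Error ? outcome.reason.message : String(outcome.reason),
				values: { provider },
			})
		}
	})
	return results
}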