diff --git a/src/api/providers/lmstudio.ts b/src/api/providers/lmstudio.ts
index 6053ed056d4..0901cb27680 100644
--- a/src/api/providers/lmstudio.ts
+++ b/src/api/providers/lmstudio.ts
@@ -25,57 +25,108 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 	}
 
 	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
-			{ role: "system", content: systemPrompt },
-			...convertToOpenAiMessages(messages),
-		]
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...convertToOpenAiMessages(messages),
+		]
 
-		try {
-			// Create params object with optional draft model
-			const params: any = {
-				model: this.getModel().id,
-				messages: openAiMessages,
-				temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
-				stream: true,
-			}
-
-			// Add draft model if speculative decoding is enabled and a draft model is specified
-			if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
-				params.draft_model = this.options.lmStudioDraftModelId
-			}
+		// -------------------------
+		// Track token usage
+		// -------------------------
+		const toContentBlocks = (
+			blocks: Anthropic.Messages.MessageParam[] | string,
+		): Anthropic.Messages.ContentBlockParam[] => {
+			if (typeof blocks === "string") {
+				return [{ type: "text", text: blocks }]
+			}
 
-			const results = await this.client.chat.completions.create(params)
-
-			const matcher = new XmlMatcher(
-				"think",
-				(chunk) =>
-					({
-						type: chunk.matched ? "reasoning" : "text",
-						text: chunk.data,
-					}) as const,
-			)
-
-			// Stream handling
-			// @ts-ignore
-			for await (const chunk of results) {
-				const delta = chunk.choices[0]?.delta
-
-				if (delta?.content) {
-					for (const chunk of matcher.update(delta.content)) {
-						yield chunk
+			const result: Anthropic.Messages.ContentBlockParam[] = []
+			for (const msg of blocks) {
+				if (typeof msg.content === "string") {
+					result.push({ type: "text", text: msg.content })
+				} else if (Array.isArray(msg.content)) {
+					for (const part of msg.content) {
+						if (part.type === "text") {
+							result.push({ type: "text", text: part.text })
 						}
 					}
 				}
-			for (const chunk of matcher.final()) {
-				yield chunk
+			}
+			return result
+		}
+
+		let inputTokens = 0
+		try {
+			inputTokens = await this.countTokens([
+				{ type: "text", text: systemPrompt },
+				...toContentBlocks(messages),
+			])
+		} catch (err) {
+			console.error("[LmStudio] Failed to count input tokens:", err)
+			inputTokens = 0
+		}
+
+		let assistantText = ""
+
+		try {
+			const params: OpenAI.Chat.ChatCompletionCreateParamsStreaming & { draft_model?: string } = {
+				model: this.getModel().id,
+				messages: openAiMessages,
+				temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
+				stream: true,
+			}
+
+			if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
+				params.draft_model = this.options.lmStudioDraftModelId
+			}
+
+			const results = await this.client.chat.completions.create(params)
+
+			const matcher = new XmlMatcher(
+				"think",
+				(chunk) =>
+					({
+						type: chunk.matched ? "reasoning" : "text",
+						text: chunk.data,
+					}) as const,
+			)
+
+			for await (const chunk of results) {
+				const delta = chunk.choices[0]?.delta
+
+				if (delta?.content) {
+					assistantText += delta.content
+					for (const processedChunk of matcher.update(delta.content)) {
+						yield processedChunk
+					}
 				}
-		} catch (error) {
-			// LM Studio doesn't return an error code/body for now
-			throw new Error(
-				"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Roo Code's prompts.",
-			)
 			}
+
+			for (const processedChunk of matcher.final()) {
+				yield processedChunk
+			}
+
+			let outputTokens = 0
+			try {
+				outputTokens = await this.countTokens([{ type: "text", text: assistantText }])
+			} catch (err) {
+				console.error("[LmStudio] Failed to count output tokens:", err)
+				outputTokens = 0
+			}
+
+			yield {
+				type: "usage",
+				inputTokens,
+				outputTokens,
+			} as const
+		} catch (error) {
+			throw new Error(
+				"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Roo Code's prompts.",
+			)
 		}
+	}
+
 	override getModel(): { id: string; info: ModelInfo } {
 		return {
diff --git a/src/api/providers/ollama.ts b/src/api/providers/ollama.ts
index 26374d5d583..1b721a59093 100644
--- a/src/api/providers/ollama.ts
+++ b/src/api/providers/ollama.ts
@@ -11,6 +11,9 @@ import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
 import { XmlMatcher } from "../../utils/xml-matcher"
 import { BaseProvider } from "./base-provider"
 
+// Alias for the usage object returned in streaming chunks
+type CompletionUsage = OpenAI.Chat.Completions.ChatCompletionChunk["usage"]
+
 export class OllamaHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: OpenAI
@@ -37,6 +40,7 @@ export class OllamaHandler extends BaseProvider implements SingleCompletionHandl
 			messages: openAiMessages,
 			temperature: this.options.modelTemperature ?? 0,
 			stream: true,
+			stream_options: { include_usage: true },
 		})
 		const matcher = new XmlMatcher(
 			"think",
@@ -46,18 +50,30 @@ export class OllamaHandler extends BaseProvider implements SingleCompletionHandl
 				text: chunk.data,
 			}) as const,
 		)
+		let lastUsage: CompletionUsage | undefined
 		for await (const chunk of stream) {
 			const delta = chunk.choices[0]?.delta
 			if (delta?.content) {
-				for (const chunk of matcher.update(delta.content)) {
-					yield chunk
+				for (const matcherChunk of matcher.update(delta.content)) {
+					yield matcherChunk
 				}
 			}
+			if (chunk.usage) {
+				lastUsage = chunk.usage
+			}
 		}
		for (const chunk of matcher.final()) {
 			yield chunk
 		}
+
+		if (lastUsage) {
+			yield {
+				type: "usage",
+				inputTokens: lastUsage?.prompt_tokens || 0,
+				outputTokens: lastUsage?.completion_tokens || 0,
+			}
+		}
 	}
 
 	override getModel(): { id: string; info: ModelInfo } {