@@ -380,7 +380,7 @@ describe("BaseOpenAiCompatibleProvider", () => {
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 100, outputTokens: 50 })
expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 100, outputTokens: 50 })
})
})
})
16 changes: 6 additions & 10 deletions src/api/providers/__tests__/featherless.spec.ts
@@ -123,11 +123,9 @@ describe("FeatherlessHandler", () => {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "reasoning", text: "Thinking..." },
{ type: "text", text: "Hello" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
expect(chunks[0]).toEqual({ type: "reasoning", text: "Thinking..." })
expect(chunks[1]).toEqual({ type: "text", text: "Hello" })
expect(chunks[2]).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 5 })
})

it("should fall back to base provider for non-DeepSeek models", async () => {
@@ -145,10 +143,8 @@ describe("FeatherlessHandler", () => {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "text", text: "Test response" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
expect(chunks[0]).toEqual({ type: "text", text: "Test response" })
expect(chunks[1]).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 5 })
})

it("should return default model when no model is specified", () => {
@@ -226,7 +222,7 @@ describe("FeatherlessHandler", () => {
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
})

it("createMessage should pass correct parameters to Featherless client for DeepSeek R1", async () => {
10 changes: 4 additions & 6 deletions src/api/providers/__tests__/fireworks.spec.ts
@@ -384,7 +384,7 @@ describe("FireworksHandler", () => {
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
})

it("createMessage should pass correct parameters to Fireworks client", async () => {
@@ -494,10 +494,8 @@ describe("FireworksHandler", () => {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "text", text: "Hello" },
{ type: "text", text: " world" },
{ type: "usage", inputTokens: 5, outputTokens: 10 },
])
expect(chunks[0]).toEqual({ type: "text", text: "Hello" })
expect(chunks[1]).toEqual({ type: "text", text: " world" })
expect(chunks[2]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 })
})
})
8 changes: 5 additions & 3 deletions src/api/providers/__tests__/groq.spec.ts
@@ -112,9 +112,10 @@ describe("GroqHandler", () => {
type: "usage",
inputTokens: 10,
outputTokens: 20,
cacheWriteTokens: 0,
cacheReadTokens: 0,
})
// cacheWriteTokens and cacheReadTokens will be undefined when 0
expect(firstChunk.value.cacheWriteTokens).toBeUndefined()
expect(firstChunk.value.cacheReadTokens).toBeUndefined()
// Check that totalCost is a number (we don't need to test the exact value as that's tested in cost.spec.ts)
expect(typeof firstChunk.value.totalCost).toBe("number")
})
@@ -151,9 +152,10 @@ describe("GroqHandler", () => {
type: "usage",
inputTokens: 100,
outputTokens: 50,
cacheWriteTokens: 0,
cacheReadTokens: 30,
})
// cacheWriteTokens will be undefined when 0
expect(firstChunk.value.cacheWriteTokens).toBeUndefined()
expect(typeof firstChunk.value.totalCost).toBe("number")
})

4 changes: 2 additions & 2 deletions src/api/providers/__tests__/io-intelligence.spec.ts
@@ -178,7 +178,7 @@ describe("IOIntelligenceHandler", () => {
expect(results).toHaveLength(3)
expect(results[0]).toEqual({ type: "text", text: "Hello" })
expect(results[1]).toEqual({ type: "text", text: " world" })
expect(results[2]).toEqual({
expect(results[2]).toMatchObject({
type: "usage",
inputTokens: 10,
outputTokens: 5,
@@ -243,7 +243,7 @@ describe("IOIntelligenceHandler", () => {
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
})

it("should return model info from cache when available", () => {
2 changes: 1 addition & 1 deletion src/api/providers/__tests__/sambanova.spec.ts
@@ -113,7 +113,7 @@ describe("SambaNovaHandler", () => {
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
})

it("createMessage should pass correct parameters to SambaNova client", async () => {
2 changes: 1 addition & 1 deletion src/api/providers/__tests__/zai.spec.ts
@@ -252,7 +252,7 @@ describe("ZAiHandler", () => {
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
})

it("createMessage should pass correct parameters to Z AI client", async () => {
70 changes: 55 additions & 15 deletions src/api/providers/base-openai-compatible-provider.ts
@@ -5,13 +5,14 @@ import type { ModelInfo } from "@roo-code/types"

import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api"
import { XmlMatcher } from "../../utils/xml-matcher"
import { ApiStream } from "../transform/stream"
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"

import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import { handleOpenAIError } from "./utils/openai-error-handler"
import { calculateApiCostOpenAI } from "../../shared/cost"

type BaseOpenAiCompatibleProviderOptions<ModelName extends string> = ApiHandlerOptions & {
providerName: string
@@ -94,6 +95,11 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
}

// Add thinking parameter if reasoning is enabled and model supports it
if (this.options.enableReasoningEffort && info.supportsReasoningBinary) {
;(params as any).thinking = { type: "enabled" }
}

try {
return this.client.chat.completions.create(params, requestOptions)
} catch (error) {
@@ -119,6 +125,8 @@

const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()

let lastUsage: OpenAI.CompletionUsage | undefined

for await (const chunk of stream) {
// Check for provider-specific error responses (e.g., MiniMax base_resp)
const chunkAny = chunk as any
@@ -137,10 +145,15 @@
}
}

if (delta && "reasoning_content" in delta) {
const reasoning_content = (delta.reasoning_content as string | undefined) || ""
if (reasoning_content?.trim()) {
yield { type: "reasoning", text: reasoning_content }
if (delta) {
for (const key of ["reasoning_content", "reasoning"] as const) {
if (key in delta) {
const reasoning_content = ((delta as any)[key] as string | undefined) || ""
if (reasoning_content?.trim()) {
yield { type: "reasoning", text: reasoning_content }
}
break
}
}
}

@@ -176,11 +189,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}

if (chunk.usage) {
yield {
type: "usage",
inputTokens: chunk.usage.prompt_tokens || 0,
outputTokens: chunk.usage.completion_tokens || 0,
}
lastUsage = chunk.usage
}
}

@@ -198,20 +207,51 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
toolCallAccumulator.clear()
}

if (lastUsage) {
yield this.processUsageMetrics(lastUsage, this.getModel().info)
}

// Process any remaining content
for (const processedChunk of matcher.final()) {
yield processedChunk
}
}

protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
const inputTokens = usage?.prompt_tokens || 0
const outputTokens = usage?.completion_tokens || 0
const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

const { totalCost } = modelInfo
? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
: { totalCost: 0 }

return {
type: "usage",
inputTokens,
outputTokens,
cacheWriteTokens: cacheWriteTokens || undefined,
cacheReadTokens: cacheReadTokens || undefined,
totalCost,
}
[Review comment — Contributor; a sketch of the suggested fix follows after this file's diff]

The new processUsageMetrics method in the base class now includes cost calculation via calculateApiCostOpenAI. However, FeatherlessHandler has a custom createMessage implementation for DeepSeek-R1 models that yields usage without cost calculation (lines 79-85). This creates an inconsistency: non-R1 models (which use super.createMessage()) will get cost calculation, but R1 models won't. Consider updating the R1 path to also use processUsageMetrics or call calculateApiCostOpenAI directly to maintain consistent cost tracking across all Featherless models.
}

async completePrompt(prompt: string): Promise<string> {
const { id: modelId } = this.getModel()
const { id: modelId, info: modelInfo } = this.getModel()

const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
model: modelId,
messages: [{ role: "user", content: prompt }],
}

// Add thinking parameter if reasoning is enabled and model supports it
if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
;(params as any).thinking = { type: "enabled" }
}

try {
const response = await this.client.chat.completions.create({
model: modelId,
messages: [{ role: "user", content: prompt }],
})
const response = await this.client.chat.completions.create(params)

// Check for provider-specific error responses (e.g., MiniMax base_resp)
const responseAny = response as any
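Following up on the review comment above, here is a minimal sketch of how FeatherlessHandler's DeepSeek-R1 path could defer usage emission to the shared base-class helper. This is illustrative only: the createR1Stream helper name and the elided loop body are assumptions, not the actual Featherless implementation; only processUsageMetrics, getModel(), and the type names come from this PR.

// Sketch only: reuse the base class's usage helper in the R1-specific streaming path
// so DeepSeek-R1 models also report cache tokens and totalCost.
override async *createMessage(
	systemPrompt: string,
	messages: Anthropic.Messages.MessageParam[],
	metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
	// Hypothetical helper standing in for the existing R1-specific request/stream setup.
	const stream = await this.createR1Stream(systemPrompt, messages, metadata)

	let lastUsage: OpenAI.CompletionUsage | undefined

	for await (const chunk of stream) {
		// ...yield reasoning and text chunks exactly as the current R1 path does...

		if (chunk.usage) {
			lastUsage = chunk.usage
		}
	}

	if (lastUsage) {
		// Same code path the base class now uses, keeping cost tracking consistent across models.
		yield this.processUsageMetrics(lastUsage, this.getModel().info)
	}
}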
59 changes: 0 additions & 59 deletions src/api/providers/groq.ts
@@ -1,22 +1,9 @@
import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
import type { ApiHandlerCreateMessageMetadata } from "../index"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { calculateApiCostOpenAI } from "../../shared/cost"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

// Enhanced usage interface to support Groq's cached token fields
interface GroqUsage extends OpenAI.CompletionUsage {
prompt_tokens_details?: {
cached_tokens?: number
}
}

export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
constructor(options: ApiHandlerOptions) {
super({
@@ -29,50 +16,4 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
defaultTemperature: 0.5,
})
}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const stream = await this.createStream(systemPrompt, messages, metadata)

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta

if (delta?.content) {
yield {
type: "text",
text: delta.content,
}
}

if (chunk.usage) {
yield* this.yieldUsage(chunk.usage as GroqUsage)
}
}
}

private async *yieldUsage(usage: GroqUsage | undefined): ApiStream {
const { info } = this.getModel()
const inputTokens = usage?.prompt_tokens || 0
const outputTokens = usage?.completion_tokens || 0

const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

// Groq does not track cache writes
const cacheWriteTokens = 0

// Calculate cost using OpenAI-compatible cost calculation
const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)

yield {
type: "usage",
inputTokens,
outputTokens,
cacheWriteTokens,
cacheReadTokens,
totalCost,
}
}
}