Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/api/providers/__tests__/chutes.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,83 @@ describe("ChutesHandler", () => {
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
})

it("createMessage should yield tool_call_partial from stream", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: "call_123",
function: { name: "test_tool", arguments: '{"arg":"value"}' },
},
],
},
},
],
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({
type: "tool_call_partial",
index: 0,
id: "call_123",
name: "test_tool",
arguments: '{"arg":"value"}',
})
})

it("createMessage should pass tools and tool_choice to API", async () => {
const tools = [
{
type: "function" as const,
function: {
name: "test_tool",
description: "A test tool",
parameters: { type: "object", properties: {} },
},
},
]
const tool_choice = "auto" as const

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi.fn().mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [], { tools, tool_choice, taskId: "test-task-id" })
// Consume stream
for await (const _ of stream) {
// noop
}

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
tools,
tool_choice,
}),
)
})

it("should apply DeepSeek default temperature for R1 models", () => {
const testModelId = "deepseek-ai/DeepSeek-R1"
const handlerWithModel = new ChutesHandler({
Expand Down
36 changes: 34 additions & 2 deletions src/api/providers/chutes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
private getCompletionParams(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
const { id: model, info } = this.getModel()

Expand All @@ -46,6 +47,8 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
...(metadata?.tools && { tools: metadata.tools }),
...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
}

// Only add temperature if model supports it
Expand All @@ -65,7 +68,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan

if (model.id.includes("DeepSeek-R1")) {
const stream = await this.client.chat.completions.create({
...this.getCompletionParams(systemPrompt, messages),
...this.getCompletionParams(systemPrompt, messages, metadata),
messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
})

Expand All @@ -87,6 +90,19 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
}
}

// Emit raw tool call chunks - NativeToolCallParser handles state management
if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) {
for (const toolCall of delta.tool_calls) {
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
}

if (chunk.usage) {
yield {
type: "usage",
Expand All @@ -102,7 +118,9 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
}
} else {
// For non-DeepSeek-R1 models, use standard OpenAI streaming
const stream = await this.client.chat.completions.create(this.getCompletionParams(systemPrompt, messages))
const stream = await this.client.chat.completions.create(
this.getCompletionParams(systemPrompt, messages, metadata),
)

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta
Expand All @@ -115,6 +133,19 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
}

// Emit raw tool call chunks - NativeToolCallParser handles state management
if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) {
for (const toolCall of delta.tool_calls) {
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
}

if (chunk.usage) {
yield {
type: "usage",
Expand Down Expand Up @@ -166,6 +197,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
override getModel() {
const model = super.getModel()
const isDeepSeekR1 = model.id.includes("DeepSeek-R1")

return {
...model,
info: {
Expand Down
215 changes: 215 additions & 0 deletions src/api/providers/fetchers/__tests__/chutes.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Mocks must come first, before imports
vi.mock("axios")

import type { Mock } from "vitest"
import type { ModelInfo } from "@roo-code/types"
import axios from "axios"
import { getChutesModels } from "../chutes"
import { chutesModels } from "@roo-code/types"

// Typed view of the mocked axios module so `get` exposes vitest's Mock API
// (mockResolvedValue / mockRejectedValue) without sprinkling casts in tests.
const mockedAxios = axios as typeof axios & {
	get: Mock
}
describe("getChutesModels", () => {
	beforeEach(() => {
		// Reset call history and one-shot implementations so each test
		// starts from a clean mock state.
		vi.clearAllMocks()
	})

it("should fetch and parse models successfully", async () => {
const mockResponse = {
data: {
data: [
{
id: "test/new-model",
object: "model",
owned_by: "test",
created: 1234567890,
context_length: 128000,
max_model_len: 8192,
input_modalities: ["text"],
},
],
},
}

mockedAxios.get.mockResolvedValue(mockResponse)

const models = await getChutesModels("test-api-key")

expect(mockedAxios.get).toHaveBeenCalledWith(
"https://llm.chutes.ai/v1/models",
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer test-api-key",
}),
}),
)

expect(models["test/new-model"]).toEqual({
maxTokens: 8192,
contextWindow: 128000,
supportsImages: false,
supportsPromptCache: false,
supportsNativeTools: false,
inputPrice: 0,
outputPrice: 0,
description: "Chutes AI model: test/new-model",
})
})

it("should override hardcoded models with dynamic API data", async () => {
// Find any hardcoded model
const [modelId] = Object.entries(chutesModels)[0]

const mockResponse = {
data: {
data: [
{
id: modelId,
object: "model",
owned_by: "test",
created: 1234567890,
context_length: 200000, // Different from hardcoded
max_model_len: 10000, // Different from hardcoded
input_modalities: ["text", "image"],
},
],
},
}

mockedAxios.get.mockResolvedValue(mockResponse)

const models = await getChutesModels("test-api-key")

// Dynamic values should override hardcoded
expect(models[modelId]).toBeDefined()
expect(models[modelId].contextWindow).toBe(200000)
expect(models[modelId].maxTokens).toBe(10000)
expect(models[modelId].supportsImages).toBe(true)
})

it("should return hardcoded models when API returns empty", async () => {
const mockResponse = {
data: {
data: [],
},
}

mockedAxios.get.mockResolvedValue(mockResponse)

const models = await getChutesModels("test-api-key")

// Should still have hardcoded models
expect(Object.keys(models).length).toBeGreaterThan(0)
expect(models).toEqual(expect.objectContaining(chutesModels))
})

it("should return hardcoded models on API error", async () => {
mockedAxios.get.mockRejectedValue(new Error("Network error"))

const models = await getChutesModels("test-api-key")

// Should still have hardcoded models
expect(Object.keys(models).length).toBeGreaterThan(0)
expect(models).toEqual(chutesModels)
})

it("should work without API key", async () => {
const mockResponse = {
data: {
data: [],
},
}

mockedAxios.get.mockResolvedValue(mockResponse)

const models = await getChutesModels()

expect(mockedAxios.get).toHaveBeenCalledWith(
"https://llm.chutes.ai/v1/models",
expect.objectContaining({
headers: expect.not.objectContaining({
Authorization: expect.anything(),
}),
}),
)

expect(Object.keys(models).length).toBeGreaterThan(0)
})

it("should detect image support from input_modalities", async () => {
const mockResponse = {
data: {
data: [
{
id: "test/image-model",
object: "model",
owned_by: "test",
created: 1234567890,
context_length: 128000,
max_model_len: 8192,
input_modalities: ["text", "image"],
},
],
},
}

mockedAxios.get.mockResolvedValue(mockResponse)

const models = await getChutesModels("test-api-key")

expect(models["test/image-model"].supportsImages).toBe(true)
})

it("should detect native tool support from supported_features", async () => {
const mockResponse = {
data: {
data: [
{
id: "test/tools-model",
object: "model",
owned_by: "test",
created: 1234567890,
context_length: 128000,
max_model_len: 8192,
input_modalities: ["text"],
supported_features: ["json_mode", "tools", "reasoning"],
},
],
},
}

mockedAxios.get.mockResolvedValue(mockResponse)

const models = await getChutesModels("test-api-key")

expect(models["test/tools-model"].supportsNativeTools).toBe(true)
})

it("should not enable native tool support when tools is not in supported_features", async () => {
const mockResponse = {
data: {
data: [
{
id: "test/no-tools-model",
object: "model",
owned_by: "test",
created: 1234567890,
context_length: 128000,
max_model_len: 8192,
input_modalities: ["text"],
supported_features: ["json_mode", "reasoning"],
},
],
},
}

mockedAxios.get.mockResolvedValue(mockResponse)

const models = await getChutesModels("test-api-key")

expect(models["test/no-tools-model"].supportsNativeTools).toBe(false)
expect(models["test/no-tools-model"].defaultToolProtocol).toBeUndefined()
})
})
Loading
Loading