diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts index 452664e7dda..05b3e405f2e 100644 --- a/src/api/providers/__tests__/openai.spec.ts +++ b/src/api/providers/__tests__/openai.spec.ts @@ -279,6 +279,56 @@ describe("OpenAiHandler", () => { }) }) + it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => { + mockCreate.mockImplementation(async (options) => { + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_fallback", + function: { name: "fallback_tool", arguments: '{"test":"fallback"}' }, + }, + ], + }, + finish_reason: null, + }, + ], + } + // Stream ends without finish_reason being set to "tool_calls" + yield { + choices: [ + { + delta: {}, + finish_reason: "stop", // Different finish reason + }, + ], + } + }, + } + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Tool calls should still be yielded via the fallback mechanism + const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call") + expect(toolCallChunks).toHaveLength(1) + expect(toolCallChunks[0]).toEqual({ + type: "tool_call", + id: "call_fallback", + name: "fallback_tool", + arguments: '{"test":"fallback"}', + }) + }) + it("should include reasoning_effort when reasoning effort is enabled", async () => { const reasoningOptions: ApiHandlerOptions = { ...mockOptions, @@ -779,6 +829,58 @@ describe("OpenAiHandler", () => { }) }) + it("should yield tool calls for O3 model even when finish_reason is not set (fallback behavior)", async () => { + const o3Handler = new OpenAiHandler(o3Options) + + mockCreate.mockImplementation(async (options) => { + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_o3_fallback", + 
function: { name: "o3_fallback_tool", arguments: '{"o3":"test"}' }, + }, + ], + }, + finish_reason: null, + }, + ], + } + // Stream ends with different finish reason + yield { + choices: [ + { + delta: {}, + finish_reason: "length", // Different finish reason + }, + ], + } + }, + } + }) + + const stream = o3Handler.createMessage("system", []) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Tool calls should still be yielded via the fallback mechanism + const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call") + expect(toolCallChunks).toHaveLength(1) + expect(toolCallChunks[0]).toEqual({ + type: "tool_call", + id: "call_o3_fallback", + name: "o3_fallback_tool", + arguments: '{"o3":"test"}', + }) + }) + it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => { const o3Handler = new OpenAiHandler({ ...o3Options, diff --git a/src/api/providers/__tests__/roo.spec.ts b/src/api/providers/__tests__/roo.spec.ts index 7555a49d498..9dc9aff3db8 100644 --- a/src/api/providers/__tests__/roo.spec.ts +++ b/src/api/providers/__tests__/roo.spec.ts @@ -630,4 +630,280 @@ describe("RooHandler", () => { ) }) }) + + describe("tool calls handling", () => { + beforeEach(() => { + handler = new RooHandler(mockOptions) + }) + + it("should yield tool calls when finish_reason is tool_calls", async () => { + mockCreate.mockResolvedValueOnce({ + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_123", + function: { name: "read_file", arguments: '{"path":"' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + function: { arguments: 'test.ts"}' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: {}, + finish_reason: "tool_calls", + index: 0, + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 5, 
total_tokens: 15 }, + } + }, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call") + expect(toolCallChunks).toHaveLength(1) + expect(toolCallChunks[0].id).toBe("call_123") + expect(toolCallChunks[0].name).toBe("read_file") + expect(toolCallChunks[0].arguments).toBe('{"path":"test.ts"}') + }) + + it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => { + mockCreate.mockResolvedValueOnce({ + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_456", + function: { + name: "write_to_file", + arguments: '{"path":"test.ts","content":"hello"}', + }, + }, + ], + }, + index: 0, + }, + ], + } + // Stream ends without finish_reason being set to "tool_calls" + yield { + choices: [ + { + delta: {}, + finish_reason: "stop", // Different finish reason + index: 0, + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } + }, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Tool calls should still be yielded via the fallback mechanism + const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call") + expect(toolCallChunks).toHaveLength(1) + expect(toolCallChunks[0].id).toBe("call_456") + expect(toolCallChunks[0].name).toBe("write_to_file") + expect(toolCallChunks[0].arguments).toBe('{"path":"test.ts","content":"hello"}') + }) + + it("should handle multiple tool calls", async () => { + mockCreate.mockResolvedValueOnce({ + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_1", + function: { name: "read_file", arguments: '{"path":"file1.ts"}' }, + }, + ], + 
}, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 1, + id: "call_2", + function: { name: "read_file", arguments: '{"path":"file2.ts"}' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: {}, + finish_reason: "tool_calls", + index: 0, + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } + }, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call") + expect(toolCallChunks).toHaveLength(2) + expect(toolCallChunks[0].id).toBe("call_1") + expect(toolCallChunks[0].name).toBe("read_file") + expect(toolCallChunks[1].id).toBe("call_2") + expect(toolCallChunks[1].name).toBe("read_file") + }) + + it("should accumulate tool call arguments across multiple chunks", async () => { + mockCreate.mockResolvedValueOnce({ + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_789", + function: { name: "execute_command", arguments: '{"command":"' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + function: { arguments: "npm install" }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + function: { arguments: '"}' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: {}, + finish_reason: "tool_calls", + index: 0, + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } + }, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call") + expect(toolCallChunks).toHaveLength(1) + 
expect(toolCallChunks[0].id).toBe("call_789") + expect(toolCallChunks[0].name).toBe("execute_command") + expect(toolCallChunks[0].arguments).toBe('{"command":"npm install"}') + }) + + it("should not yield empty tool calls when no tool calls present", async () => { + mockCreate.mockResolvedValueOnce({ + [Symbol.asyncIterator]: async function* () { + yield { + choices: [{ delta: { content: "Regular text response" }, index: 0 }], + } + yield { + choices: [{ delta: {}, finish_reason: "stop", index: 0 }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } + }, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call") + expect(toolCallChunks).toHaveLength(0) + }) + }) }) diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts index 51db85410e1..ea0dc1b2e83 100644 --- a/src/api/providers/base-openai-compatible-provider.ts +++ b/src/api/providers/base-openai-compatible-provider.ts @@ -184,6 +184,20 @@ export abstract class BaseOpenAiCompatibleProvider } } + // Fallback: If stream ends with accumulated tool calls that weren't yielded + // (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls') + if (toolCallAccumulator.size > 0) { + for (const toolCall of toolCallAccumulator.values()) { + yield { + type: "tool_call", + id: toolCall.id, + name: toolCall.name, + arguments: toolCall.arguments, + } + } + toolCallAccumulator.clear() + } + // Process any remaining content for (const processedChunk of matcher.final()) { yield processedChunk diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index 79d65e82e2b..1c8d3c7d9d1 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -246,6 +246,20 @@ export class OpenAiHandler extends BaseProvider implements 
SingleCompletionHandl } } + // Fallback: If stream ends with accumulated tool calls that weren't yielded + // (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls') + if (toolCallAccumulator.size > 0) { + for (const toolCall of toolCallAccumulator.values()) { + yield { + type: "tool_call", + id: toolCall.id, + name: toolCall.name, + arguments: toolCall.arguments, + } + } + toolCallAccumulator.clear() + } + for (const chunk of matcher.final()) { yield chunk } @@ -506,6 +520,20 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } } } + + // Fallback: If stream ends with accumulated tool calls that weren't yielded + // (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls') + if (toolCallAccumulator.size > 0) { + for (const toolCall of toolCallAccumulator.values()) { + yield { + type: "tool_call", + id: toolCall.id, + name: toolCall.name, + arguments: toolCall.arguments, + } + } + toolCallAccumulator.clear() + } } private _getUrlHost(baseUrl?: string): string { diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index fa3aa5e5b0f..c63142aad92 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -265,6 +265,20 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } } + // Fallback: If stream ends with accumulated tool calls that weren't yielded + // (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls') + if (toolCallAccumulator.size > 0) { + for (const toolCall of toolCallAccumulator.values()) { + yield { + type: "tool_call", + id: toolCall.id, + name: toolCall.name, + arguments: toolCall.arguments, + } + } + toolCallAccumulator.clear() + } + if (lastUsage) { yield { type: "usage", diff --git a/src/api/providers/roo.ts b/src/api/providers/roo.ts index 8c1cee939c6..393740d3bd4 100644 --- a/src/api/providers/roo.ts +++ b/src/api/providers/roo.ts @@ -199,6 +199,20 @@ export class RooHandler extends 
BaseOpenAiCompatibleProvider { } } + // Fallback: If stream ends with accumulated tool calls that weren't yielded + // (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls') + if (toolCallAccumulator.size > 0) { + for (const toolCall of toolCallAccumulator.values()) { + yield { + type: "tool_call", + id: toolCall.id, + name: toolCall.name, + arguments: toolCall.arguments, + } + } + toolCallAccumulator.clear() + } + if (lastUsage) { // Check if the current model is marked as free const model = this.getModel()