From 8655db725770b959ce0e05e4a803219c08c3dfc6 Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Fri, 30 Jan 2026 13:30:50 -0500 Subject: [PATCH 1/3] feat(api): migrate Fireworks provider to AI SDK - Replace BaseOpenAiCompatibleProvider with BaseProvider + AI SDK - Use @ai-sdk/fireworks package for native Fireworks support - Implement createMessage using streamText for streaming responses - Implement completePrompt using generateText - Add processUsageMetrics for handling cache and reasoning tokens - Add mapToolChoice for AI SDK toolChoice format conversion - Update tests to use AI SDK mocking pattern Resolves: EXT-698 --- pnpm-lock.yaml | 30 +- src/api/providers/__tests__/fireworks.spec.ts | 1285 ++++++++++------- src/api/providers/fireworks.ts | 199 ++- src/package.json | 1 + 4 files changed, 1010 insertions(+), 505 deletions(-) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index e287787ede..01f1fb5f41 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -752,6 +752,9 @@ importers: '@ai-sdk/deepseek': specifier: ^2.0.14 version: 2.0.14(zod@3.25.76) + '@ai-sdk/fireworks': + specifier: ^2.0.26 + version: 2.0.26(zod@3.25.76) '@ai-sdk/groq': specifier: ^3.0.19 version: 3.0.19(zod@3.25.76) @@ -1411,6 +1414,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/fireworks@2.0.26': + resolution: {integrity: sha512-vBqSSksHhDGrSNYnmEmVGvLicHFjL4yAxFZfCb6ydrg+qgnlW2bdyTQDMI69BKG4spNZ1/iHMxRNIQpx19Yf6w==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/gateway@3.0.25': resolution: {integrity: sha512-j0AQeA7hOVqwImykQlganf/Euj3uEXf0h3G0O4qKTDpEwE+EZGIPnVimCWht5W91lAetPZSfavDyvfpuPDd2PQ==} engines: {node: '>=18'} @@ -1429,6 +1438,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.24': + resolution: {integrity: sha512-3QrCKpQCn3g6sIMoFGuEroaqk7Xg+qfsohRp4dKszjto5stjBg4SdtOKqHg+CpE3X4woj2O62w2qr5dSekMZeQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20': resolution: {integrity: sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==} engines: {node: '>=18'} @@ -11042,6 +11057,13 @@ snapshots: '@ai-sdk/provider-utils': 4.0.10(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/fireworks@2.0.26(zod@3.25.76)': + dependencies: + '@ai-sdk/openai-compatible': 2.0.24(zod@3.25.76) + '@ai-sdk/provider': 3.0.6 + '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/gateway@3.0.25(zod@3.25.76)': dependencies: '@ai-sdk/provider': 3.0.5 @@ -11061,6 +11083,12 @@ snapshots: '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.24(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.6 + '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20(zod@3.25.76)': dependencies: '@ai-sdk/provider': 2.0.1 @@ -15070,7 +15098,7 @@ snapshots: sirv: 3.0.1 tinyglobby: 0.2.14 tinyrainbow: 2.0.0 - vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.17.57)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.17.50)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) '@vitest/utils@3.2.4': dependencies: diff --git a/src/api/providers/__tests__/fireworks.spec.ts b/src/api/providers/__tests__/fireworks.spec.ts index 79f69f868b..8307abd7ff 100644 --- a/src/api/providers/__tests__/fireworks.spec.ts +++ b/src/api/providers/__tests__/fireworks.spec.ts @@ -1,594 +1,891 @@ -// npx vitest run 
api/providers/__tests__/fireworks.spec.ts +// npx vitest run src/api/providers/__tests__/fireworks.spec.ts -import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) -import { type FireworksModelId, fireworksDefaultModelId, fireworksModels } from "@roo-code/types" +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, + } +}) -import { FireworksHandler } from "../fireworks" +vi.mock("@ai-sdk/fireworks", () => ({ + createFireworks: vi.fn(() => { + // Return a function that returns a mock language model + return vi.fn(() => ({ + modelId: "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507", + provider: "fireworks", + })) + }), +})) -// Create mock functions -const mockCreate = vi.fn() +import type { Anthropic } from "@anthropic-ai/sdk" -// Mock OpenAI module -vi.mock("openai", () => ({ - default: vi.fn(() => ({ - chat: { - completions: { - create: mockCreate, - }, - }, - })), -})) +import { fireworksDefaultModelId, fireworksModels, type FireworksModelId } from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../../shared/api" + +import { FireworksHandler } from "../fireworks" describe("FireworksHandler", () => { let handler: FireworksHandler + let mockOptions: ApiHandlerOptions beforeEach(() => { + mockOptions = { + fireworksApiKey: "test-fireworks-api-key", + apiModelId: "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507", + } + handler = new FireworksHandler(mockOptions) vi.clearAllMocks() - // Set up default mock implementation - mockCreate.mockImplementation(async () => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], - usage: null, - } - yield { - choices: [ - { - delta: {}, - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - } - }, - })) - handler = new FireworksHandler({ fireworksApiKey: "test-key" }) }) - afterEach(() => { - vi.restoreAllMocks() - }) + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(FireworksHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) - it("should use the correct Fireworks base URL", () => { - new FireworksHandler({ fireworksApiKey: "test-fireworks-api-key" }) - expect(OpenAI).toHaveBeenCalledWith( - expect.objectContaining({ baseURL: "https://api.fireworks.ai/inference/v1" }), - ) + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new FireworksHandler({ + ...mockOptions, + apiModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe(fireworksDefaultModelId) + }) }) - it("should use the provided API key", () => { - const fireworksApiKey = "test-fireworks-api-key" - new FireworksHandler({ fireworksApiKey }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: fireworksApiKey })) - }) + describe("getModel", () => { + it("should return default model when no model is specified", () => { + const handlerWithoutModel = new FireworksHandler({ + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(fireworksDefaultModelId) + 
expect(model.info).toEqual(fireworksModels[fireworksDefaultModelId]) + }) - it("should throw error when API key is not provided", () => { - expect(() => new FireworksHandler({})).toThrow("API key is required") - }) + it("should return specified model when valid model is provided", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual(fireworksModels[testModelId]) + }) - it("should return default model when no model is specified", () => { - const model = handler.getModel() - expect(model.id).toBe(fireworksDefaultModelId) - expect(model.info).toEqual(expect.objectContaining(fireworksModels[fireworksDefaultModelId])) - }) + it("should return Kimi K2 Instruct model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.6, + outputPrice: 2.5, + description: expect.stringContaining("Kimi K2 is a state-of-the-art mixture-of-experts"), + }), + ) + }) - it("should return specified model when valid model is provided", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return Kimi K2 Thinking model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-thinking" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16000, + contextWindow: 256000, + supportsImages: false, + supportsPromptCache: true, + supportsTemperature: true, + preserveReasoning: true, + defaultTemperature: 1.0, + inputPrice: 0.6, + outputPrice: 2.5, + cacheReadsPrice: 0.15, + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual(expect.objectContaining(fireworksModels[testModelId])) - }) - it("should return Kimi K2 Instruct model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return MiniMax M2 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/minimax-m2" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 4096, + contextWindow: 204800, + supportsImages: false, + supportsPromptCache: false, + 
inputPrice: 0.3, + outputPrice: 1.2, + description: expect.stringContaining("MiniMax M2 is a high-performance language model"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.6, - outputPrice: 2.5, - description: expect.stringContaining("Kimi K2 is a state-of-the-art mixture-of-experts"), - }), - ) - }) - it("should return Kimi K2 Thinking model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-thinking" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return Qwen3 235B model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 32768, + contextWindow: 256000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.22, + outputPrice: 0.88, + description: + "Latest Qwen3 thinking model, competitive against the best closed source models in Jul 2025.", + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16000, - contextWindow: 256000, - supportsImages: false, - supportsPromptCache: true, - supportsTemperature: true, - preserveReasoning: true, - defaultTemperature: 1.0, - inputPrice: 0.6, - outputPrice: 2.5, - cacheReadsPrice: 0.15, - }), - ) - }) - it("should return MiniMax M2 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/minimax-m2" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return DeepSeek R1 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-r1-0528" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 20480, + contextWindow: 160000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 3, + outputPrice: 8, + description: expect.stringContaining("05/28 updated checkpoint of Deepseek R1"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 4096, - contextWindow: 204800, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.3, - outputPrice: 1.2, - description: expect.stringContaining("MiniMax M2 is a high-performance language model"), - }), - ) - }) - it("should return Qwen3 235B model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return DeepSeek 
V3 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-v3" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.9, + outputPrice: 0.9, + description: expect.stringContaining("strong Mixture-of-Experts (MoE) language model"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 32768, - contextWindow: 256000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.22, - outputPrice: 0.88, - description: - "Latest Qwen3 thinking model, competitive against the best closed source models in Jul 2025.", - }), - ) - }) - it("should return DeepSeek R1 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-r1-0528" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return DeepSeek V3.1 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-v3p1" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 163840, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.56, + outputPrice: 1.68, + description: expect.stringContaining("DeepSeek v3.1 is an improved version"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 20480, - contextWindow: 160000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 3, - outputPrice: 8, - description: expect.stringContaining("05/28 updated checkpoint of Deepseek R1"), - }), - ) - }) - it("should return DeepSeek V3 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-v3" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return GLM-4.5 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p5" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.55, + outputPrice: 2.19, + description: expect.stringContaining("Z.ai GLM-4.5 with 355B total parameters"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.9, - outputPrice: 
0.9, - description: expect.stringContaining("strong Mixture-of-Experts (MoE) language model"), - }), - ) - }) - it("should return DeepSeek V3.1 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-v3p1" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return GLM-4.5-Air model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p5-air" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.55, + outputPrice: 2.19, + description: expect.stringContaining("Z.ai GLM-4.5-Air with 106B total parameters"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 163840, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.56, - outputPrice: 1.68, - description: expect.stringContaining("DeepSeek v3.1 is an improved version"), - }), - ) - }) - it("should return GLM-4.5 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p5" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return GLM-4.6 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p6" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 25344, + contextWindow: 198000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.55, + outputPrice: 2.19, + description: expect.stringContaining("Z.ai GLM-4.6 is an advanced coding model"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.55, - outputPrice: 2.19, - description: expect.stringContaining("Z.ai GLM-4.5 with 355B total parameters"), - }), - ) - }) - it("should return GLM-4.5-Air model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p5-air" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return gpt-oss-20b model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/gpt-oss-20b" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.07, + 
outputPrice: 0.3, + description: expect.stringContaining( + "OpenAI gpt-oss-20b: Compact model for local/edge deployments", + ), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.55, - outputPrice: 2.19, - description: expect.stringContaining("Z.ai GLM-4.5-Air with 106B total parameters"), - }), - ) - }) - it("should return GLM-4.6 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p6" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return gpt-oss-120b model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/gpt-oss-120b" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.15, + outputPrice: 0.6, + description: expect.stringContaining( + "OpenAI gpt-oss-120b: Production-grade, general-purpose model", + ), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 25344, - contextWindow: 198000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.55, - outputPrice: 2.19, - description: expect.stringContaining("Z.ai GLM-4.6 is an advanced coding model"), - }), - ) - }) - it("should return gpt-oss-20b model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/gpt-oss-20b" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return provided model ID with default model info if model does not exist", () => { + const handlerWithInvalidModel = new FireworksHandler({ + ...mockOptions, + apiModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe("invalid-model") + expect(model.info).toBeDefined() + // Should use default model info + expect(model.info).toBe(fireworksModels[fireworksDefaultModelId]) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.07, - outputPrice: 0.3, - description: expect.stringContaining("OpenAI gpt-oss-20b: Compact model for local/edge deployments"), - }), - ) - }) - it("should return gpt-oss-120b model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/gpt-oss-120b" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - 
expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.15, - outputPrice: 0.6, - description: expect.stringContaining("OpenAI gpt-oss-120b: Production-grade, general-purpose model"), - }), - ) }) - it("completePrompt method should return text from Fireworks API", async () => { - const expectedResponse = "This is a test response from Fireworks" - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) - const result = await handler.completePrompt("test prompt") - expect(result).toBe(expectedResponse) - }) + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] - it("should handle errors in completePrompt", async () => { - const errorMessage = "Fireworks API error" - mockCreate.mockRejectedValueOnce(new Error(errorMessage)) - await expect(handler.completePrompt("test prompt")).rejects.toThrow( - `Fireworks completion error: ${errorMessage}`, - ) - }) + it("should handle streaming responses", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response from Fireworks" } + } - it("createMessage should yield text content from stream", async () => { - const testContent = "This is test content from Fireworks stream" - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: { content: testContent } }] }, - }) - .mockResolvedValueOnce({ done: true }), - }), + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response from Fireworks") + }) + + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(20) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should handle cached tokens in usage data from providerMetadata", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } - expect(firstChunk.done).toBe(false) - 
expect(firstChunk.value).toEqual({ type: "text", text: testContent }) - }) + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + }) - it("createMessage should yield usage data from stream", async () => { - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } }, - }) - .mockResolvedValueOnce({ done: true }), - }), + // Fireworks provides cache metrics via providerMetadata for supported models + const mockProviderMetadata = Promise.resolve({ + fireworks: { + promptCacheHitTokens: 30, + promptCacheMissTokens: 70, + }, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(100) + expect(usageChunks[0].outputTokens).toBe(50) + expect(usageChunks[0].cacheReadTokens).toBe(30) + expect(usageChunks[0].cacheWriteTokens).toBe(70) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should handle usage with details.cachedInputTokens when providerMetadata is not available", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 }) - }) + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + details: { + cachedInputTokens: 25, + }, + }) - it("createMessage should pass correct parameters to Fireworks client", async () => { - const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct" - const modelInfo = fireworksModels[modelId] - const handlerWithModel = new FireworksHandler({ - apiModelId: modelId, - fireworksApiKey: "test-fireworks-api-key", + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].cacheReadTokens).toBe(25) + expect(usageChunks[0].cacheWriteTokens).toBeUndefined() }) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, + it("should pass correct temperature (0.5 default) to streamText", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithDefaultTemp = new FireworksHandler({ + fireworksApiKey: "test-key", + apiModelId: "accounts/fireworks/models/kimi-k2-instruct", + }) + + const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages) + for await 
(const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, }), + ) + }) + + it("should use model defaultTemperature (1.0) over provider default (0.5) for kimi-k2-thinking", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithThinkingModel = new FireworksHandler({ + fireworksApiKey: "test-key", + apiModelId: "accounts/fireworks/models/kimi-k2-thinking", + }) + + const stream = handlerWithThinkingModel.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + // Model's defaultTemperature (1.0) should take precedence over provider's default (0.5) + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 1.0, + }), + ) + }) + + it("should use user-specified temperature over model and provider defaults", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithCustomTemp = new FireworksHandler({ + fireworksApiKey: "test-key", + apiModelId: "accounts/fireworks/models/kimi-k2-thinking", + modelTemperature: 0.7, + }) + + const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream } + + // User-specified temperature should take precedence over everything + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.7, + }), + ) }) - const systemPrompt = "Test system prompt for Fireworks" - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Fireworks" }] + it("should handle stream with multiple chunks", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: " world" } + } - const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages) - await messageGenerator.next() + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }), + providerMetadata: Promise.resolve({}), + }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: modelId, - max_tokens: modelInfo.maxTokens, - temperature: 0.5, - messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]), - stream: true, - stream_options: { include_usage: true }, - }), - undefined, - ) + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const textChunks = chunks.filter((c) => c.type === "text") + expect(textChunks[0]).toEqual({ type: "text", text: "Hello" }) + expect(textChunks[1]).toEqual({ type: "text", text: " world" }) + + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 }) + }) }) - it("should use provider default temperature of 0.5 for models without defaultTemperature", async () => { - const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct" - 
const handlerWithModel = new FireworksHandler({ - apiModelId: modelId, - fireworksApiKey: "test-fireworks-api-key", + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion from Fireworks", + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion from Fireworks") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", + }), + ) }) - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - })) + it("should use default temperature in completePrompt", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) - const messageGenerator = handlerWithModel.createMessage("system", []) - await messageGenerator.next() + await handler.completePrompt("Test prompt") - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 0.5, - }), - undefined, - ) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, + }), + ) + }) }) - it("should use model defaultTemperature (1.0) over provider default (0.5) for kimi-k2-thinking", async () => { - const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-thinking" - const handlerWithModel = new FireworksHandler({ - apiModelId: modelId, - fireworksApiKey: "test-fireworks-api-key", - }) + describe("processUsageMetrics", () => { + it("should correctly process usage metrics including cache information from providerMetadata", () => { + class TestFireworksHandler extends FireworksHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestFireworksHandler(mockOptions) - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const providerMetadata = { + fireworks: { + promptCacheHitTokens: 20, + promptCacheMissTokens: 80, }, - }), - })) + } - const messageGenerator = handlerWithModel.createMessage("system", []) - await messageGenerator.next() + const result = testHandler.testProcessUsageMetrics(usage, providerMetadata) - // Model's defaultTemperature (1.0) should take precedence over provider's default (0.5) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 1.0, - }), - undefined, - ) - }) + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(80) + expect(result.cacheReadTokens).toBe(20) + }) - it("should use user-specified temperature over model and provider defaults", async () => { - const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-thinking" - const handlerWithModel = new FireworksHandler({ - apiModelId: modelId, - fireworksApiKey: "test-fireworks-api-key", - modelTemperature: 0.7, + it("should handle missing cache metrics gracefully", () => { + class TestFireworksHandler extends FireworksHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestFireworksHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = 
testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBeUndefined() + expect(result.cacheReadTokens).toBeUndefined() }) - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - })) + it("should include reasoning tokens when provided", () => { + class TestFireworksHandler extends FireworksHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } - const messageGenerator = handlerWithModel.createMessage("system", []) - await messageGenerator.next() + const testHandler = new TestFireworksHandler(mockOptions) - // User-specified temperature should take precedence over everything - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 0.7, - }), - undefined, - ) - }) + const usage = { + inputTokens: 100, + outputTokens: 50, + details: { + reasoningTokens: 30, + }, + } - it("should handle empty response in completePrompt", async () => { - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: null } }] }) - const result = await handler.completePrompt("test prompt") - expect(result).toBe("") - }) + const result = testHandler.testProcessUsageMetrics(usage) - it("should handle missing choices in completePrompt", async () => { - mockCreate.mockResolvedValueOnce({ choices: [] }) - const result = await handler.completePrompt("test prompt") - expect(result).toBe("") + expect(result.reasoningTokens).toBe(30) + }) }) - it("createMessage should handle stream with multiple chunks", async () => { - mockCreate.mockImplementationOnce(async () => ({ - [Symbol.asyncIterator]: async function* () { + describe("tool handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" 
}], + }, + ] + + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { yield { - choices: [ - { - delta: { content: "Hello" }, - index: 0, - }, - ], - usage: null, + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", } yield { - choices: [ - { - delta: { content: " world" }, - index: 0, - }, - ], - usage: null, + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', } yield { - choices: [ - { - delta: {}, - index: 0, + type: "tool-input-end", + id: "tool-call-1", + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, }, - ], - usage: { - prompt_tokens: 5, - completion_tokens: 10, - total_tokens: 15, }, + ], + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") + }) + + it("should ignore tool-call events to prevent duplicate tools in UI", async () => { + async function* mockFullStream() { + yield { + type: "tool-call", + toolCallId: "tool-call-1", + toolName: "read_file", + input: { path: "test.ts" }, } - }, - })) + } - const systemPrompt = "You are a helpful assistant." 
- const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }] + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) - const stream = handler.createMessage(systemPrompt, messages) - const chunks = [] - for await (const chunk of stream) { - chunks.push(chunk) - } + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // tool-call events should be ignored (only tool-input-start/delta/end are processed) + const toolCallChunks = chunks.filter( + (c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end", + ) + expect(toolCallChunks.length).toBe(0) + }) + }) + + describe("mapToolChoice", () => { + it("should map string tool choices correctly", () => { + class TestFireworksHandler extends FireworksHandler { + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) + } + } + + const testHandler = new TestFireworksHandler(mockOptions) - expect(chunks[0]).toEqual({ type: "text", text: "Hello" }) - expect(chunks[1]).toEqual({ type: "text", text: " world" }) - expect(chunks[2]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 }) + expect(testHandler.testMapToolChoice("auto")).toBe("auto") + expect(testHandler.testMapToolChoice("none")).toBe("none") + expect(testHandler.testMapToolChoice("required")).toBe("required") + expect(testHandler.testMapToolChoice("unknown")).toBe("auto") + }) + + it("should map object tool choices correctly", () => { + class TestFireworksHandler extends FireworksHandler { + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) + } + } + + const testHandler = new TestFireworksHandler(mockOptions) + + const result = testHandler.testMapToolChoice({ + type: "function", + function: { name: "read_file" }, + }) + expect(result).toEqual({ type: "tool", toolName: "read_file" }) + }) + + it("should return undefined for null/undefined", () => { + class TestFireworksHandler extends FireworksHandler { + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) + } + } + + const testHandler = new TestFireworksHandler(mockOptions) + + expect(testHandler.testMapToolChoice(null)).toBeUndefined() + expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() + }) }) }) diff --git a/src/api/providers/fireworks.ts b/src/api/providers/fireworks.ts index db29e7bf3f..0a7a074083 100644 --- a/src/api/providers/fireworks.ts +++ b/src/api/providers/fireworks.ts @@ -1,19 +1,198 @@ -import { type FireworksModelId, fireworksDefaultModelId, fireworksModels } from "@roo-code/types" +import { Anthropic } from "@anthropic-ai/sdk" +import { createFireworks } from "@ai-sdk/fireworks" +import { streamText, generateText, ToolSet } from "ai" + +import { fireworksModels, fireworksDefaultModelId, type ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" +import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + +import { DEFAULT_HEADERS } from 
"./constants" +import { BaseProvider } from "./base-provider" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" + +const FIREWORKS_DEFAULT_TEMPERATURE = 0.5 + +/** + * Fireworks provider using the dedicated @ai-sdk/fireworks package. + * Provides native support for various models including reasoning models. + */ +export class FireworksHandler extends BaseProvider implements SingleCompletionHandler { + protected options: ApiHandlerOptions + protected provider: ReturnType -export class FireworksHandler extends BaseOpenAiCompatibleProvider { constructor(options: ApiHandlerOptions) { - super({ - ...options, - providerName: "Fireworks", + super() + this.options = options + + // Create the Fireworks provider using AI SDK + this.provider = createFireworks({ baseURL: "https://api.fireworks.ai/inference/v1", - apiKey: options.fireworksApiKey, - defaultProviderModelId: fireworksDefaultModelId, - providerModels: fireworksModels, - defaultTemperature: 0.5, + apiKey: options.fireworksApiKey ?? "not-provided", + headers: DEFAULT_HEADERS, }) } + + override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } { + const id = this.options.apiModelId ?? fireworksDefaultModelId + const info = fireworksModels[id as keyof typeof fireworksModels] || fireworksModels[fireworksDefaultModelId] + const params = getModelParams({ + format: "openai", + modelId: id, + model: info, + settings: this.options, + defaultTemperature: FIREWORKS_DEFAULT_TEMPERATURE, + }) + return { id, info, ...params } + } + + /** + * Get the language model for the configured model ID. + */ + protected getLanguageModel() { + const { id } = this.getModel() + return this.provider(id) + } + + /** + * Process usage metrics from the AI SDK response. + */ + protected processUsageMetrics( + usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }, + providerMetadata?: { + fireworks?: { + promptCacheHitTokens?: number + promptCacheMissTokens?: number + } + }, + ): ApiStreamUsageChunk { + // Extract cache metrics from Fireworks' providerMetadata if available + const cacheReadTokens = providerMetadata?.fireworks?.promptCacheHitTokens ?? usage.details?.cachedInputTokens + const cacheWriteTokens = providerMetadata?.fireworks?.promptCacheMissTokens + + return { + type: "usage", + inputTokens: usage.inputTokens || 0, + outputTokens: usage.outputTokens || 0, + cacheReadTokens, + cacheWriteTokens, + reasoningTokens: usage.details?.reasoningTokens, + } + } + + /** + * Map OpenAI tool_choice to AI SDK toolChoice format. + */ + protected mapToolChoice( + toolChoice: any, + ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { + if (!toolChoice) { + return undefined + } + + // Handle string values + if (typeof toolChoice === "string") { + switch (toolChoice) { + case "auto": + return "auto" + case "none": + return "none" + case "required": + return "required" + default: + return "auto" + } + } + + // Handle object values (OpenAI ChatCompletionNamedToolChoice format) + if (typeof toolChoice === "object" && "type" in toolChoice) { + if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { + return { type: "tool", toolName: toolChoice.function.name } + } + } + + return undefined + } + + /** + * Get the max tokens parameter to include in the request. 
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 */
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		// Convert messages to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		// Convert tools to OpenAI format first, then to AI SDK format
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		// Build the request options
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? FIREWORKS_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: this.mapToolChoice(metadata?.tool_choice),
+		}
+
+		// Use streamText for streaming responses
+		const result = streamText(requestOptions)
+
+		// Process the full stream to get all events including reasoning
+		for await (const part of result.fullStream) {
+			for (const chunk of processAiSdkStreamPart(part)) {
+				yield chunk
+			}
+		}
+
+		// Yield usage metrics at the end, including cache metrics from providerMetadata
+		const usage = await result.usage
+		const providerMetadata = await result.providerMetadata
+		if (usage) {
+			yield this.processUsageMetrics(usage, providerMetadata as any)
+		}
+	}
+
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
+	async completePrompt(prompt: string): Promise<string> {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature: this.options.modelTemperature ?? temperature ?? 
FIREWORKS_DEFAULT_TEMPERATURE, + }) + + return text + } } diff --git a/src/package.json b/src/package.json index 7641e8d061..b5c79fff8c 100644 --- a/src/package.json +++ b/src/package.json @@ -452,6 +452,7 @@ "dependencies": { "@ai-sdk/cerebras": "^1.0.0", "@ai-sdk/deepseek": "^2.0.14", + "@ai-sdk/fireworks": "^2.0.26", "@ai-sdk/groq": "^3.0.19", "@anthropic-ai/bedrock-sdk": "^0.10.2", "@anthropic-ai/sdk": "^0.37.0", From b6c44114c72d8506fdea9ca7f9fbf848aee26b8f Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Fri, 30 Jan 2026 13:39:36 -0500 Subject: [PATCH 2/3] refactor: extract mapToolChoice to shared ai-sdk utilities - Move duplicated mapToolChoice function from all providers to ai-sdk.ts - Update fireworks, groq, deepseek, cerebras, openai-compatible providers - Consolidate tests for mapToolChoice in ai-sdk.spec.ts - Remove duplicate tests from individual provider test files --- src/api/providers/__tests__/cerebras.spec.ts | 47 ------------------- src/api/providers/__tests__/deepseek.spec.ts | 47 ------------------- src/api/providers/__tests__/fireworks.spec.ts | 46 ------------------ src/api/providers/__tests__/groq.spec.ts | 47 ------------------- src/api/providers/cerebras.ts | 43 +++-------------- src/api/providers/deepseek.ts | 43 +++-------------- src/api/providers/fireworks.ts | 43 +++-------------- src/api/providers/groq.ts | 43 +++-------------- src/api/providers/openai-compatible.ts | 43 +++-------------- src/api/transform/__tests__/ai-sdk.spec.ts | 47 ++++++++++++++++++- src/api/transform/ai-sdk.ts | 41 ++++++++++++++++ 11 files changed, 122 insertions(+), 368 deletions(-) diff --git a/src/api/providers/__tests__/cerebras.spec.ts b/src/api/providers/__tests__/cerebras.spec.ts index aefb8a599c..caf8861b46 100644 --- a/src/api/providers/__tests__/cerebras.spec.ts +++ b/src/api/providers/__tests__/cerebras.spec.ts @@ -452,51 +452,4 @@ describe("CerebrasHandler", () => { expect(toolCallChunks.length).toBe(0) }) }) - - describe("mapToolChoice", () => { - it("should handle string tool choices", () => { - class TestCerebrasHandler extends CerebrasHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestCerebrasHandler(mockOptions) - - expect(testHandler.testMapToolChoice("auto")).toBe("auto") - expect(testHandler.testMapToolChoice("none")).toBe("none") - expect(testHandler.testMapToolChoice("required")).toBe("required") - expect(testHandler.testMapToolChoice("unknown")).toBe("auto") - }) - - it("should handle object tool choice with function name", () => { - class TestCerebrasHandler extends CerebrasHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestCerebrasHandler(mockOptions) - - const result = testHandler.testMapToolChoice({ - type: "function", - function: { name: "my_tool" }, - }) - - expect(result).toEqual({ type: "tool", toolName: "my_tool" }) - }) - - it("should return undefined for null or undefined", () => { - class TestCerebrasHandler extends CerebrasHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestCerebrasHandler(mockOptions) - - expect(testHandler.testMapToolChoice(null)).toBeUndefined() - expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() - }) - }) }) diff --git a/src/api/providers/__tests__/deepseek.spec.ts b/src/api/providers/__tests__/deepseek.spec.ts index 82b08aaad5..ece03c068e 100644 --- 
a/src/api/providers/__tests__/deepseek.spec.ts +++ b/src/api/providers/__tests__/deepseek.spec.ts @@ -733,51 +733,4 @@ describe("DeepSeekHandler", () => { expect(result).toBe(8192) }) }) - - describe("mapToolChoice", () => { - it("should handle string tool choices", () => { - class TestDeepSeekHandler extends DeepSeekHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestDeepSeekHandler(mockOptions) - - expect(testHandler.testMapToolChoice("auto")).toBe("auto") - expect(testHandler.testMapToolChoice("none")).toBe("none") - expect(testHandler.testMapToolChoice("required")).toBe("required") - expect(testHandler.testMapToolChoice("unknown")).toBe("auto") - }) - - it("should handle object tool choice with function name", () => { - class TestDeepSeekHandler extends DeepSeekHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestDeepSeekHandler(mockOptions) - - const result = testHandler.testMapToolChoice({ - type: "function", - function: { name: "my_tool" }, - }) - - expect(result).toEqual({ type: "tool", toolName: "my_tool" }) - }) - - it("should return undefined for null or undefined", () => { - class TestDeepSeekHandler extends DeepSeekHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestDeepSeekHandler(mockOptions) - - expect(testHandler.testMapToolChoice(null)).toBeUndefined() - expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() - }) - }) }) diff --git a/src/api/providers/__tests__/fireworks.spec.ts b/src/api/providers/__tests__/fireworks.spec.ts index 8307abd7ff..77c4b10f45 100644 --- a/src/api/providers/__tests__/fireworks.spec.ts +++ b/src/api/providers/__tests__/fireworks.spec.ts @@ -842,50 +842,4 @@ describe("FireworksHandler", () => { expect(toolCallChunks.length).toBe(0) }) }) - - describe("mapToolChoice", () => { - it("should map string tool choices correctly", () => { - class TestFireworksHandler extends FireworksHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestFireworksHandler(mockOptions) - - expect(testHandler.testMapToolChoice("auto")).toBe("auto") - expect(testHandler.testMapToolChoice("none")).toBe("none") - expect(testHandler.testMapToolChoice("required")).toBe("required") - expect(testHandler.testMapToolChoice("unknown")).toBe("auto") - }) - - it("should map object tool choices correctly", () => { - class TestFireworksHandler extends FireworksHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestFireworksHandler(mockOptions) - - const result = testHandler.testMapToolChoice({ - type: "function", - function: { name: "read_file" }, - }) - expect(result).toEqual({ type: "tool", toolName: "read_file" }) - }) - - it("should return undefined for null/undefined", () => { - class TestFireworksHandler extends FireworksHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestFireworksHandler(mockOptions) - - expect(testHandler.testMapToolChoice(null)).toBeUndefined() - expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() - }) - }) }) diff --git a/src/api/providers/__tests__/groq.spec.ts b/src/api/providers/__tests__/groq.spec.ts index c4a9471c87..efb5712cb9 100644 --- 
a/src/api/providers/__tests__/groq.spec.ts +++ b/src/api/providers/__tests__/groq.spec.ts @@ -575,51 +575,4 @@ describe("GroqHandler", () => { expect(result).toBe(customMaxTokens) }) }) - - describe("mapToolChoice", () => { - it("should handle string tool choices", () => { - class TestGroqHandler extends GroqHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestGroqHandler(mockOptions) - - expect(testHandler.testMapToolChoice("auto")).toBe("auto") - expect(testHandler.testMapToolChoice("none")).toBe("none") - expect(testHandler.testMapToolChoice("required")).toBe("required") - expect(testHandler.testMapToolChoice("unknown")).toBe("auto") - }) - - it("should handle object tool choice with function name", () => { - class TestGroqHandler extends GroqHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestGroqHandler(mockOptions) - - const result = testHandler.testMapToolChoice({ - type: "function", - function: { name: "my_tool" }, - }) - - expect(result).toEqual({ type: "tool", toolName: "my_tool" }) - }) - - it("should return undefined for null or undefined", () => { - class TestGroqHandler extends GroqHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestGroqHandler(mockOptions) - - expect(testHandler.testMapToolChoice(null)).toBeUndefined() - expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() - }) - }) }) diff --git a/src/api/providers/cerebras.ts b/src/api/providers/cerebras.ts index 0fbc375bdc..b86c2c519b 100644 --- a/src/api/providers/cerebras.ts +++ b/src/api/providers/cerebras.ts @@ -6,7 +6,12 @@ import { cerebrasModels, cerebrasDefaultModelId, type CerebrasModelId, type Mode import type { ApiHandlerOptions } from "../../shared/api" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -75,40 +80,6 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan } } - /** - * Map OpenAI tool_choice to AI SDK toolChoice format. - */ - protected mapToolChoice( - toolChoice: any, - ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { - if (!toolChoice) { - return undefined - } - - // Handle string values - if (typeof toolChoice === "string") { - switch (toolChoice) { - case "auto": - return "auto" - case "none": - return "none" - case "required": - return "required" - default: - return "auto" - } - } - - // Handle object values (OpenAI ChatCompletionNamedToolChoice format) - if (typeof toolChoice === "object" && "type" in toolChoice) { - if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { - return { type: "tool", toolName: toolChoice.function.name } - } - } - - return undefined - } - /** * Get the max tokens parameter to include in the request. */ @@ -143,7 +114,7 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan temperature: this.options.modelTemperature ?? temperature ?? 
CEREBRAS_DEFAULT_TEMPERATURE, maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, - toolChoice: this.mapToolChoice(metadata?.tool_choice), + toolChoice: mapToolChoice(metadata?.tool_choice), } // Use streamText for streaming responses diff --git a/src/api/providers/deepseek.ts b/src/api/providers/deepseek.ts index 27ecaebd5e..949dcfb306 100644 --- a/src/api/providers/deepseek.ts +++ b/src/api/providers/deepseek.ts @@ -6,7 +6,12 @@ import { deepSeekModels, deepSeekDefaultModelId, DEEP_SEEK_DEFAULT_TEMPERATURE, import type { ApiHandlerOptions } from "../../shared/api" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -83,40 +88,6 @@ export class DeepSeekHandler extends BaseProvider implements SingleCompletionHan } } - /** - * Map OpenAI tool_choice to AI SDK toolChoice format. - */ - protected mapToolChoice( - toolChoice: any, - ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { - if (!toolChoice) { - return undefined - } - - // Handle string values - if (typeof toolChoice === "string") { - switch (toolChoice) { - case "auto": - return "auto" - case "none": - return "none" - case "required": - return "required" - default: - return "auto" - } - } - - // Handle object values (OpenAI ChatCompletionNamedToolChoice format) - if (typeof toolChoice === "object" && "type" in toolChoice) { - if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { - return { type: "tool", toolName: toolChoice.function.name } - } - } - - return undefined - } - /** * Get the max tokens parameter to include in the request. */ @@ -152,7 +123,7 @@ export class DeepSeekHandler extends BaseProvider implements SingleCompletionHan temperature: this.options.modelTemperature ?? temperature ?? DEEP_SEEK_DEFAULT_TEMPERATURE, maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, - toolChoice: this.mapToolChoice(metadata?.tool_choice), + toolChoice: mapToolChoice(metadata?.tool_choice), } // Use streamText for streaming responses diff --git a/src/api/providers/fireworks.ts b/src/api/providers/fireworks.ts index 0a7a074083..ee1096805c 100644 --- a/src/api/providers/fireworks.ts +++ b/src/api/providers/fireworks.ts @@ -6,7 +6,12 @@ import { fireworksModels, fireworksDefaultModelId, type ModelInfo } from "@roo-c import type { ApiHandlerOptions } from "../../shared/api" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -90,40 +95,6 @@ export class FireworksHandler extends BaseProvider implements SingleCompletionHa } } - /** - * Map OpenAI tool_choice to AI SDK toolChoice format. 
- */ - protected mapToolChoice( - toolChoice: any, - ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { - if (!toolChoice) { - return undefined - } - - // Handle string values - if (typeof toolChoice === "string") { - switch (toolChoice) { - case "auto": - return "auto" - case "none": - return "none" - case "required": - return "required" - default: - return "auto" - } - } - - // Handle object values (OpenAI ChatCompletionNamedToolChoice format) - if (typeof toolChoice === "object" && "type" in toolChoice) { - if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { - return { type: "tool", toolName: toolChoice.function.name } - } - } - - return undefined - } - /** * Get the max tokens parameter to include in the request. */ @@ -158,7 +129,7 @@ export class FireworksHandler extends BaseProvider implements SingleCompletionHa temperature: this.options.modelTemperature ?? temperature ?? FIREWORKS_DEFAULT_TEMPERATURE, maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, - toolChoice: this.mapToolChoice(metadata?.tool_choice), + toolChoice: mapToolChoice(metadata?.tool_choice), } // Use streamText for streaming responses diff --git a/src/api/providers/groq.ts b/src/api/providers/groq.ts index 64399ad674..27c7bf2c27 100644 --- a/src/api/providers/groq.ts +++ b/src/api/providers/groq.ts @@ -6,7 +6,12 @@ import { groqModels, groqDefaultModelId, type ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -91,40 +96,6 @@ export class GroqHandler extends BaseProvider implements SingleCompletionHandler } } - /** - * Map OpenAI tool_choice to AI SDK toolChoice format. - */ - protected mapToolChoice( - toolChoice: any, - ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { - if (!toolChoice) { - return undefined - } - - // Handle string values - if (typeof toolChoice === "string") { - switch (toolChoice) { - case "auto": - return "auto" - case "none": - return "none" - case "required": - return "required" - default: - return "auto" - } - } - - // Handle object values (OpenAI ChatCompletionNamedToolChoice format) - if (typeof toolChoice === "object" && "type" in toolChoice) { - if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { - return { type: "tool", toolName: toolChoice.function.name } - } - } - - return undefined - } - /** * Get the max tokens parameter to include in the request. */ @@ -160,7 +131,7 @@ export class GroqHandler extends BaseProvider implements SingleCompletionHandler temperature: this.options.modelTemperature ?? temperature ?? 
GROQ_DEFAULT_TEMPERATURE, maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, - toolChoice: this.mapToolChoice(metadata?.tool_choice), + toolChoice: mapToolChoice(metadata?.tool_choice), } // Use streamText for streaming responses diff --git a/src/api/providers/openai-compatible.ts b/src/api/providers/openai-compatible.ts index d129e72452..e2b5843442 100644 --- a/src/api/providers/openai-compatible.ts +++ b/src/api/providers/openai-compatible.ts @@ -12,7 +12,12 @@ import type { ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { DEFAULT_HEADERS } from "./constants" @@ -103,40 +108,6 @@ export abstract class OpenAICompatibleHandler extends BaseProvider implements Si } } - /** - * Map OpenAI tool_choice to AI SDK toolChoice format. - */ - protected mapToolChoice( - toolChoice: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"], - ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { - if (!toolChoice) { - return undefined - } - - // Handle string values - if (typeof toolChoice === "string") { - switch (toolChoice) { - case "auto": - return "auto" - case "none": - return "none" - case "required": - return "required" - default: - return "auto" - } - } - - // Handle object values (OpenAI ChatCompletionNamedToolChoice format) - if (typeof toolChoice === "object" && "type" in toolChoice) { - if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { - return { type: "tool", toolName: toolChoice.function.name } - } - } - - return undefined - } - /** * Get the max tokens parameter to include in the request. */ @@ -173,7 +144,7 @@ export abstract class OpenAICompatibleHandler extends BaseProvider implements Si temperature: model.temperature ?? this.config.temperature ?? 
0, maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, - toolChoice: this.mapToolChoice(metadata?.tool_choice), + toolChoice: mapToolChoice(metadata?.tool_choice), } // Use streamText for streaming responses diff --git a/src/api/transform/__tests__/ai-sdk.spec.ts b/src/api/transform/__tests__/ai-sdk.spec.ts index 293d720d48..a7f0ece3ec 100644 --- a/src/api/transform/__tests__/ai-sdk.spec.ts +++ b/src/api/transform/__tests__/ai-sdk.spec.ts @@ -1,6 +1,6 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../ai-sdk" +import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart, mapToolChoice } from "../ai-sdk" vitest.mock("ai", () => ({ tool: vitest.fn((t) => t), @@ -486,4 +486,49 @@ describe("AI SDK conversion utilities", () => { } }) }) + + describe("mapToolChoice", () => { + it("should return undefined for null or undefined", () => { + expect(mapToolChoice(null)).toBeUndefined() + expect(mapToolChoice(undefined)).toBeUndefined() + }) + + it("should handle string tool choices", () => { + expect(mapToolChoice("auto")).toBe("auto") + expect(mapToolChoice("none")).toBe("none") + expect(mapToolChoice("required")).toBe("required") + }) + + it("should return auto for unknown string values", () => { + expect(mapToolChoice("unknown")).toBe("auto") + expect(mapToolChoice("invalid")).toBe("auto") + }) + + it("should handle object tool choice with function name", () => { + const result = mapToolChoice({ + type: "function", + function: { name: "my_tool" }, + }) + + expect(result).toEqual({ type: "tool", toolName: "my_tool" }) + }) + + it("should return undefined for object without function name", () => { + const result = mapToolChoice({ + type: "function", + function: {}, + }) + + expect(result).toBeUndefined() + }) + + it("should return undefined for object with non-function type", () => { + const result = mapToolChoice({ + type: "other", + function: { name: "my_tool" }, + }) + + expect(result).toBeUndefined() + }) + }) }) diff --git a/src/api/transform/ai-sdk.ts b/src/api/transform/ai-sdk.ts index fd86532bf1..0722767072 100644 --- a/src/api/transform/ai-sdk.ts +++ b/src/api/transform/ai-sdk.ts @@ -273,3 +273,44 @@ export function* processAiSdkStreamPart(part: ExtendedStreamPart): Generator Date: Fri, 30 Jan 2026 14:31:52 -0500 Subject: [PATCH 3/3] fix(api): add proper error handling for AI SDK providers - Added extractAiSdkErrorMessage utility to extract user-friendly error messages from AI SDK errors (AI_RetryError, AI_APICallError) - Added handleAiSdkError utility to wrap errors with provider name and preserve status codes for retry logic - Updated all AI SDK providers (Fireworks, Groq, DeepSeek, Cerebras, OpenAI-compatible) with try/catch error handling - Added comprehensive tests for error handling utilities This ensures errors like 'AI_RetryError: Failed after 3 attempts. Last error: Too Many Requests' are properly surfaced in the UI instead of showing 'No output generated. Check the stream for errors.' 
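
Illustrative usage of the new utilities (not part of the diff below; the
relative import path and the 429 retry-error shape are assumptions that
mirror the added tests):

    import { handleAiSdkError } from "../transform/ai-sdk"

    const retryError = {
        name: "AI_RetryError",
        errors: [new Error("429"), new Error("429"), new Error("Too Many Requests")],
        lastError: { message: "Too Many Requests", status: 429 },
    }

    const wrapped = handleAiSdkError(retryError, "Fireworks")
    // wrapped.message === "Fireworks: Failed after 3 attempts (429): Too Many Requests"
    // (wrapped as any).status === 429, so retry logic can still branch on the status code
    // (wrapped as any).cause === retryError, preserving the original error for debugging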
--- src/api/providers/cerebras.ts | 24 +++-- src/api/providers/deepseek.ts | 26 +++-- src/api/providers/fireworks.ts | 26 +++-- src/api/providers/groq.ts | 26 +++-- src/api/providers/openai-compatible.ts | 26 +++-- src/api/transform/__tests__/ai-sdk.spec.ts | 115 ++++++++++++++++++++- src/api/transform/ai-sdk.ts | 81 +++++++++++++++ 7 files changed, 274 insertions(+), 50 deletions(-) diff --git a/src/api/providers/cerebras.ts b/src/api/providers/cerebras.ts index b86c2c519b..de1a4b2dbb 100644 --- a/src/api/providers/cerebras.ts +++ b/src/api/providers/cerebras.ts @@ -11,6 +11,7 @@ import { convertToolsForAiSdk, processAiSdkStreamPart, mapToolChoice, + handleAiSdkError, } from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -120,17 +121,22 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan // Use streamText for streaming responses const result = streamText(requestOptions) - // Process the full stream to get all events including reasoning - for await (const part of result.fullStream) { - for (const chunk of processAiSdkStreamPart(part)) { - yield chunk + try { + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - } - // Yield usage metrics at the end - const usage = await result.usage - if (usage) { - yield this.processUsageMetrics(usage) + // Yield usage metrics at the end + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) + throw handleAiSdkError(error, "Cerebras") } } diff --git a/src/api/providers/deepseek.ts b/src/api/providers/deepseek.ts index 949dcfb306..ba9c9d47e3 100644 --- a/src/api/providers/deepseek.ts +++ b/src/api/providers/deepseek.ts @@ -11,6 +11,7 @@ import { convertToolsForAiSdk, processAiSdkStreamPart, mapToolChoice, + handleAiSdkError, } from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -129,18 +130,23 @@ export class DeepSeekHandler extends BaseProvider implements SingleCompletionHan // Use streamText for streaming responses const result = streamText(requestOptions) - // Process the full stream to get all events including reasoning - for await (const part of result.fullStream) { - for (const chunk of processAiSdkStreamPart(part)) { - yield chunk + try { + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - } - // Yield usage metrics at the end, including cache metrics from providerMetadata - const usage = await result.usage - const providerMetadata = await result.providerMetadata - if (usage) { - yield this.processUsageMetrics(usage, providerMetadata as any) + // Yield usage metrics at the end, including cache metrics from providerMetadata + const usage = await result.usage + const providerMetadata = await result.providerMetadata + if (usage) { + yield this.processUsageMetrics(usage, providerMetadata as any) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) 
+ throw handleAiSdkError(error, "DeepSeek") } } diff --git a/src/api/providers/fireworks.ts b/src/api/providers/fireworks.ts index ee1096805c..52bf431bb6 100644 --- a/src/api/providers/fireworks.ts +++ b/src/api/providers/fireworks.ts @@ -11,6 +11,7 @@ import { convertToolsForAiSdk, processAiSdkStreamPart, mapToolChoice, + handleAiSdkError, } from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -135,18 +136,23 @@ export class FireworksHandler extends BaseProvider implements SingleCompletionHa // Use streamText for streaming responses const result = streamText(requestOptions) - // Process the full stream to get all events including reasoning - for await (const part of result.fullStream) { - for (const chunk of processAiSdkStreamPart(part)) { - yield chunk + try { + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - } - // Yield usage metrics at the end, including cache metrics from providerMetadata - const usage = await result.usage - const providerMetadata = await result.providerMetadata - if (usage) { - yield this.processUsageMetrics(usage, providerMetadata as any) + // Yield usage metrics at the end, including cache metrics from providerMetadata + const usage = await result.usage + const providerMetadata = await result.providerMetadata + if (usage) { + yield this.processUsageMetrics(usage, providerMetadata as any) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) + throw handleAiSdkError(error, "Fireworks") } } diff --git a/src/api/providers/groq.ts b/src/api/providers/groq.ts index 27c7bf2c27..648679f92c 100644 --- a/src/api/providers/groq.ts +++ b/src/api/providers/groq.ts @@ -11,6 +11,7 @@ import { convertToolsForAiSdk, processAiSdkStreamPart, mapToolChoice, + handleAiSdkError, } from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -137,18 +138,23 @@ export class GroqHandler extends BaseProvider implements SingleCompletionHandler // Use streamText for streaming responses const result = streamText(requestOptions) - // Process the full stream to get all events including reasoning - for await (const part of result.fullStream) { - for (const chunk of processAiSdkStreamPart(part)) { - yield chunk + try { + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - } - // Yield usage metrics at the end, including cache metrics from providerMetadata - const usage = await result.usage - const providerMetadata = await result.providerMetadata - if (usage) { - yield this.processUsageMetrics(usage, providerMetadata as any) + // Yield usage metrics at the end, including cache metrics from providerMetadata + const usage = await result.usage + const providerMetadata = await result.providerMetadata + if (usage) { + yield this.processUsageMetrics(usage, providerMetadata as any) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) 
+ throw handleAiSdkError(error, "Groq") } } diff --git a/src/api/providers/openai-compatible.ts b/src/api/providers/openai-compatible.ts index e2b5843442..240de747be 100644 --- a/src/api/providers/openai-compatible.ts +++ b/src/api/providers/openai-compatible.ts @@ -17,6 +17,7 @@ import { convertToolsForAiSdk, processAiSdkStreamPart, mapToolChoice, + handleAiSdkError, } from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" @@ -150,18 +151,23 @@ export abstract class OpenAICompatibleHandler extends BaseProvider implements Si // Use streamText for streaming responses const result = streamText(requestOptions) - // Process the full stream to get all events - for await (const part of result.fullStream) { - // Use the processAiSdkStreamPart utility to convert stream parts - for (const chunk of processAiSdkStreamPart(part)) { - yield chunk + try { + // Process the full stream to get all events + for await (const part of result.fullStream) { + // Use the processAiSdkStreamPart utility to convert stream parts + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - } - // Yield usage metrics at the end - const usage = await result.usage - if (usage) { - yield this.processUsageMetrics(usage) + // Yield usage metrics at the end + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) + throw handleAiSdkError(error, this.config.providerName) } } diff --git a/src/api/transform/__tests__/ai-sdk.spec.ts b/src/api/transform/__tests__/ai-sdk.spec.ts index a7f0ece3ec..bd87fd8eeb 100644 --- a/src/api/transform/__tests__/ai-sdk.spec.ts +++ b/src/api/transform/__tests__/ai-sdk.spec.ts @@ -1,6 +1,13 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart, mapToolChoice } from "../ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + extractAiSdkErrorMessage, + handleAiSdkError, +} from "../ai-sdk" vitest.mock("ai", () => ({ tool: vitest.fn((t) => t), @@ -531,4 +538,110 @@ describe("AI SDK conversion utilities", () => { expect(result).toBeUndefined() }) }) + + describe("extractAiSdkErrorMessage", () => { + it("should return 'Unknown error' for null/undefined", () => { + expect(extractAiSdkErrorMessage(null)).toBe("Unknown error") + expect(extractAiSdkErrorMessage(undefined)).toBe("Unknown error") + }) + + it("should extract message from AI_RetryError", () => { + const retryError = { + name: "AI_RetryError", + message: "Failed after 3 attempts", + errors: [new Error("Error 1"), new Error("Error 2"), new Error("Too Many Requests")], + lastError: { message: "Too Many Requests", status: 429 }, + } + + const result = extractAiSdkErrorMessage(retryError) + expect(result).toBe("Failed after 3 attempts (429): Too Many Requests") + }) + + it("should handle AI_RetryError without status", () => { + const retryError = { + name: "AI_RetryError", + message: "Failed after 2 attempts", + errors: [new Error("Error 1"), new Error("Connection failed")], + lastError: { message: "Connection failed" }, + } + + const result = extractAiSdkErrorMessage(retryError) + expect(result).toBe("Failed after 2 attempts: Connection failed") + }) + + it("should extract message from AI_APICallError", () => { + const apiError = { + name: "AI_APICallError", + message: "Rate limit exceeded", + status: 429, + } + + 
const result = extractAiSdkErrorMessage(apiError) + expect(result).toBe("API Error (429): Rate limit exceeded") + }) + + it("should handle AI_APICallError without status", () => { + const apiError = { + name: "AI_APICallError", + message: "Connection timeout", + } + + const result = extractAiSdkErrorMessage(apiError) + expect(result).toBe("Connection timeout") + }) + + it("should extract message from standard Error", () => { + const error = new Error("Something went wrong") + expect(extractAiSdkErrorMessage(error)).toBe("Something went wrong") + }) + + it("should convert non-Error to string", () => { + expect(extractAiSdkErrorMessage("string error")).toBe("string error") + expect(extractAiSdkErrorMessage({ custom: "object" })).toBe("[object Object]") + }) + }) + + describe("handleAiSdkError", () => { + it("should wrap error with provider name", () => { + const error = new Error("API Error") + const result = handleAiSdkError(error, "Fireworks") + + expect(result.message).toBe("Fireworks: API Error") + }) + + it("should preserve status code from AI_RetryError", () => { + const retryError = { + name: "AI_RetryError", + errors: [new Error("Too Many Requests")], + lastError: { message: "Too Many Requests", status: 429 }, + } + + const result = handleAiSdkError(retryError, "Groq") + + expect(result.message).toContain("Groq:") + expect(result.message).toContain("429") + expect((result as any).status).toBe(429) + }) + + it("should preserve status code from AI_APICallError", () => { + const apiError = { + name: "AI_APICallError", + message: "Unauthorized", + status: 401, + } + + const result = handleAiSdkError(apiError, "DeepSeek") + + expect(result.message).toContain("DeepSeek:") + expect(result.message).toContain("401") + expect((result as any).status).toBe(401) + }) + + it("should preserve original error as cause", () => { + const originalError = new Error("Original error") + const result = handleAiSdkError(originalError, "Cerebras") + + expect((result as any).cause).toBe(originalError) + }) + }) }) diff --git a/src/api/transform/ai-sdk.ts b/src/api/transform/ai-sdk.ts index 0722767072..ebbf1a8661 100644 --- a/src/api/transform/ai-sdk.ts +++ b/src/api/transform/ai-sdk.ts @@ -314,3 +314,84 @@ export function mapToolChoice(toolChoice: any): AiSdkToolChoice { return undefined } + +/** + * Extract a user-friendly error message from AI SDK errors. + * The AI SDK wraps errors in types like AI_RetryError and AI_APICallError + * which need to be unwrapped to get the actual error message. 
+ * + * @param error - The error to extract the message from + * @returns A user-friendly error message + */ +export function extractAiSdkErrorMessage(error: unknown): string { + if (!error) { + return "Unknown error" + } + + // Cast to access AI SDK error properties + const anyError = error as any + + // AI_RetryError has a lastError property with the actual error + if (anyError.name === "AI_RetryError") { + const retryCount = anyError.errors?.length || 0 + const lastError = anyError.lastError + const lastErrorMessage = lastError?.message || lastError?.toString() || "Unknown error" + + // Extract status code if available + const statusCode = + lastError?.status || lastError?.statusCode || anyError.status || anyError.statusCode || undefined + + if (statusCode) { + return `Failed after ${retryCount} attempts (${statusCode}): ${lastErrorMessage}` + } + return `Failed after ${retryCount} attempts: ${lastErrorMessage}` + } + + // AI_APICallError has message and optional status + if (anyError.name === "AI_APICallError") { + const statusCode = anyError.status || anyError.statusCode + if (statusCode) { + return `API Error (${statusCode}): ${anyError.message}` + } + return anyError.message || "API call failed" + } + + // Standard Error + if (error instanceof Error) { + return error.message + } + + // Fallback for non-Error objects + return String(error) +} + +/** + * Handle AI SDK errors by extracting the message and preserving status codes. + * Returns an Error object with proper status preserved for retry logic. + * + * @param error - The AI SDK error to handle + * @param providerName - The name of the provider for context + * @returns An Error with preserved status code + */ +export function handleAiSdkError(error: unknown, providerName: string): Error { + const message = extractAiSdkErrorMessage(error) + const wrappedError = new Error(`${providerName}: ${message}`) + + // Preserve status code for retry logic + const anyError = error as any + const statusCode = + anyError?.lastError?.status || + anyError?.lastError?.statusCode || + anyError?.status || + anyError?.statusCode || + undefined + + if (statusCode) { + ;(wrappedError as any).status = statusCode + } + + // Preserve the original error for debugging + ;(wrappedError as any).cause = error + + return wrappedError +}