diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index e287787ede..01f1fb5f41 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -752,6 +752,9 @@ importers: '@ai-sdk/deepseek': specifier: ^2.0.14 version: 2.0.14(zod@3.25.76) + '@ai-sdk/fireworks': + specifier: ^2.0.26 + version: 2.0.26(zod@3.25.76) '@ai-sdk/groq': specifier: ^3.0.19 version: 3.0.19(zod@3.25.76) @@ -1411,6 +1414,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/fireworks@2.0.26': + resolution: {integrity: sha512-vBqSSksHhDGrSNYnmEmVGvLicHFjL4yAxFZfCb6ydrg+qgnlW2bdyTQDMI69BKG4spNZ1/iHMxRNIQpx19Yf6w==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/gateway@3.0.25': resolution: {integrity: sha512-j0AQeA7hOVqwImykQlganf/Euj3uEXf0h3G0O4qKTDpEwE+EZGIPnVimCWht5W91lAetPZSfavDyvfpuPDd2PQ==} engines: {node: '>=18'} @@ -1429,6 +1438,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.24': + resolution: {integrity: sha512-3QrCKpQCn3g6sIMoFGuEroaqk7Xg+qfsohRp4dKszjto5stjBg4SdtOKqHg+CpE3X4woj2O62w2qr5dSekMZeQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20': resolution: {integrity: sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==} engines: {node: '>=18'} @@ -11042,6 +11057,13 @@ snapshots: '@ai-sdk/provider-utils': 4.0.10(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/fireworks@2.0.26(zod@3.25.76)': + dependencies: + '@ai-sdk/openai-compatible': 2.0.24(zod@3.25.76) + '@ai-sdk/provider': 3.0.6 + '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/gateway@3.0.25(zod@3.25.76)': dependencies: '@ai-sdk/provider': 3.0.5 @@ -11061,6 +11083,12 @@ snapshots: '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.24(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.6 + '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20(zod@3.25.76)': dependencies: '@ai-sdk/provider': 2.0.1 @@ -15070,7 +15098,7 @@ snapshots: sirv: 3.0.1 tinyglobby: 0.2.14 tinyrainbow: 2.0.0 - vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.17.57)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.17.50)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) '@vitest/utils@3.2.4': dependencies: diff --git a/src/api/providers/__tests__/cerebras.spec.ts b/src/api/providers/__tests__/cerebras.spec.ts index aefb8a599c..caf8861b46 100644 --- a/src/api/providers/__tests__/cerebras.spec.ts +++ b/src/api/providers/__tests__/cerebras.spec.ts @@ -452,51 +452,4 @@ describe("CerebrasHandler", () => { expect(toolCallChunks.length).toBe(0) }) }) - - describe("mapToolChoice", () => { - it("should handle string tool choices", () => { - class TestCerebrasHandler extends CerebrasHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestCerebrasHandler(mockOptions) - - expect(testHandler.testMapToolChoice("auto")).toBe("auto") - expect(testHandler.testMapToolChoice("none")).toBe("none") - expect(testHandler.testMapToolChoice("required")).toBe("required") - expect(testHandler.testMapToolChoice("unknown")).toBe("auto") - }) - - it("should handle object tool choice with function name", () => { - class TestCerebrasHandler extends CerebrasHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const 
testHandler = new TestCerebrasHandler(mockOptions) - - const result = testHandler.testMapToolChoice({ - type: "function", - function: { name: "my_tool" }, - }) - - expect(result).toEqual({ type: "tool", toolName: "my_tool" }) - }) - - it("should return undefined for null or undefined", () => { - class TestCerebrasHandler extends CerebrasHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestCerebrasHandler(mockOptions) - - expect(testHandler.testMapToolChoice(null)).toBeUndefined() - expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() - }) - }) }) diff --git a/src/api/providers/__tests__/deepseek.spec.ts b/src/api/providers/__tests__/deepseek.spec.ts index 82b08aaad5..ece03c068e 100644 --- a/src/api/providers/__tests__/deepseek.spec.ts +++ b/src/api/providers/__tests__/deepseek.spec.ts @@ -733,51 +733,4 @@ describe("DeepSeekHandler", () => { expect(result).toBe(8192) }) }) - - describe("mapToolChoice", () => { - it("should handle string tool choices", () => { - class TestDeepSeekHandler extends DeepSeekHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestDeepSeekHandler(mockOptions) - - expect(testHandler.testMapToolChoice("auto")).toBe("auto") - expect(testHandler.testMapToolChoice("none")).toBe("none") - expect(testHandler.testMapToolChoice("required")).toBe("required") - expect(testHandler.testMapToolChoice("unknown")).toBe("auto") - }) - - it("should handle object tool choice with function name", () => { - class TestDeepSeekHandler extends DeepSeekHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestDeepSeekHandler(mockOptions) - - const result = testHandler.testMapToolChoice({ - type: "function", - function: { name: "my_tool" }, - }) - - expect(result).toEqual({ type: "tool", toolName: "my_tool" }) - }) - - it("should return undefined for null or undefined", () => { - class TestDeepSeekHandler extends DeepSeekHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestDeepSeekHandler(mockOptions) - - expect(testHandler.testMapToolChoice(null)).toBeUndefined() - expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() - }) - }) }) diff --git a/src/api/providers/__tests__/fireworks.spec.ts b/src/api/providers/__tests__/fireworks.spec.ts index 79f69f868b..77c4b10f45 100644 --- a/src/api/providers/__tests__/fireworks.spec.ts +++ b/src/api/providers/__tests__/fireworks.spec.ts @@ -1,594 +1,845 @@ -// npx vitest run api/providers/__tests__/fireworks.spec.ts +// npx vitest run src/api/providers/__tests__/fireworks.spec.ts -import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) -import { type FireworksModelId, fireworksDefaultModelId, fireworksModels } from "@roo-code/types" +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, + } +}) -import { FireworksHandler } from "../fireworks" +vi.mock("@ai-sdk/fireworks", () => ({ + createFireworks: vi.fn(() => { + // Return a function that returns a mock 
language model + return vi.fn(() => ({ + modelId: "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507", + provider: "fireworks", + })) + }), +})) -// Create mock functions -const mockCreate = vi.fn() +import type { Anthropic } from "@anthropic-ai/sdk" -// Mock OpenAI module -vi.mock("openai", () => ({ - default: vi.fn(() => ({ - chat: { - completions: { - create: mockCreate, - }, - }, - })), -})) +import { fireworksDefaultModelId, fireworksModels, type FireworksModelId } from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../../shared/api" + +import { FireworksHandler } from "../fireworks" describe("FireworksHandler", () => { let handler: FireworksHandler + let mockOptions: ApiHandlerOptions beforeEach(() => { + mockOptions = { + fireworksApiKey: "test-fireworks-api-key", + apiModelId: "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507", + } + handler = new FireworksHandler(mockOptions) vi.clearAllMocks() - // Set up default mock implementation - mockCreate.mockImplementation(async () => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], - usage: null, - } - yield { - choices: [ - { - delta: {}, - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - } - }, - })) - handler = new FireworksHandler({ fireworksApiKey: "test-key" }) }) - afterEach(() => { - vi.restoreAllMocks() - }) + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(FireworksHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) - it("should use the correct Fireworks base URL", () => { - new FireworksHandler({ fireworksApiKey: "test-fireworks-api-key" }) - expect(OpenAI).toHaveBeenCalledWith( - expect.objectContaining({ baseURL: "https://api.fireworks.ai/inference/v1" }), - ) + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new FireworksHandler({ + ...mockOptions, + apiModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe(fireworksDefaultModelId) + }) }) - it("should use the provided API key", () => { - const fireworksApiKey = "test-fireworks-api-key" - new FireworksHandler({ fireworksApiKey }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: fireworksApiKey })) - }) + describe("getModel", () => { + it("should return default model when no model is specified", () => { + const handlerWithoutModel = new FireworksHandler({ + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(fireworksDefaultModelId) + expect(model.info).toEqual(fireworksModels[fireworksDefaultModelId]) + }) - it("should throw error when API key is not provided", () => { - expect(() => new FireworksHandler({})).toThrow("API key is required") - }) + it("should return specified model when valid model is provided", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual(fireworksModels[testModelId]) + }) - it("should return default model when no model is specified", () => { - const model = handler.getModel() - expect(model.id).toBe(fireworksDefaultModelId) - 
expect(model.info).toEqual(expect.objectContaining(fireworksModels[fireworksDefaultModelId])) - }) + it("should return Kimi K2 Instruct model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.6, + outputPrice: 2.5, + description: expect.stringContaining("Kimi K2 is a state-of-the-art mixture-of-experts"), + }), + ) + }) - it("should return specified model when valid model is provided", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return Kimi K2 Thinking model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-thinking" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16000, + contextWindow: 256000, + supportsImages: false, + supportsPromptCache: true, + supportsTemperature: true, + preserveReasoning: true, + defaultTemperature: 1.0, + inputPrice: 0.6, + outputPrice: 2.5, + cacheReadsPrice: 0.15, + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual(expect.objectContaining(fireworksModels[testModelId])) - }) - it("should return Kimi K2 Instruct model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return MiniMax M2 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/minimax-m2" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 4096, + contextWindow: 204800, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.3, + outputPrice: 1.2, + description: expect.stringContaining("MiniMax M2 is a high-performance language model"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.6, - outputPrice: 2.5, - description: expect.stringContaining("Kimi K2 is a state-of-the-art mixture-of-experts"), - }), - ) - }) - it("should return Kimi K2 Thinking model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-thinking" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + 
it("should return Qwen3 235B model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 32768, + contextWindow: 256000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.22, + outputPrice: 0.88, + description: + "Latest Qwen3 thinking model, competitive against the best closed source models in Jul 2025.", + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16000, - contextWindow: 256000, - supportsImages: false, - supportsPromptCache: true, - supportsTemperature: true, - preserveReasoning: true, - defaultTemperature: 1.0, - inputPrice: 0.6, - outputPrice: 2.5, - cacheReadsPrice: 0.15, - }), - ) - }) - it("should return MiniMax M2 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/minimax-m2" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return DeepSeek R1 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-r1-0528" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 20480, + contextWindow: 160000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 3, + outputPrice: 8, + description: expect.stringContaining("05/28 updated checkpoint of Deepseek R1"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 4096, - contextWindow: 204800, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.3, - outputPrice: 1.2, - description: expect.stringContaining("MiniMax M2 is a high-performance language model"), - }), - ) - }) - it("should return Qwen3 235B model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return DeepSeek V3 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-v3" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.9, + outputPrice: 0.9, + description: expect.stringContaining("strong Mixture-of-Experts (MoE) language model"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 32768, - contextWindow: 
256000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.22, - outputPrice: 0.88, - description: - "Latest Qwen3 thinking model, competitive against the best closed source models in Jul 2025.", - }), - ) - }) - it("should return DeepSeek R1 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-r1-0528" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return DeepSeek V3.1 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-v3p1" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 163840, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.56, + outputPrice: 1.68, + description: expect.stringContaining("DeepSeek v3.1 is an improved version"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 20480, - contextWindow: 160000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 3, - outputPrice: 8, - description: expect.stringContaining("05/28 updated checkpoint of Deepseek R1"), - }), - ) - }) - it("should return DeepSeek V3 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-v3" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return GLM-4.5 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p5" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.55, + outputPrice: 2.19, + description: expect.stringContaining("Z.ai GLM-4.5 with 355B total parameters"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.9, - outputPrice: 0.9, - description: expect.stringContaining("strong Mixture-of-Experts (MoE) language model"), - }), - ) - }) - it("should return DeepSeek V3.1 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/deepseek-v3p1" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return GLM-4.5-Air model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p5-air" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + 
expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.55, + outputPrice: 2.19, + description: expect.stringContaining("Z.ai GLM-4.5-Air with 106B total parameters"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 163840, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.56, - outputPrice: 1.68, - description: expect.stringContaining("DeepSeek v3.1 is an improved version"), - }), - ) - }) - it("should return GLM-4.5 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p5" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return GLM-4.6 model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p6" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 25344, + contextWindow: 198000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.55, + outputPrice: 2.19, + description: expect.stringContaining("Z.ai GLM-4.6 is an advanced coding model"), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.55, - outputPrice: 2.19, - description: expect.stringContaining("Z.ai GLM-4.5 with 355B total parameters"), - }), - ) - }) - it("should return GLM-4.5-Air model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p5-air" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return gpt-oss-20b model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/gpt-oss-20b" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.07, + outputPrice: 0.3, + description: expect.stringContaining( + "OpenAI gpt-oss-20b: Compact model for local/edge deployments", + ), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.55, - outputPrice: 2.19, - description: expect.stringContaining("Z.ai GLM-4.5-Air with 106B total parameters"), - }), - ) - }) - it("should return GLM-4.6 model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/glm-4p6" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + 
it("should return gpt-oss-120b model with correct configuration", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/gpt-oss-120b" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.15, + outputPrice: 0.6, + description: expect.stringContaining( + "OpenAI gpt-oss-120b: Production-grade, general-purpose model", + ), + }), + ) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 25344, - contextWindow: 198000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.55, - outputPrice: 2.19, - description: expect.stringContaining("Z.ai GLM-4.6 is an advanced coding model"), - }), - ) - }) - it("should return gpt-oss-20b model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/gpt-oss-20b" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should return provided model ID with default model info if model does not exist", () => { + const handlerWithInvalidModel = new FireworksHandler({ + ...mockOptions, + apiModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe("invalid-model") + expect(model.info).toBeDefined() + // Should use default model info + expect(model.info).toBe(fireworksModels[fireworksDefaultModelId]) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.07, - outputPrice: 0.3, - description: expect.stringContaining("OpenAI gpt-oss-20b: Compact model for local/edge deployments"), - }), - ) - }) - it("should return gpt-oss-120b model with correct configuration", () => { - const testModelId: FireworksModelId = "accounts/fireworks/models/gpt-oss-120b" - const handlerWithModel = new FireworksHandler({ - apiModelId: testModelId, - fireworksApiKey: "test-fireworks-api-key", + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual( - expect.objectContaining({ - maxTokens: 16384, - contextWindow: 128000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0.15, - outputPrice: 0.6, - description: expect.stringContaining("OpenAI gpt-oss-120b: Production-grade, general-purpose model"), - }), - ) }) - it("completePrompt method should return text from Fireworks API", async () => { - const expectedResponse = "This is a test response from Fireworks" - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) - const result = await handler.completePrompt("test prompt") - expect(result).toBe(expectedResponse) - }) + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." 
+ const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] - it("should handle errors in completePrompt", async () => { - const errorMessage = "Fireworks API error" - mockCreate.mockRejectedValueOnce(new Error(errorMessage)) - await expect(handler.completePrompt("test prompt")).rejects.toThrow( - `Fireworks completion error: ${errorMessage}`, - ) - }) + it("should handle streaming responses", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response from Fireworks" } + } - it("createMessage should yield text content from stream", async () => { - const testContent = "This is test content from Fireworks stream" - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: { content: testContent } }] }, - }) - .mockResolvedValueOnce({ done: true }), - }), + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response from Fireworks") }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ type: "text", text: testContent }) - }) + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) - it("createMessage should yield usage data from stream", async () => { - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } }, - }) - .mockResolvedValueOnce({ done: true }), - }), + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(20) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should handle cached tokens in usage data from providerMetadata", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 }) - }) + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + }) - 
it("createMessage should pass correct parameters to Fireworks client", async () => { - const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct" - const modelInfo = fireworksModels[modelId] - const handlerWithModel = new FireworksHandler({ - apiModelId: modelId, - fireworksApiKey: "test-fireworks-api-key", + // Fireworks provides cache metrics via providerMetadata for supported models + const mockProviderMetadata = Promise.resolve({ + fireworks: { + promptCacheHitTokens: 30, + promptCacheMissTokens: 70, + }, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(100) + expect(usageChunks[0].outputTokens).toBe(50) + expect(usageChunks[0].cacheReadTokens).toBe(30) + expect(usageChunks[0].cacheWriteTokens).toBe(70) }) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, + it("should handle usage with details.cachedInputTokens when providerMetadata is not available", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + details: { + cachedInputTokens: 25, + }, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].cacheReadTokens).toBe(25) + expect(usageChunks[0].cacheWriteTokens).toBeUndefined() + }) + + it("should pass correct temperature (0.5 default) to streamText", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithDefaultTemp = new FireworksHandler({ + fireworksApiKey: "test-key", + apiModelId: "accounts/fireworks/models/kimi-k2-instruct", + }) + + const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, }), + ) + }) + + it("should use model defaultTemperature (1.0) over provider default (0.5) for kimi-k2-thinking", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithThinkingModel = new FireworksHandler({ + fireworksApiKey: "test-key", + apiModelId: "accounts/fireworks/models/kimi-k2-thinking", + }) + + const stream = 
handlerWithThinkingModel.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream } + + // Model's defaultTemperature (1.0) should take precedence over provider's default (0.5) + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 1.0, + }), + ) }) - const systemPrompt = "Test system prompt for Fireworks" - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Fireworks" }] + it("should use user-specified temperature over model and provider defaults", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } - const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages) - await messageGenerator.next() + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithCustomTemp = new FireworksHandler({ + fireworksApiKey: "test-key", + apiModelId: "accounts/fireworks/models/kimi-k2-thinking", + modelTemperature: 0.7, + }) + + const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: modelId, - max_tokens: modelInfo.maxTokens, - temperature: 0.5, - messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]), - stream: true, - stream_options: { include_usage: true }, - }), - undefined, - ) + // User-specified temperature should take precedence over everything + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.7, + }), + ) + }) + + it("should handle stream with multiple chunks", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: " world" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const textChunks = chunks.filter((c) => c.type === "text") + expect(textChunks[0]).toEqual({ type: "text", text: "Hello" }) + expect(textChunks[1]).toEqual({ type: "text", text: " world" }) + + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 }) + }) }) - it("should use provider default temperature of 0.5 for models without defaultTemperature", async () => { - const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct" - const handlerWithModel = new FireworksHandler({ - apiModelId: modelId, - fireworksApiKey: "test-fireworks-api-key", + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion from Fireworks", + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion from Fireworks") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", + }), + ) }) - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - })) + it("should use default 
temperature in completePrompt", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) - const messageGenerator = handlerWithModel.createMessage("system", []) - await messageGenerator.next() + await handler.completePrompt("Test prompt") - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 0.5, - }), - undefined, - ) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, + }), + ) + }) }) - it("should use model defaultTemperature (1.0) over provider default (0.5) for kimi-k2-thinking", async () => { - const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-thinking" - const handlerWithModel = new FireworksHandler({ - apiModelId: modelId, - fireworksApiKey: "test-fireworks-api-key", - }) + describe("processUsageMetrics", () => { + it("should correctly process usage metrics including cache information from providerMetadata", () => { + class TestFireworksHandler extends FireworksHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestFireworksHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } + const providerMetadata = { + fireworks: { + promptCacheHitTokens: 20, + promptCacheMissTokens: 80, }, - }), - })) + } - const messageGenerator = handlerWithModel.createMessage("system", []) - await messageGenerator.next() + const result = testHandler.testProcessUsageMetrics(usage, providerMetadata) - // Model's defaultTemperature (1.0) should take precedence over provider's default (0.5) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 1.0, - }), - undefined, - ) - }) + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(80) + expect(result.cacheReadTokens).toBe(20) + }) - it("should use user-specified temperature over model and provider defaults", async () => { - const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-thinking" - const handlerWithModel = new FireworksHandler({ - apiModelId: modelId, - fireworksApiKey: "test-fireworks-api-key", - modelTemperature: 0.7, + it("should handle missing cache metrics gracefully", () => { + class TestFireworksHandler extends FireworksHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestFireworksHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBeUndefined() + expect(result.cacheReadTokens).toBeUndefined() }) - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - })) + it("should include reasoning tokens when provided", () => { + class TestFireworksHandler extends FireworksHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } - const messageGenerator = 
handlerWithModel.createMessage("system", []) - await messageGenerator.next() + const testHandler = new TestFireworksHandler(mockOptions) - // User-specified temperature should take precedence over everything - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 0.7, - }), - undefined, - ) - }) + const usage = { + inputTokens: 100, + outputTokens: 50, + details: { + reasoningTokens: 30, + }, + } - it("should handle empty response in completePrompt", async () => { - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: null } }] }) - const result = await handler.completePrompt("test prompt") - expect(result).toBe("") - }) + const result = testHandler.testProcessUsageMetrics(usage) - it("should handle missing choices in completePrompt", async () => { - mockCreate.mockResolvedValueOnce({ choices: [] }) - const result = await handler.completePrompt("test prompt") - expect(result).toBe("") + expect(result.reasoningTokens).toBe(30) + }) }) - it("createMessage should handle stream with multiple chunks", async () => { - mockCreate.mockImplementationOnce(async () => ({ - [Symbol.asyncIterator]: async function* () { + describe("tool handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" }], + }, + ] + + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { yield { - choices: [ - { - delta: { content: "Hello" }, - index: 0, - }, - ], - usage: null, + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", } yield { - choices: [ - { - delta: { content: " world" }, - index: 0, - }, - ], - usage: null, + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', } yield { - choices: [ - { - delta: {}, - index: 0, + type: "tool-input-end", + id: "tool-call-1", + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, }, - ], - usage: { - prompt_tokens: 5, - completion_tokens: 10, - total_tokens: 15, }, + ], + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") + }) + + it("should ignore tool-call events to prevent duplicate tools in UI", async () => { + async function* mockFullStream() { + yield { + type: "tool-call", + toolCallId: "tool-call-1", + toolName: "read_file", + input: { path: 
"test.ts" }, } - }, - })) + } - const systemPrompt = "You are a helpful assistant." - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }] + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) - const stream = handler.createMessage(systemPrompt, messages) - const chunks = [] - for await (const chunk of stream) { - chunks.push(chunk) - } + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) - expect(chunks[0]).toEqual({ type: "text", text: "Hello" }) - expect(chunks[1]).toEqual({ type: "text", text: " world" }) - expect(chunks[2]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 }) + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // tool-call events should be ignored (only tool-input-start/delta/end are processed) + const toolCallChunks = chunks.filter( + (c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end", + ) + expect(toolCallChunks.length).toBe(0) + }) }) }) diff --git a/src/api/providers/__tests__/groq.spec.ts b/src/api/providers/__tests__/groq.spec.ts index c4a9471c87..efb5712cb9 100644 --- a/src/api/providers/__tests__/groq.spec.ts +++ b/src/api/providers/__tests__/groq.spec.ts @@ -575,51 +575,4 @@ describe("GroqHandler", () => { expect(result).toBe(customMaxTokens) }) }) - - describe("mapToolChoice", () => { - it("should handle string tool choices", () => { - class TestGroqHandler extends GroqHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestGroqHandler(mockOptions) - - expect(testHandler.testMapToolChoice("auto")).toBe("auto") - expect(testHandler.testMapToolChoice("none")).toBe("none") - expect(testHandler.testMapToolChoice("required")).toBe("required") - expect(testHandler.testMapToolChoice("unknown")).toBe("auto") - }) - - it("should handle object tool choice with function name", () => { - class TestGroqHandler extends GroqHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestGroqHandler(mockOptions) - - const result = testHandler.testMapToolChoice({ - type: "function", - function: { name: "my_tool" }, - }) - - expect(result).toEqual({ type: "tool", toolName: "my_tool" }) - }) - - it("should return undefined for null or undefined", () => { - class TestGroqHandler extends GroqHandler { - public testMapToolChoice(toolChoice: any) { - return this.mapToolChoice(toolChoice) - } - } - - const testHandler = new TestGroqHandler(mockOptions) - - expect(testHandler.testMapToolChoice(null)).toBeUndefined() - expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() - }) - }) }) diff --git a/src/api/providers/cerebras.ts b/src/api/providers/cerebras.ts index 0fbc375bdc..de1a4b2dbb 100644 --- a/src/api/providers/cerebras.ts +++ b/src/api/providers/cerebras.ts @@ -6,7 +6,13 @@ import { cerebrasModels, cerebrasDefaultModelId, type CerebrasModelId, type Mode import type { ApiHandlerOptions } from "../../shared/api" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from 
"../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -75,40 +81,6 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan } } - /** - * Map OpenAI tool_choice to AI SDK toolChoice format. - */ - protected mapToolChoice( - toolChoice: any, - ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { - if (!toolChoice) { - return undefined - } - - // Handle string values - if (typeof toolChoice === "string") { - switch (toolChoice) { - case "auto": - return "auto" - case "none": - return "none" - case "required": - return "required" - default: - return "auto" - } - } - - // Handle object values (OpenAI ChatCompletionNamedToolChoice format) - if (typeof toolChoice === "object" && "type" in toolChoice) { - if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { - return { type: "tool", toolName: toolChoice.function.name } - } - } - - return undefined - } - /** * Get the max tokens parameter to include in the request. */ @@ -143,23 +115,28 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan temperature: this.options.modelTemperature ?? temperature ?? CEREBRAS_DEFAULT_TEMPERATURE, maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, - toolChoice: this.mapToolChoice(metadata?.tool_choice), + toolChoice: mapToolChoice(metadata?.tool_choice), } // Use streamText for streaming responses const result = streamText(requestOptions) - // Process the full stream to get all events including reasoning - for await (const part of result.fullStream) { - for (const chunk of processAiSdkStreamPart(part)) { - yield chunk + try { + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - } - // Yield usage metrics at the end - const usage = await result.usage - if (usage) { - yield this.processUsageMetrics(usage) + // Yield usage metrics at the end + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) + throw handleAiSdkError(error, "Cerebras") } } diff --git a/src/api/providers/deepseek.ts b/src/api/providers/deepseek.ts index 27ecaebd5e..ba9c9d47e3 100644 --- a/src/api/providers/deepseek.ts +++ b/src/api/providers/deepseek.ts @@ -6,7 +6,13 @@ import { deepSeekModels, deepSeekDefaultModelId, DEEP_SEEK_DEFAULT_TEMPERATURE, import type { ApiHandlerOptions } from "../../shared/api" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -83,40 +89,6 @@ export class DeepSeekHandler extends BaseProvider implements SingleCompletionHan } } - /** - * Map OpenAI tool_choice to AI SDK toolChoice format. 
- */ - protected mapToolChoice( - toolChoice: any, - ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { - if (!toolChoice) { - return undefined - } - - // Handle string values - if (typeof toolChoice === "string") { - switch (toolChoice) { - case "auto": - return "auto" - case "none": - return "none" - case "required": - return "required" - default: - return "auto" - } - } - - // Handle object values (OpenAI ChatCompletionNamedToolChoice format) - if (typeof toolChoice === "object" && "type" in toolChoice) { - if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { - return { type: "tool", toolName: toolChoice.function.name } - } - } - - return undefined - } - /** * Get the max tokens parameter to include in the request. */ @@ -152,24 +124,29 @@ export class DeepSeekHandler extends BaseProvider implements SingleCompletionHan temperature: this.options.modelTemperature ?? temperature ?? DEEP_SEEK_DEFAULT_TEMPERATURE, maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, - toolChoice: this.mapToolChoice(metadata?.tool_choice), + toolChoice: mapToolChoice(metadata?.tool_choice), } // Use streamText for streaming responses const result = streamText(requestOptions) - // Process the full stream to get all events including reasoning - for await (const part of result.fullStream) { - for (const chunk of processAiSdkStreamPart(part)) { - yield chunk + try { + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - } - // Yield usage metrics at the end, including cache metrics from providerMetadata - const usage = await result.usage - const providerMetadata = await result.providerMetadata - if (usage) { - yield this.processUsageMetrics(usage, providerMetadata as any) + // Yield usage metrics at the end, including cache metrics from providerMetadata + const usage = await result.usage + const providerMetadata = await result.providerMetadata + if (usage) { + yield this.processUsageMetrics(usage, providerMetadata as any) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) 
+			throw handleAiSdkError(error, "DeepSeek")
 		}
 	}
diff --git a/src/api/providers/fireworks.ts b/src/api/providers/fireworks.ts
index db29e7bf3f..52bf431bb6 100644
--- a/src/api/providers/fireworks.ts
+++ b/src/api/providers/fireworks.ts
@@ -1,19 +1,175 @@
-import { type FireworksModelId, fireworksDefaultModelId, fireworksModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import { createFireworks } from "@ai-sdk/fireworks"
+import { streamText, generateText, ToolSet } from "ai"
+
+import { fireworksModels, fireworksDefaultModelId, type ModelInfo } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
 
-import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	mapToolChoice,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
+
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+
+const FIREWORKS_DEFAULT_TEMPERATURE = 0.5
+
+/**
+ * Fireworks provider using the dedicated @ai-sdk/fireworks package.
+ * Provides native support for various models including reasoning models.
+ */
+export class FireworksHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	protected provider: ReturnType<typeof createFireworks>
 
-export class FireworksHandler extends BaseOpenAiCompatibleProvider<FireworksModelId> {
 	constructor(options: ApiHandlerOptions) {
-		super({
-			...options,
-			providerName: "Fireworks",
+		super()
+		this.options = options
+
+		// Create the Fireworks provider using AI SDK
+		this.provider = createFireworks({
 			baseURL: "https://api.fireworks.ai/inference/v1",
-			apiKey: options.fireworksApiKey,
-			defaultProviderModelId: fireworksDefaultModelId,
-			providerModels: fireworksModels,
-			defaultTemperature: 0.5,
+			apiKey: options.fireworksApiKey ?? "not-provided",
+			headers: DEFAULT_HEADERS,
+		})
+	}
+
+	override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
+		const id = this.options.apiModelId ?? fireworksDefaultModelId
+		const info = fireworksModels[id as keyof typeof fireworksModels] || fireworksModels[fireworksDefaultModelId]
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+			defaultTemperature: FIREWORKS_DEFAULT_TEMPERATURE,
+		})
+		return { id, info, ...params }
+	}
+
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
+	}
+
+	/**
+	 * Process usage metrics from the AI SDK response.
+	 */
+	protected processUsageMetrics(
+		usage: {
+			inputTokens?: number
+			outputTokens?: number
+			details?: {
+				cachedInputTokens?: number
+				reasoningTokens?: number
+			}
+		},
+		providerMetadata?: {
+			fireworks?: {
+				promptCacheHitTokens?: number
+				promptCacheMissTokens?: number
+			}
+		},
+	): ApiStreamUsageChunk {
+		// Extract cache metrics from Fireworks' providerMetadata if available
+		const cacheReadTokens = providerMetadata?.fireworks?.promptCacheHitTokens ?? usage.details?.cachedInputTokens
+		const cacheWriteTokens = providerMetadata?.fireworks?.promptCacheMissTokens
+
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			cacheReadTokens,
+			cacheWriteTokens,
+			reasoningTokens: usage.details?.reasoningTokens,
+		}
+	}
+
+	/**
+	 * Get the max tokens parameter to include in the request.
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 */
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		// Convert messages to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		// Convert tools to OpenAI format first, then to AI SDK format
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		// Build the request options
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? FIREWORKS_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: mapToolChoice(metadata?.tool_choice),
+		}
+
+		// Use streamText for streaming responses
+		const result = streamText(requestOptions)
+
+		try {
+			// Process the full stream to get all events including reasoning
+			for await (const part of result.fullStream) {
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
+				}
+			}
+
+			// Yield usage metrics at the end, including cache metrics from providerMetadata
+			const usage = await result.usage
+			const providerMetadata = await result.providerMetadata
+			if (usage) {
+				yield this.processUsageMetrics(usage, providerMetadata as any)
+			}
+		} catch (error) {
+			// Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.)
+			throw handleAiSdkError(error, "Fireworks")
+		}
+	}
+
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
+	async completePrompt(prompt: string): Promise<string> {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature: this.options.modelTemperature ?? temperature ?? FIREWORKS_DEFAULT_TEMPERATURE,
FIREWORKS_DEFAULT_TEMPERATURE, }) + + return text } } diff --git a/src/api/providers/groq.ts b/src/api/providers/groq.ts index 64399ad674..648679f92c 100644 --- a/src/api/providers/groq.ts +++ b/src/api/providers/groq.ts @@ -6,7 +6,13 @@ import { groqModels, groqDefaultModelId, type ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -91,40 +97,6 @@ export class GroqHandler extends BaseProvider implements SingleCompletionHandler } } - /** - * Map OpenAI tool_choice to AI SDK toolChoice format. - */ - protected mapToolChoice( - toolChoice: any, - ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { - if (!toolChoice) { - return undefined - } - - // Handle string values - if (typeof toolChoice === "string") { - switch (toolChoice) { - case "auto": - return "auto" - case "none": - return "none" - case "required": - return "required" - default: - return "auto" - } - } - - // Handle object values (OpenAI ChatCompletionNamedToolChoice format) - if (typeof toolChoice === "object" && "type" in toolChoice) { - if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { - return { type: "tool", toolName: toolChoice.function.name } - } - } - - return undefined - } - /** * Get the max tokens parameter to include in the request. */ @@ -160,24 +132,29 @@ export class GroqHandler extends BaseProvider implements SingleCompletionHandler temperature: this.options.modelTemperature ?? temperature ?? GROQ_DEFAULT_TEMPERATURE, maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, - toolChoice: this.mapToolChoice(metadata?.tool_choice), + toolChoice: mapToolChoice(metadata?.tool_choice), } // Use streamText for streaming responses const result = streamText(requestOptions) - // Process the full stream to get all events including reasoning - for await (const part of result.fullStream) { - for (const chunk of processAiSdkStreamPart(part)) { - yield chunk + try { + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - } - // Yield usage metrics at the end, including cache metrics from providerMetadata - const usage = await result.usage - const providerMetadata = await result.providerMetadata - if (usage) { - yield this.processUsageMetrics(usage, providerMetadata as any) + // Yield usage metrics at the end, including cache metrics from providerMetadata + const usage = await result.usage + const providerMetadata = await result.providerMetadata + if (usage) { + yield this.processUsageMetrics(usage, providerMetadata as any) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) 
+ throw handleAiSdkError(error, "Groq") } } diff --git a/src/api/providers/openai-compatible.ts b/src/api/providers/openai-compatible.ts index d129e72452..240de747be 100644 --- a/src/api/providers/openai-compatible.ts +++ b/src/api/providers/openai-compatible.ts @@ -12,7 +12,13 @@ import type { ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { DEFAULT_HEADERS } from "./constants" @@ -103,40 +109,6 @@ export abstract class OpenAICompatibleHandler extends BaseProvider implements Si } } - /** - * Map OpenAI tool_choice to AI SDK toolChoice format. - */ - protected mapToolChoice( - toolChoice: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"], - ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { - if (!toolChoice) { - return undefined - } - - // Handle string values - if (typeof toolChoice === "string") { - switch (toolChoice) { - case "auto": - return "auto" - case "none": - return "none" - case "required": - return "required" - default: - return "auto" - } - } - - // Handle object values (OpenAI ChatCompletionNamedToolChoice format) - if (typeof toolChoice === "object" && "type" in toolChoice) { - if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { - return { type: "tool", toolName: toolChoice.function.name } - } - } - - return undefined - } - /** * Get the max tokens parameter to include in the request. */ @@ -173,24 +145,29 @@ export abstract class OpenAICompatibleHandler extends BaseProvider implements Si temperature: model.temperature ?? this.config.temperature ?? 0, maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, - toolChoice: this.mapToolChoice(metadata?.tool_choice), + toolChoice: mapToolChoice(metadata?.tool_choice), } // Use streamText for streaming responses const result = streamText(requestOptions) - // Process the full stream to get all events - for await (const part of result.fullStream) { - // Use the processAiSdkStreamPart utility to convert stream parts - for (const chunk of processAiSdkStreamPart(part)) { - yield chunk + try { + // Process the full stream to get all events + for await (const part of result.fullStream) { + // Use the processAiSdkStreamPart utility to convert stream parts + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - } - // Yield usage metrics at the end - const usage = await result.usage - if (usage) { - yield this.processUsageMetrics(usage) + // Yield usage metrics at the end + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) 
+ throw handleAiSdkError(error, this.config.providerName) } } diff --git a/src/api/transform/__tests__/ai-sdk.spec.ts b/src/api/transform/__tests__/ai-sdk.spec.ts index 293d720d48..bd87fd8eeb 100644 --- a/src/api/transform/__tests__/ai-sdk.spec.ts +++ b/src/api/transform/__tests__/ai-sdk.spec.ts @@ -1,6 +1,13 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../ai-sdk" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + extractAiSdkErrorMessage, + handleAiSdkError, +} from "../ai-sdk" vitest.mock("ai", () => ({ tool: vitest.fn((t) => t), @@ -486,4 +493,155 @@ describe("AI SDK conversion utilities", () => { } }) }) + + describe("mapToolChoice", () => { + it("should return undefined for null or undefined", () => { + expect(mapToolChoice(null)).toBeUndefined() + expect(mapToolChoice(undefined)).toBeUndefined() + }) + + it("should handle string tool choices", () => { + expect(mapToolChoice("auto")).toBe("auto") + expect(mapToolChoice("none")).toBe("none") + expect(mapToolChoice("required")).toBe("required") + }) + + it("should return auto for unknown string values", () => { + expect(mapToolChoice("unknown")).toBe("auto") + expect(mapToolChoice("invalid")).toBe("auto") + }) + + it("should handle object tool choice with function name", () => { + const result = mapToolChoice({ + type: "function", + function: { name: "my_tool" }, + }) + + expect(result).toEqual({ type: "tool", toolName: "my_tool" }) + }) + + it("should return undefined for object without function name", () => { + const result = mapToolChoice({ + type: "function", + function: {}, + }) + + expect(result).toBeUndefined() + }) + + it("should return undefined for object with non-function type", () => { + const result = mapToolChoice({ + type: "other", + function: { name: "my_tool" }, + }) + + expect(result).toBeUndefined() + }) + }) + + describe("extractAiSdkErrorMessage", () => { + it("should return 'Unknown error' for null/undefined", () => { + expect(extractAiSdkErrorMessage(null)).toBe("Unknown error") + expect(extractAiSdkErrorMessage(undefined)).toBe("Unknown error") + }) + + it("should extract message from AI_RetryError", () => { + const retryError = { + name: "AI_RetryError", + message: "Failed after 3 attempts", + errors: [new Error("Error 1"), new Error("Error 2"), new Error("Too Many Requests")], + lastError: { message: "Too Many Requests", status: 429 }, + } + + const result = extractAiSdkErrorMessage(retryError) + expect(result).toBe("Failed after 3 attempts (429): Too Many Requests") + }) + + it("should handle AI_RetryError without status", () => { + const retryError = { + name: "AI_RetryError", + message: "Failed after 2 attempts", + errors: [new Error("Error 1"), new Error("Connection failed")], + lastError: { message: "Connection failed" }, + } + + const result = extractAiSdkErrorMessage(retryError) + expect(result).toBe("Failed after 2 attempts: Connection failed") + }) + + it("should extract message from AI_APICallError", () => { + const apiError = { + name: "AI_APICallError", + message: "Rate limit exceeded", + status: 429, + } + + const result = extractAiSdkErrorMessage(apiError) + expect(result).toBe("API Error (429): Rate limit exceeded") + }) + + it("should handle AI_APICallError without status", () => { + const apiError = { + name: "AI_APICallError", + message: "Connection timeout", + } + + const result = extractAiSdkErrorMessage(apiError) + 
expect(result).toBe("Connection timeout") + }) + + it("should extract message from standard Error", () => { + const error = new Error("Something went wrong") + expect(extractAiSdkErrorMessage(error)).toBe("Something went wrong") + }) + + it("should convert non-Error to string", () => { + expect(extractAiSdkErrorMessage("string error")).toBe("string error") + expect(extractAiSdkErrorMessage({ custom: "object" })).toBe("[object Object]") + }) + }) + + describe("handleAiSdkError", () => { + it("should wrap error with provider name", () => { + const error = new Error("API Error") + const result = handleAiSdkError(error, "Fireworks") + + expect(result.message).toBe("Fireworks: API Error") + }) + + it("should preserve status code from AI_RetryError", () => { + const retryError = { + name: "AI_RetryError", + errors: [new Error("Too Many Requests")], + lastError: { message: "Too Many Requests", status: 429 }, + } + + const result = handleAiSdkError(retryError, "Groq") + + expect(result.message).toContain("Groq:") + expect(result.message).toContain("429") + expect((result as any).status).toBe(429) + }) + + it("should preserve status code from AI_APICallError", () => { + const apiError = { + name: "AI_APICallError", + message: "Unauthorized", + status: 401, + } + + const result = handleAiSdkError(apiError, "DeepSeek") + + expect(result.message).toContain("DeepSeek:") + expect(result.message).toContain("401") + expect((result as any).status).toBe(401) + }) + + it("should preserve original error as cause", () => { + const originalError = new Error("Original error") + const result = handleAiSdkError(originalError, "Cerebras") + + expect((result as any).cause).toBe(originalError) + }) + }) }) diff --git a/src/api/transform/ai-sdk.ts b/src/api/transform/ai-sdk.ts index fd86532bf1..ebbf1a8661 100644 --- a/src/api/transform/ai-sdk.ts +++ b/src/api/transform/ai-sdk.ts @@ -273,3 +273,125 @@ export function* processAiSdkStreamPart(part: ExtendedStreamPart): Generator