From f5b6409275eeb5a315aac681bcf5a1acf569bced Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Thu, 29 Jan 2026 17:36:36 -0500 Subject: [PATCH] feat: migrate Groq provider to @ai-sdk/groq - Replace openai SDK with @ai-sdk/groq package - Extend BaseAiSdkProvider instead of BaseOpenAiCompatibleProvider - Implement Groq-specific cache metrics via providerMetadata - Update tests to use AI SDK mocking patterns Part of EXT-644 --- pnpm-lock.yaml | 27 +- src/api/providers/__tests__/groq.spec.ts | 715 ++++++++++++++++++----- src/api/providers/groq.ts | 201 ++++++- src/package.json | 1 + 4 files changed, 787 insertions(+), 157 deletions(-) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a7eebc66f10..8a9015cb079 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -748,7 +748,10 @@ importers: version: 1.0.35(zod@3.25.76) '@ai-sdk/deepseek': specifier: ^2.0.14 - version: 2.0.15(zod@3.25.76) + version: 2.0.14(zod@3.25.76) + '@ai-sdk/groq': + specifier: ^3.0.19 + version: 3.0.19(zod@3.25.76) '@anthropic-ai/bedrock-sdk': specifier: ^0.10.2 version: 0.10.4 @@ -1399,8 +1402,8 @@ packages: peerDependencies: zod: 3.25.76 - '@ai-sdk/deepseek@2.0.15': - resolution: {integrity: sha512-3wJUjNjGrTZS3K8OEfHD1PZYhzkcXuoL8KIVtzi6WrC5xrDQPjCBPATmdKPV7DgDCF+wujQOaMz5cv40Yg+hog==} + '@ai-sdk/deepseek@2.0.14': + resolution: {integrity: sha512-1vXh8sVwRJYd1JO57qdy1rACucaNLDoBRCwOER3EbPgSF2vNVPcdJywGutA01Bhn7Cta+UJQ+k5y/yzMAIpP2w==} engines: {node: '>=18'} peerDependencies: zod: 3.25.76 @@ -1411,6 +1414,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/groq@3.0.19': + resolution: {integrity: sha512-WAeGVnp9rvU3RUvu6S1HiD8hAjKgNlhq+z3m4j5Z1fIKRXqcKjOscVZGwL36If8qxsqXNVCtG3ltXawM5UAa8w==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/openai-compatible@1.0.31': resolution: {integrity: sha512-znBvaVHM0M6yWNerIEy3hR+O8ZK2sPcE7e2cxfb6kYLEX3k//JH5VDnRnajseVofg7LXtTCFFdjsB7WLf1BdeQ==} engines: {node: '>=18'} @@ -10845,10 +10854,10 @@ snapshots: '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) zod: 3.25.76 - '@ai-sdk/deepseek@2.0.15(zod@3.25.76)': + '@ai-sdk/deepseek@2.0.14(zod@3.25.76)': dependencies: - '@ai-sdk/provider': 3.0.6 - '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) + '@ai-sdk/provider': 3.0.5 + '@ai-sdk/provider-utils': 4.0.10(zod@3.25.76) zod: 3.25.76 '@ai-sdk/gateway@3.0.25(zod@3.25.76)': @@ -10858,6 +10867,12 @@ snapshots: '@vercel/oidc': 3.1.0 zod: 3.25.76 + '@ai-sdk/groq@3.0.19(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.6 + '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/openai-compatible@1.0.31(zod@3.25.76)': dependencies: '@ai-sdk/provider': 2.0.1 diff --git a/src/api/providers/__tests__/groq.spec.ts b/src/api/providers/__tests__/groq.spec.ts index f89fd62a7fd..c4a9471c87c 100644 --- a/src/api/providers/__tests__/groq.spec.ts +++ b/src/api/providers/__tests__/groq.spec.ts @@ -1,192 +1,625 @@ // npx vitest run src/api/providers/__tests__/groq.spec.ts -import OpenAI from "openai" -import { Anthropic } from "@anthropic-ai/sdk" +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) -import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types" - -import { GroqHandler } from "../groq" - -vitest.mock("openai", () => { - const createMock = vitest.fn() +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - default: vitest.fn(() 
=> ({ chat: { completions: { create: createMock } } })), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) +vi.mock("@ai-sdk/groq", () => ({ + createGroq: vi.fn(() => { + // Return a function that returns a mock language model + return vi.fn(() => ({ + modelId: "moonshotai/kimi-k2-instruct-0905", + provider: "groq", + })) + }), +})) + +import type { Anthropic } from "@anthropic-ai/sdk" + +import { groqDefaultModelId, groqModels, type GroqModelId } from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../../shared/api" + +import { GroqHandler } from "../groq" + describe("GroqHandler", () => { let handler: GroqHandler - let mockCreate: any + let mockOptions: ApiHandlerOptions beforeEach(() => { - vitest.clearAllMocks() - mockCreate = (OpenAI as unknown as any)().chat.completions.create - handler = new GroqHandler({ groqApiKey: "test-groq-api-key" }) + mockOptions = { + groqApiKey: "test-groq-api-key", + apiModelId: "moonshotai/kimi-k2-instruct-0905", + } + handler = new GroqHandler(mockOptions) + vi.clearAllMocks() }) - it("should use the correct Groq base URL", () => { - new GroqHandler({ groqApiKey: "test-groq-api-key" }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.groq.com/openai/v1" })) - }) + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(GroqHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) - it("should use the provided API key", () => { - const groqApiKey = "test-groq-api-key" - new GroqHandler({ groqApiKey }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: groqApiKey })) + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new GroqHandler({ + ...mockOptions, + apiModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe(groqDefaultModelId) + }) }) - it("should return default model when no model is specified", () => { - const model = handler.getModel() - expect(model.id).toBe(groqDefaultModelId) - expect(model.info).toEqual(groqModels[groqDefaultModelId]) - }) + describe("getModel", () => { + it("should return default model when no model is specified", () => { + const handlerWithoutModel = new GroqHandler({ + groqApiKey: "test-groq-api-key", + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(groqDefaultModelId) + expect(model.info).toEqual(groqModels[groqDefaultModelId]) + }) - it("should return specified model when valid model is provided", () => { - const testModelId: GroqModelId = "llama-3.3-70b-versatile" - const handlerWithModel = new GroqHandler({ apiModelId: testModelId, groqApiKey: "test-groq-api-key" }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual(groqModels[testModelId]) - }) + it("should return specified model when valid model is provided", () => { + const testModelId: GroqModelId = "llama-3.3-70b-versatile" + const handlerWithModel = new GroqHandler({ + apiModelId: testModelId, + groqApiKey: "test-groq-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual(groqModels[testModelId]) + }) - it("completePrompt method should return text from Groq API", async () => { - const expectedResponse = "This is a test response from Groq" - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) - const result = await 
handler.completePrompt("test prompt") - expect(result).toBe(expectedResponse) - }) + it("should return model info for llama-3.1-8b-instant", () => { + const handlerWithLlama = new GroqHandler({ + ...mockOptions, + apiModelId: "llama-3.1-8b-instant", + }) + const model = handlerWithLlama.getModel() + expect(model.id).toBe("llama-3.1-8b-instant") + expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBe(8192) + expect(model.info.contextWindow).toBe(131072) + expect(model.info.supportsImages).toBe(false) + expect(model.info.supportsPromptCache).toBe(false) + }) - it("should handle errors in completePrompt", async () => { - const errorMessage = "Groq API error" - mockCreate.mockRejectedValueOnce(new Error(errorMessage)) - await expect(handler.completePrompt("test prompt")).rejects.toThrow(`Groq completion error: ${errorMessage}`) + it("should return model info for kimi-k2 which supports prompt cache", () => { + const handlerWithKimi = new GroqHandler({ + ...mockOptions, + apiModelId: "moonshotai/kimi-k2-instruct-0905", + }) + const model = handlerWithKimi.getModel() + expect(model.id).toBe("moonshotai/kimi-k2-instruct-0905") + expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBe(16384) + expect(model.info.contextWindow).toBe(262144) + expect(model.info.supportsPromptCache).toBe(true) + }) + + it("should return provided model ID with default model info if model does not exist", () => { + const handlerWithInvalidModel = new GroqHandler({ + ...mockOptions, + apiModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe("invalid-model") + expect(model.info).toBeDefined() + // Should use default model info + expect(model.info).toBe(groqModels[groqDefaultModelId]) + }) + + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") + }) }) - it("createMessage should yield text content from stream", async () => { - const testContent = "This is test content from Groq stream" - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vitest - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: { content: testContent } }] }, - }) - .mockResolvedValueOnce({ done: true }), - }), + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." 
+ const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] + + it("should handle streaming responses", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response from Groq" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response from Groq") + }) + + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(20) + }) + + it("should handle cached tokens in usage data from providerMetadata", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + }) + + // Groq provides cache metrics via providerMetadata for supported models + const mockProviderMetadata = Promise.resolve({ + groq: { + promptCacheHitTokens: 30, + promptCacheMissTokens: 70, + }, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(100) + expect(usageChunks[0].outputTokens).toBe(50) + expect(usageChunks[0].cacheReadTokens).toBe(30) + expect(usageChunks[0].cacheWriteTokens).toBe(70) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should handle usage with details.cachedInputTokens when providerMetadata is not available", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + details: { + cachedInputTokens: 25, + }, + }) + + const mockProviderMetadata = Promise.resolve({}) - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ type: "text", text: testContent }) + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: 
mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].cacheReadTokens).toBe(25) + expect(usageChunks[0].cacheWriteTokens).toBeUndefined() + }) + + it("should pass correct temperature (0.5 default) to streamText", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithDefaultTemp = new GroqHandler({ + groqApiKey: "test-key", + apiModelId: "llama-3.1-8b-instant", + }) + + const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, + }), + ) + }) }) - it("createMessage should yield usage data from stream", async () => { - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vitest - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } }, - }) - .mockResolvedValueOnce({ done: true }), + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion from Groq", + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion from Groq") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", + }), + ) + }) + + it("should use default temperature in completePrompt", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) + + await handler.completePrompt("Test prompt") + + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, }), + ) + }) + }) + + describe("processUsageMetrics", () => { + it("should correctly process usage metrics including cache information from providerMetadata", () => { + class TestGroqHandler extends GroqHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestGroqHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, } + + const providerMetadata = { + groq: { + promptCacheHitTokens: 20, + promptCacheMissTokens: 80, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage, providerMetadata) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(80) + expect(result.cacheReadTokens).toBe(20) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should handle missing cache metrics gracefully", () => { + class TestGroqHandler extends GroqHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestGroqHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const 
result = testHandler.testProcessUsageMetrics(usage) - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toMatchObject({ - type: "usage", - inputTokens: 10, - outputTokens: 20, + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBeUndefined() + expect(result.cacheReadTokens).toBeUndefined() + }) + + it("should include reasoning tokens when provided", () => { + class TestGroqHandler extends GroqHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestGroqHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + details: { + reasoningTokens: 30, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.reasoningTokens).toBe(30) }) - // cacheWriteTokens and cacheReadTokens will be undefined when 0 - expect(firstChunk.value.cacheWriteTokens).toBeUndefined() - expect(firstChunk.value.cacheReadTokens).toBeUndefined() - // Check that totalCost is a number (we don't need to test the exact value as that's tested in cost.spec.ts) - expect(typeof firstChunk.value.totalCost).toBe("number") }) - it("createMessage should handle cached tokens in usage data", async () => { - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vitest - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: {} }], - usage: { - prompt_tokens: 100, - completion_tokens: 50, - prompt_tokens_details: { - cached_tokens: 30, - }, - }, + describe("tool handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" 
}], + }, + ] + + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', + } + yield { + type: "tool-input-end", + id: "tool-call-1", + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], }, - }) - .mockResolvedValueOnce({ done: true }), - }), + }, + }, + ], + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } + + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should ignore tool-call events to prevent duplicate tools in UI", async () => { + async function* mockFullStream() { + yield { + type: "tool-call", + toolCallId: "tool-call-1", + toolName: "read_file", + input: { path: "test.ts" }, + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, + }, + }, + ], + }) - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toMatchObject({ - type: "usage", - inputTokens: 100, - outputTokens: 50, - cacheReadTokens: 30, + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // tool-call events are ignored, so no tool_call chunks should be emitted + const toolCallChunks = chunks.filter((c) => c.type === "tool_call") + expect(toolCallChunks.length).toBe(0) }) - // cacheWriteTokens will be undefined when 0 - expect(firstChunk.value.cacheWriteTokens).toBeUndefined() - expect(typeof firstChunk.value.totalCost).toBe("number") }) - it("createMessage should pass correct parameters to Groq client", async () => { - const modelId: GroqModelId = "llama-3.1-8b-instant" - const modelInfo = groqModels[modelId] - const handlerWithModel = new GroqHandler({ apiModelId: modelId, groqApiKey: "test-groq-api-key" }) + 
describe("getMaxOutputTokens", () => { + it("should return maxTokens from model info", () => { + class TestGroqHandler extends GroqHandler { + public testGetMaxOutputTokens() { + return this.getMaxOutputTokens() + } + } - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), + const testHandler = new TestGroqHandler({ + ...mockOptions, + apiModelId: "llama-3.1-8b-instant", + }) + const result = testHandler.testGetMaxOutputTokens() + + // llama-3.1-8b-instant has maxTokens of 8192 + expect(result).toBe(8192) + }) + + it("should use modelMaxTokens when provided", () => { + class TestGroqHandler extends GroqHandler { + public testGetMaxOutputTokens() { + return this.getMaxOutputTokens() + } } + + const customMaxTokens = 5000 + const testHandler = new TestGroqHandler({ + ...mockOptions, + modelMaxTokens: customMaxTokens, + }) + + const result = testHandler.testGetMaxOutputTokens() + expect(result).toBe(customMaxTokens) }) + }) - const systemPrompt = "Test system prompt for Groq" - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Groq" }] + describe("mapToolChoice", () => { + it("should handle string tool choices", () => { + class TestGroqHandler extends GroqHandler { + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) + } + } - const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages) - await messageGenerator.next() + const testHandler = new TestGroqHandler(mockOptions) + + expect(testHandler.testMapToolChoice("auto")).toBe("auto") + expect(testHandler.testMapToolChoice("none")).toBe("none") + expect(testHandler.testMapToolChoice("required")).toBe("required") + expect(testHandler.testMapToolChoice("unknown")).toBe("auto") + }) + + it("should handle object tool choice with function name", () => { + class TestGroqHandler extends GroqHandler { + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) + } + } - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: modelId, - max_tokens: modelInfo.maxTokens, - temperature: 0.5, - messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]), - stream: true, - stream_options: { include_usage: true }, - }), - undefined, - ) + const testHandler = new TestGroqHandler(mockOptions) + + const result = testHandler.testMapToolChoice({ + type: "function", + function: { name: "my_tool" }, + }) + + expect(result).toEqual({ type: "tool", toolName: "my_tool" }) + }) + + it("should return undefined for null or undefined", () => { + class TestGroqHandler extends GroqHandler { + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) + } + } + + const testHandler = new TestGroqHandler(mockOptions) + + expect(testHandler.testMapToolChoice(null)).toBeUndefined() + expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() + }) }) }) diff --git a/src/api/providers/groq.ts b/src/api/providers/groq.ts index 7583edc51cb..64399ad6749 100644 --- a/src/api/providers/groq.ts +++ b/src/api/providers/groq.ts @@ -1,19 +1,200 @@ -import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types" +import { Anthropic } from "@anthropic-ai/sdk" +import { createGroq } from "@ai-sdk/groq" +import { streamText, generateText, ToolSet } from "ai" + +import { groqModels, groqDefaultModelId, type ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" 
-import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
+import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
+
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+
+const GROQ_DEFAULT_TEMPERATURE = 0.5
+
+/**
+ * Groq provider using the dedicated @ai-sdk/groq package.
+ * Provides native support for reasoning models and prompt caching.
+ */
+export class GroqHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	protected provider: ReturnType<typeof createGroq>
 
-export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 	constructor(options: ApiHandlerOptions) {
-		super({
-			...options,
-			providerName: "Groq",
+		super()
+		this.options = options
+
+		// Create the Groq provider using AI SDK
+		this.provider = createGroq({
 			baseURL: "https://api.groq.com/openai/v1",
-			apiKey: options.groqApiKey,
-			defaultProviderModelId: groqDefaultModelId,
-			providerModels: groqModels,
-			defaultTemperature: 0.5,
+			apiKey: options.groqApiKey ?? "not-provided",
+			headers: DEFAULT_HEADERS,
 		})
 	}
+
+	override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
+		const id = this.options.apiModelId ?? groqDefaultModelId
+		const info = groqModels[id as keyof typeof groqModels] || groqModels[groqDefaultModelId]
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+			defaultTemperature: GROQ_DEFAULT_TEMPERATURE,
+		})
+		return { id, info, ...params }
+	}
+
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
+	}
+
+	/**
+	 * Process usage metrics from the AI SDK response, including Groq's cache metrics.
+	 * Groq provides cache hit/miss info via providerMetadata for supported models.
+	 */
+	protected processUsageMetrics(
+		usage: {
+			inputTokens?: number
+			outputTokens?: number
+			details?: {
+				cachedInputTokens?: number
+				reasoningTokens?: number
+			}
+		},
+		providerMetadata?: {
+			groq?: {
+				promptCacheHitTokens?: number
+				promptCacheMissTokens?: number
+			}
+		},
+	): ApiStreamUsageChunk {
+		// Extract cache metrics from Groq's providerMetadata
+		const cacheReadTokens = providerMetadata?.groq?.promptCacheHitTokens ?? usage.details?.cachedInputTokens
+		const cacheWriteTokens = providerMetadata?.groq?.promptCacheMissTokens
+
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			cacheReadTokens,
+			cacheWriteTokens,
+			reasoningTokens: usage.details?.reasoningTokens,
+		}
+	}
+
+	/**
+	 * Map OpenAI tool_choice to AI SDK toolChoice format.
+	 */
+	protected mapToolChoice(
+		toolChoice: any,
+	): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined {
+		if (!toolChoice) {
+			return undefined
+		}
+
+		// Handle string values
+		if (typeof toolChoice === "string") {
+			switch (toolChoice) {
+				case "auto":
+					return "auto"
+				case "none":
+					return "none"
+				case "required":
+					return "required"
+				default:
+					return "auto"
+			}
+		}
+
+		// Handle object values (OpenAI ChatCompletionNamedToolChoice format)
+		if (typeof toolChoice === "object" && "type" in toolChoice) {
+			if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) {
+				return { type: "tool", toolName: toolChoice.function.name }
+			}
+		}
+
+		return undefined
+	}
+
+	/**
+	 * Get the max tokens parameter to include in the request.
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 * Groq supports reasoning for models like qwen/qwen3-32b via reasoningFormat: 'parsed'.
+	 */
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		// Convert messages to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		// Convert tools to OpenAI format first, then to AI SDK format
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		// Build the request options
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? GROQ_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: this.mapToolChoice(metadata?.tool_choice),
+		}
+
+		// Use streamText for streaming responses
+		const result = streamText(requestOptions)
+
+		// Process the full stream to get all events including reasoning
+		for await (const part of result.fullStream) {
+			for (const chunk of processAiSdkStreamPart(part)) {
+				yield chunk
+			}
+		}
+
+		// Yield usage metrics at the end, including cache metrics from providerMetadata
+		const usage = await result.usage
+		const providerMetadata = await result.providerMetadata
+		if (usage) {
+			yield this.processUsageMetrics(usage, providerMetadata as any)
+		}
+	}
+
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
+	async completePrompt(prompt: string): Promise<string> {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature: this.options.modelTemperature ?? temperature ?? GROQ_DEFAULT_TEMPERATURE,
+		})
+
+		return text
+	}
 }
diff --git a/src/package.json b/src/package.json
index edacfd403d2..7641e8d0612 100644
--- a/src/package.json
+++ b/src/package.json
@@ -452,6 +452,7 @@
 	"dependencies": {
 		"@ai-sdk/cerebras": "^1.0.0",
 		"@ai-sdk/deepseek": "^2.0.14",
+		"@ai-sdk/groq": "^3.0.19",
 		"@anthropic-ai/bedrock-sdk": "^0.10.2",
 		"@anthropic-ai/sdk": "^0.37.0",
 		"@anthropic-ai/vertex-sdk": "^0.7.0",