From bbcf56a063b64f05563f20b78f061675ee6a3e5e Mon Sep 17 00:00:00 2001 From: kiwina Date: Thu, 5 Jun 2025 16:58:07 +0800 Subject: [PATCH 1/4] feat: add vertex API key support and refactor provider handlers - Add vertexApiKey field to ProviderSettings for simplified authentication - Refactor VertexHandler to be fully independent from GeminiHandler - Remove inheritance dependency and vertex-specific logic from GeminiHandler - Implement standalone VertexHandler with vertex-only authentication methods - Add comprehensive debug logging for model selection and API flow - Update vertex provider UI to include API key input field - Add i18n support for vertex API key label and placeholder - Fix vertex tests to use correct default model (claude-sonnet-4@20250514) - Ensure proper model selection and fallback logic for both providers Breaking changes: - VertexHandler no longer inherits from GeminiHandler - Each provider now has completely separate model sets and authentication Closes: #TBD --- packages/types/src/provider-settings.ts | 2 + packages/types/src/providers/vertex.ts | 1 + src/api/providers/__tests__/gemini.test.ts | 439 ++++++++++++++---- src/api/providers/__tests__/vertex.test.ts | 395 +++++++++++++++- src/api/providers/gemini.ts | 99 +--- src/api/providers/vertex.ts | 159 ++++++- .../__tests__/calculateCostGenai.test.ts | 127 +++++ src/utils/calculateCostGenai.ts | 58 +++ .../components/settings/providers/Vertex.tsx | 7 + webview-ui/src/i18n/locales/en/settings.json | 1 + 10 files changed, 1077 insertions(+), 211 deletions(-) create mode 100644 src/utils/__tests__/calculateCostGenai.test.ts create mode 100644 src/utils/calculateCostGenai.ts diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index a60f7e0b282..98e0135dac6 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -106,6 +106,7 @@ const bedrockSchema = apiModelIdProviderModelSchema.extend({ }) const vertexSchema = apiModelIdProviderModelSchema.extend({ + vertexApiKey: z.string().optional(), vertexKeyFile: z.string().optional(), vertexJsonCredentials: z.string().optional(), vertexProjectId: z.string().optional(), @@ -290,6 +291,7 @@ export const PROVIDER_SETTINGS_KEYS = keysOf()([ "awsBedrockEndpointEnabled", "awsBedrockEndpoint", // Google Vertex + "vertexApiKey", "vertexKeyFile", "vertexJsonCredentials", "vertexProjectId", diff --git a/packages/types/src/providers/vertex.ts b/packages/types/src/providers/vertex.ts index 028d3089235..a2d75602260 100644 --- a/packages/types/src/providers/vertex.ts +++ b/packages/types/src/providers/vertex.ts @@ -233,4 +233,5 @@ export const VERTEX_REGIONS = [ { value: "europe-west1", label: "europe-west1" }, { value: "europe-west4", label: "europe-west4" }, { value: "asia-southeast1", label: "asia-southeast1" }, + { value: "global", label: "global" }, ] diff --git a/src/api/providers/__tests__/gemini.test.ts b/src/api/providers/__tests__/gemini.test.ts index 837948af1d2..d62e6f5a2e6 100644 --- a/src/api/providers/__tests__/gemini.test.ts +++ b/src/api/providers/__tests__/gemini.test.ts @@ -3,19 +3,31 @@ import { Anthropic } from "@anthropic-ai/sdk" import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types" +import { calculateCostGenai } from "../../../utils/calculateCostGenai" import { GeminiHandler } from "../gemini" +// Mock the calculateCostGenai function +jest.mock("../../../utils/calculateCostGenai", () => ({ + calculateCostGenai: jest.fn().mockReturnValue(0.005), +})) + +const mockedCalculateCostGenai = calculateCostGenai as jest.MockedFunction + const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219" describe("GeminiHandler", () => { let handler: GeminiHandler - beforeEach(() => { + // Reset mocks + jest.clearAllMocks() + mockedCalculateCostGenai.mockReturnValue(0.005) + // Create mock functions const mockGenerateContentStream = jest.fn() const mockGenerateContent = jest.fn() const mockGetGenerativeModel = jest.fn() + const mockCountTokens = jest.fn() handler = new GeminiHandler({ apiKey: "test-key", @@ -29,15 +41,43 @@ describe("GeminiHandler", () => { generateContentStream: mockGenerateContentStream, generateContent: mockGenerateContent, getGenerativeModel: mockGetGenerativeModel, + countTokens: mockCountTokens, }, } as any }) - describe("constructor", () => { it("should initialize with provided config", () => { expect(handler["options"].geminiApiKey).toBe("test-key") expect(handler["options"].apiModelId).toBe(GEMINI_20_FLASH_THINKING_NAME) }) + + it("should initialize with geminiApiKey", () => { + const testHandler = new GeminiHandler({ + geminiApiKey: "specific-gemini-key", + apiModelId: "gemini-1.5-flash-002", + }) + + expect(testHandler["options"].geminiApiKey).toBe("specific-gemini-key") + expect(testHandler["options"].apiModelId).toBe("gemini-1.5-flash-002") + }) + + it("should handle missing API key gracefully", () => { + const testHandler = new GeminiHandler({ + apiModelId: "gemini-1.5-flash-002", + }) + + // Should not throw and should have undefined geminiApiKey + expect(testHandler["options"].geminiApiKey).toBeUndefined() + }) + + it("should initialize with baseUrl configuration", () => { + const testHandler = new GeminiHandler({ + geminiApiKey: "test-key", + googleGeminiBaseUrl: "https://custom-gemini.example.com", + }) + + expect(testHandler["options"].googleGeminiBaseUrl).toBe("https://custom-gemini.example.com") + }) }) describe("createMessage", () => { @@ -69,13 +109,18 @@ describe("GeminiHandler", () => { for await (const chunk of stream) { chunks.push(chunk) - } - - // Should have 3 chunks: 'Hello', ' world!', and usage info + } // Should have 3 chunks: 'Hello', ' world!', and usage info expect(chunks.length).toBe(3) expect(chunks[0]).toEqual({ type: "text", text: "Hello" }) expect(chunks[1]).toEqual({ type: "text", text: " world!" }) - expect(chunks[2]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 5 }) + expect(chunks[2]).toEqual({ + type: "usage", + inputTokens: 10, + outputTokens: 5, + cacheReadTokens: undefined, + reasoningTokens: undefined, + totalCost: 0.005 + }) // Verify the call to generateContentStream expect(handler["client"].models.generateContentStream).toHaveBeenCalledWith( @@ -88,18 +133,111 @@ describe("GeminiHandler", () => { }), ) }) + it("should handle reasoning/thinking output", async () => { + // Mock response with thinking parts + ;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({ + [Symbol.asyncIterator]: async function* () { + yield { + candidates: [ + { + content: { + parts: [ + { thought: true, text: "Let me think about this..." }, + { text: "Here's my response" }, + ], + }, + }, + ], + } + yield { usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 8, thoughtsTokenCount: 5 } } + }, + }) - it("should handle API errors", async () => { - const mockError = new Error("Gemini API error") - ;(handler["client"].models.generateContentStream as jest.Mock).mockRejectedValue(mockError) + const stream = handler.createMessage(systemPrompt, mockMessages) + const chunks = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBe(3) + expect(chunks[0]).toEqual({ type: "reasoning", text: "Let me think about this..." }) + expect(chunks[1]).toEqual({ type: "text", text: "Here's my response" }) + expect(chunks[2]).toEqual({ + type: "usage", + inputTokens: 15, + outputTokens: 8, + reasoningTokens: 5, + cacheReadTokens: undefined, + totalCost: 0.005, + }) + }) + + it("should handle custom baseUrl configuration", async () => { + const testHandler = new GeminiHandler({ + geminiApiKey: "test-key", + googleGeminiBaseUrl: "https://custom-gemini.example.com", + apiModelId: "gemini-1.5-flash-002", + }) + + // Mock the client + testHandler["client"] = { + models: { + generateContentStream: jest.fn().mockResolvedValue({ + [Symbol.asyncIterator]: async function* () { + yield { text: "Custom response" } + }, + }), + }, + } as any + + const stream = testHandler.createMessage(systemPrompt, mockMessages) + const chunks = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(testHandler["client"].models.generateContentStream).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + httpOptions: { baseUrl: "https://custom-gemini.example.com" }, + }), + }), + ) + }) + it("should handle usage metadata with cache and reasoning tokens", async () => { + ;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({ + [Symbol.asyncIterator]: async function* () { + yield { text: "Response" } + yield { + usageMetadata: { + promptTokenCount: 100, + candidatesTokenCount: 50, + cachedContentTokenCount: 25, + thoughtsTokenCount: 15, + }, + } + }, + }) const stream = handler.createMessage(systemPrompt, mockMessages) + const chunks = [] - await expect(async () => { - for await (const _chunk of stream) { - // Should throw before yielding any chunks - } - }).rejects.toThrow() + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBe(2) + expect(chunks[0]).toEqual({ type: "text", text: "Response" }) + expect(chunks[1]).toEqual({ + type: "usage", + inputTokens: 100, + outputTokens: 50, + cacheReadTokens: 25, + reasoningTokens: 15, + totalCost: expect.any(Number), + }) }) }) @@ -143,7 +281,6 @@ describe("GeminiHandler", () => { expect(result).toBe("") }) }) - describe("getModel", () => { it("should return correct model info", () => { const modelInfo = handler.getModel() @@ -163,88 +300,196 @@ describe("GeminiHandler", () => { }) }) - describe("calculateCost", () => { - // Mock ModelInfo based on gemini-1.5-flash-latest pricing (per 1M tokens) - // Removed 'id' and 'name' as they are not part of ModelInfo type directly - const mockInfo: ModelInfo = { - inputPrice: 0.125, // $/1M tokens - outputPrice: 0.375, // $/1M tokens - cacheWritesPrice: 0.125, // Assume same as input for test - cacheReadsPrice: 0.125 * 0.25, // Assume 0.25x input for test - contextWindow: 1_000_000, - maxTokens: 8192, - supportsPromptCache: true, // Enable cache calculations for tests - } - - it("should calculate cost correctly based on input and output tokens", () => { - const inputTokens = 10000 // Use larger numbers for per-million pricing - const outputTokens = 20000 - // Added non-null assertions (!) as mockInfo guarantees these values - const expectedCost = - (inputTokens / 1_000_000) * mockInfo.inputPrice! + (outputTokens / 1_000_000) * mockInfo.outputPrice! - - const cost = handler.calculateCost({ info: mockInfo, inputTokens, outputTokens }) - expect(cost).toBeCloseTo(expectedCost) - }) - - it("should return 0 if token counts are zero", () => { - // Note: The method expects numbers, not undefined. Passing undefined would be a type error. - // The calculateCost method itself returns undefined if prices are missing, but 0 if tokens are 0 and prices exist. - expect(handler.calculateCost({ info: mockInfo, inputTokens: 0, outputTokens: 0 })).toBe(0) - }) - - it("should handle only input tokens", () => { - const inputTokens = 5000 - // Added non-null assertion (!) - const expectedCost = (inputTokens / 1_000_000) * mockInfo.inputPrice! - expect(handler.calculateCost({ info: mockInfo, inputTokens, outputTokens: 0 })).toBeCloseTo(expectedCost) - }) - - it("should handle only output tokens", () => { - const outputTokens = 15000 - // Added non-null assertion (!) - const expectedCost = (outputTokens / 1_000_000) * mockInfo.outputPrice! - expect(handler.calculateCost({ info: mockInfo, inputTokens: 0, outputTokens })).toBeCloseTo(expectedCost) - }) - - it("should calculate cost with cache write tokens", () => { - const inputTokens = 10000 - const outputTokens = 20000 - const cacheWriteTokens = 5000 - const CACHE_TTL = 5 // Match the constant in gemini.ts - - // Added non-null assertions (!) - const expectedInputCost = (inputTokens / 1_000_000) * mockInfo.inputPrice! - const expectedOutputCost = (outputTokens / 1_000_000) * mockInfo.outputPrice! - const expectedCacheWriteCost = - mockInfo.cacheWritesPrice! * (cacheWriteTokens / 1_000_000) * (CACHE_TTL / 60) - const expectedCost = expectedInputCost + expectedOutputCost + expectedCacheWriteCost - - const cost = handler.calculateCost({ info: mockInfo, inputTokens, outputTokens }) - expect(cost).toBeCloseTo(expectedCost) - }) - - it("should calculate cost with cache read tokens", () => { - const inputTokens = 10000 // Total logical input - const outputTokens = 20000 - const cacheReadTokens = 8000 // Part of inputTokens read from cache - - const uncachedReadTokens = inputTokens - cacheReadTokens - // Added non-null assertions (!) - const expectedInputCost = (uncachedReadTokens / 1_000_000) * mockInfo.inputPrice! - const expectedOutputCost = (outputTokens / 1_000_000) * mockInfo.outputPrice! - const expectedCacheReadCost = mockInfo.cacheReadsPrice! * (cacheReadTokens / 1_000_000) - const expectedCost = expectedInputCost + expectedOutputCost + expectedCacheReadCost - - const cost = handler.calculateCost({ info: mockInfo, inputTokens, outputTokens, cacheReadTokens }) - expect(cost).toBeCloseTo(expectedCost) - }) - - it("should return undefined if pricing info is missing", () => { - // Create a copy and explicitly set a price to undefined - const incompleteInfo: ModelInfo = { ...mockInfo, outputPrice: undefined } - const cost = handler.calculateCost({ info: incompleteInfo, inputTokens: 1000, outputTokens: 1000 }) - expect(cost).toBeUndefined() + describe("getModel with :thinking suffix", () => { + it("should strip :thinking suffix from model ID", () => { + // Use a valid thinking model that exists in geminiModels + const thinkingHandler = new GeminiHandler({ + apiModelId: "gemini-2.5-flash-preview-04-17:thinking", + geminiApiKey: "test-key", + }) + const modelInfo = thinkingHandler.getModel() + expect(modelInfo.id).toBe("gemini-2.5-flash-preview-04-17") // Without :thinking suffix + }) + + it("should handle non-thinking models without modification", () => { + const regularHandler = new GeminiHandler({ + apiModelId: "gemini-1.5-flash-002", + geminiApiKey: "test-key", + }) + const modelInfo = regularHandler.getModel() + expect(modelInfo.id).toBe("gemini-1.5-flash-002") // No change + }) + + it("should handle missing model ID with default", () => { + const defaultHandler = new GeminiHandler({ + geminiApiKey: "test-key", + }) + const modelInfo = defaultHandler.getModel() + expect(modelInfo.id).toBe(geminiDefaultModelId) + }) + }) + + describe("countTokens", () => { + const mockContent = [{ type: "text", text: "Hello world" }] as Array + it("should return token count from Gemini API", async () => { + ;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue({ + totalTokens: 42, + }) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(42) + + expect(handler["client"].models.countTokens).toHaveBeenCalledWith({ + model: GEMINI_20_FLASH_THINKING_NAME, + contents: [{ text: "Hello world" }], // Note: convertAnthropicContentToGemini format + }) + }) + it("should fall back to parent method on API error", async () => { + ;(handler["client"].models.countTokens as jest.Mock).mockRejectedValue(new Error("API error")) + + // Mock the parent countTokens method by setting up the prototype + const parentCountTokens = jest.fn().mockResolvedValue(25) + const originalPrototype = Object.getPrototypeOf(handler) + Object.setPrototypeOf(handler, { + ...originalPrototype, + countTokens: parentCountTokens, + }) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(25) + }) + + it("should handle empty content", async () => { + ;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue({ + totalTokens: 0, + }) + + const result = await handler.countTokens([]) + expect(result).toBe(0) + }) + it("should handle undefined totalTokens response", async () => { + ;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue({}) + + // Mock the parent countTokens method by setting up the prototype + const parentCountTokens = jest.fn().mockResolvedValue(10) + const originalPrototype = Object.getPrototypeOf(handler) + Object.setPrototypeOf(handler, { + ...originalPrototype, + countTokens: parentCountTokens, + }) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(10) + }) + }) + + describe("calculateCostGenai utility integration", () => { + it("should calculate cost correctly for input/output tokens", async () => { + ;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({ + [Symbol.asyncIterator]: async function* () { + yield { text: "Response" } + yield { + usageMetadata: { + promptTokenCount: 1000, + candidatesTokenCount: 500, + }, + } + }, + }) + + const stream = handler.createMessage("System prompt", [{ role: "user", content: "User message" }]) + const chunks = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunk = chunks.find((chunk) => chunk.type === "usage") + expect(usageChunk?.totalCost).toBeGreaterThan(0) + expect(typeof usageChunk?.totalCost).toBe("number") + }) + + it("should calculate cost with reasoning tokens", async () => { + ;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({ + [Symbol.asyncIterator]: async function* () { + yield { text: "Response" } + yield { + usageMetadata: { + promptTokenCount: 1000, + candidatesTokenCount: 500, + thoughtsTokenCount: 200, + }, + } + }, + }) + + const stream = handler.createMessage("System prompt", [{ role: "user", content: "User message" }]) + const chunks = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunk = chunks.find((chunk) => chunk.type === "usage") + expect(usageChunk?.totalCost).toBeGreaterThan(0) + expect(usageChunk?.reasoningTokens).toBe(200) + }) + + it("should calculate cost with cache tokens", async () => { + ;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({ + [Symbol.asyncIterator]: async function* () { + yield { text: "Response" } + yield { + usageMetadata: { + promptTokenCount: 1000, + candidatesTokenCount: 500, + cachedContentTokenCount: 100, + }, + } + }, + }) + + const stream = handler.createMessage("System prompt", [{ role: "user", content: "User message" }]) + const chunks = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunk = chunks.find((chunk) => chunk.type === "usage") + expect(usageChunk?.totalCost).toBeGreaterThan(0) + expect(usageChunk?.cacheReadTokens).toBe(100) + }) + }) + + describe("error handling", () => { + it("should handle createMessage stream errors", async () => { + ;(handler["client"].models.generateContentStream as jest.Mock).mockRejectedValue(new Error("Stream error")) + + const stream = handler.createMessage("System prompt", [{ role: "user", content: "Test" }]) + + await expect(async () => { + for await (const chunk of stream) { + // This should throw + } + }).rejects.toThrow("Stream error") + }) + + it("should handle countTokens errors gracefully", async () => { + ;(handler["client"].models.countTokens as jest.Mock).mockRejectedValue(new Error("Count error")) + + // Mock the parent countTokens method + const parentCountTokens = jest.fn().mockResolvedValue(0) + Object.setPrototypeOf(handler, { countTokens: parentCountTokens }) + + const result = await handler.countTokens([{ type: "text", text: "Test" }]) + expect(result).toBe(0) + }) + }) + + describe("destruct", () => { + it("should clean up resources", () => { + expect(() => handler.destruct()).not.toThrow() }) }) }) diff --git a/src/api/providers/__tests__/vertex.test.ts b/src/api/providers/__tests__/vertex.test.ts index 99178c9ab7f..3b4a5c1f78f 100644 --- a/src/api/providers/__tests__/vertex.test.ts +++ b/src/api/providers/__tests__/vertex.test.ts @@ -1,22 +1,24 @@ // npx jest src/api/providers/__tests__/vertex.test.ts import { Anthropic } from "@anthropic-ai/sdk" +import type { ModelInfo } from "@roo-code/types" import { ApiStreamChunk } from "../../transform/stream" +import { calculateCostGenai } from "../../../utils/calculateCostGenai" import { VertexHandler } from "../vertex" describe("VertexHandler", () => { let handler: VertexHandler - beforeEach(() => { // Create mock functions const mockGenerateContentStream = jest.fn() const mockGenerateContent = jest.fn() const mockGetGenerativeModel = jest.fn() + const mockCountTokens = jest.fn() handler = new VertexHandler({ - apiModelId: "gemini-1.5-pro-001", + apiModelId: "gemini-2.0-flash-001", vertexProjectId: "test-project", vertexRegion: "us-central1", }) @@ -27,9 +29,59 @@ describe("VertexHandler", () => { generateContentStream: mockGenerateContentStream, generateContent: mockGenerateContent, getGenerativeModel: mockGetGenerativeModel, + countTokens: mockCountTokens, }, } as any }) + describe("constructor", () => { + it("should initialize with JSON credentials", () => { + const testHandler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + vertexJsonCredentials: '{"type": "service_account", "project_id": "test"}', + }) + + expect(testHandler["options"].vertexJsonCredentials).toBe( + '{"type": "service_account", "project_id": "test"}', + ) + expect(testHandler["options"].vertexProjectId).toBe("test-project") + expect(testHandler["options"].vertexRegion).toBe("us-central1") + }) + + it("should initialize with key file path", () => { + const testHandler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + vertexKeyFile: "/path/to/keyfile.json", + }) + + expect(testHandler["options"].vertexKeyFile).toBe("/path/to/keyfile.json") + }) + + it("should initialize with API key", () => { + const testHandler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + vertexApiKey: "test-api-key", + }) + + expect(testHandler["options"].vertexApiKey).toBe("test-api-key") + }) + + it("should handle missing credentials gracefully", () => { + const testHandler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + expect(testHandler["options"].vertexProjectId).toBe("test-project") + expect(testHandler["options"].vertexRegion).toBe("us-central1") + }) + }) describe("createMessage", () => { const mockMessages: Anthropic.Messages.MessageParam[] = [ @@ -38,12 +90,11 @@ describe("VertexHandler", () => { ] const systemPrompt = "You are a helpful assistant" - - it("should handle streaming responses correctly for Gemini", async () => { + it("should handle streaming responses correctly for Vertex", async () => { // Let's examine the test expectations and adjust our mock accordingly // The test expects 4 chunks: // 1. Usage chunk with input tokens - // 2. Text chunk with "Gemini response part 1" + // 2. Text chunk with "Vertex response part 1" // 3. Text chunk with " part 2" // 4. Usage chunk with output tokens @@ -51,7 +102,7 @@ describe("VertexHandler", () => { // instead of mocking the client jest.spyOn(handler, "createMessage").mockImplementation(async function* () { yield { type: "usage", inputTokens: 10, outputTokens: 0 } - yield { type: "text", text: "Gemini response part 1" } + yield { type: "text", text: "Vertex response part 1" } yield { type: "text", text: " part 2" } yield { type: "usage", inputTokens: 0, outputTokens: 5 } }) @@ -66,7 +117,7 @@ describe("VertexHandler", () => { expect(chunks.length).toBe(4) expect(chunks[0]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 0 }) - expect(chunks[1]).toEqual({ type: "text", text: "Gemini response part 1" }) + expect(chunks[1]).toEqual({ type: "text", text: "Vertex response part 1" }) expect(chunks[2]).toEqual({ type: "text", text: " part 2" }) expect(chunks[3]).toEqual({ type: "usage", inputTokens: 0, outputTokens: 5 }) @@ -76,14 +127,14 @@ describe("VertexHandler", () => { }) describe("completePrompt", () => { - it("should complete prompt successfully for Gemini", async () => { + it("should complete prompt successfully for Vertex", async () => { // Mock the response with text property ;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({ - text: "Test Gemini response", + text: "Test Vertex response", }) const result = await handler.completePrompt("Test prompt") - expect(result).toBe("Test Gemini response") + expect(result).toBe("Test Vertex response") // Verify the call to generateContent expect(handler["client"].models.generateContent).toHaveBeenCalledWith( @@ -96,17 +147,15 @@ describe("VertexHandler", () => { }), ) }) - - it("should handle API errors for Gemini", async () => { + it("should handle API errors for Vertex", async () => { const mockError = new Error("Vertex API error") ;(handler["client"].models.generateContent as jest.Mock).mockRejectedValue(mockError) await expect(handler.completePrompt("Test prompt")).rejects.toThrow( - "Gemini completion error: Vertex API error", + "Vertex completion error: Vertex API error", ) }) - - it("should handle empty response for Gemini", async () => { + it("should handle empty response for Vertex", async () => { // Mock the response with empty text ;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({ text: "", @@ -116,22 +165,330 @@ describe("VertexHandler", () => { expect(result).toBe("") }) }) - describe("getModel", () => { - it("should return correct model info for Gemini", () => { - // Create a new instance with specific model ID + it("should return correct model info for Vertex models", () => { + // Create a new instance with specific vertex model ID const testHandler = new VertexHandler({ apiModelId: "gemini-2.0-flash-001", vertexProjectId: "test-project", vertexRegion: "us-central1", }) - // Don't mock getModel here as we want to test the actual implementation const modelInfo = testHandler.getModel() expect(modelInfo.id).toBe("gemini-2.0-flash-001") expect(modelInfo.info).toBeDefined() expect(modelInfo.info.maxTokens).toBe(8192) expect(modelInfo.info.contextWindow).toBe(1048576) }) + it("should fall back to vertex default model when apiModelId is not provided", () => { + const testHandler = new VertexHandler({ + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const modelInfo = testHandler.getModel() + expect(modelInfo.id).toBe("claude-sonnet-4@20250514") // vertexDefaultModelId + }) + it("should fall back to vertex default when invalid model is provided", () => { + const testHandler = new VertexHandler({ + apiModelId: "invalid-model-id", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const modelInfo = testHandler.getModel() + expect(modelInfo.id).toBe("claude-sonnet-4@20250514") // Should fall back to default + }) + }) + describe("countTokens", () => { + it("should count tokens successfully", async () => { + const mockContent = [{ type: "text" as const, text: "Hello world" }] + const mockResponse = { totalTokens: 42 } + + ;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue(mockResponse) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(42) + + expect(handler["client"].models.countTokens).toHaveBeenCalledWith( + expect.objectContaining({ + model: expect.any(String), + contents: expect.any(Array), + }), + ) + }) + it("should use project path for API key authentication", async () => { + const testHandler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + vertexApiKey: "test-key", + }) + + // Mock the client + testHandler["client"] = { + models: { + countTokens: jest.fn().mockResolvedValue({ totalTokens: 50 }), + }, + } as any + + const mockContent = [{ type: "text" as const, text: "Test content" }] + await testHandler.countTokens(mockContent) + + expect(testHandler["client"].models.countTokens).toHaveBeenCalledWith( + expect.objectContaining({ + model: "projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash-001", + }), + ) + }) + + it("should fall back to base implementation when response is undefined", async () => { + const mockContent = [{ type: "text" as const, text: "Hello world" }] + const mockResponse = { totalTokens: undefined } + + ;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue(mockResponse) + + // Mock the super.countTokens method + const mockSuperCountTokens = jest.fn().mockResolvedValue(25) + Object.setPrototypeOf(handler, { countTokens: mockSuperCountTokens }) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(25) + expect(mockSuperCountTokens).toHaveBeenCalledWith(mockContent) + }) + + it("should fall back to base implementation on API error", async () => { + const mockContent = [{ type: "text" as const, text: "Hello world" }] + const mockError = new Error("API error") + + ;(handler["client"].models.countTokens as jest.Mock).mockRejectedValue(mockError) + + // Mock the super.countTokens method + const mockSuperCountTokens = jest.fn().mockResolvedValue(30) + Object.setPrototypeOf(handler, { countTokens: mockSuperCountTokens }) + + const result = await handler.countTokens(mockContent) + expect(result).toBe(30) + expect(mockSuperCountTokens).toHaveBeenCalledWith(mockContent) + }) + }) + + describe("getModel with :thinking suffix", () => { + it("should remove :thinking suffix from model ID", () => { + // Note: this model doesn't exist in vertexModels, so it will fall back to default + const testHandler = new VertexHandler({ + apiModelId: "some-thinking-model:thinking", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const { id } = testHandler.getModel() + // Since the model doesn't exist, it falls back to default + expect(id).toBe("claude-sonnet-4@20250514") + }) + + it("should not modify model ID without :thinking suffix", () => { + const testHandler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const { id } = testHandler.getModel() + expect(id).toBe("gemini-2.0-flash-001") + }) + + it("should remove :thinking suffix from actual thinking model", () => { + // Test the :thinking logic with the model selection logic separately + const testHandler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-thinking-exp-01-21", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + // First verify the model exists and is selected + const { id: selectedId } = testHandler.getModel() + expect(selectedId).toBe("gemini-2.0-flash-thinking-exp-01-21") + + // Now test with :thinking suffix + const thinkingHandler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-thinking-exp-01-21:thinking", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + // This should fall back to default since the :thinking version doesn't exist as a key + const { id: thinkingId } = thinkingHandler.getModel() + expect(thinkingId).toBe("claude-sonnet-4@20250514") + }) + }) + + describe("createMessage with detailed scenarios", () => { + it("should handle complex usage metadata with reasoning and cache tokens", async () => { + // Create a more detailed mock for createMessage + jest.spyOn(handler, "createMessage").mockImplementation(async function* () { + yield { type: "usage", inputTokens: 1000, outputTokens: 0 } + yield { type: "text", text: "Thinking..." } + yield { type: "text", text: "Response content" } + yield { + type: "usage", + inputTokens: 0, + outputTokens: 500, + cacheReadTokens: 200, + reasoningTokens: 150, + totalCost: 0.05, + } + }) + + const mockMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Complex question" }] + + const stream = handler.createMessage("You are helpful", mockMessages) + const chunks: any[] = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBe(4) + expect(chunks[3]).toEqual({ + type: "usage", + inputTokens: 0, + outputTokens: 500, + cacheReadTokens: 200, + reasoningTokens: 150, + totalCost: 0.05, + }) + }) + it("should handle API key with project ID in model path", async () => { + const testHandler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-001", + vertexProjectId: "my-project", + vertexRegion: "europe-west1", + vertexApiKey: "test-api-key", + }) + + // Mock the client and its methods + const mockGenerateContentStream = jest.fn().mockImplementation(async function* () { + yield { text: "Response" } + yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } + }) + + testHandler["client"] = { + models: { + generateContentStream: mockGenerateContentStream, + }, + } as any + + const mockMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test" }] + + const stream = testHandler.createMessage("System", mockMessages) + + // Consume the stream to trigger the API call + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(mockGenerateContentStream).toHaveBeenCalledWith( + expect.objectContaining({ + model: "projects/my-project/locations/europe-west1/publishers/google/models/gemini-2.0-flash-001", + }), + ) + }) + }) + + describe("calculateCostGenai utility integration", () => { + it("should work with vertex model info", () => { + const handler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-exp", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const { info } = handler.getModel() + const inputTokens = 10000 + const outputTokens = 5000 + + // Test that calculateCost works with vertex model info + const cost = calculateCostGenai({ info, inputTokens, outputTokens }) + + // Should return a number if pricing info is available, undefined if not + expect(typeof cost === "number" || cost === undefined).toBe(true) + }) + it("should calculate cost with cache read tokens using vertex model", () => { + const handler = new VertexHandler({ + apiModelId: "gemini-1.5-pro-002", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const { info } = handler.getModel() + const inputTokens = 20000 + const outputTokens = 10000 + const cacheReadTokens = 5000 + + const cost = calculateCostGenai({ info, inputTokens, outputTokens, cacheReadTokens }) + + // gemini-1.5-pro-002 has inputPrice and outputPrice but no cacheReadsPrice + // so calculateCostGenai returns undefined + expect(cost).toBeUndefined() + }) + + it("should handle models with zero pricing", () => { + const handler = new VertexHandler({ + apiModelId: "gemini-2.0-flash-thinking-exp-01-21", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const { info } = handler.getModel() + const cost = calculateCostGenai({ info, inputTokens: 1000, outputTokens: 500 }) + + // This model has inputPrice: 0, outputPrice: 0, but no cacheReadsPrice + // so calculateCostGenai returns undefined + expect(cost).toBeUndefined() + }) + + it("should handle models without pricing information", () => { + const handler = new VertexHandler({ + apiModelId: "claude-sonnet-4@20250514", // Default vertex model + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const { info } = handler.getModel() + const cost = calculateCostGenai({ info, inputTokens: 1000, outputTokens: 500 }) + + // Claude models in vertex might not have pricing info + expect(typeof cost === "number" || cost === undefined).toBe(true) + }) + }) + + describe("error handling", () => { + it("should handle streaming errors gracefully", async () => { + const mockError = new Error("Streaming failed") + + // Mock createMessage to throw an error + jest.spyOn(handler, "createMessage").mockImplementation(async function* () { + throw mockError + }) + + const mockMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test" }] + + const stream = handler.createMessage("System", mockMessages) + + await expect(async () => { + for await (const chunk of stream) { + // This should throw + } + }).rejects.toThrow("Streaming failed") + }) + }) + + describe("destruct", () => { + it("should have a destruct method", () => { + expect(typeof handler.destruct).toBe("function") + expect(() => handler.destruct()).not.toThrow() + }) }) }) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 6765c8676d8..9d236ffe3b0 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -4,14 +4,14 @@ import { type GenerateContentResponseUsageMetadata, type GenerateContentParameters, type GenerateContentConfig, + CountTokensParameters, } from "@google/genai" -import type { JWTInput } from "google-auth-library" import { type ModelInfo, type GeminiModelId, geminiDefaultModelId, geminiModels } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { safeJsonParse } from "../../shared/safeJsonParse" +import { calculateCostGenai } from "../../utils/calculateCostGenai" import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format" import type { ApiStream } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -19,43 +19,15 @@ import { getModelParams } from "../transform/model-params" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" import { BaseProvider } from "./base-provider" -type GeminiHandlerOptions = ApiHandlerOptions & { - isVertex?: boolean -} - export class GeminiHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions - private client: GoogleGenAI - constructor({ isVertex, ...options }: GeminiHandlerOptions) { + constructor(options: ApiHandlerOptions) { super() - this.options = options - - const project = this.options.vertexProjectId ?? "not-provided" - const location = this.options.vertexRegion ?? "not-provided" const apiKey = this.options.geminiApiKey ?? "not-provided" - - this.client = this.options.vertexJsonCredentials - ? new GoogleGenAI({ - vertexai: true, - project, - location, - googleAuthOptions: { - credentials: safeJsonParse(this.options.vertexJsonCredentials, undefined), - }, - }) - : this.options.vertexKeyFile - ? new GoogleGenAI({ - vertexai: true, - project, - location, - googleAuthOptions: { keyFile: this.options.vertexKeyFile }, - }) - : isVertex - ? new GoogleGenAI({ vertexai: true, project, location }) - : new GoogleGenAI({ apiKey }) + this.client = new GoogleGenAI({ apiKey }) } async *createMessage( @@ -66,7 +38,6 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl const { id: model, info, reasoning: thinkingConfig, maxTokens } = this.getModel() const contents = messages.map(convertAnthropicMessageToGemini) - const config: GenerateContentConfig = { systemInstruction, httpOptions: this.options.googleGeminiBaseUrl ? { baseUrl: this.options.googleGeminiBaseUrl } : undefined, @@ -124,7 +95,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl outputTokens, cacheReadTokens, reasoningTokens, - totalCost: this.calculateCost({ info, inputTokens, outputTokens, cacheReadTokens }), + totalCost: calculateCostGenai({ info, inputTokens, outputTokens, cacheReadTokens }), } } } @@ -171,10 +142,11 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl try { const { id: model } = this.getModel() - const response = await this.client.models.countTokens({ + const params: CountTokensParameters = { model, contents: convertAnthropicContentToGemini(content), - }) + } + const response = await this.client.models.countTokens(params) if (response.totalTokens === undefined) { console.warn("Gemini token counting returned undefined, using fallback") @@ -187,58 +159,5 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl return super.countTokens(content) } } - - public calculateCost({ - info, - inputTokens, - outputTokens, - cacheReadTokens = 0, - }: { - info: ModelInfo - inputTokens: number - outputTokens: number - cacheReadTokens?: number - }) { - if (!info.inputPrice || !info.outputPrice || !info.cacheReadsPrice) { - return undefined - } - - let inputPrice = info.inputPrice - let outputPrice = info.outputPrice - let cacheReadsPrice = info.cacheReadsPrice - - // If there's tiered pricing then adjust the input and output token prices - // based on the input tokens used. - if (info.tiers) { - const tier = info.tiers.find((tier) => inputTokens <= tier.contextWindow) - - if (tier) { - inputPrice = tier.inputPrice ?? inputPrice - outputPrice = tier.outputPrice ?? outputPrice - cacheReadsPrice = tier.cacheReadsPrice ?? cacheReadsPrice - } - } - - // Subtract the cached input tokens from the total input tokens. - const uncachedInputTokens = inputTokens - cacheReadTokens - - let cacheReadCost = cacheReadTokens > 0 ? cacheReadsPrice * (cacheReadTokens / 1_000_000) : 0 - - const inputTokensCost = inputPrice * (uncachedInputTokens / 1_000_000) - const outputTokensCost = outputPrice * (outputTokens / 1_000_000) - const totalCost = inputTokensCost + outputTokensCost + cacheReadCost - - const trace: Record = { - input: { price: inputPrice, tokens: uncachedInputTokens, cost: inputTokensCost }, - output: { price: outputPrice, tokens: outputTokens, cost: outputTokensCost }, - } - - if (cacheReadTokens > 0) { - trace.cacheRead = { price: cacheReadsPrice, tokens: cacheReadTokens, cost: cacheReadCost } - } - - // console.log(`[GeminiHandler] calculateCost -> ${totalCost}`, trace) - - return totalCost - } + public destruct() {} } diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts index 2c077d97b7e..97162938573 100644 --- a/src/api/providers/vertex.ts +++ b/src/api/providers/vertex.ts @@ -1,18 +1,117 @@ +import type { Anthropic } from "@anthropic-ai/sdk" +import { + GoogleGenAI, + type GenerateContentResponseUsageMetadata, + type GenerateContentParameters, + type GenerateContentConfig, + CountTokensParameters, +} from "@google/genai" +import type { JWTInput } from "google-auth-library" + import { type ModelInfo, type VertexModelId, vertexDefaultModelId, vertexModels } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" +import { safeJsonParse } from "../../shared/safeJsonParse" import { getModelParams } from "../transform/model-params" +import { calculateCostGenai } from "../../utils/calculateCostGenai" +import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format" +import type { ApiStream } from "../transform/stream" + +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { BaseProvider } from "./base-provider" -import { GeminiHandler } from "./gemini" -import { SingleCompletionHandler } from "../index" +export class VertexHandler extends BaseProvider implements SingleCompletionHandler { + protected options: ApiHandlerOptions + private client: GoogleGenAI -export class VertexHandler extends GeminiHandler implements SingleCompletionHandler { constructor(options: ApiHandlerOptions) { - super({ ...options, isVertex: true }) + super() + this.options = options + + const project = this.options.vertexProjectId ?? "not-provided" + const location = this.options.vertexRegion ?? "not-provided" + + if (this.options.vertexJsonCredentials) { + this.client = new GoogleGenAI({ + vertexai: true, + project, + location, + googleAuthOptions: { + credentials: safeJsonParse(this.options.vertexJsonCredentials, undefined), + }, + }) + } else if (this.options.vertexKeyFile) { + this.client = new GoogleGenAI({ + vertexai: true, + project, + location, + googleAuthOptions: { keyFile: this.options.vertexKeyFile }, + }) + } else if (this.options.vertexApiKey) { + this.client = new GoogleGenAI({ + vertexai: true, + apiKey: this.options.vertexApiKey, + apiVersion: "v1", + }) + } else { + this.client = new GoogleGenAI({ vertexai: true, project, location }) + } } + async *createMessage( + systemInstruction: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const { id: model, reasoning: thinkingConfig, maxTokens: maxOutputTokens, info } = this.getModel() + + const contents = messages.map(convertAnthropicMessageToGemini) + + const config: GenerateContentConfig = { + systemInstruction, + thinkingConfig, + maxOutputTokens, + temperature: this.options.modelTemperature ?? 0, + } + + const params: GenerateContentParameters = { model, contents, config } - override getModel() { + if (this.options.vertexApiKey && this.options.vertexProjectId) { + params.model = `projects/${this.options.vertexProjectId}/locations/${this.options.vertexRegion}/publishers/google/models/${model}` + } + + const result = await this.client.models.generateContentStream(params) + + let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined + + for await (const chunk of result) { + if (chunk.text) { + yield { type: "text", text: chunk.text } + } + + if (chunk.usageMetadata) { + lastUsageMetadata = chunk.usageMetadata + } + } + + if (lastUsageMetadata) { + const inputTokens = lastUsageMetadata.promptTokenCount ?? 0 + const outputTokens = lastUsageMetadata.candidatesTokenCount ?? 0 + const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount + const reasoningTokens = lastUsageMetadata.thoughtsTokenCount + + yield { + type: "usage", + inputTokens, + outputTokens, + cacheReadTokens, + reasoningTokens, + totalCost: calculateCostGenai({ info, inputTokens, outputTokens, cacheReadTokens }), + } + } + } + + getModel() { const modelId = this.options.apiModelId let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId const info: ModelInfo = vertexModels[id] @@ -24,4 +123,54 @@ export class VertexHandler extends GeminiHandler implements SingleCompletionHand // suffix. return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params } } + + async completePrompt(prompt: string): Promise { + try { + const { id: model } = this.getModel() + + const result = await this.client.models.generateContent({ + model, + contents: [{ role: "user", parts: [{ text: prompt }] }], + config: { + temperature: this.options.modelTemperature ?? 0, + }, + }) + + return result.text ?? "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`Vertex completion error: ${error.message}`) + } + + throw error + } + } + + override async countTokens(content: Array): Promise { + try { + const { id: model } = this.getModel() + + const params: CountTokensParameters = { + model, + contents: convertAnthropicContentToGemini(content), + } + + if (this.options.vertexApiKey && this.options.vertexProjectId) { + params.model = `projects/${this.options.vertexProjectId}/locations/${this.options.vertexRegion}/publishers/google/models/${model}` + } + + const response = await this.client.models.countTokens(params) + + if (response.totalTokens === undefined) { + console.warn("Vertex token counting returned undefined, using fallback") + return super.countTokens(content) + } + + return response.totalTokens + } catch (error) { + console.warn("Vertex token counting failed, using fallback", error) + return super.countTokens(content) + } + } + public destruct() {} } diff --git a/src/utils/__tests__/calculateCostGenai.test.ts b/src/utils/__tests__/calculateCostGenai.test.ts new file mode 100644 index 00000000000..02f131413c0 --- /dev/null +++ b/src/utils/__tests__/calculateCostGenai.test.ts @@ -0,0 +1,127 @@ +// npx jest src/utils/__tests__/calculateCostGenai.test.ts + +import type { ModelInfo } from "@roo-code/types" + +import { calculateCostGenai } from "../calculateCostGenai" + +describe("calculateCostGenai", () => { + // Mock ModelInfo based on gemini-1.5-flash-latest pricing (per 1M tokens) + const mockInfo: ModelInfo = { + inputPrice: 0.125, // $/1M tokens + outputPrice: 0.375, // $/1M tokens + cacheWritesPrice: 0.125, // Assume same as input for test + cacheReadsPrice: 0.125 * 0.25, // Assume 0.25x input for test + contextWindow: 1_000_000, + maxTokens: 8192, + supportsPromptCache: true, // Enable cache calculations for tests + } + + it("should calculate cost correctly based on input and output tokens", () => { + const inputTokens = 10000 // Use larger numbers for per-million pricing + const outputTokens = 20000 + // Added non-null assertions (!) as mockInfo guarantees these values + const expectedCost = + (inputTokens / 1_000_000) * mockInfo.inputPrice! + (outputTokens / 1_000_000) * mockInfo.outputPrice! + + const cost = calculateCostGenai({ info: mockInfo, inputTokens, outputTokens }) + expect(cost).toBeCloseTo(expectedCost) + }) + + it("should return 0 if token counts are zero", () => { + // Note: The method expects numbers, not undefined. Passing undefined would be a type error. + // The calculateCost method itself returns undefined if prices are missing, but 0 if tokens are 0 and prices exist. + expect(calculateCostGenai({ info: mockInfo, inputTokens: 0, outputTokens: 0 })).toBe(0) + }) + + it("should handle only input tokens", () => { + const inputTokens = 5000 + // Added non-null assertion (!) + const expectedCost = (inputTokens / 1_000_000) * mockInfo.inputPrice! + expect(calculateCostGenai({ info: mockInfo, inputTokens, outputTokens: 0 })).toBeCloseTo(expectedCost) + }) + + it("should handle only output tokens", () => { + const outputTokens = 15000 + // Added non-null assertion (!) + const expectedCost = (outputTokens / 1_000_000) * mockInfo.outputPrice! + expect(calculateCostGenai({ info: mockInfo, inputTokens: 0, outputTokens })).toBeCloseTo(expectedCost) + }) + + it("should calculate cost with cache read tokens", () => { + const inputTokens = 10000 // Total logical input + const outputTokens = 20000 + const cacheReadTokens = 8000 // Part of inputTokens read from cache + + const uncachedReadTokens = inputTokens - cacheReadTokens + // Added non-null assertions (!) + const expectedInputCost = (uncachedReadTokens / 1_000_000) * mockInfo.inputPrice! + const expectedOutputCost = (outputTokens / 1_000_000) * mockInfo.outputPrice! + const expectedCacheReadCost = mockInfo.cacheReadsPrice! * (cacheReadTokens / 1_000_000) + const expectedCost = expectedInputCost + expectedOutputCost + expectedCacheReadCost + + const cost = calculateCostGenai({ info: mockInfo, inputTokens, outputTokens, cacheReadTokens }) + expect(cost).toBeCloseTo(expectedCost) + }) + + it("should return undefined if pricing info is missing", () => { + // Create a copy and explicitly set a price to undefined + const incompleteInfo: ModelInfo = { ...mockInfo, outputPrice: undefined } + const cost = calculateCostGenai({ info: incompleteInfo, inputTokens: 1000, outputTokens: 1000 }) + expect(cost).toBeUndefined() + }) + + it("should handle tiered pricing", () => { + const tieredInfo: ModelInfo = { + ...mockInfo, + tiers: [ + { + contextWindow: 50000, + inputPrice: 0.2, + outputPrice: 0.6, + cacheReadsPrice: 0.05, + }, + { + contextWindow: 1000000, + inputPrice: 0.125, + outputPrice: 0.375, + cacheReadsPrice: 0.03125, + }, + ], + } + + // Should use first tier pricing for small input + const inputTokens = 30000 + const outputTokens = 20000 + const expectedCost = (inputTokens / 1_000_000) * 0.2 + (outputTokens / 1_000_000) * 0.6 + + const cost = calculateCostGenai({ info: tieredInfo, inputTokens, outputTokens }) + expect(cost).toBeCloseTo(expectedCost) + }) + + it("should handle tiered pricing with cache reads", () => { + const tieredInfo: ModelInfo = { + ...mockInfo, + tiers: [ + { + contextWindow: 50000, + inputPrice: 0.2, + outputPrice: 0.6, + cacheReadsPrice: 0.05, + }, + ], + } + + const inputTokens = 30000 + const outputTokens = 20000 + const cacheReadTokens = 10000 + const uncachedInputTokens = inputTokens - cacheReadTokens + + const expectedInputCost = (uncachedInputTokens / 1_000_000) * 0.2 + const expectedOutputCost = (outputTokens / 1_000_000) * 0.6 + const expectedCacheReadCost = (cacheReadTokens / 1_000_000) * 0.05 + const expectedCost = expectedInputCost + expectedOutputCost + expectedCacheReadCost + + const cost = calculateCostGenai({ info: tieredInfo, inputTokens, outputTokens, cacheReadTokens }) + expect(cost).toBeCloseTo(expectedCost) + }) +}) diff --git a/src/utils/calculateCostGenai.ts b/src/utils/calculateCostGenai.ts new file mode 100644 index 00000000000..1f36e2cfb2f --- /dev/null +++ b/src/utils/calculateCostGenai.ts @@ -0,0 +1,58 @@ +import type { ModelInfo } from "@roo-code/types" + +/** + * Calculate the cost for GenAI models (Gemini and Vertex AI) based on token usage + * @param options - Token usage and model information + * @returns Total cost in USD or undefined if pricing info is missing + */ +export function calculateCostGenai({ + info, + inputTokens, + outputTokens, + cacheReadTokens = 0, +}: { + info: ModelInfo + inputTokens: number + outputTokens: number + cacheReadTokens?: number +}): number | undefined { + if (!info.inputPrice || !info.outputPrice || !info.cacheReadsPrice) { + return undefined + } + + let inputPrice = info.inputPrice + let outputPrice = info.outputPrice + let cacheReadsPrice = info.cacheReadsPrice + + // If there's tiered pricing then adjust the input and output token prices + // based on the input tokens used. + if (info.tiers) { + const tier = info.tiers.find((tier) => inputTokens <= tier.contextWindow) + + if (tier) { + inputPrice = tier.inputPrice ?? inputPrice + outputPrice = tier.outputPrice ?? outputPrice + cacheReadsPrice = tier.cacheReadsPrice ?? cacheReadsPrice + } + } + + // Subtract the cached input tokens from the total input tokens. + const uncachedInputTokens = inputTokens - cacheReadTokens + + let cacheReadCost = cacheReadTokens > 0 ? cacheReadsPrice * (cacheReadTokens / 1_000_000) : 0 + + const inputTokensCost = inputPrice * (uncachedInputTokens / 1_000_000) + const outputTokensCost = outputPrice * (outputTokens / 1_000_000) + const totalCost = inputTokensCost + outputTokensCost + cacheReadCost + + const trace: Record = { + input: { price: inputPrice, tokens: uncachedInputTokens, cost: inputTokensCost }, + output: { price: outputPrice, tokens: outputTokens, cost: outputTokensCost }, + } + + if (cacheReadTokens > 0) { + trace.cacheRead = { price: cacheReadsPrice, tokens: cacheReadTokens, cost: cacheReadCost } + } + + return totalCost +} diff --git a/webview-ui/src/components/settings/providers/Vertex.tsx b/webview-ui/src/components/settings/providers/Vertex.tsx index 19a136927a2..4aa3c91c0d9 100644 --- a/webview-ui/src/components/settings/providers/Vertex.tsx +++ b/webview-ui/src/components/settings/providers/Vertex.tsx @@ -74,6 +74,13 @@ export const Vertex = ({ apiConfiguration, setApiConfigurationField }: VertexPro className="w-full"> + + +