diff --git a/packages/evals/README.md b/packages/evals/README.md
index bb202a70940..2880145de3b 100644
--- a/packages/evals/README.md
+++ b/packages/evals/README.md
@@ -68,7 +68,6 @@ To stop an evals run early you can simply stop the "controller" container using

 Screenshot 2025-06-06 at 9 00 41 AM

-
 ## Advanced Usage / Debugging

 The evals system runs VS Code headlessly in Docker containers for consistent, reproducible environments. While this design ensures reliability, it can make debugging more challenging. For debugging purposes, you can run the system locally on macOS, though this approach is less reliable due to hardware and environment variability.
diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index a60f7e0b282..98e0135dac6 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -106,6 +106,7 @@ const bedrockSchema = apiModelIdProviderModelSchema.extend({
 })

 const vertexSchema = apiModelIdProviderModelSchema.extend({
+	vertexApiKey: z.string().optional(),
 	vertexKeyFile: z.string().optional(),
 	vertexJsonCredentials: z.string().optional(),
 	vertexProjectId: z.string().optional(),
@@ -290,6 +291,7 @@ export const PROVIDER_SETTINGS_KEYS = keysOf<ProviderSettings>()([
 	"awsBedrockEndpointEnabled",
 	"awsBedrockEndpoint",
 	// Google Vertex
+	"vertexApiKey",
 	"vertexKeyFile",
 	"vertexJsonCredentials",
 	"vertexProjectId",
diff --git a/packages/types/src/providers/vertex.ts b/packages/types/src/providers/vertex.ts
index 028d3089235..a2d75602260 100644
--- a/packages/types/src/providers/vertex.ts
+++ b/packages/types/src/providers/vertex.ts
@@ -233,4 +233,5 @@ export const VERTEX_REGIONS = [
 	{ value: "europe-west1", label: "europe-west1" },
 	{ value: "europe-west4", label: "europe-west4" },
 	{ value: "asia-southeast1", label: "asia-southeast1" },
+	{ value: "global", label: "global" },
 ]
diff --git a/src/api/providers/__tests__/gemini.test.ts b/src/api/providers/__tests__/gemini.test.ts
index 837948af1d2..d62e6f5a2e6 100644
--- a/src/api/providers/__tests__/gemini.test.ts
+++ b/src/api/providers/__tests__/gemini.test.ts
@@ -3,19 +3,31 @@
 import { Anthropic } from "@anthropic-ai/sdk"

 import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types"

+import { calculateCostGenai } from "../../../utils/calculateCostGenai"
 import { GeminiHandler } from "../gemini"

+// Mock the calculateCostGenai function
+jest.mock("../../../utils/calculateCostGenai", () => ({
+	calculateCostGenai: jest.fn().mockReturnValue(0.005),
+}))
+
+const mockedCalculateCostGenai = calculateCostGenai as jest.MockedFunction<typeof calculateCostGenai>
+
 const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"

 describe("GeminiHandler", () => {
 	let handler: GeminiHandler
-
 	beforeEach(() => {
+		// Reset mocks
+		jest.clearAllMocks()
+		mockedCalculateCostGenai.mockReturnValue(0.005)
+
 		// Create mock functions
 		const mockGenerateContentStream = jest.fn()
 		const mockGenerateContent = jest.fn()
 		const mockGetGenerativeModel = jest.fn()
+		const mockCountTokens = jest.fn()

 		handler = new GeminiHandler({
 			apiKey: "test-key",
@@ -29,15 +41,43 @@ describe("GeminiHandler", () => {
 				generateContentStream: mockGenerateContentStream,
 				generateContent: mockGenerateContent,
 				getGenerativeModel: mockGetGenerativeModel,
+				countTokens: mockCountTokens,
 			},
 		} as any
 	})
-
 	describe("constructor", () => {
 		it("should initialize with provided config", () => {
 			expect(handler["options"].geminiApiKey).toBe("test-key")
 			expect(handler["options"].apiModelId).toBe(GEMINI_20_FLASH_THINKING_NAME)
 		})
+
+		it("should initialize with geminiApiKey", () => {
+			const testHandler = new GeminiHandler({
+				geminiApiKey: "specific-gemini-key",
+				apiModelId: "gemini-1.5-flash-002",
+			})
+
+			expect(testHandler["options"].geminiApiKey).toBe("specific-gemini-key")
+			expect(testHandler["options"].apiModelId).toBe("gemini-1.5-flash-002")
+		})
+
+		it("should handle missing API key gracefully", () => {
+			const testHandler = new GeminiHandler({
+				apiModelId: "gemini-1.5-flash-002",
+			})
+
+			// Should not throw and should have undefined geminiApiKey
+			expect(testHandler["options"].geminiApiKey).toBeUndefined()
+		})
+
+		it("should initialize with baseUrl configuration", () => {
+			const testHandler = new GeminiHandler({
+				geminiApiKey: "test-key",
+				googleGeminiBaseUrl: "https://custom-gemini.example.com",
+			})
+
+			expect(testHandler["options"].googleGeminiBaseUrl).toBe("https://custom-gemini.example.com")
+		})
 	})

 	describe("createMessage", () => {
@@ -69,13 +109,18 @@ describe("GeminiHandler", () => {

 			for await (const chunk of stream) {
 				chunks.push(chunk)
-			}
-
-			// Should have 3 chunks: 'Hello', ' world!', and usage info
+			} // Should have 3 chunks: 'Hello', ' world!', and usage info
 			expect(chunks.length).toBe(3)
 			expect(chunks[0]).toEqual({ type: "text", text: "Hello" })
 			expect(chunks[1]).toEqual({ type: "text", text: " world!" })
-			expect(chunks[2]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 5 })
+			expect(chunks[2]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 5,
+				cacheReadTokens: undefined,
+				reasoningTokens: undefined,
+				totalCost: 0.005,
+			})

 			// Verify the call to generateContentStream
 			expect(handler["client"].models.generateContentStream).toHaveBeenCalledWith(
@@ -88,18 +133,111 @@ describe("GeminiHandler", () => {
 				}),
 			)
 		})
+
+		it("should handle reasoning/thinking output", async () => {
+			// Mock response with thinking parts
+			;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						candidates: [
+							{
+								content: {
+									parts: [
+										{ thought: true, text: "Let me think about this..." },
+										{ text: "Here's my response" },
+									],
+								},
+							},
+						],
+					}
+					yield { usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 8, thoughtsTokenCount: 5 } }
+				},
+			})

-		it("should handle API errors", async () => {
-			const mockError = new Error("Gemini API error")
-			;(handler["client"].models.generateContentStream as jest.Mock).mockRejectedValue(mockError)
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBe(3)
+			expect(chunks[0]).toEqual({ type: "reasoning", text: "Let me think about this..." })
+			expect(chunks[1]).toEqual({ type: "text", text: "Here's my response" })
+			expect(chunks[2]).toEqual({
+				type: "usage",
+				inputTokens: 15,
+				outputTokens: 8,
+				reasoningTokens: 5,
+				cacheReadTokens: undefined,
+				totalCost: 0.005,
+			})
+		})
+
+		it("should handle custom baseUrl configuration", async () => {
+			const testHandler = new GeminiHandler({
+				geminiApiKey: "test-key",
+				googleGeminiBaseUrl: "https://custom-gemini.example.com",
+				apiModelId: "gemini-1.5-flash-002",
+			})
+
+			// Mock the client
+			testHandler["client"] = {
+				models: {
+					generateContentStream: jest.fn().mockResolvedValue({
+						[Symbol.asyncIterator]: async function* () {
+							yield { text: "Custom response" }
+						},
+					}),
+				},
+			} as any
+
+			const stream = testHandler.createMessage(systemPrompt, mockMessages)
+			const chunks = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(testHandler["client"].models.generateContentStream).toHaveBeenCalledWith(
+				expect.objectContaining({
+					config: expect.objectContaining({
+						httpOptions: { baseUrl: "https://custom-gemini.example.com" },
+					}),
+				}),
+			)
+		})
+
+		it("should handle usage metadata with cache and reasoning tokens", async () => {
+			;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({
+				[Symbol.asyncIterator]: async function* () {
+					yield { text: "Response" }
+					yield {
+						usageMetadata: {
+							promptTokenCount: 100,
+							candidatesTokenCount: 50,
+							cachedContentTokenCount: 25,
+							thoughtsTokenCount: 15,
+						},
+					}
+				},
+			})

 			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks = []

-			await expect(async () => {
-				for await (const _chunk of stream) {
-					// Should throw before yielding any chunks
-				}
-			}).rejects.toThrow()
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBe(2)
+			expect(chunks[0]).toEqual({ type: "text", text: "Response" })
+			expect(chunks[1]).toEqual({
+				type: "usage",
+				inputTokens: 100,
+				outputTokens: 50,
+				cacheReadTokens: 25,
+				reasoningTokens: 15,
+				totalCost: expect.any(Number),
+			})
 		})
 	})

@@ -143,7 +281,6 @@ describe("GeminiHandler", () => {
 			expect(result).toBe("")
 		})
 	})
-
 	describe("getModel", () => {
 		it("should return correct model info", () => {
 			const modelInfo = handler.getModel()
@@ -163,88 +300,196 @@ describe("GeminiHandler", () => {
 		})
 	})

-	describe("calculateCost", () => {
-		// Mock ModelInfo based on gemini-1.5-flash-latest pricing (per 1M tokens)
-		// Removed 'id' and 'name' as they are not part of ModelInfo type directly
-		const mockInfo: ModelInfo = {
-			inputPrice: 0.125, // $/1M tokens
-			outputPrice: 0.375, // $/1M tokens
-			cacheWritesPrice: 0.125, // Assume same as input for test
-			cacheReadsPrice: 0.125 * 0.25, // Assume 0.25x input for test
-			contextWindow: 1_000_000,
-			maxTokens: 8192,
-			supportsPromptCache: true, // Enable cache calculations for tests
-		}
-
-		it("should calculate cost correctly based on input and output tokens", () => {
-			const inputTokens = 10000 // Use larger numbers for per-million pricing
-			const outputTokens = 20000
-			// Added non-null assertions (!) as mockInfo guarantees these values
-			const expectedCost =
-				(inputTokens / 1_000_000) * mockInfo.inputPrice! + (outputTokens / 1_000_000) * mockInfo.outputPrice!
-
-			const cost = handler.calculateCost({ info: mockInfo, inputTokens, outputTokens })
-			expect(cost).toBeCloseTo(expectedCost)
-		})
-
-		it("should return 0 if token counts are zero", () => {
-			// Note: The method expects numbers, not undefined. Passing undefined would be a type error.
-			// The calculateCost method itself returns undefined if prices are missing, but 0 if tokens are 0 and prices exist.
-			expect(handler.calculateCost({ info: mockInfo, inputTokens: 0, outputTokens: 0 })).toBe(0)
-		})
-
-		it("should handle only input tokens", () => {
-			const inputTokens = 5000
-			// Added non-null assertion (!)
-			const expectedCost = (inputTokens / 1_000_000) * mockInfo.inputPrice!
-			expect(handler.calculateCost({ info: mockInfo, inputTokens, outputTokens: 0 })).toBeCloseTo(expectedCost)
-		})
-
-		it("should handle only output tokens", () => {
-			const outputTokens = 15000
-			// Added non-null assertion (!)
-			const expectedCost = (outputTokens / 1_000_000) * mockInfo.outputPrice!
-			expect(handler.calculateCost({ info: mockInfo, inputTokens: 0, outputTokens })).toBeCloseTo(expectedCost)
-		})
-
-		it("should calculate cost with cache write tokens", () => {
-			const inputTokens = 10000
-			const outputTokens = 20000
-			const cacheWriteTokens = 5000
-			const CACHE_TTL = 5 // Match the constant in gemini.ts
-
-			// Added non-null assertions (!)
-			const expectedInputCost = (inputTokens / 1_000_000) * mockInfo.inputPrice!
-			const expectedOutputCost = (outputTokens / 1_000_000) * mockInfo.outputPrice!
-			const expectedCacheWriteCost =
-				mockInfo.cacheWritesPrice! * (cacheWriteTokens / 1_000_000) * (CACHE_TTL / 60)
-			const expectedCost = expectedInputCost + expectedOutputCost + expectedCacheWriteCost
-
-			const cost = handler.calculateCost({ info: mockInfo, inputTokens, outputTokens })
-			expect(cost).toBeCloseTo(expectedCost)
-		})
-
-		it("should calculate cost with cache read tokens", () => {
-			const inputTokens = 10000 // Total logical input
-			const outputTokens = 20000
-			const cacheReadTokens = 8000 // Part of inputTokens read from cache
-
-			const uncachedReadTokens = inputTokens - cacheReadTokens
-			// Added non-null assertions (!)
-			const expectedInputCost = (uncachedReadTokens / 1_000_000) * mockInfo.inputPrice!
-			const expectedOutputCost = (outputTokens / 1_000_000) * mockInfo.outputPrice!
-			const expectedCacheReadCost = mockInfo.cacheReadsPrice! * (cacheReadTokens / 1_000_000)
-			const expectedCost = expectedInputCost + expectedOutputCost + expectedCacheReadCost
-
-			const cost = handler.calculateCost({ info: mockInfo, inputTokens, outputTokens, cacheReadTokens })
-			expect(cost).toBeCloseTo(expectedCost)
-		})
-
-		it("should return undefined if pricing info is missing", () => {
-			// Create a copy and explicitly set a price to undefined
-			const incompleteInfo: ModelInfo = { ...mockInfo, outputPrice: undefined }
-			const cost = handler.calculateCost({ info: incompleteInfo, inputTokens: 1000, outputTokens: 1000 })
-			expect(cost).toBeUndefined()
+	describe("getModel with :thinking suffix", () => {
+		it("should strip :thinking suffix from model ID", () => {
+			// Use a valid thinking model that exists in geminiModels
+			const thinkingHandler = new GeminiHandler({
+				apiModelId: "gemini-2.5-flash-preview-04-17:thinking",
+				geminiApiKey: "test-key",
+			})
+			const modelInfo = thinkingHandler.getModel()
+			expect(modelInfo.id).toBe("gemini-2.5-flash-preview-04-17") // Without :thinking suffix
+		})
+
+		it("should handle non-thinking models without modification", () => {
+			const regularHandler = new GeminiHandler({
+				apiModelId: "gemini-1.5-flash-002",
+				geminiApiKey: "test-key",
+			})
+			const modelInfo = regularHandler.getModel()
+			expect(modelInfo.id).toBe("gemini-1.5-flash-002") // No change
+		})
+
+		it("should handle missing model ID with default", () => {
+			const defaultHandler = new GeminiHandler({
+				geminiApiKey: "test-key",
+			})
+			const modelInfo = defaultHandler.getModel()
+			expect(modelInfo.id).toBe(geminiDefaultModelId)
+		})
+	})
+
+	describe("countTokens", () => {
+		const mockContent = [{ type: "text", text: "Hello world" }] as Array<Anthropic.Messages.ContentBlockParam>
+
+		it("should return token count from Gemini API", async () => {
+			;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue({
+				totalTokens: 42,
+			})
+
+			const result = await handler.countTokens(mockContent)
+			expect(result).toBe(42)
+
+			expect(handler["client"].models.countTokens).toHaveBeenCalledWith({
+				model: GEMINI_20_FLASH_THINKING_NAME,
+				contents: [{ text: "Hello world" }], // Note: convertAnthropicContentToGemini format
+			})
+		})
+
+		it("should fall back to parent method on API error", async () => {
+			;(handler["client"].models.countTokens as jest.Mock).mockRejectedValue(new Error("API error"))
+
+			// Mock the parent countTokens method by setting up the prototype
+			const parentCountTokens = jest.fn().mockResolvedValue(25)
+			const originalPrototype = Object.getPrototypeOf(handler)
+			Object.setPrototypeOf(handler, {
+				...originalPrototype,
+				countTokens: parentCountTokens,
+			})
+
+			const result = await handler.countTokens(mockContent)
+			expect(result).toBe(25)
+		})
+
+		it("should handle empty content", async () => {
+			;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue({
+				totalTokens: 0,
+			})
+
+			const result = await handler.countTokens([])
+			expect(result).toBe(0)
+		})
+
+		it("should handle undefined totalTokens response", async () => {
+			;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue({})
+
+			// Mock the parent countTokens method by setting up the prototype
+			const parentCountTokens = jest.fn().mockResolvedValue(10)
+			const originalPrototype = Object.getPrototypeOf(handler)
+			Object.setPrototypeOf(handler, {
+				...originalPrototype,
+				countTokens: parentCountTokens,
+			})
+
+			const result = await handler.countTokens(mockContent)
+			expect(result).toBe(10)
+		})
+	})
+
+	describe("calculateCostGenai utility integration", () => {
+		it("should calculate cost correctly for input/output tokens", async () => {
+			;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({
+				[Symbol.asyncIterator]: async function* () {
+					yield { text: "Response" }
+					yield {
+						usageMetadata: {
+							promptTokenCount: 1000,
+							candidatesTokenCount: 500,
+						},
+					}
+				},
+			})
+
+			const stream = handler.createMessage("System prompt", [{ role: "user", content: "User message" }])
+			const chunks = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunk = chunks.find((chunk) => chunk.type === "usage")
+			expect(usageChunk?.totalCost).toBeGreaterThan(0)
+			expect(typeof usageChunk?.totalCost).toBe("number")
+		})
+
+		it("should calculate cost with reasoning tokens", async () => {
+			;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({
+				[Symbol.asyncIterator]: async function* () {
+					yield { text: "Response" }
+					yield {
+						usageMetadata: {
+							promptTokenCount: 1000,
+							candidatesTokenCount: 500,
+							thoughtsTokenCount: 200,
+						},
+					}
+				},
+			})
+
+			const stream = handler.createMessage("System prompt", [{ role: "user", content: "User message" }])
+			const chunks = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunk = chunks.find((chunk) => chunk.type === "usage")
+			expect(usageChunk?.totalCost).toBeGreaterThan(0)
+			expect(usageChunk?.reasoningTokens).toBe(200)
+		})
+
+		it("should calculate cost with cache tokens", async () => {
+			;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({
+				[Symbol.asyncIterator]: async function* () {
+					yield { text: "Response" }
+					yield {
+						usageMetadata: {
+							promptTokenCount: 1000,
+							candidatesTokenCount: 500,
+							cachedContentTokenCount: 100,
+						},
+					}
+				},
+			})
+
+			const stream = handler.createMessage("System prompt", [{ role: "user", content: "User message" }])
+			const chunks = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunk = chunks.find((chunk) => chunk.type === "usage")
+			expect(usageChunk?.totalCost).toBeGreaterThan(0)
+			expect(usageChunk?.cacheReadTokens).toBe(100)
+		})
+	})
+
+	describe("error handling", () => {
+		it("should handle createMessage stream errors", async () => {
+			;(handler["client"].models.generateContentStream as jest.Mock).mockRejectedValue(new Error("Stream error"))
+
+			const stream = handler.createMessage("System prompt", [{ role: "user", content: "Test" }])
+
+			await expect(async () => {
+				for await (const chunk of stream) {
+					// This should throw
+				}
+			}).rejects.toThrow("Stream error")
+		})
+
+		it("should handle countTokens errors gracefully", async () => {
+			;(handler["client"].models.countTokens as jest.Mock).mockRejectedValue(new Error("Count error"))
+
+			// Mock the parent countTokens method
+			const parentCountTokens = jest.fn().mockResolvedValue(0)
+			Object.setPrototypeOf(handler, { countTokens: parentCountTokens })
+
+			const result = await handler.countTokens([{ type: "text", text: "Test" }])
+			expect(result).toBe(0)
+		})
+	})
+
+	describe("destruct", () => {
+		it("should clean up resources", () => {
+			expect(() => handler.destruct()).not.toThrow()
 		})
 	})
 })
diff --git a/src/api/providers/__tests__/vertex.test.ts b/src/api/providers/__tests__/vertex.test.ts
index 99178c9ab7f..e8bcfdc46f8 100644
--- a/src/api/providers/__tests__/vertex.test.ts
+++ b/src/api/providers/__tests__/vertex.test.ts
@@ -1,22 +1,24 @@
 // npx jest src/api/providers/__tests__/vertex.test.ts

 import { Anthropic } from "@anthropic-ai/sdk"
+import type { ModelInfo } from "@roo-code/types"

 import { ApiStreamChunk } from "../../transform/stream"
+import { calculateCostGenai } from "../../../utils/calculateCostGenai"

 import { VertexHandler } from "../vertex"

 describe("VertexHandler", () => {
 	let handler: VertexHandler
-
 	beforeEach(() => {
 		// Create mock functions
 		const mockGenerateContentStream = jest.fn()
 		const mockGenerateContent = jest.fn()
 		const mockGetGenerativeModel = jest.fn()
+		const mockCountTokens = jest.fn()

 		handler = new VertexHandler({
-			apiModelId: "gemini-1.5-pro-001",
+			apiModelId: "gemini-2.0-flash-001",
 			vertexProjectId: "test-project",
 			vertexRegion: "us-central1",
 		})
@@ -27,9 +29,59 @@ describe("VertexHandler", () => {
 				generateContentStream: mockGenerateContentStream,
 				generateContent: mockGenerateContent,
 				getGenerativeModel: mockGetGenerativeModel,
+				countTokens: mockCountTokens,
 			},
 		} as any
 	})

+	describe("constructor", () => {
+		it("should initialize with JSON credentials", () => {
+			const testHandler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				vertexJsonCredentials: '{"type": "service_account", "project_id": "test"}',
+			})
+
+			expect(testHandler["options"].vertexJsonCredentials).toBe(
+				'{"type": "service_account", "project_id": "test"}',
+			)
+			expect(testHandler["options"].vertexProjectId).toBe("test-project")
+			expect(testHandler["options"].vertexRegion).toBe("us-central1")
+		})
+
+		it("should initialize with key file path", () => {
+			const testHandler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				vertexKeyFile: "/path/to/keyfile.json",
+			})
+
+			expect(testHandler["options"].vertexKeyFile).toBe("/path/to/keyfile.json")
+		})
+
+		it("should initialize with API key", () => {
+			const testHandler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				vertexApiKey: "test-api-key",
+			})
+
+			expect(testHandler["options"].vertexApiKey).toBe("test-api-key")
+		})
+
+		it("should handle missing credentials gracefully", () => {
+			const testHandler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			expect(testHandler["options"].vertexProjectId).toBe("test-project")
+			expect(testHandler["options"].vertexRegion).toBe("us-central1")
+		})
+	})

 	describe("createMessage", () => {
 		const mockMessages: Anthropic.Messages.MessageParam[] = [
@@ -38,12 +90,11 @@ describe("VertexHandler", () => {
 		]

 		const systemPrompt = "You are a helpful assistant"
-
-		it("should handle streaming responses correctly for Gemini", async () => {
+		it("should handle streaming responses correctly for Vertex", async () => {
 			// Let's examine the test expectations and adjust our mock accordingly
 			// The test expects 4 chunks:
 			// 1. Usage chunk with input tokens
-			// 2. Text chunk with "Gemini response part 1"
+			// 2. Text chunk with "Vertex response part 1"
 			// 3. Text chunk with " part 2"
 			// 4. Usage chunk with output tokens
@@ -51,7 +102,7 @@ describe("VertexHandler", () => {
 			// instead of mocking the client
 			jest.spyOn(handler, "createMessage").mockImplementation(async function* () {
 				yield { type: "usage", inputTokens: 10, outputTokens: 0 }
-				yield { type: "text", text: "Gemini response part 1" }
+				yield { type: "text", text: "Vertex response part 1" }
 				yield { type: "text", text: " part 2" }
 				yield { type: "usage", inputTokens: 0, outputTokens: 5 }
 			})
@@ -66,7 +117,7 @@ describe("VertexHandler", () => {

 			expect(chunks.length).toBe(4)
 			expect(chunks[0]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 0 })
-			expect(chunks[1]).toEqual({ type: "text", text: "Gemini response part 1" })
+			expect(chunks[1]).toEqual({ type: "text", text: "Vertex response part 1" })
 			expect(chunks[2]).toEqual({ type: "text", text: " part 2" })
 			expect(chunks[3]).toEqual({ type: "usage", inputTokens: 0, outputTokens: 5 })

@@ -76,14 +127,14 @@ describe("VertexHandler", () => {
 	})

 	describe("completePrompt", () => {
-		it("should complete prompt successfully for Gemini", async () => {
+		it("should complete prompt successfully for Vertex", async () => {
 			// Mock the response with text property
 			;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({
-				text: "Test Gemini response",
+				text: "Test Vertex response",
 			})

 			const result = await handler.completePrompt("Test prompt")
-			expect(result).toBe("Test Gemini response")
+			expect(result).toBe("Test Vertex response")

 			// Verify the call to generateContent
 			expect(handler["client"].models.generateContent).toHaveBeenCalledWith(
@@ -96,17 +147,15 @@ describe("VertexHandler", () => {
 				}),
 			)
 		})
-
-		it("should handle API errors for Gemini", async () => {
+		it("should handle API errors for Vertex", async () => {
 			const mockError = new Error("Vertex API error")
 			;(handler["client"].models.generateContent as jest.Mock).mockRejectedValue(mockError)

 			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
-				"Gemini completion error: Vertex API error",
+				"Vertex completion error: Vertex API error",
 			)
 		})
-
-		it("should handle empty response for Gemini", async () => {
+		it("should handle empty response for Vertex", async () => {
 			// Mock the response with empty text
 			;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({
 				text: "",
@@ -116,22 +165,329 @@ describe("VertexHandler", () => {
 			expect(result).toBe("")
 		})
 	})
-
 	describe("getModel", () => {
-		it("should return correct model info for Gemini", () => {
-			// Create a new instance with specific model ID
+		it("should return correct model info for Vertex models", () => {
+			// Create a new instance with specific vertex model ID
 			const testHandler = new VertexHandler({
 				apiModelId: "gemini-2.0-flash-001",
 				vertexProjectId: "test-project",
 				vertexRegion: "us-central1",
 			})
-			// Don't mock getModel here as we want to test the actual implementation
 			const modelInfo = testHandler.getModel()
 			expect(modelInfo.id).toBe("gemini-2.0-flash-001")
 			expect(modelInfo.info).toBeDefined()
 			expect(modelInfo.info.maxTokens).toBe(8192)
 			expect(modelInfo.info.contextWindow).toBe(1048576)
 		})
+
+		it("should fall back to vertex default model when apiModelId is not provided", () => {
+			const testHandler = new VertexHandler({
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const modelInfo = testHandler.getModel()
+			expect(modelInfo.id).toBe("claude-sonnet-4@20250514") // vertexDefaultModelId
+		})
+
+		it("should fall back to vertex default when invalid model is provided", () => {
+			const testHandler = new VertexHandler({
+				apiModelId: "invalid-model-id",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const modelInfo = testHandler.getModel()
+			expect(modelInfo.id).toBe("claude-sonnet-4@20250514") // Should fall back to default
+		})
+	})
+
+	describe("countTokens", () => {
+		it("should count tokens successfully", async () => {
+			const mockContent = [{ type: "text" as const, text: "Hello world" }]
+			const mockResponse = { totalTokens: 42 }
+
+			;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue(mockResponse)
+
+			const result = await handler.countTokens(mockContent)
+			expect(result).toBe(42)
+
+			expect(handler["client"].models.countTokens).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: expect.any(String),
+					contents: expect.any(Array),
+				}),
+			)
+		})
+
+		it("should use project path for API key authentication", async () => {
+			const testHandler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				vertexApiKey: "test-key",
+			})
+
+			// Mock the client
+			testHandler["client"] = {
+				models: {
+					countTokens: jest.fn().mockResolvedValue({ totalTokens: 50 }),
+				},
+			} as any
+
+			const mockContent = [{ type: "text" as const, text: "Test content" }]
+			await testHandler.countTokens(mockContent)
+
+			expect(testHandler["client"].models.countTokens).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash-001",
+				}),
+			)
+		})
+
+		it("should fall back to base implementation when response is undefined", async () => {
+			const mockContent = [{ type: "text" as const, text: "Hello world" }]
+			const mockResponse = { totalTokens: undefined }
+
+			;(handler["client"].models.countTokens as jest.Mock).mockResolvedValue(mockResponse)
+
+			// Mock the super.countTokens method
+			const mockSuperCountTokens = jest.fn().mockResolvedValue(25)
+			Object.setPrototypeOf(handler, { countTokens: mockSuperCountTokens })
+
+			const result = await handler.countTokens(mockContent)
+			expect(result).toBe(25)
+			expect(mockSuperCountTokens).toHaveBeenCalledWith(mockContent)
+		})
+
+		it("should fall back to base implementation on API error", async () => {
+			const mockContent = [{ type: "text" as const, text: "Hello world" }]
+			const mockError = new Error("API error")
+
+			;(handler["client"].models.countTokens as jest.Mock).mockRejectedValue(mockError)
+
+			// Mock the super.countTokens method
+			const mockSuperCountTokens = jest.fn().mockResolvedValue(30)
+			Object.setPrototypeOf(handler, { countTokens: mockSuperCountTokens })
+
+			const result = await handler.countTokens(mockContent)
+			expect(result).toBe(30)
+			expect(mockSuperCountTokens).toHaveBeenCalledWith(mockContent)
+		})
+	})
+
+	describe("getModel with :thinking suffix", () => {
+		it("should remove :thinking suffix from model ID", () => {
+			// Note: this model doesn't exist in vertexModels, so it will fall back to default
+			const testHandler = new VertexHandler({
+				apiModelId: "some-thinking-model:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const { id } = testHandler.getModel()
+			// Since the model doesn't exist, it falls back to default
+			expect(id).toBe("claude-sonnet-4@20250514")
+		})
+
+		it("should not modify model ID without :thinking suffix", () => {
+			const testHandler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const { id } = testHandler.getModel()
+			expect(id).toBe("gemini-2.0-flash-001")
+		})
+
+		it("should remove :thinking suffix from actual thinking model", () => {
+			// Test the :thinking logic with the model selection logic separately
+			const testHandler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-thinking-exp-01-21",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			// First verify the model exists and is selected
+			const { id: selectedId } = testHandler.getModel()
+			expect(selectedId).toBe("gemini-2.0-flash-thinking-exp-01-21")
+
+			// Now test with :thinking suffix
+			const thinkingHandler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-thinking-exp-01-21:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			// This should fall back to default since the :thinking version doesn't exist as a key
+			const { id: thinkingId } = thinkingHandler.getModel()
+			expect(thinkingId).toBe("claude-sonnet-4@20250514")
+		})
+	})
+
+	describe("createMessage with detailed scenarios", () => {
+		it("should handle complex usage metadata with reasoning and cache tokens", async () => {
+			// Create a more detailed mock for createMessage
+			jest.spyOn(handler, "createMessage").mockImplementation(async function* () {
+				yield { type: "usage", inputTokens: 1000, outputTokens: 0 }
+				yield { type: "text", text: "Thinking..." }
+				yield { type: "text", text: "Response content" }
+				yield {
+					type: "usage",
+					inputTokens: 0,
+					outputTokens: 500,
+					cacheReadTokens: 200,
+					reasoningTokens: 150,
+					totalCost: 0.05,
+				}
+			})
+
+			const mockMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Complex question" }]
+
+			const stream = handler.createMessage("You are helpful", mockMessages)
+			const chunks: any[] = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBe(4)
+			expect(chunks[3]).toEqual({
+				type: "usage",
+				inputTokens: 0,
+				outputTokens: 500,
+				cacheReadTokens: 200,
+				reasoningTokens: 150,
+				totalCost: 0.05,
+			})
+		})
+
+		it("should handle API key with project ID in model path", async () => {
+			const testHandler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-001",
+				vertexProjectId: "my-project",
+				vertexRegion: "europe-west1",
+				vertexApiKey: "test-api-key",
+			})
+
+			// Mock the client and its methods
+			const mockGenerateContentStream = jest.fn().mockImplementation(async function* () {
+				yield { text: "Response" }
+				yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
+			})
+
+			testHandler["client"] = {
+				models: {
+					generateContentStream: mockGenerateContentStream,
+				},
+			} as any
+
+			const mockMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test" }]
+
+			const stream = testHandler.createMessage("System", mockMessages)
+
+			// Consume the stream to trigger the API call
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockGenerateContentStream).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "projects/my-project/locations/europe-west1/publishers/google/models/gemini-2.0-flash-001",
+				}),
+			)
+		})
+	})
+
+	describe("calculateCostGenai utility integration", () => {
+		it("should work with vertex model info", () => {
+			const handler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-exp",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const { info } = handler.getModel()
+			const inputTokens = 10000
+			const outputTokens = 5000
+
+			// Test that calculateCost works with vertex model info
+			const cost = calculateCostGenai({ info, inputTokens, outputTokens })
+
+			// Should return a number if pricing info is available, undefined if not
+			expect(typeof cost === "number" || cost === undefined).toBe(true)
+		})
+
+		it("should calculate cost with cache read tokens using vertex model", () => {
+			const handler = new VertexHandler({
+				apiModelId: "gemini-1.5-pro-002",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const { info } = handler.getModel()
+			const inputTokens = 20000
+			const outputTokens = 10000
+			const cacheReadTokens = 5000
+
+			const cost = calculateCostGenai({ info, inputTokens, outputTokens, cacheReadTokens })
+
+			// gemini-1.5-pro-002 has inputPrice and outputPrice but no cacheReadsPrice
+			// so calculateCostGenai returns undefined
+			expect(cost).toBeUndefined()
+		})
+
+		it("should handle models with zero pricing", () => {
+			const handler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-thinking-exp-01-21",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const { info } = handler.getModel()
+			const cost = calculateCostGenai({ info, inputTokens: 1000, outputTokens: 500 })
+
+			// This model has inputPrice: 0, outputPrice: 0, but no cacheReadsPrice
+			// so calculateCostGenai returns undefined
+			expect(cost).toBeUndefined()
+		})
+
+		it("should handle models without pricing information", () => {
+			const handler = new VertexHandler({
+				apiModelId: "claude-sonnet-4@20250514", // Default vertex model
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const { info } = handler.getModel()
+			const cost = calculateCostGenai({ info, inputTokens: 1000, outputTokens: 500 })
+
+			// Claude models in vertex might not have pricing info
+			expect(typeof cost === "number" || cost === undefined).toBe(true)
+		})
+	})
+
+	describe("error handling", () => {
+		it("should handle streaming errors gracefully", async () => {
+			const mockError = new Error("Streaming failed")
+
+			// Mock createMessage to throw an error
+			// eslint-disable-next-line require-yield
+			jest.spyOn(handler, "createMessage").mockImplementation(async function* () {
+				throw mockError
+			})
+
+			const mockMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test" }]
+
+			const stream = handler.createMessage("System", mockMessages)
+
+			await expect(async () => {
+				for await (const chunk of stream) {
+					// This should throw
+				}
+			}).rejects.toThrow("Streaming failed")
+		})
+	})
+
+	describe("destruct", () => {
+		it("should have a destruct method", () => {
+			expect(typeof handler.destruct).toBe("function")
+			expect(() => handler.destruct()).not.toThrow()
+		})
 	})
 })
diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts
index 6765c8676d8..9d236ffe3b0 100644
--- a/src/api/providers/gemini.ts
+++ b/src/api/providers/gemini.ts
@@ -4,14 +4,14 @@ import {
 	type GenerateContentResponseUsageMetadata,
 	type GenerateContentParameters,
 	type GenerateContentConfig,
+	CountTokensParameters,
 } from "@google/genai"
-import type { JWTInput } from "google-auth-library"

 import { type ModelInfo, type GeminiModelId, geminiDefaultModelId, geminiModels } from "@roo-code/types"

 import type { ApiHandlerOptions } from "../../shared/api"
-import { safeJsonParse } from "../../shared/safeJsonParse"
+import { calculateCostGenai } from "../../utils/calculateCostGenai"
 import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
 import type { ApiStream } from "../transform/stream"
 import { getModelParams } from "../transform/model-params"
@@ -19,43 +19,15 @@ import { getModelParams } from "../transform/model-params"

 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { BaseProvider } from "./base-provider"

-type GeminiHandlerOptions = ApiHandlerOptions & {
-	isVertex?: boolean
-}
-
 export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
-
 	private client: GoogleGenAI

-	constructor({ isVertex, ...options }: GeminiHandlerOptions) {
+	constructor(options: ApiHandlerOptions) {
 		super()
-
 		this.options = options
-
-		const project = this.options.vertexProjectId ?? "not-provided"
-		const location = this.options.vertexRegion ?? "not-provided"
 		const apiKey = this.options.geminiApiKey ?? "not-provided"
-
-		this.client = this.options.vertexJsonCredentials
-			? new GoogleGenAI({
-					vertexai: true,
-					project,
-					location,
-					googleAuthOptions: {
-						credentials: safeJsonParse<JWTInput>(this.options.vertexJsonCredentials, undefined),
-					},
-				})
-			: this.options.vertexKeyFile
-				? new GoogleGenAI({
-						vertexai: true,
-						project,
-						location,
-						googleAuthOptions: { keyFile: this.options.vertexKeyFile },
-					})
-				: isVertex
-					? new GoogleGenAI({ vertexai: true, project, location })
-					: new GoogleGenAI({ apiKey })
+		this.client = new GoogleGenAI({ apiKey })
 	}

 	async *createMessage(
@@ -66,7 +38,6 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
 		const { id: model, info, reasoning: thinkingConfig, maxTokens } = this.getModel()

 		const contents = messages.map(convertAnthropicMessageToGemini)
-
 		const config: GenerateContentConfig = {
 			systemInstruction,
 			httpOptions: this.options.googleGeminiBaseUrl ? { baseUrl: this.options.googleGeminiBaseUrl } : undefined,
@@ -124,7 +95,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
 				outputTokens,
 				cacheReadTokens,
 				reasoningTokens,
-				totalCost: this.calculateCost({ info, inputTokens, outputTokens, cacheReadTokens }),
+				totalCost: calculateCostGenai({ info, inputTokens, outputTokens, cacheReadTokens }),
 			}
 		}
 	}
@@ -171,10 +142,11 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
 		try {
 			const { id: model } = this.getModel()

-			const response = await this.client.models.countTokens({
+			const params: CountTokensParameters = {
 				model,
 				contents: convertAnthropicContentToGemini(content),
-			})
+			}
+			const response = await this.client.models.countTokens(params)

 			if (response.totalTokens === undefined) {
 				console.warn("Gemini token counting returned undefined, using fallback")
@@ -187,58 +159,5 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
 			return super.countTokens(content)
 		}
 	}
-
-	public calculateCost({
-		info,
-		inputTokens,
-		outputTokens,
-		cacheReadTokens = 0,
-	}: {
-		info: ModelInfo
-		inputTokens: number
-		outputTokens: number
-		cacheReadTokens?: number
-	}) {
-		if (!info.inputPrice || !info.outputPrice || !info.cacheReadsPrice) {
-			return undefined
-		}
-
-		let inputPrice = info.inputPrice
-		let outputPrice = info.outputPrice
-		let cacheReadsPrice = info.cacheReadsPrice
-
-		// If there's tiered pricing then adjust the input and output token prices
-		// based on the input tokens used.
-		if (info.tiers) {
-			const tier = info.tiers.find((tier) => inputTokens <= tier.contextWindow)
-
-			if (tier) {
-				inputPrice = tier.inputPrice ?? inputPrice
-				outputPrice = tier.outputPrice ?? outputPrice
-				cacheReadsPrice = tier.cacheReadsPrice ?? cacheReadsPrice
-			}
-		}
-
-		// Subtract the cached input tokens from the total input tokens.
-		const uncachedInputTokens = inputTokens - cacheReadTokens
-
-		let cacheReadCost = cacheReadTokens > 0 ? cacheReadsPrice * (cacheReadTokens / 1_000_000) : 0
-
-		const inputTokensCost = inputPrice * (uncachedInputTokens / 1_000_000)
-		const outputTokensCost = outputPrice * (outputTokens / 1_000_000)
-		const totalCost = inputTokensCost + outputTokensCost + cacheReadCost
-
-		const trace: Record<string, unknown> = {
-			input: { price: inputPrice, tokens: uncachedInputTokens, cost: inputTokensCost },
-			output: { price: outputPrice, tokens: outputTokens, cost: outputTokensCost },
-		}
-
-		if (cacheReadTokens > 0) {
-			trace.cacheRead = { price: cacheReadsPrice, tokens: cacheReadTokens, cost: cacheReadCost }
-		}
-
-		// console.log(`[GeminiHandler] calculateCost -> ${totalCost}`, trace)
-
-		return totalCost
-	}
+	public destruct() {}
 }
diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts
index 2c077d97b7e..f0d532ce012 100644
--- a/src/api/providers/vertex.ts
+++ b/src/api/providers/vertex.ts
@@ -1,18 +1,117 @@
+import type { Anthropic } from "@anthropic-ai/sdk"
+import {
+	GoogleGenAI,
+	type GenerateContentResponseUsageMetadata,
+	type GenerateContentParameters,
+	type GenerateContentConfig,
+	CountTokensParameters,
+} from "@google/genai"
+import type { JWTInput } from "google-auth-library"
+
 import { type ModelInfo, type VertexModelId, vertexDefaultModelId, vertexModels } from "@roo-code/types"

 import type { ApiHandlerOptions } from "../../shared/api"
+import { safeJsonParse } from "../../shared/safeJsonParse"
 import { getModelParams } from "../transform/model-params"
+import { calculateCostGenai } from "../../utils/calculateCostGenai"
+import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
+import type { ApiStream } from "../transform/stream"
+
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { BaseProvider } from "./base-provider"

-import { GeminiHandler } from "./gemini"
-import { SingleCompletionHandler } from "../index"
+export class VertexHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	private client: GoogleGenAI

-export class VertexHandler extends GeminiHandler implements SingleCompletionHandler {
 	constructor(options: ApiHandlerOptions) {
-		super({ ...options, isVertex: true })
+		super()
+		this.options = options
+
+		const project = this.options.vertexProjectId ?? "not-provided"
+		const location = this.options.vertexRegion ?? "not-provided"
+
+		if (this.options.vertexJsonCredentials) {
+			this.client = new GoogleGenAI({
+				vertexai: true,
+				project,
+				location,
+				googleAuthOptions: {
+					credentials: safeJsonParse<JWTInput>(this.options.vertexJsonCredentials, undefined),
+				},
+			})
+		} else if (this.options.vertexKeyFile) {
+			this.client = new GoogleGenAI({
+				vertexai: true,
+				project,
+				location,
+				googleAuthOptions: { keyFile: this.options.vertexKeyFile },
+			})
+		} else if (this.options.vertexApiKey) {
+			this.client = new GoogleGenAI({
+				vertexai: true,
+				apiKey: this.options.vertexApiKey,
+				apiVersion: "v1",
+			})
+		} else {
+			this.client = new GoogleGenAI({ vertexai: true, project, location })
+		}
+	}
+
+	async *createMessage(
+		systemInstruction: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { id: model, reasoning: thinkingConfig, maxTokens: maxOutputTokens, info } = this.getModel()
+
+		const contents = messages.map(convertAnthropicMessageToGemini)
+
+		const config: GenerateContentConfig = {
+			systemInstruction,
+			thinkingConfig,
+			maxOutputTokens,
+			temperature: this.options.modelTemperature ?? 0,
+		}
+
+		const params: GenerateContentParameters = { model, contents, config }
+
+		if (this.options.vertexApiKey && this.options.vertexProjectId) {
+			params.model = `projects/${this.options.vertexProjectId}/locations/${this.options.vertexRegion}/publishers/google/models/${model}`
+		}
+
+		const result = await this.client.models.generateContentStream(params)
+
+		let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined
+
+		for await (const chunk of result) {
+			if (chunk.text) {
+				yield { type: "text", text: chunk.text }
+			}
+
+			if (chunk.usageMetadata) {
+				lastUsageMetadata = chunk.usageMetadata
+			}
+		}
+
+		if (lastUsageMetadata) {
+			const inputTokens = lastUsageMetadata.promptTokenCount ?? 0
+			const outputTokens = lastUsageMetadata.candidatesTokenCount ?? 0
+			const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
+			const reasoningTokens = lastUsageMetadata.thoughtsTokenCount
+
+			yield {
+				type: "usage",
+				inputTokens,
+				outputTokens,
+				cacheReadTokens,
+				reasoningTokens,
+				totalCost: calculateCostGenai({ info, inputTokens, outputTokens, cacheReadTokens }),
+			}
+		}
 	}

-	override getModel() {
+	getModel() {
 		const modelId = this.options.apiModelId
 		let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId
 		const info: ModelInfo = vertexModels[id]
@@ -24,4 +123,60 @@ export class VertexHandler extends GeminiHandler implements SingleCompletionHandler {
 		// suffix.
 		return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const { id: model } = this.getModel()
+
+			const params: GenerateContentParameters = {
+				model,
+				contents: [{ role: "user", parts: [{ text: prompt }] }],
+				config: {
+					temperature: this.options.modelTemperature ?? 0,
+				},
+			}
+
+			if (this.options.vertexApiKey && this.options.vertexProjectId) {
+				params.model = `projects/${this.options.vertexProjectId}/locations/${this.options.vertexRegion}/publishers/google/models/${model}`
+			}
+
+			const result = await this.client.models.generateContent(params)
+
+			return result.text ?? ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Vertex completion error: ${error.message}`)
+			}
+
+			throw error
+		}
+	}
+
+	override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
+		try {
+			const { id: model } = this.getModel()
+
+			const params: CountTokensParameters = {
+				model,
+				contents: convertAnthropicContentToGemini(content),
+			}
+
+			if (this.options.vertexApiKey && this.options.vertexProjectId) {
+				params.model = `projects/${this.options.vertexProjectId}/locations/${this.options.vertexRegion}/publishers/google/models/${model}`
+			}
+
+			const response = await this.client.models.countTokens(params)
+
+			if (response.totalTokens === undefined) {
+				console.warn("Vertex token counting returned undefined, using fallback")
+				return super.countTokens(content)
+			}
+
+			return response.totalTokens
+		} catch (error) {
+			console.warn("Vertex token counting failed, using fallback", error)
+			return super.countTokens(content)
+		}
+	}
+
+	public destruct() {}
 }
diff --git a/src/utils/__tests__/calculateCostGenai.test.ts b/src/utils/__tests__/calculateCostGenai.test.ts
new file mode 100644
index 00000000000..02f131413c0
--- /dev/null
+++ b/src/utils/__tests__/calculateCostGenai.test.ts
@@ -0,0 +1,127 @@
+// npx jest src/utils/__tests__/calculateCostGenai.test.ts
+
+import type { ModelInfo } from "@roo-code/types"
+
+import { calculateCostGenai } from "../calculateCostGenai"
+
+describe("calculateCostGenai", () => {
+	// Mock ModelInfo based on gemini-1.5-flash-latest pricing (per 1M tokens)
+	const mockInfo: ModelInfo = {
+		inputPrice: 0.125, // $/1M tokens
+		outputPrice: 0.375, // $/1M tokens
+		cacheWritesPrice: 0.125, // Assume same as input for test
+		cacheReadsPrice: 0.125 * 0.25, // Assume 0.25x input for test
+		contextWindow: 1_000_000,
+		maxTokens: 8192,
+		supportsPromptCache: true, // Enable cache calculations for tests
+	}
+
+	it("should calculate cost correctly based on input and output tokens", () => {
+		const inputTokens = 10000 // Use larger numbers for per-million pricing
+		const outputTokens = 20000
+		// Added non-null assertions (!) as mockInfo guarantees these values
+		const expectedCost =
+			(inputTokens / 1_000_000) * mockInfo.inputPrice! + (outputTokens / 1_000_000) * mockInfo.outputPrice!
+
+		const cost = calculateCostGenai({ info: mockInfo, inputTokens, outputTokens })
+		expect(cost).toBeCloseTo(expectedCost)
+	})
+
+	it("should return 0 if token counts are zero", () => {
+		// Note: The method expects numbers, not undefined. Passing undefined would be a type error.
+		// The calculateCost method itself returns undefined if prices are missing, but 0 if tokens are 0 and prices exist.
+		expect(calculateCostGenai({ info: mockInfo, inputTokens: 0, outputTokens: 0 })).toBe(0)
+	})
+
+	it("should handle only input tokens", () => {
+		const inputTokens = 5000
+		// Added non-null assertion (!)
+		const expectedCost = (inputTokens / 1_000_000) * mockInfo.inputPrice!
+		expect(calculateCostGenai({ info: mockInfo, inputTokens, outputTokens: 0 })).toBeCloseTo(expectedCost)
+	})
+
+	it("should handle only output tokens", () => {
+		const outputTokens = 15000
+		// Added non-null assertion (!)
+		const expectedCost = (outputTokens / 1_000_000) * mockInfo.outputPrice!
+		expect(calculateCostGenai({ info: mockInfo, inputTokens: 0, outputTokens })).toBeCloseTo(expectedCost)
+	})
+
+	it("should calculate cost with cache read tokens", () => {
+		const inputTokens = 10000 // Total logical input
+		const outputTokens = 20000
+		const cacheReadTokens = 8000 // Part of inputTokens read from cache
+
+		const uncachedReadTokens = inputTokens - cacheReadTokens
+		// Added non-null assertions (!)
+		const expectedInputCost = (uncachedReadTokens / 1_000_000) * mockInfo.inputPrice!
+		const expectedOutputCost = (outputTokens / 1_000_000) * mockInfo.outputPrice!
+		const expectedCacheReadCost = mockInfo.cacheReadsPrice! * (cacheReadTokens / 1_000_000)
+		const expectedCost = expectedInputCost + expectedOutputCost + expectedCacheReadCost
+
+		const cost = calculateCostGenai({ info: mockInfo, inputTokens, outputTokens, cacheReadTokens })
+		expect(cost).toBeCloseTo(expectedCost)
+	})
+
+	it("should return undefined if pricing info is missing", () => {
+		// Create a copy and explicitly set a price to undefined
+		const incompleteInfo: ModelInfo = { ...mockInfo, outputPrice: undefined }
+		const cost = calculateCostGenai({ info: incompleteInfo, inputTokens: 1000, outputTokens: 1000 })
+		expect(cost).toBeUndefined()
+	})
+
+	it("should handle tiered pricing", () => {
+		const tieredInfo: ModelInfo = {
+			...mockInfo,
+			tiers: [
+				{
+					contextWindow: 50000,
+					inputPrice: 0.2,
+					outputPrice: 0.6,
+					cacheReadsPrice: 0.05,
+				},
+				{
+					contextWindow: 1000000,
+					inputPrice: 0.125,
+					outputPrice: 0.375,
+					cacheReadsPrice: 0.03125,
+				},
+			],
+		}
+
+		// Should use first tier pricing for small input
+		const inputTokens = 30000
+		const outputTokens = 20000
+		const expectedCost = (inputTokens / 1_000_000) * 0.2 + (outputTokens / 1_000_000) * 0.6
+
+		const cost = calculateCostGenai({ info: tieredInfo, inputTokens, outputTokens })
+		expect(cost).toBeCloseTo(expectedCost)
+	})
+
+	it("should handle tiered pricing with cache reads", () => {
+		const tieredInfo: ModelInfo = {
+			...mockInfo,
+			tiers: [
+				{
+					contextWindow: 50000,
+					inputPrice: 0.2,
+					outputPrice: 0.6,
+					cacheReadsPrice: 0.05,
+				},
+			],
+		}
+
+		const inputTokens = 30000
+		const outputTokens = 20000
+		const cacheReadTokens = 10000
+		const uncachedInputTokens = inputTokens - cacheReadTokens
+
+		const expectedInputCost = (uncachedInputTokens / 1_000_000) * 0.2
+		const expectedOutputCost = (outputTokens / 1_000_000) * 0.6
+		const expectedCacheReadCost = (cacheReadTokens / 1_000_000) * 0.05
+		const expectedCost = expectedInputCost + expectedOutputCost + expectedCacheReadCost
+
+		const cost = calculateCostGenai({ info: tieredInfo, inputTokens, outputTokens, cacheReadTokens })
+		expect(cost).toBeCloseTo(expectedCost)
+	})
+})
diff --git a/src/utils/calculateCostGenai.ts b/src/utils/calculateCostGenai.ts
new file mode 100644
index 00000000000..1f36e2cfb2f
--- /dev/null
+++ b/src/utils/calculateCostGenai.ts
@@ -0,0 +1,58 @@
+import type { ModelInfo } from "@roo-code/types"
+
+/**
+ * Calculate the cost for GenAI models (Gemini and Vertex AI) based on token usage
+ * @param options - Token usage and model information
+ * @returns Total cost in USD or undefined if pricing info is missing
+ */
+export function calculateCostGenai({
+	info,
+	inputTokens,
+	outputTokens,
+	cacheReadTokens = 0,
+}: {
+	info: ModelInfo
+	inputTokens: number
+	outputTokens: number
+	cacheReadTokens?: number
+}): number | undefined {
+	if (!info.inputPrice || !info.outputPrice || !info.cacheReadsPrice) {
+		return undefined
+	}
+
+	let inputPrice = info.inputPrice
+	let outputPrice = info.outputPrice
+	let cacheReadsPrice = info.cacheReadsPrice
+
+	// If there's tiered pricing then adjust the input and output token prices
+	// based on the input tokens used.
+	if (info.tiers) {
+		const tier = info.tiers.find((tier) => inputTokens <= tier.contextWindow)
+
+		if (tier) {
+			inputPrice = tier.inputPrice ?? inputPrice
+			outputPrice = tier.outputPrice ?? outputPrice
+			cacheReadsPrice = tier.cacheReadsPrice ?? cacheReadsPrice
+		}
+	}
+
+	// Subtract the cached input tokens from the total input tokens.
+	const uncachedInputTokens = inputTokens - cacheReadTokens
+
+	let cacheReadCost = cacheReadTokens > 0 ? cacheReadsPrice * (cacheReadTokens / 1_000_000) : 0
+
+	const inputTokensCost = inputPrice * (uncachedInputTokens / 1_000_000)
+	const outputTokensCost = outputPrice * (outputTokens / 1_000_000)
+	const totalCost = inputTokensCost + outputTokensCost + cacheReadCost
+
+	const trace: Record<string, unknown> = {
+		input: { price: inputPrice, tokens: uncachedInputTokens, cost: inputTokensCost },
+		output: { price: outputPrice, tokens: outputTokens, cost: outputTokensCost },
+	}
+
+	if (cacheReadTokens > 0) {
+		trace.cacheRead = { price: cacheReadsPrice, tokens: cacheReadTokens, cost: cacheReadCost }
+	}
+
+	return totalCost
+}
diff --git a/webview-ui/src/components/settings/providers/Vertex.tsx b/webview-ui/src/components/settings/providers/Vertex.tsx
index 19a136927a2..4aa3c91c0d9 100644
--- a/webview-ui/src/components/settings/providers/Vertex.tsx
+++ b/webview-ui/src/components/settings/providers/Vertex.tsx
@@ -74,6 +74,13 @@ export const Vertex = ({ apiConfiguration, setApiConfigurationField }: VertexProps) => {
 					className="w-full">
+
+
+
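For reference, here is a minimal usage sketch of the extracted cost utility. It is illustrative only and not part of the patch: the relative import path is an assumption, and the pricing object simply mirrors the `mockInfo` fixture from `calculateCostGenai.test.ts` above.

```typescript
import type { ModelInfo } from "@roo-code/types"

import { calculateCostGenai } from "./src/utils/calculateCostGenai" // import path assumed

// Pricing mirrors the mockInfo fixture in calculateCostGenai.test.ts ($ per 1M tokens).
const info: ModelInfo = {
	inputPrice: 0.125,
	outputPrice: 0.375,
	cacheWritesPrice: 0.125,
	cacheReadsPrice: 0.125 * 0.25, // 0.03125
	contextWindow: 1_000_000,
	maxTokens: 8192,
	supportsPromptCache: true,
}

// 100k prompt tokens (40k of them served from cache) and 20k output tokens:
//   uncached input: (60_000 / 1e6) * 0.125   = 0.0075
//   output:         (20_000 / 1e6) * 0.375   = 0.0075
//   cache reads:    (40_000 / 1e6) * 0.03125 = 0.00125
const cost = calculateCostGenai({
	info,
	inputTokens: 100_000,
	outputTokens: 20_000,
	cacheReadTokens: 40_000,
})

console.log(cost) // 0.01625; would be undefined if any of the three prices were missing
```

Because the tiered-pricing lookup now lives in this one function, both `GeminiHandler` and the rewritten `VertexHandler` report the same `totalCost` for identical usage metadata.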