From 564b28607f6454171a89b92128505c6165d6ddcc Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Fri, 6 Feb 2026 13:38:48 -0500 Subject: [PATCH 1/2] refactor: migrate baseten provider to AI SDK --- src/api/providers/__tests__/baseten.spec.ts | 463 ++++++++++++++++++++ src/api/providers/baseten.ts | 37 +- 2 files changed, 492 insertions(+), 8 deletions(-) create mode 100644 src/api/providers/__tests__/baseten.spec.ts diff --git a/src/api/providers/__tests__/baseten.spec.ts b/src/api/providers/__tests__/baseten.spec.ts new file mode 100644 index 00000000000..6467fc91ab5 --- /dev/null +++ b/src/api/providers/__tests__/baseten.spec.ts @@ -0,0 +1,463 @@ +// npx vitest run src/api/providers/__tests__/baseten.spec.ts + +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) + +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, + } +}) + +vi.mock("@ai-sdk/openai-compatible", () => ({ + createOpenAICompatible: vi.fn(() => { + return vi.fn(() => ({ + modelId: "zai-org/GLM-4.6", + provider: "baseten", + })) + }), +})) + +import type { Anthropic } from "@anthropic-ai/sdk" + +import { basetenDefaultModelId, basetenModels, type BasetenModelId } from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../../shared/api" + +import { BasetenHandler } from "../baseten" + +describe("BasetenHandler", () => { + let handler: BasetenHandler + let mockOptions: ApiHandlerOptions + + beforeEach(() => { + mockOptions = { + basetenApiKey: "test-baseten-api-key", + apiModelId: "zai-org/GLM-4.6", + } + handler = new BasetenHandler(mockOptions) + vi.clearAllMocks() + }) + + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(BasetenHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) + + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new BasetenHandler({ + ...mockOptions, + apiModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe(basetenDefaultModelId) + }) + }) + + describe("getModel", () => { + it("should return default model when no model is specified", () => { + const handlerWithoutModel = new BasetenHandler({ + basetenApiKey: "test-baseten-api-key", + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(basetenDefaultModelId) + expect(model.info).toEqual(basetenModels[basetenDefaultModelId]) + }) + + it("should return specified model when valid model is provided", () => { + const testModelId: BasetenModelId = "deepseek-ai/DeepSeek-R1" + const handlerWithModel = new BasetenHandler({ + apiModelId: testModelId, + basetenApiKey: "test-baseten-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual(basetenModels[testModelId]) + }) + + it("should return provided model ID with default model info if model does not exist", () => { + const handlerWithInvalidModel = new BasetenHandler({ + ...mockOptions, + apiModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe("invalid-model") + expect(model.info).toBeDefined() + expect(model.info).toBe(basetenModels[basetenDefaultModelId]) + }) + + it("should include model parameters from getModelParams", () => { + 
const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") + }) + }) + + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] + + it("should handle streaming responses", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response from Baseten" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response from Baseten") + }) + + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(20) + }) + + it("should pass correct temperature (0.5 default) to streamText", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithDefaultTemp = new BasetenHandler({ + basetenApiKey: "test-key", + apiModelId: "zai-org/GLM-4.6", + }) + + const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, + }), + ) + }) + + it("should use user-specified temperature over default", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithCustomTemp = new BasetenHandler({ + basetenApiKey: "test-key", + apiModelId: "zai-org/GLM-4.6", + modelTemperature: 0.9, + }) + + const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.9, + }), + ) + }) + + it("should handle stream with multiple chunks", async () => { + async function* mockFullStream() { + 
yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: " world" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const textChunks = chunks.filter((c) => c.type === "text") + expect(textChunks[0]).toEqual({ type: "text", text: "Hello" }) + expect(textChunks[1]).toEqual({ type: "text", text: " world" }) + + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 }) + }) + }) + + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion from Baseten", + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion from Baseten") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", + }), + ) + }) + + it("should use default temperature in completePrompt", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) + + await handler.completePrompt("Test prompt") + + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, + }), + ) + }) + }) + + describe("tool handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" }], + }, + ] + + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', + } + yield { + type: "tool-input-end", + id: "tool-call-1", + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, + }, + }, + ], + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") + }) + + it("should ignore tool-call events to prevent duplicate tools in UI", async () => { + async function* mockFullStream() { + yield { + type: 
"tool-call", + toolCallId: "tool-call-1", + toolName: "read_file", + input: { path: "test.ts" }, + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallChunks = chunks.filter( + (c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end", + ) + expect(toolCallChunks.length).toBe(0) + }) + }) + + describe("error handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" }], + }, + ] + + it("should handle AI SDK errors with handleAiSdkError", async () => { + // eslint-disable-next-line require-yield + async function* mockFullStream(): AsyncGenerator { + throw new Error("API Error") + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + + await expect(async () => { + for await (const _ of stream) { + // consume stream + } + }).rejects.toThrow("Baseten: API Error") + }) + + it("should preserve status codes in error handling", async () => { + const apiError = new Error("Rate limit exceeded") + ;(apiError as any).status = 429 + + // eslint-disable-next-line require-yield + async function* mockFullStream(): AsyncGenerator { + throw apiError + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + + try { + for await (const _ of stream) { + // consume stream + } + expect.fail("Should have thrown an error") + } catch (error: any) { + expect(error.message).toContain("Baseten") + expect(error.status).toBe(429) + } + }) + }) +}) diff --git a/src/api/providers/baseten.ts b/src/api/providers/baseten.ts index ca0c2867756..ef3eaf231ce 100644 --- a/src/api/providers/baseten.ts +++ b/src/api/providers/baseten.ts @@ -1,18 +1,39 @@ -import { type BasetenModelId, basetenDefaultModelId, basetenModels } from "@roo-code/types" +import { basetenModels, basetenDefaultModelId, type BasetenModelId } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" -export class BasetenHandler extends BaseOpenAiCompatibleProvider { +import { getModelParams } from "../transform/model-params" + +import { OpenAICompatibleHandler, type OpenAICompatibleConfig } from "./openai-compatible" + +export class BasetenHandler extends OpenAICompatibleHandler { constructor(options: ApiHandlerOptions) { - super({ - ...options, + const modelId = options.apiModelId ?? 
basetenDefaultModelId + const modelInfo = basetenModels[modelId as keyof typeof basetenModels] || basetenModels[basetenDefaultModelId] + + const config: OpenAICompatibleConfig = { providerName: "Baseten", baseURL: "https://inference.baseten.co/v1", - apiKey: options.basetenApiKey, - defaultProviderModelId: basetenDefaultModelId, - providerModels: basetenModels, + apiKey: options.basetenApiKey ?? "not-provided", + modelId, + modelInfo, + modelMaxTokens: options.modelMaxTokens ?? undefined, + temperature: options.modelTemperature ?? 0.5, + } + + super(options, config) + } + + override getModel() { + const id = this.options.apiModelId ?? basetenDefaultModelId + const info = basetenModels[id as keyof typeof basetenModels] || basetenModels[basetenDefaultModelId] + const params = getModelParams({ + format: "openai", + modelId: id, + model: info, + settings: this.options, defaultTemperature: 0.5, }) + return { id, info, ...params } } } From 6736d8ffde8e8998f573f6bed939f945690a7726 Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Fri, 6 Feb 2026 18:49:38 -0500 Subject: [PATCH 2/2] refactor(baseten): migrate to native @ai-sdk/baseten package Replace OpenAICompatibleHandler with dedicated @ai-sdk/baseten package, following the same pattern used by other native AI SDK providers (groq, deepseek, etc.). This uses createBaseten() for provider initialization and extends BaseProvider directly instead of the generic OpenAI-compatible handler. --- pnpm-lock.yaml | 196 ++++++++++++++++++++ src/api/providers/__tests__/baseten.spec.ts | 21 +-- src/api/providers/baseten.ts | 149 +++++++++++++-- src/esbuild.mjs | 16 ++ src/package.json | 1 + 5 files changed, 348 insertions(+), 35 deletions(-) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 58f6354f62d..c231b43e53a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -749,6 +749,9 @@ importers: '@ai-sdk/amazon-bedrock': specifier: ^4.0.50 version: 4.0.50(zod@3.25.76) + '@ai-sdk/baseten': + specifier: ^1.0.31 + version: 1.0.31(zod@3.25.76) '@ai-sdk/cerebras': specifier: ^1.0.0 version: 1.0.35(zod@3.25.76) @@ -1435,6 +1438,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/baseten@1.0.31': + resolution: {integrity: sha512-tGbV96WBb5nnfyUYFrPyBxrhw53YlKSJbMC+rH3HhQlUaIs8+m/Bm4M0isrek9owIIf4MmmSDZ5VZL08zz7eFQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35': resolution: {integrity: sha512-JrNdMYptrOUjNthibgBeAcBjZ/H+fXb49sSrWhOx5Aq8eUcrYvwQ2DtSAi8VraHssZu78NAnBMrgFWSUOTXFxw==} engines: {node: '>=18'} @@ -1513,6 +1522,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.28': + resolution: {integrity: sha512-WzDnU0B13FMSSupDtm2lksFZvWGXnOfhG5S0HoPI0pkX5uVkr6N1UTATMyVaxLCG0MRkMhXCjkg4NXgEbb330Q==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20': resolution: {integrity: sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==} engines: {node: '>=18'} @@ -1543,6 +1558,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/provider-utils@4.0.14': + resolution: {integrity: sha512-7bzKd9lgiDeXM7O4U4nQ8iTxguAOkg8LZGD9AfDVZYjO5cKYRwBPwVjboFcVrxncRHu0tYxZtXZtiLKpG4pEng==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider@2.0.0': resolution: {integrity: sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==} engines: {node: '>=18'} @@ -1563,6 +1584,10 @@ packages: resolution: {integrity: 
sha512-VkPLrutM6VdA924/mG8OS+5frbVTcu6e046D2bgDo00tehBANR1QBJ/mPcZ9tXMFOsVcm6SQArOregxePzTFPw==} engines: {node: '>=18'} + '@ai-sdk/provider@3.0.8': + resolution: {integrity: sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==} + engines: {node: '>=18'} + '@ai-sdk/xai@3.0.46': resolution: {integrity: sha512-26qM/jYcFhF5krTM7bQT1CiZcdz22EQmA+r5me1hKYFM/yM20sSUMHnAcUzvzuuG9oQVKF0tziU2IcC0HX5huQ==} engines: {node: '>=18'} @@ -1880,6 +1905,93 @@ packages: resolution: {integrity: sha512-+EzkxvLNfiUeKMgy/3luqfsCWFRXLb7U6wNQTk60tovuckwB15B191tJWvpp4HjiQWdJkCxO3Wbvc6jlk3Xb2Q==} engines: {node: '>=6.9.0'} + '@basetenlabs/performance-client-android-arm-eabi@0.0.10': + resolution: {integrity: sha512-gwDZ6GDJA0AAmQAHxt2vaCz0tYTaLjxJKZnoYt+0Eji4gy231JZZFAwvbAqNdQCrGEQ9lXnk7SNM1Apet4NlYg==} + engines: {node: '>= 10'} + cpu: [arm] + os: [android] + + '@basetenlabs/performance-client-android-arm64@0.0.10': + resolution: {integrity: sha512-oGRB/6hH89majhsmoVmj1IAZv4C7F2aLeTSebevBelmdYO4CFkn5qewxLzU1pDkkmxVVk2k+TRpYa1Dt4B96qQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [android] + + '@basetenlabs/performance-client-darwin-arm64@0.0.10': + resolution: {integrity: sha512-QpBOUjeO05tWgFWkDw2RUQZa3BMplX5jNiBBTi5mH1lIL/m1sm2vkxoc0iorEESp1mMPstYFS/fr4ssBuO7wyA==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + + '@basetenlabs/performance-client-darwin-universal@0.0.10': + resolution: {integrity: sha512-CBM38GAhekjylrlf7jW/0WNyFAGnAMBCNHZxaPnAjjhDNzJh1tcrwhvtOs66XbAqCOjO/tkt5Pdu6mg2Ui2Pjw==} + engines: {node: '>= 10'} + os: [darwin] + + '@basetenlabs/performance-client-darwin-x64@0.0.10': + resolution: {integrity: sha512-R+NsA72Axclh1CUpmaWOCLTWCqXn5/tFMj2z9BnHVSRTelx/pYFlx6ZngVTB1HYp1n21m3upPXGo8CHF8R7Itw==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + + '@basetenlabs/performance-client-linux-arm-gnueabihf@0.0.10': + resolution: {integrity: sha512-96kEo0Eas4GVQdFkxIB1aAv6dy5Ga57j+RIg5l0Yiawv+AYIEmgk9BsGkqcwayp8Iiu6LN22Z+AUsGY2gstNrg==} + engines: {node: '>= 10'} + cpu: [arm] + os: [linux] + + '@basetenlabs/performance-client-linux-arm-musleabihf@0.0.10': + resolution: {integrity: sha512-lzEHeu+/BWDl2q+QZcqCkg1rDGF4MeyM3HgYwX+07t+vGZoqtM2we9vEV68wXMpl6ToEHQr7ML2KHA1Gb6ogxg==} + engines: {node: '>= 10'} + cpu: [arm] + os: [linux] + + '@basetenlabs/performance-client-linux-arm64-gnu@0.0.10': + resolution: {integrity: sha512-MnY2cIRY/cQOYERWIHhh5CoaS2wgmmXtGDVGSLYyZvjwizrXZvjkEz7Whv2jaQ21T5S56VER67RABjz2TItrHQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@basetenlabs/performance-client-linux-riscv64-gnu@0.0.10': + resolution: {integrity: sha512-2KUvdK4wuoZdIqNnJhx7cu6ybXCwtiwGAtlrEvhai3FOkUQ3wE2Xa+TQ33mNGSyFbw6wAvLawYtKVFmmw27gJw==} + engines: {node: '>= 10'} + cpu: [riscv64] + os: [linux] + + '@basetenlabs/performance-client-linux-x64-gnu@0.0.10': + resolution: {integrity: sha512-9jjQPjHLiVOGwUPlmhnBl7OmmO7hQ8WMt+v3mJuxkS5JTNDmVOngfmgGlbN9NjBhQMENjdcMUVOquVo7HeybGQ==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@basetenlabs/performance-client-linux-x64-musl@0.0.10': + resolution: {integrity: sha512-bjYB8FKcPvEa251Ep2Gm3tvywADL9eavVjZsikdf0AvJ1K5pT+vLLvJBU9ihBsTPWnbF4pJgxVjwS6UjVObsQA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@basetenlabs/performance-client-win32-arm64-msvc@0.0.10': + resolution: {integrity: sha512-Vxq5UXEmfh3C3hpwXdp3Daaf0dnLR9zFH2x8MJ1Hf/TcilmOP1clneewNpIv0e7MrnT56Z4pM6P3d8VFMZqBKg==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + + 
'@basetenlabs/performance-client-win32-ia32-msvc@0.0.10': + resolution: {integrity: sha512-KJrm7CgZdP/UDC5+tHtqE6w9XMfY5YUfMOxJfBZGSsLMqS2OGsakQsaF0a55k+58l29X5w/nAkjHrI1BcQO03w==} + engines: {node: '>= 10'} + cpu: [ia32] + os: [win32] + + '@basetenlabs/performance-client-win32-x64-msvc@0.0.10': + resolution: {integrity: sha512-M/mhvfTItUcUX+aeXRb5g5MbRlndfg6yelV7tSYfLU4YixMIe5yoGaAP3iDilpFJjcC99f+EU4l4+yLbPtpXig==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] + + '@basetenlabs/performance-client@0.0.10': + resolution: {integrity: sha512-H6bpd1JcDbuJsOS2dNft+CCGLzBqHJO/ST/4mMKhLAW641J6PpVJUw1szYsk/dTetdedbWxHpMkvFObOKeP8nw==} + engines: {node: '>= 10'} + '@bcoe/v8-coverage@0.2.3': resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==} @@ -11016,6 +11128,14 @@ snapshots: '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/baseten@1.0.31(zod@3.25.76)': + dependencies: + '@ai-sdk/openai-compatible': 2.0.28(zod@3.25.76) + '@ai-sdk/provider': 3.0.8 + '@ai-sdk/provider-utils': 4.0.14(zod@3.25.76) + '@basetenlabs/performance-client': 0.0.10 + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35(zod@3.25.76)': dependencies: '@ai-sdk/openai-compatible': 1.0.31(zod@3.25.76) @@ -11102,6 +11222,12 @@ snapshots: '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.28(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.8 + '@ai-sdk/provider-utils': 4.0.14(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20(zod@3.25.76)': dependencies: '@ai-sdk/provider': 2.0.1 @@ -11138,6 +11264,13 @@ snapshots: eventsource-parser: 3.0.6 zod: 3.25.76 + '@ai-sdk/provider-utils@4.0.14(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.8 + '@standard-schema/spec': 1.1.0 + eventsource-parser: 3.0.6 + zod: 3.25.76 + '@ai-sdk/provider@2.0.0': dependencies: json-schema: 0.4.0 @@ -11158,6 +11291,10 @@ snapshots: dependencies: json-schema: 0.4.0 + '@ai-sdk/provider@3.0.8': + dependencies: + json-schema: 0.4.0 + '@ai-sdk/xai@3.0.46(zod@3.25.76)': dependencies: '@ai-sdk/openai-compatible': 2.0.26(zod@3.25.76) @@ -11893,6 +12030,65 @@ snapshots: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.27.1 + '@basetenlabs/performance-client-android-arm-eabi@0.0.10': + optional: true + + '@basetenlabs/performance-client-android-arm64@0.0.10': + optional: true + + '@basetenlabs/performance-client-darwin-arm64@0.0.10': + optional: true + + '@basetenlabs/performance-client-darwin-universal@0.0.10': + optional: true + + '@basetenlabs/performance-client-darwin-x64@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-arm-gnueabihf@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-arm-musleabihf@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-arm64-gnu@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-riscv64-gnu@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-x64-gnu@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-x64-musl@0.0.10': + optional: true + + '@basetenlabs/performance-client-win32-arm64-msvc@0.0.10': + optional: true + + '@basetenlabs/performance-client-win32-ia32-msvc@0.0.10': + optional: true + + '@basetenlabs/performance-client-win32-x64-msvc@0.0.10': + optional: true + + '@basetenlabs/performance-client@0.0.10': + optionalDependencies: + '@basetenlabs/performance-client-android-arm-eabi': 0.0.10 + '@basetenlabs/performance-client-android-arm64': 0.0.10 
+ '@basetenlabs/performance-client-darwin-arm64': 0.0.10 + '@basetenlabs/performance-client-darwin-universal': 0.0.10 + '@basetenlabs/performance-client-darwin-x64': 0.0.10 + '@basetenlabs/performance-client-linux-arm-gnueabihf': 0.0.10 + '@basetenlabs/performance-client-linux-arm-musleabihf': 0.0.10 + '@basetenlabs/performance-client-linux-arm64-gnu': 0.0.10 + '@basetenlabs/performance-client-linux-riscv64-gnu': 0.0.10 + '@basetenlabs/performance-client-linux-x64-gnu': 0.0.10 + '@basetenlabs/performance-client-linux-x64-musl': 0.0.10 + '@basetenlabs/performance-client-win32-arm64-msvc': 0.0.10 + '@basetenlabs/performance-client-win32-ia32-msvc': 0.0.10 + '@basetenlabs/performance-client-win32-x64-msvc': 0.0.10 + '@bcoe/v8-coverage@0.2.3': {} '@braintree/sanitize-url@7.1.1': {} diff --git a/src/api/providers/__tests__/baseten.spec.ts b/src/api/providers/__tests__/baseten.spec.ts index 6467fc91ab5..e44b201f291 100644 --- a/src/api/providers/__tests__/baseten.spec.ts +++ b/src/api/providers/__tests__/baseten.spec.ts @@ -15,8 +15,8 @@ vi.mock("ai", async (importOriginal) => { } }) -vi.mock("@ai-sdk/openai-compatible", () => ({ - createOpenAICompatible: vi.fn(() => { +vi.mock("@ai-sdk/baseten", () => ({ + createBaseten: vi.fn(() => { return vi.fn(() => ({ modelId: "zai-org/GLM-4.6", provider: "baseten", @@ -123,12 +123,9 @@ describe("BasetenHandler", () => { outputTokens: 5, }) - const mockProviderMetadata = Promise.resolve({}) - mockStreamText.mockReturnValue({ fullStream: mockFullStream(), usage: mockUsage, - providerMetadata: mockProviderMetadata, }) const stream = handler.createMessage(systemPrompt, messages) @@ -153,12 +150,9 @@ describe("BasetenHandler", () => { outputTokens: 20, }) - const mockProviderMetadata = Promise.resolve({}) - mockStreamText.mockReturnValue({ fullStream: mockFullStream(), usage: mockUsage, - providerMetadata: mockProviderMetadata, }) const stream = handler.createMessage(systemPrompt, messages) @@ -181,7 +175,6 @@ describe("BasetenHandler", () => { mockStreamText.mockReturnValue({ fullStream: mockFullStream(), usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), - providerMetadata: Promise.resolve({}), }) const handlerWithDefaultTemp = new BasetenHandler({ @@ -209,7 +202,6 @@ describe("BasetenHandler", () => { mockStreamText.mockReturnValue({ fullStream: mockFullStream(), usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), - providerMetadata: Promise.resolve({}), }) const handlerWithCustomTemp = new BasetenHandler({ @@ -239,7 +231,6 @@ describe("BasetenHandler", () => { mockStreamText.mockReturnValue({ fullStream: mockFullStream(), usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }), - providerMetadata: Promise.resolve({}), }) const stream = handler.createMessage(systemPrompt, messages) @@ -320,12 +311,9 @@ describe("BasetenHandler", () => { outputTokens: 5, }) - const mockProviderMetadata = Promise.resolve({}) - mockStreamText.mockReturnValue({ fullStream: mockFullStream(), usage: mockUsage, - providerMetadata: mockProviderMetadata, }) const stream = handler.createMessage(systemPrompt, messages, { @@ -381,12 +369,9 @@ describe("BasetenHandler", () => { outputTokens: 5, }) - const mockProviderMetadata = Promise.resolve({}) - mockStreamText.mockReturnValue({ fullStream: mockFullStream(), usage: mockUsage, - providerMetadata: mockProviderMetadata, }) const stream = handler.createMessage(systemPrompt, messages) @@ -420,7 +405,6 @@ describe("BasetenHandler", () => { mockStreamText.mockReturnValue({ fullStream: mockFullStream(), usage: 
Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
-				providerMetadata: Promise.resolve({}),
 			})
 
 			const stream = handler.createMessage(systemPrompt, messages)
@@ -444,7 +428,6 @@ describe("BasetenHandler", () => {
 			mockStreamText.mockReturnValue({
 				fullStream: mockFullStream(),
 				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
-				providerMetadata: Promise.resolve({}),
 			})
 
 			const stream = handler.createMessage(systemPrompt, messages)
diff --git a/src/api/providers/baseten.ts b/src/api/providers/baseten.ts
index ef3eaf231ce..2e63f3d52c1 100644
--- a/src/api/providers/baseten.ts
+++ b/src/api/providers/baseten.ts
@@ -1,30 +1,47 @@
-import { basetenModels, basetenDefaultModelId, type BasetenModelId } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import { createBaseten } from "@ai-sdk/baseten"
+import { streamText, generateText, ToolSet } from "ai"
+
+import { basetenModels, basetenDefaultModelId, type ModelInfo } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
 
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	mapToolChoice,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { getModelParams } from "../transform/model-params"
 
-import { OpenAICompatibleHandler, type OpenAICompatibleConfig } from "./openai-compatible"
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+
+const BASETEN_DEFAULT_TEMPERATURE = 0.5
+
+/**
+ * Baseten provider using the dedicated @ai-sdk/baseten package.
+ * Provides native support for Baseten's inference API.
+ */
+export class BasetenHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	protected provider: ReturnType<typeof createBaseten>
 
-export class BasetenHandler extends OpenAICompatibleHandler {
 	constructor(options: ApiHandlerOptions) {
-		const modelId = options.apiModelId ?? basetenDefaultModelId
-		const modelInfo = basetenModels[modelId as keyof typeof basetenModels] || basetenModels[basetenDefaultModelId]
+		super()
+		this.options = options
 
-		const config: OpenAICompatibleConfig = {
-			providerName: "Baseten",
+		this.provider = createBaseten({
 			baseURL: "https://inference.baseten.co/v1",
 			apiKey: options.basetenApiKey ?? "not-provided",
-			modelId,
-			modelInfo,
-			modelMaxTokens: options.modelMaxTokens ?? undefined,
-			temperature: options.modelTemperature ?? 0.5,
-		}
-
-		super(options, config)
+			headers: DEFAULT_HEADERS,
+		})
 	}
 
-	override getModel() {
+	override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
 		const id = this.options.apiModelId ?? basetenDefaultModelId
 		const info = basetenModels[id as keyof typeof basetenModels] || basetenModels[basetenDefaultModelId]
 		const params = getModelParams({
@@ -32,8 +49,108 @@ export class BasetenHandler extends OpenAICompatibleHandler {
 			modelId: id,
 			model: info,
 			settings: this.options,
-			defaultTemperature: 0.5,
+			defaultTemperature: BASETEN_DEFAULT_TEMPERATURE,
 		})
 		return { id, info, ...params }
 	}
+
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
+	}
+
+	/**
+	 * Process usage metrics from the AI SDK response.
+	 */
+	protected processUsageMetrics(usage: {
+		inputTokens?: number
+		outputTokens?: number
+		details?: {
+			cachedInputTokens?: number
+			reasoningTokens?: number
+		}
+	}): ApiStreamUsageChunk {
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			reasoningTokens: usage.details?.reasoningTokens,
+		}
+	}
+
+	/**
+	 * Get the max tokens parameter to include in the request.
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 */
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? BASETEN_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: mapToolChoice(metadata?.tool_choice),
+		}
+
+		const result = streamText(requestOptions)
+
+		try {
+			for await (const part of result.fullStream) {
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
+				}
+			}
+
+			const usage = await result.usage
+			if (usage) {
+				yield this.processUsageMetrics(usage)
+			}
+		} catch (error) {
+			throw handleAiSdkError(error, "Baseten")
+		}
+	}
+
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
+	async completePrompt(prompt: string): Promise<string> {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature: this.options.modelTemperature ?? temperature ?? BASETEN_DEFAULT_TEMPERATURE,
+		})
+
+		return text
+	}
+
+	override isAiSdkProvider(): boolean {
+		return true
+	}
 }
diff --git a/src/esbuild.mjs b/src/esbuild.mjs
index aabacfcee99..fb7b1866797 100644
--- a/src/esbuild.mjs
+++ b/src/esbuild.mjs
@@ -43,6 +43,22 @@ async function main() {
 	 * @type {import('esbuild').Plugin[]}
 	 */
 	const plugins = [
+		{
+			// Stub out @basetenlabs/performance-client which contains native .node
+			// binaries that esbuild cannot bundle. This module is only used by
+			// @ai-sdk/baseten for embedding models, not for chat completions.
+			name: "stub-baseten-native",
+			setup(build) {
+				build.onResolve({ filter: /^@basetenlabs\/performance-client/ }, (args) => ({
+					path: args.path,
+					namespace: "stub-baseten-native",
+				}))
+				build.onLoad({ filter: /.*/, namespace: "stub-baseten-native" }, () => ({
+					contents: "module.exports = { PerformanceClient: class PerformanceClient {} };",
+					loader: "js",
+				}))
+			},
+		},
 		{
 			name: "copyFiles",
 			setup(build) {
diff --git a/src/package.json b/src/package.json
index f2b574c2626..eb67c0d1d7d 100644
--- a/src/package.json
+++ b/src/package.json
@@ -451,6 +451,7 @@
 	},
 	"dependencies": {
 		"@ai-sdk/amazon-bedrock": "^4.0.50",
+		"@ai-sdk/baseten": "^1.0.31",
 		"@ai-sdk/cerebras": "^1.0.0",
 		"@ai-sdk/deepseek": "^2.0.14",
 		"@ai-sdk/fireworks": "^2.0.26",
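
Reviewer note, not part of the patch: a minimal usage sketch of the migrated handler, mirroring the options and message shapes exercised in baseten.spec.ts above. The import path, environment-variable handling, and model ID are illustrative assumptions, not code from this series.

import { BasetenHandler } from "./src/api/providers/baseten"

async function smokeTest() {
	// Options mirror the fields used in the spec; the API key and model ID are placeholders.
	const handler = new BasetenHandler({
		basetenApiKey: process.env.BASETEN_API_KEY ?? "not-provided",
		apiModelId: "zai-org/GLM-4.6",
		modelTemperature: 0.5,
	})

	// createMessage() yields "text" and "usage" chunks converted from the AI SDK fullStream.
	const stream = handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: [{ type: "text", text: "Hello!" }] },
	])
	for await (const chunk of stream) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log(`\nin=${chunk.inputTokens} out=${chunk.outputTokens}`)
	}

	// completePrompt() is the non-streaming path backed by generateText.
	console.log(await handler.completePrompt("Say hello"))
}

smokeTest()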