diff --git a/packages/inference/README.md b/packages/inference/README.md index 9e686d9089..b2910ed186 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -61,6 +61,7 @@ Currently, we support the following providers: - [Sambanova](https://sambanova.ai) - [Scaleway](https://www.scaleway.com/en/generative-apis/) - [Together](https://together.xyz) +- [Baseten](https://baseten.co) - [Blackforestlabs](https://blackforestlabs.ai) - [Cohere](https://cohere.com) - [Cerebras](https://cerebras.ai/) @@ -97,6 +98,7 @@ Only a subset of models are supported when requesting third-party providers. You - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models) - [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models) - [Together supported models](https://huggingface.co/api/partners/together/models) +- [Baseten supported models](https://huggingface.co/api/partners/baseten/models) - [Cohere supported models](https://huggingface.co/api/partners/cohere/models) - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models) - [Groq supported models](https://console.groq.com/docs/models) diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index d863d5dc5c..c41f929cab 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -1,3 +1,4 @@ +import * as Baseten from "../providers/baseten.js"; import * as BlackForestLabs from "../providers/black-forest-labs.js"; import * as Cerebras from "../providers/cerebras.js"; import * as Cohere from "../providers/cohere.js"; @@ -55,6 +56,9 @@ import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from import { InferenceClientInputError } from "../errors.js"; export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = { + baseten: { + conversational: new Baseten.BasetenConversationalTask(), + }, "black-forest-labs": { "text-to-image": new
BlackForestLabs.BlackForestLabsTextToImageTask(), }, diff --git a/packages/inference/src/providers/baseten.ts b/packages/inference/src/providers/baseten.ts new file mode 100644 index 0000000000..c39661af2a --- /dev/null +++ b/packages/inference/src/providers/baseten.ts @@ -0,0 +1,25 @@ +/** + * See the registered mapping of HF model ID => Baseten model ID here: + * + * https://huggingface.co/api/partners/baseten/models + * + * This is a publicly available mapping. + * + * If you want to try to run inference for a new model locally before it's registered on huggingface.co, + * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. + * + * - If you work at Baseten and want to update this mapping, please use the model mapping API we provide on huggingface.co + * - If you're a community member and want to add a new supported HF model to Baseten, please open an issue on the present repo + * and we will tag Baseten team members. + * + * Thanks! + */ +import { BaseConversationalTask } from "./providerHelper.js"; + +const BASETEN_API_BASE_URL = "https://inference.baseten.co"; + +export class BasetenConversationalTask extends BaseConversationalTask { + constructor() { + super("baseten", BASETEN_API_BASE_URL); + } +} diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts index 6c1bec7dec..1dfa656736 100644 --- a/packages/inference/src/providers/consts.ts +++ b/packages/inference/src/providers/consts.ts @@ -18,6 +18,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record< * Example: * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", */ + baseten: {}, "black-forest-labs": {}, cerebras: {}, cohere: {}, diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index c218ebbfdb..b5d4815b5a 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -45,6 +45,7 @@ export interface Options { export type InferenceTask = Exclude<PipelineType, "other"> |
"conversational"; export const INFERENCE_PROVIDERS = [ + "baseten", "black-forest-labs", "cerebras", "cohere", diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 52d27253d6..4520c389a1 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -2343,4 +2343,62 @@ describe.skip("InferenceClient", () => { }, TIMEOUT ); + + describe.concurrent( + "Baseten", + () => { + const client = new InferenceClient(env.HF_BASETEN_KEY ?? "dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["baseten"] = { + "Qwen/Qwen3-235B-A22B-Instruct-2507": { + provider: "baseten", + hfModelId: "Qwen/Qwen3-235B-A22B-Instruct-2507", + providerId: "Qwen/Qwen3-235B-A22B-Instruct-2507", + status: "live", + task: "conversational", + }, + }; + + it("chatCompletion - Qwen3 235B Instruct", async () => { + const res = await client.chatCompletion({ + model: "Qwen/Qwen3-235B-A22B-Instruct-2507", + provider: "baseten", + messages: [{ role: "user", content: "What is 5 + 3?" 
}], + max_tokens: 20, + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toBeDefined(); + expect(typeof completion).toBe("string"); + expect(completion).toMatch(/(eight|8)/i); + } + }); + + it("chatCompletion stream - Qwen3 235B", async () => { + const stream = client.chatCompletionStream({ + model: "Qwen/Qwen3-235B-A22B-Instruct-2507", + provider: "baseten", + messages: [{ role: "user", content: "Count from 1 to 3" }], + stream: true, + max_tokens: 20, + }) as AsyncGenerator<ChatCompletionStreamOutput>; + + let fullResponse = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } + } + } + + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); + expect(fullResponse).toMatch(/1.*2.*3/); + }); + }, + TIMEOUT + ); });