From 06fd004635fba6f2822590235b8a7aeb19844c12 Mon Sep 17 00:00:00 2001
From: chen sihan
Date: Sun, 25 Jan 2026 19:30:16 -0800
Subject: [PATCH 1/2] add keywordsai gateway

---
 packages/types/src/provider-settings.ts       |  12 ++
 src/api/index.ts                              |   3 +
 src/api/providers/fetchers/keywordsai.ts      |  55 +++++++++
 src/api/providers/fetchers/modelCache.ts      |   8 ++
 src/api/providers/index.ts                    |   1 +
 src/api/providers/keywordsai.ts               |  30 +++++
 src/api/providers/openai.ts                   |  12 ++
 src/core/task/Task.ts                         |  10 +-
 src/core/webview/webviewMessageHandler.ts     |   8 ++
 src/shared/api.ts                             |   1 +
 .../src/components/settings/ApiOptions.tsx    |  13 ++
 .../src/components/settings/constants.ts      |   1 +
 .../settings/providers/KeywordsAI.tsx         | 111 ++++++++++++++++++
 .../components/settings/providers/index.ts    |   1 +
 14 files changed, 262 insertions(+), 4 deletions(-)
 create mode 100644 src/api/providers/fetchers/keywordsai.ts
 create mode 100644 src/api/providers/keywordsai.ts
 create mode 100644 webview-ui/src/components/settings/providers/KeywordsAI.tsx

diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index 0c5965f7ff6..7fcfee0411e 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -50,6 +50,7 @@ export const dynamicProviders = [
 	"unbound",
 	"roo",
 	"chutes",
+	"keywordsai",
 ] as const
 
 export type DynamicProvider = (typeof dynamicProviders)[number]
@@ -140,6 +141,7 @@ export const providerNames = [
 	"vertex",
 	"xai",
 	"zai",
+	"keywordsai",
 ] as const
 
 export const providerNamesSchema = z.enum(providerNames)
@@ -417,6 +419,12 @@ const basetenSchema = apiModelIdProviderModelSchema.extend({
 	basetenApiKey: z.string().optional(),
 })
 
+const keywordsaiSchema = apiModelIdProviderModelSchema.extend({
+	keywordsaiApiKey: z.string().optional(),
+	keywordsaiBaseUrl: z.string().optional(),
+	keywordsaiEnableLogging: z.boolean().optional(),
+})
+
 const defaultSchema = z.object({
 	apiProvider: z.undefined(),
 })
@@ -458,6 +466,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	qwenCodeSchema.merge(z.object({ apiProvider: z.literal("qwen-code") })),
 	rooSchema.merge(z.object({ apiProvider: z.literal("roo") })),
 	vercelAiGatewaySchema.merge(z.object({ apiProvider: z.literal("vercel-ai-gateway") })),
+	keywordsaiSchema.merge(z.object({ apiProvider: z.literal("keywordsai") })),
 	defaultSchema,
 ])
@@ -499,6 +508,7 @@ export const providerSettingsSchema = z.object({
 	...qwenCodeSchema.shape,
 	...rooSchema.shape,
 	...vercelAiGatewaySchema.shape,
+	...keywordsaiSchema.shape,
 	...codebaseIndexProviderSchema.shape,
 })
@@ -584,6 +594,7 @@ export const modelIdKeysByProvider: Record = {
 	"io-intelligence": "ioIntelligenceModelId",
 	roo: "apiModelId",
 	"vercel-ai-gateway": "vercelAiGatewayModelId",
+	keywordsai: "apiModelId",
 }
 
 /**
@@ -720,6 +731,7 @@ export const MODELS_BY_PROVIDER: Record<
 	deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
 	"vercel-ai-gateway": { id: "vercel-ai-gateway", label: "Vercel AI Gateway", models: [] },
 	chutes: { id: "chutes", label: "Chutes AI", models: [] },
+	keywordsai: { id: "keywordsai", label: "Keywords AI", models: [] },
 
 	// Local providers; models discovered from localhost endpoints.
 	lmstudio: { id: "lmstudio", label: "LM Studio", models: [] },
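
The schema additions above make "keywordsai" a first-class provider id: a profile is described by the shared apiModelId plus three optional keywordsai* fields, and the discriminated union gains a matching branch. A minimal sketch of a profile that should validate and resolve to the new handler, assuming providerSettingsSchemaDiscriminated is re-exported from @roo-code/types alongside ProviderSettings, and using the buildApiHandler case added in src/api/index.ts below (the key, base URL, model id, and import path are placeholders, not part of this patch):

```ts
import { providerSettingsSchemaDiscriminated, type ProviderSettings } from "@roo-code/types"
import { buildApiHandler } from "../src/api" // illustrative path

const profile: ProviderSettings = {
	apiProvider: "keywordsai",
	apiModelId: "gpt-4o", // hypothetical id; real ids come from the /models/public listing
	keywordsaiApiKey: process.env.KEYWORDSAI_API_KEY,
	keywordsaiBaseUrl: "https://api.keywordsai.co/api/",
	keywordsaiEnableLogging: false, // the handler will then send disable_log: true
}

// All keywordsai* fields are optional, so validation checks shape, not presence.
const parsed = providerSettingsSchemaDiscriminated.safeParse(profile)
if (parsed.success) {
	const handler = buildApiHandler(profile) // resolves to a KeywordsAiHandler
}
```

Because the key is optional at the schema level, a missing key only surfaces at request time; the handler below falls back to the "not-provided" placeholder rather than failing construction.
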
diff --git a/src/api/index.ts b/src/api/index.ts
index 1995380a68d..416e9835155 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -41,6 +41,7 @@ import {
 	DeepInfraHandler,
 	MiniMaxHandler,
 	BasetenHandler,
+	KeywordsAiHandler,
 } from "./providers"
 
 import { NativeOllamaHandler } from "./providers/native-ollama"
@@ -197,6 +198,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new MiniMaxHandler(options)
 		case "baseten":
 			return new BasetenHandler(options)
+		case "keywordsai":
+			return new KeywordsAiHandler(options)
 		default:
 			return new AnthropicHandler(options)
 	}
diff --git a/src/api/providers/fetchers/keywordsai.ts b/src/api/providers/fetchers/keywordsai.ts
new file mode 100644
index 00000000000..c2b2d768621
--- /dev/null
+++ b/src/api/providers/fetchers/keywordsai.ts
@@ -0,0 +1,55 @@
+import axios from "axios"
+import { z } from "zod"
+
+import type { ModelInfo } from "@roo-code/types"
+
+import { DEFAULT_HEADERS } from "../constants"
+
+const KeywordsAIProviderSchema = z.object({
+	provider_name: z.string().optional(),
+	provider_id: z.string().optional(),
+	moderation: z.string().optional(),
+	credential_fields: z.array(z.string()).optional(),
+})
+
+const KeywordsAIModelSchema = z.object({
+	model_name: z.string(),
+	max_context_window: z.number(),
+	input_cost: z.number(),
+	output_cost: z.number(),
+	rate_limit: z.number().optional(),
+	provider: KeywordsAIProviderSchema.optional(),
+})
+
+const KeywordsAIModelsResponseSchema = z.object({
+	models: z.array(KeywordsAIModelSchema),
+})
+
+export async function getKeywordsAiModels(
+	baseUrl: string = "https://api.keywordsai.co/api/",
+): Promise<Record<string, ModelInfo>> {
+	const url = `${baseUrl.replace(/\/$/, "")}/models/public`
+	const models: Record<string, ModelInfo> = {}
+
+	const response = await axios.get(url, { headers: DEFAULT_HEADERS })
+	const parsed = KeywordsAIModelsResponseSchema.safeParse(response.data)
+	const data = parsed.success ? parsed.data.models : (response.data?.models ?? [])
+
+	for (const m of data as z.infer<typeof KeywordsAIModelSchema>[]) {
+		const contextWindow = m.max_context_window ?? 8192
+		const maxTokens = Math.ceil(contextWindow * 0.2)
+
+		const info: ModelInfo = {
+			maxTokens,
+			contextWindow,
+			supportsImages: false,
+			supportsPromptCache: false,
+			inputPrice: m.input_cost,
+			outputPrice: m.output_cost,
+		}
+
+		models[m.model_name] = info
+	}
+
+	return models
+}
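
getKeywordsAiModels() above derives a conservative maxTokens as 20% of the advertised context window; for example, a 131,072-token window yields Math.ceil(131072 * 0.2) = 26215. A short sketch of what the returned record looks like, assuming the endpoint is reachable (the model name and prices are illustrative; only the field names come from the code above):

```ts
import { getKeywordsAiModels } from "./keywordsai"

async function listKeywordsAiModels() {
	// Defaults to https://api.keywordsai.co/api/ and requests `${baseUrl}/models/public`.
	const models = await getKeywordsAiModels()

	// Each entry maps the API's model_name to a ModelInfo, e.g.:
	// models["gpt-4o"] === {
	//   maxTokens: 26215,          // Math.ceil(131072 * 0.2)
	//   contextWindow: 131072,
	//   supportsImages: false,     // hard-coded; vision support is not read from the endpoint
	//   supportsPromptCache: false,
	//   inputPrice: 2.5,           // copied verbatim from the API's input_cost field
	//   outputPrice: 10,           // copied verbatim from output_cost
	// }
	return Object.keys(models)
}
```
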
diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts
index 51ca19e2bce..35eabe1b081 100644
--- a/src/api/providers/fetchers/modelCache.ts
+++ b/src/api/providers/fetchers/modelCache.ts
@@ -29,6 +29,7 @@ import { getDeepInfraModels } from "./deepinfra"
 import { getHuggingFaceModels } from "./huggingface"
 import { getRooModels } from "./roo"
 import { getChutesModels } from "./chutes"
+import { getKeywordsAiModels } from "./keywordsai"
 
 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
@@ -108,6 +109,9 @@ async function fetchModelsFromProvider(options: GetModelsOptions): Promise {
 	{ provider: "openrouter", options: { provider: "openrouter" } },
 	{ provider: "vercel-ai-gateway", options: { provider: "vercel-ai-gateway" } },
 	{ provider: "chutes", options: { provider: "chutes" } },
+	{
+		provider: "keywordsai",
+		options: { provider: "keywordsai", baseUrl: "https://api.keywordsai.co/api/" },
+	},
 ]
 
 // Refresh each provider in background (fire and forget)
diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts
index 141839e29f9..21f327ed475 100644
--- a/src/api/providers/index.ts
+++ b/src/api/providers/index.ts
@@ -32,3 +32,4 @@ export { VercelAiGatewayHandler } from "./vercel-ai-gateway"
 export { DeepInfraHandler } from "./deepinfra"
 export { MiniMaxHandler } from "./minimax"
 export { BasetenHandler } from "./baseten"
+export { KeywordsAiHandler } from "./keywordsai"
diff --git a/src/api/providers/keywordsai.ts b/src/api/providers/keywordsai.ts
new file mode 100644
index 00000000000..034f4eb6785
--- /dev/null
+++ b/src/api/providers/keywordsai.ts
@@ -0,0 +1,30 @@
+import type { ApiHandlerOptions } from "../../shared/api"
+import { OpenAiHandler } from "./openai"
+
+/**
+ * Keywords AI gateway handler. Reuses the OpenAI-compatible handler and adds the
+ * gateway's disable_log parameter only when keywordsaiEnableLogging is explicitly false.
+ */
+export class KeywordsAiHandler extends OpenAiHandler {
+	constructor(options: ApiHandlerOptions) {
+		const baseUrl = options.keywordsaiBaseUrl || "https://api.keywordsai.co/api/"
+		super({
+			...options,
+			openAiApiKey: options.keywordsaiApiKey ?? "not-provided",
+			openAiBaseUrl: baseUrl,
+			openAiModelId: options.apiModelId,
+			openAiStreamingEnabled: true,
+			openAiHeaders: {
+				"X-KeywordsAI-Source": "RooCode-Extension",
+				...(options.openAiHeaders || {}),
+			},
+		})
+	}
+
+	protected override getExtraRequestParams(): Record<string, unknown> {
+		if (this.options.keywordsaiEnableLogging === false) {
+			return { disable_log: true }
+		}
+		return {}
+	}
+}
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index 74cbb511138..a04e5e7912b 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -167,6 +167,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		// Add max_tokens if needed
 		this.addMaxTokensIfNeeded(requestOptions, modelInfo)
 
+		Object.assign(requestOptions, this.getExtraRequestParams())
+
 		let stream
 		try {
 			stream = await this.client.chat.completions.create(
@@ -235,6 +237,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		// Add max_tokens if needed
 		this.addMaxTokensIfNeeded(requestOptions, modelInfo)
 
+		Object.assign(requestOptions, this.getExtraRequestParams())
+
 		let response
 		try {
 			response = await this.client.chat.completions.create(
@@ -269,6 +273,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		}
 	}
 
+	/**
+	 * Optional extra body params merged into the chat.completions.create() request.
+	 * Subclasses (e.g. the Keywords AI gateway) use this for provider-specific params such as disable_log.
+	 */
+	protected getExtraRequestParams(): Record<string, unknown> {
+		return {}
+	}
+
 	protected processUsageMetrics(usage: any, _modelInfo?: ModelInfo): ApiStreamUsageChunk {
 		return {
 			type: "usage",
diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts
index d2be23714e2..d323dd4733a 100644
--- a/src/core/task/Task.ts
+++ b/src/core/task/Task.ts
@@ -2704,7 +2704,9 @@ export class Task extends EventEmitter implements TaskLike {
 				// Yields only if the first chunk is successful, otherwise will
 				// allow the user to retry the request (most likely due to rate
 				// limit error, which gets thrown on the first chunk).
-				const stream = this.attemptApiRequest(currentItem.retryAttempt ?? 0, { skipProviderRateLimit: true })
+				const stream = this.attemptApiRequest(currentItem.retryAttempt ?? 0, {
+					skipProviderRateLimit: true,
+				})
 				let assistantMessage = ""
 				let reasoningMessage = ""
 				let pendingGroundingSources: GroundingSource[] = []
@@ -4202,7 +4204,7 @@ export class Task extends EventEmitter implements TaskLike {
 					)
 					await this.handleContextWindowExceededError()
 					// Retry the request after handling the context window error
-					yield* this.attemptApiRequest(retryAttempt + 1)
+					yield* this.attemptApiRequest(retryAttempt + 1, options)
 					return
 				}
@@ -4222,7 +4224,7 @@ export class Task extends EventEmitter implements TaskLike {
 					// Delegate generator output from the recursive call with
 					// incremented retry count.
-					yield* this.attemptApiRequest(retryAttempt + 1)
+					yield* this.attemptApiRequest(retryAttempt + 1, options)
 					return
 				} else {
@@ -4240,7 +4242,7 @@ export class Task extends EventEmitter implements TaskLike {
 					await this.say("api_req_retried")
 					// Delegate generator output from the recursive call.
-					yield* this.attemptApiRequest()
+					yield* this.attemptApiRequest(0, options)
 					return
 				}
 			}
 		}
diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index 2dc77d05027..316d1c37658 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -874,6 +874,7 @@ export const webviewMessageHandler = async (
 				lmstudio: {},
 				roo: {},
 				chutes: {},
+				keywordsai: {},
 			}
 
 			const safeGetModels = async (options: GetModelsOptions): Promise => {
@@ -924,6 +925,13 @@ export const webviewMessageHandler = async (
 					key: "chutes",
 					options: { provider: "chutes", apiKey: apiConfiguration.chutesApiKey },
 				},
+				{
+					key: "keywordsai",
+					options: {
+						provider: "keywordsai",
+						baseUrl: apiConfiguration.keywordsaiBaseUrl || "https://api.keywordsai.co/api/",
+					},
+				},
 			]
 
 			// IO Intelligence is conditional on api key
diff --git a/src/shared/api.ts b/src/shared/api.ts
index b2ba1e35420..61207601645 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -181,6 +181,7 @@ const dynamicProviderExtras = {
 	lmstudio: {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type
 	roo: {} as { apiKey?: string; baseUrl?: string },
 	chutes: {} as { apiKey?: string },
+	keywordsai: {} as { apiKey?: string; baseUrl?: string },
 } as const satisfies Record
 
 // Build the dynamic options union from the map, intersected with CommonFetchParams
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index 939d2734d4b..f73d5b0b40b 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -106,6 +106,7 @@ import {
 	VercelAiGateway,
 	DeepInfra,
 	MiniMax,
+	KeywordsAI,
 } from "./providers"
 
 import { MODELS_BY_PROVIDER, PROVIDERS } from "./constants"
@@ -591,6 +592,18 @@ const ApiOptions = ({
 				/>
 			)}
+			{selectedProvider === "keywordsai" && (
+			)}
+
 			{selectedProvider === "bedrock" && (
 a.label.localeCompare(b.label))
diff --git a/webview-ui/src/components/settings/providers/KeywordsAI.tsx b/webview-ui/src/components/settings/providers/KeywordsAI.tsx
new file mode 100644
index 00000000000..79ee0992f5c
--- /dev/null
+++ b/webview-ui/src/components/settings/providers/KeywordsAI.tsx
@@ -0,0 +1,111 @@
+import { useCallback, useState } from "react"
+import { VSCodeCheckbox, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
+
+import type { OrganizationAllowList, ProviderSettings, RouterModels } from "@roo-code/types"
+
+import { vscode } from "@src/utils/vscode"
+import { useAppTranslation } from "@src/i18n/TranslationContext"
+import { Button } from "@src/components/ui"
+
+import { inputEventTransform } from "../transforms"
+import { ModelPicker } from "../ModelPicker"
+
+type KeywordsAIProps = {
+	apiConfiguration: ProviderSettings
+	setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
+	routerModels?: RouterModels
+	refetchRouterModels: () => void
+	organizationAllowList: OrganizationAllowList
+	modelValidationError?: string
+	simplifySettings?: boolean
+}
+
+export const KeywordsAI = ({
+	apiConfiguration,
+	setApiConfigurationField,
+	routerModels,
+	refetchRouterModels,
+	organizationAllowList,
+	modelValidationError,
+	simplifySettings,
+}: KeywordsAIProps) => {
+	const { t } = useAppTranslation()
+	const [didRefetch, setDidRefetch] = useState<boolean>()
+
+	const handleInputChange = useCallback(
+		<K extends keyof ProviderSettings, E>(
+			field: K,
+			transform: (event: E) => ProviderSettings[K] = inputEventTransform,
+		)
=> + (event: E | Event) => { + setApiConfigurationField(field, transform(event as E)) + }, + [setApiConfigurationField], + ) + + return ( + <> + + + + + + + + { + setApiConfigurationField("keywordsaiEnableLogging", e.target.checked) + }}> + Enable Logging + +
+ Controls whether request/response data is logged. When disabled, only performance metrics are recorded. +
+ + + {didRefetch && ( +
+ {t("settings:providers.refreshModels.hint")} +
+ )} + + + + ) +} diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts index bca620d052d..e9649ba1254 100644 --- a/webview-ui/src/components/settings/providers/index.ts +++ b/webview-ui/src/components/settings/providers/index.ts @@ -32,3 +32,4 @@ export { VercelAiGateway } from "./VercelAiGateway" export { DeepInfra } from "./DeepInfra" export { MiniMax } from "./MiniMax" export { Baseten } from "./Baseten" +export { KeywordsAI } from "./KeywordsAI" From 9d8f385502663e5c8a7b4d7d77bbbbccfbef7e3f Mon Sep 17 00:00:00 2001 From: chen sihan Date: Sun, 25 Jan 2026 19:41:45 -0800 Subject: [PATCH 2/2] add keywords ai case --- webview-ui/src/components/ui/hooks/useSelectedModel.ts | 5 +++++ webview-ui/src/utils/__tests__/validate.spec.ts | 1 + 2 files changed, 6 insertions(+) diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 8eac6fa7403..2a9e3293a4f 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -368,6 +368,11 @@ function getSelectedModel({ const info = routerModels["vercel-ai-gateway"]?.[id] return { id, info } } + case "keywordsai": { + const id = getValidatedModelId(apiConfiguration.apiModelId, routerModels.keywordsai, defaultModelId) + const routerInfo = routerModels.keywordsai?.[id] + return { id, info: routerInfo } + } // case "anthropic": // case "fake-ai": default: { diff --git a/webview-ui/src/utils/__tests__/validate.spec.ts b/webview-ui/src/utils/__tests__/validate.spec.ts index 09239b649ca..9219adb77a5 100644 --- a/webview-ui/src/utils/__tests__/validate.spec.ts +++ b/webview-ui/src/utils/__tests__/validate.spec.ts @@ -20,6 +20,7 @@ import { getModelValidationError, validateApiConfigurationExcludingModelErrors, describe("Model Validation Functions", () => { const mockRouterModels: RouterModels = { + keywordsai: {}, openrouter: { "valid-model": { maxTokens: 8192,