diff --git a/packages/types/src/providers/huggingface.ts b/packages/types/src/providers/huggingface.ts new file mode 100644 index 00000000000..d2571a073e7 --- /dev/null +++ b/packages/types/src/providers/huggingface.ts @@ -0,0 +1,17 @@ +/** + * HuggingFace provider constants + */ + +// Default values for HuggingFace models +export const HUGGINGFACE_DEFAULT_MAX_TOKENS = 2048 +export const HUGGINGFACE_MAX_TOKENS_FALLBACK = 8192 +export const HUGGINGFACE_DEFAULT_CONTEXT_WINDOW = 128_000 + +// UI constants +export const HUGGINGFACE_SLIDER_STEP = 256 +export const HUGGINGFACE_SLIDER_MIN = 1 +export const HUGGINGFACE_TEMPERATURE_MAX_VALUE = 2 + +// API constants +export const HUGGINGFACE_API_URL = "https://router.huggingface.co/v1/models?collection=roocode" +export const HUGGINGFACE_CACHE_DURATION = 1000 * 60 * 60 // 1 hour diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts index e4e506b8a7b..f5061f152c0 100644 --- a/packages/types/src/providers/index.ts +++ b/packages/types/src/providers/index.ts @@ -6,6 +6,7 @@ export * from "./deepseek.js" export * from "./gemini.js" export * from "./glama.js" export * from "./groq.js" +export * from "./huggingface.js" export * from "./lite-llm.js" export * from "./lm-studio.js" export * from "./mistral.js" diff --git a/src/api/huggingface-models.ts b/src/api/huggingface-models.ts index ec1915d0e3d..1ee6369d4bc 100644 --- a/src/api/huggingface-models.ts +++ b/src/api/huggingface-models.ts @@ -1,4 +1,10 @@ -import { fetchHuggingFaceModels, type HuggingFaceModel } from "../services/huggingface-models" +import { + getHuggingFaceModels as fetchModels, + getCachedRawHuggingFaceModels, + type HuggingFaceModel, +} from "./providers/fetchers/huggingface" +import axios from "axios" +import { HUGGINGFACE_API_URL } from "@roo-code/types" export interface HuggingFaceModelsResponse { models: HuggingFaceModel[] @@ -7,11 +13,49 @@ export interface HuggingFaceModelsResponse { } export async function 
getHuggingFaceModels(): Promise { - const models = await fetchHuggingFaceModels() + try { + // First, trigger the fetch to populate cache + await fetchModels() - return { - models, - cached: false, // We could enhance this to track if data came from cache - timestamp: Date.now(), + // Get the raw models from cache + const cachedRawModels = getCachedRawHuggingFaceModels() + + if (cachedRawModels) { + return { + models: cachedRawModels, + cached: true, + timestamp: Date.now(), + } + } + + // If no cached raw models, fetch directly from API + const response = await axios.get(HUGGINGFACE_API_URL, { + headers: { + "Upgrade-Insecure-Requests": "1", + "Sec-Fetch-Dest": "document", + "Sec-Fetch-Mode": "navigate", + "Sec-Fetch-Site": "none", + "Sec-Fetch-User": "?1", + Priority: "u=0, i", + Pragma: "no-cache", + "Cache-Control": "no-cache", + }, + timeout: 10000, + }) + + const models = response.data?.data || [] + + return { + models, + cached: false, + timestamp: Date.now(), + } + } catch (error) { + console.error("Failed to get HuggingFace models:", error) + return { + models: [], + cached: false, + timestamp: Date.now(), + } } } diff --git a/src/api/providers/fetchers/huggingface.ts b/src/api/providers/fetchers/huggingface.ts new file mode 100644 index 00000000000..16c33b9e047 --- /dev/null +++ b/src/api/providers/fetchers/huggingface.ts @@ -0,0 +1,229 @@ +import axios from "axios" +import { z } from "zod" +import type { ModelInfo } from "@roo-code/types" +import { + HUGGINGFACE_API_URL, + HUGGINGFACE_CACHE_DURATION, + HUGGINGFACE_DEFAULT_MAX_TOKENS, + HUGGINGFACE_DEFAULT_CONTEXT_WINDOW, +} from "@roo-code/types" +import type { ModelRecord } from "../../../shared/api" + +/** + * HuggingFace Provider Schema + */ +const huggingFaceProviderSchema = z.object({ + provider: z.string(), + status: z.enum(["live", "staging", "error"]), + supports_tools: z.boolean().optional(), + supports_structured_output: z.boolean().optional(), + context_length: z.number().optional(), + 
pricing: z + .object({ + input: z.number(), + output: z.number(), + }) + .optional(), +}) + +/** + * Represents a provider that can serve a HuggingFace model + * @property provider - The provider identifier (e.g., "sambanova", "together") + * @property status - The current status of the provider + * @property supports_tools - Whether the provider supports tool/function calling + * @property supports_structured_output - Whether the provider supports structured output + * @property context_length - The maximum context length supported by this provider + * @property pricing - The pricing information for input/output tokens + */ +export type HuggingFaceProvider = z.infer + +/** + * HuggingFace Model Schema + */ +const huggingFaceModelSchema = z.object({ + id: z.string(), + object: z.literal("model"), + created: z.number(), + owned_by: z.string(), + providers: z.array(huggingFaceProviderSchema), +}) + +/** + * Represents a HuggingFace model available through the router API + * @property id - The unique identifier of the model + * @property object - The object type (always "model") + * @property created - Unix timestamp of when the model was created + * @property owned_by - The organization that owns the model + * @property providers - List of providers that can serve this model + */ +export type HuggingFaceModel = z.infer + +/** + * HuggingFace API Response Schema + */ +const huggingFaceApiResponseSchema = z.object({ + object: z.string(), + data: z.array(huggingFaceModelSchema), +}) + +/** + * Represents the response from the HuggingFace router API + * @property object - The response object type + * @property data - Array of available models + */ +type HuggingFaceApiResponse = z.infer + +/** + * Cache entry for storing fetched models + * @property data - The cached model records + * @property timestamp - Unix timestamp of when the cache was last updated + */ +interface CacheEntry { + data: ModelRecord + rawModels?: HuggingFaceModel[] + timestamp: number +} + +let cache: 
CacheEntry | null = null + +/** + * Parse a HuggingFace model into ModelInfo format + * @param model - The HuggingFace model to parse + * @param provider - Optional specific provider to use for capabilities + * @returns ModelInfo object compatible with the application's model system + */ +function parseHuggingFaceModel(model: HuggingFaceModel, provider?: HuggingFaceProvider): ModelInfo { + // Use provider-specific values if available, otherwise find first provider with values + const contextLength = + provider?.context_length || + model.providers.find((p) => p.context_length)?.context_length || + HUGGINGFACE_DEFAULT_CONTEXT_WINDOW + + const pricing = provider?.pricing || model.providers.find((p) => p.pricing)?.pricing + + // Include provider name in description if specific provider is given + const description = provider ? `${model.id} via ${provider.provider}` : `${model.id} via HuggingFace` + + return { + maxTokens: Math.min(contextLength, HUGGINGFACE_DEFAULT_MAX_TOKENS), + contextWindow: contextLength, + supportsImages: false, // HuggingFace API doesn't provide this info yet + supportsPromptCache: false, + supportsComputerUse: false, + inputPrice: pricing?.input, + outputPrice: pricing?.output, + description, + } +} + +/** + * Fetches available models from HuggingFace + * + * @returns A promise that resolves to a record of model IDs to model info + * @throws Will throw an error if the request fails + */ +export async function getHuggingFaceModels(): Promise { + const now = Date.now() + + // Check cache + if (cache && now - cache.timestamp < HUGGINGFACE_CACHE_DURATION) { + return cache.data + } + + const models: ModelRecord = {} + + try { + const response = await axios.get(HUGGINGFACE_API_URL, { + headers: { + "Upgrade-Insecure-Requests": "1", + "Sec-Fetch-Dest": "document", + "Sec-Fetch-Mode": "navigate", + "Sec-Fetch-Site": "none", + "Sec-Fetch-User": "?1", + Priority: "u=0, i", + Pragma: "no-cache", + "Cache-Control": "no-cache", + }, + timeout: 10000, // 10 
second timeout + }) + + const result = huggingFaceApiResponseSchema.safeParse(response.data) + + if (!result.success) { + console.error("HuggingFace models response validation failed:", result.error.format()) + throw new Error("Invalid response format from HuggingFace API") + } + + const validModels = result.data.data.filter((model) => model.providers.length > 0) + + for (const model of validModels) { + // Add the base model + models[model.id] = parseHuggingFaceModel(model) + + // Add provider-specific variants for all live providers + for (const provider of model.providers) { + if (provider.status === "live") { + const providerKey = `${model.id}:${provider.provider}` + const providerModel = parseHuggingFaceModel(model, provider) + + // Always add provider variants to show all available providers + models[providerKey] = providerModel + } + } + } + + // Update cache + cache = { + data: models, + rawModels: validModels, + timestamp: now, + } + + return models + } catch (error) { + console.error("Error fetching HuggingFace models:", error) + + // Return cached data if available + if (cache) { + return cache.data + } + + // Re-throw with more context + if (axios.isAxiosError(error)) { + if (error.response) { + throw new Error( + `Failed to fetch HuggingFace models: ${error.response.status} ${error.response.statusText}`, + ) + } else if (error.request) { + throw new Error( + "Failed to fetch HuggingFace models: No response from server. Check your internet connection.", + ) + } + } + + throw new Error( + `Failed to fetch HuggingFace models: ${error instanceof Error ? 
error.message : "Unknown error"}`, + ) + } +} + +/** + * Get cached models without making an API request + */ +export function getCachedHuggingFaceModels(): ModelRecord | null { + return cache?.data || null +} + +/** + * Get cached raw models for UI display + */ +export function getCachedRawHuggingFaceModels(): HuggingFaceModel[] | null { + return cache?.rawModels || null +} + +/** + * Clear the cache + */ +export function clearHuggingFaceCache(): void { + cache = null +} diff --git a/src/api/providers/huggingface.ts b/src/api/providers/huggingface.ts index 913605bd929..aa158654c9a 100644 --- a/src/api/providers/huggingface.ts +++ b/src/api/providers/huggingface.ts @@ -1,16 +1,18 @@ import OpenAI from "openai" import { Anthropic } from "@anthropic-ai/sdk" -import type { ApiHandlerOptions } from "../../shared/api" +import type { ApiHandlerOptions, ModelRecord } from "../../shared/api" import { ApiStream } from "../transform/stream" import { convertToOpenAiMessages } from "../transform/openai-format" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" +import { getHuggingFaceModels, getCachedHuggingFaceModels } from "./fetchers/huggingface" export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler { private client: OpenAI private options: ApiHandlerOptions + private modelCache: ModelRecord | null = null constructor(options: ApiHandlerOptions) { super() @@ -25,6 +27,20 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion apiKey: this.options.huggingFaceApiKey, defaultHeaders: DEFAULT_HEADERS, }) + + // Try to get cached models first + this.modelCache = getCachedHuggingFaceModels() + + // Fetch models asynchronously + this.fetchModels() + } + + private async fetchModels() { + try { + this.modelCache = await getHuggingFaceModels() + } catch (error) { + console.error("Failed to fetch 
HuggingFace models:", error) + } } override async *createMessage( @@ -43,6 +59,11 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion stream_options: { include_usage: true }, } + // Add max_tokens if specified + if (this.options.includeMaxTokens && this.options.modelMaxTokens) { + params.max_tokens = this.options.modelMaxTokens + } + const stream = await this.client.chat.completions.create(params) for await (const chunk of stream) { @@ -86,6 +107,18 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion override getModel() { const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct" + + // Try to get model info from cache + const modelInfo = this.modelCache?.[modelId] + + if (modelInfo) { + return { + id: modelId, + info: modelInfo, + } + } + + // Fallback to default values if model not found in cache return { id: modelId, info: { diff --git a/src/services/huggingface-models.ts b/src/services/huggingface-models.ts deleted file mode 100644 index 9c0bc406f93..00000000000 --- a/src/services/huggingface-models.ts +++ /dev/null @@ -1,171 +0,0 @@ -export interface HuggingFaceModel { - _id: string - id: string - inferenceProviderMapping: InferenceProviderMapping[] - trendingScore: number - config: ModelConfig - tags: string[] - pipeline_tag: "text-generation" | "image-text-to-text" - library_name?: string -} - -export interface InferenceProviderMapping { - provider: string - providerId: string - status: "live" | "staging" | "error" - task: "conversational" -} - -export interface ModelConfig { - architectures: string[] - model_type: string - tokenizer_config?: { - chat_template?: string | Array<{ name: string; template: string }> - model_max_length?: number - } -} - -interface HuggingFaceApiParams { - pipeline_tag?: "text-generation" | "image-text-to-text" - filter: string - inference_provider: string - limit: number - expand: string[] -} - -const DEFAULT_PARAMS: HuggingFaceApiParams = 
{ - filter: "conversational", - inference_provider: "all", - limit: 100, - expand: [ - "inferenceProviderMapping", - "config", - "library_name", - "pipeline_tag", - "tags", - "mask_token", - "trendingScore", - ], -} - -const BASE_URL = "https://huggingface.co/api/models" -const CACHE_DURATION = 1000 * 60 * 60 // 1 hour - -interface CacheEntry { - data: HuggingFaceModel[] - timestamp: number - status: "success" | "partial" | "error" -} - -let cache: CacheEntry | null = null - -function buildApiUrl(params: HuggingFaceApiParams): string { - const url = new URL(BASE_URL) - - // Add simple params - Object.entries(params).forEach(([key, value]) => { - if (!Array.isArray(value)) { - url.searchParams.append(key, String(value)) - } - }) - - // Handle array params specially - params.expand.forEach((item) => { - url.searchParams.append("expand[]", item) - }) - - return url.toString() -} - -const headers: HeadersInit = { - "Upgrade-Insecure-Requests": "1", - "Sec-Fetch-Dest": "document", - "Sec-Fetch-Mode": "navigate", - "Sec-Fetch-Site": "none", - "Sec-Fetch-User": "?1", - Priority: "u=0, i", - Pragma: "no-cache", - "Cache-Control": "no-cache", -} - -const requestInit: RequestInit = { - credentials: "include", - headers, - method: "GET", - mode: "cors", -} - -export async function fetchHuggingFaceModels(): Promise { - const now = Date.now() - - // Check cache - if (cache && now - cache.timestamp < CACHE_DURATION) { - console.log("Using cached Hugging Face models") - return cache.data - } - - try { - console.log("Fetching Hugging Face models from API...") - - // Fetch both text-generation and image-text-to-text models in parallel - const [textGenResponse, imgTextResponse] = await Promise.allSettled([ - fetch(buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "text-generation" }), requestInit), - fetch(buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "image-text-to-text" }), requestInit), - ]) - - let textGenModels: HuggingFaceModel[] = [] - let imgTextModels: HuggingFaceModel[] = [] 
- let hasErrors = false - - // Process text-generation models - if (textGenResponse.status === "fulfilled" && textGenResponse.value.ok) { - textGenModels = await textGenResponse.value.json() - } else { - console.error("Failed to fetch text-generation models:", textGenResponse) - hasErrors = true - } - - // Process image-text-to-text models - if (imgTextResponse.status === "fulfilled" && imgTextResponse.value.ok) { - imgTextModels = await imgTextResponse.value.json() - } else { - console.error("Failed to fetch image-text-to-text models:", imgTextResponse) - hasErrors = true - } - - // Combine and filter models - const allModels = [...textGenModels, ...imgTextModels] - .filter((model) => model.inferenceProviderMapping.length > 0) - .sort((a, b) => a.id.toLowerCase().localeCompare(b.id.toLowerCase())) - - // Update cache - cache = { - data: allModels, - timestamp: now, - status: hasErrors ? "partial" : "success", - } - - console.log(`Fetched ${allModels.length} Hugging Face models (status: ${cache.status})`) - return allModels - } catch (error) { - console.error("Error fetching Hugging Face models:", error) - - // Return cached data if available - if (cache) { - console.log("Using stale cached data due to fetch error") - cache.status = "error" - return cache.data - } - - // No cache available, return empty array - return [] - } -} - -export function getCachedModels(): HuggingFaceModel[] | null { - return cache?.data || null -} - -export function clearCache(): void { - cache = null -} diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 000762e317a..bdd32c4e36b 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -138,26 +138,21 @@ export interface ExtensionMessage { lmStudioModels?: string[] vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[] huggingFaceModels?: Array<{ - _id: string id: string - inferenceProviderMapping: Array<{ + object: string + created: number + 
owned_by: string + providers: Array<{ provider: string - providerId: string status: "live" | "staging" | "error" - task: "conversational" - }> - trendingScore: number - config: { - architectures: string[] - model_type: string - tokenizer_config?: { - chat_template?: string | Array<{ name: string; template: string }> - model_max_length?: number + supports_tools?: boolean + supports_structured_output?: boolean + context_length?: number + pricing?: { + input: number + output: number } - } - tags: string[] - pipeline_tag: "text-generation" | "image-text-to-text" - library_name?: string + }> }> mcpServers?: McpServer[] commits?: GitCommit[] diff --git a/webview-ui/src/components/settings/providers/HuggingFace.tsx b/webview-ui/src/components/settings/providers/HuggingFace.tsx index d4195492dd7..afbdf2ff599 100644 --- a/webview-ui/src/components/settings/providers/HuggingFace.tsx +++ b/webview-ui/src/components/settings/providers/HuggingFace.tsx @@ -1,38 +1,43 @@ import { useCallback, useState, useEffect, useMemo } from "react" import { useEvent } from "react-use" -import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" +import { VSCodeTextField, VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react" import type { ProviderSettings } from "@roo-code/types" +import { + HUGGINGFACE_DEFAULT_MAX_TOKENS, + HUGGINGFACE_MAX_TOKENS_FALLBACK, + HUGGINGFACE_SLIDER_STEP, + HUGGINGFACE_SLIDER_MIN, + HUGGINGFACE_TEMPERATURE_MAX_VALUE, +} from "@roo-code/types" import { ExtensionMessage } from "@roo/ExtensionMessage" import { vscode } from "@src/utils/vscode" import { useAppTranslation } from "@src/i18n/TranslationContext" import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink" -import { SearchableSelect, type SearchableSelectOption } from "@src/components/ui" +import { SearchableSelect, type SearchableSelectOption, Slider } from "@src/components/ui" +import { TemperatureControl } from "../TemperatureControl" +import { cn } from "@src/lib/utils" +import { 
formatPrice } from "@/utils/formatPrice" import { inputEventTransform } from "../transforms" type HuggingFaceModel = { - _id: string id: string - inferenceProviderMapping: Array<{ + object: string + created: number + owned_by: string + providers: Array<{ provider: string - providerId: string status: "live" | "staging" | "error" - task: "conversational" - }> - trendingScore: number - config: { - architectures: string[] - model_type: string - tokenizer_config?: { - chat_template?: string | Array<{ name: string; template: string }> - model_max_length?: number + supports_tools?: boolean + supports_structured_output?: boolean + context_length?: number + pricing?: { + input: number + output: number } - } - tags: string[] - pipeline_tag: "text-generation" | "image-text-to-text" - library_name?: string + }> } type HuggingFaceProps = { @@ -81,10 +86,7 @@ export const HuggingFace = ({ apiConfiguration, setApiConfigurationField }: Hugg // Get current model and its providers const currentModel = models.find((m) => m.id === apiConfiguration?.huggingFaceModelId) - const availableProviders = useMemo( - () => currentModel?.inferenceProviderMapping || [], - [currentModel?.inferenceProviderMapping], - ) + const availableProviders = useMemo(() => currentModel?.providers || [], [currentModel?.providers]) // Set default provider when model changes useEffect(() => { @@ -140,6 +142,32 @@ export const HuggingFace = ({ apiConfiguration, setApiConfigurationField }: Hugg return nameMap[provider] || provider.charAt(0).toUpperCase() + provider.slice(1) } + // Get current provider + const currentProvider = useMemo(() => { + if (!currentModel || !selectedProvider || selectedProvider === "auto") return null + return currentModel.providers.find((p) => p.provider === selectedProvider) + }, [currentModel, selectedProvider]) + + // Get model capabilities based on current provider + const modelCapabilities = useMemo(() => { + if (!currentModel) return null + + // For now, assume text-only models since 
we don't have pipeline_tag in new API + // This could be enhanced by checking model name patterns or adding vision support detection + const supportsImages = false + + // Use provider-specific capabilities if a specific provider is selected + const maxTokens = + currentProvider?.context_length || currentModel.providers.find((p) => p.context_length)?.context_length + const supportsTools = currentProvider?.supports_tools || currentModel.providers.some((p) => p.supports_tools) + + return { + supportsImages, + maxTokens, + supportsTools, + } + }, [currentModel, currentProvider]) + return ( <> )} + {/* Model capabilities */} + {currentModel && modelCapabilities && ( +
+
+ + {modelCapabilities.supportsImages + ? t("settings:modelInfo.supportsImages") + : t("settings:modelInfo.noImages")} +
+ {modelCapabilities.maxTokens && ( +
+ {t("settings:modelInfo.maxOutput")}:{" "} + {modelCapabilities.maxTokens.toLocaleString()} tokens +
+ )} + {currentProvider?.pricing && ( + <> +
+ {t("settings:modelInfo.inputPrice")}:{" "} + {formatPrice(currentProvider.pricing.input)} / 1M tokens +
+
+ {t("settings:modelInfo.outputPrice")}:{" "} + {formatPrice(currentProvider.pricing.output)} / 1M tokens +
+ + )} +
+ )} + + {/* Temperature control */} + setApiConfigurationField("modelTemperature", value)} + maxValue={HUGGINGFACE_TEMPERATURE_MAX_VALUE} + /> + + {/* Max tokens control */} +
+ + setApiConfigurationField("includeMaxTokens", (e.target as HTMLInputElement).checked) + }> + + +
+ {t("settings:limitMaxTokensDescription")} +
+ + {apiConfiguration?.includeMaxTokens && ( +
+
+ +
+ setApiConfigurationField("modelMaxTokens", value)} + /> + + {apiConfiguration?.modelMaxTokens || HUGGINGFACE_DEFAULT_MAX_TOKENS} + +
+
+ {t("settings:maxTokensGenerateDescription")} +
+
+
+ )} +
+
{t("settings:providers.apiKeyStorageNotice")}
diff --git a/webview-ui/src/components/settings/providers/__tests__/HuggingFace.spec.tsx b/webview-ui/src/components/settings/providers/__tests__/HuggingFace.spec.tsx index 3fd29e4c722..0256fe94d35 100644 --- a/webview-ui/src/components/settings/providers/__tests__/HuggingFace.spec.tsx +++ b/webview-ui/src/components/settings/providers/__tests__/HuggingFace.spec.tsx @@ -1,9 +1,8 @@ -import React from "react" import { render, screen } from "@/utils/test-utils" import { HuggingFace } from "../HuggingFace" import { ProviderSettings } from "@roo-code/types" -// Mock the VSCodeTextField component +// Mock the VSCode components vi.mock("@vscode/webview-ui-toolkit/react", () => ({ VSCodeTextField: ({ children, @@ -32,6 +31,18 @@ vi.mock("@vscode/webview-ui-toolkit/react", () => ({ ) }, + VSCodeCheckbox: ({ children, checked, onChange, ...rest }: any) => ( +
+ + {children} +
+ ), VSCodeLink: ({ children, href, onClick }: any) => ( {children} @@ -53,6 +64,10 @@ vi.mock("@src/i18n/TranslationContext", () => ({ "settings:providers.getHuggingFaceApiKey": "Get Hugging Face API Key", "settings:providers.huggingFaceApiKey": "Hugging Face API Key", "settings:providers.huggingFaceModelId": "Model ID", + "settings:modelInfo.fetchingModels": "Fetching models...", + "settings:modelInfo.errorFetchingModels": "Error fetching models", + "settings:modelInfo.noModelsFound": "No models found", + "settings:modelInfo.noImages": "Does not support images", } return translations[key] || key }, @@ -79,11 +94,31 @@ vi.mock("@src/components/ui", () => ({ ), })) +// Mock the formatPrice utility +vi.mock("@/utils/formatPrice", () => ({ + formatPrice: (price: number) => `$${price.toFixed(2)}`, +})) + +// Create a mock postMessage function +const mockPostMessage = vi.fn() + +// Mock the vscode module +vi.mock("@src/utils/vscode", () => ({ + vscode: { + postMessage: vi.fn(), + }, +})) + +// Import the mocked module to set up the spy +import { vscode } from "@src/utils/vscode" + describe("HuggingFace Component", () => { const mockSetApiConfigurationField = vi.fn() beforeEach(() => { vi.clearAllMocks() + // Set up the mock implementation + vi.mocked(vscode.postMessage).mockImplementation(mockPostMessage) }) it("should render with internationalized labels", () => { @@ -159,4 +194,102 @@ describe("HuggingFace Component", () => { expect(apiKeyButton).toBeInTheDocument() expect(apiKeyButton).toHaveTextContent("Get Hugging Face API Key") }) + + it("should fetch models when component mounts", () => { + const apiConfiguration: Partial = { + huggingFaceApiKey: "test-api-key", + huggingFaceModelId: "", + } + + render( + , + ) + + // Check that the fetch models message was sent + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "requestHuggingFaceModels", + }) + }) + + it("should display loading state while fetching models", () => { + const apiConfiguration: Partial = { + 
huggingFaceApiKey: "test-api-key", + huggingFaceModelId: "", + } + + render( + , + ) + + // Check for loading text in the label + expect(screen.getByText("settings:providers.huggingFaceLoading")).toBeInTheDocument() + }) + + it("should display model capabilities when a model is selected", async () => { + const apiConfiguration: Partial = { + huggingFaceApiKey: "test-api-key", + huggingFaceModelId: "test-model", + huggingFaceInferenceProvider: "test-provider", // Select a specific provider to show pricing + } + + const { rerender } = render( + , + ) + + // Simulate receiving models from the backend + const mockModels = [ + { + id: "test-model", + object: "model", + created: Date.now(), + owned_by: "test", + providers: [ + { + provider: "test-provider", + status: "live" as const, + supports_tools: false, + supports_structured_output: false, + context_length: 8192, + pricing: { + input: 0.001, + output: 0.002, + }, + }, + ], + }, + ] + + // Simulate message event + const messageEvent = new MessageEvent("message", { + data: { + type: "huggingFaceModels", + huggingFaceModels: mockModels, + }, + }) + window.dispatchEvent(messageEvent) + + // Re-render to trigger effect + rerender( + , + ) + + // Check that model capabilities are displayed + expect(screen.getByText("Does not support images")).toBeInTheDocument() + expect(screen.getByText("8,192 tokens")).toBeInTheDocument() + // Check that both input and output prices are displayed + const priceElements = screen.getAllByText("$0.00 / 1M tokens") + expect(priceElements).toHaveLength(2) // One for input, one for output + }) }) diff --git a/webview-ui/src/i18n/locales/ca/settings.json b/webview-ui/src/i18n/locales/ca/settings.json index 047b5858bf2..3388e341697 100644 --- a/webview-ui/src/i18n/locales/ca/settings.json +++ b/webview-ui/src/i18n/locales/ca/settings.json @@ -754,5 +754,8 @@ "useCustomArn": "Utilitza ARN personalitzat..." 
}, "includeMaxOutputTokens": "Incloure tokens màxims de sortida", - "includeMaxOutputTokensDescription": "Enviar el paràmetre de tokens màxims de sortida a les sol·licituds API. Alguns proveïdors poden no admetre això." + "includeMaxOutputTokensDescription": "Enviar el paràmetre de tokens màxims de sortida a les sol·licituds API. Alguns proveïdors poden no admetre això.", + "limitMaxTokensDescription": "Limitar el nombre màxim de tokens en la resposta", + "maxOutputTokensLabel": "Tokens màxims de sortida", + "maxTokensGenerateDescription": "Tokens màxims a generar en la resposta" } diff --git a/webview-ui/src/i18n/locales/de/settings.json b/webview-ui/src/i18n/locales/de/settings.json index fa646854628..c1c8ea5460f 100644 --- a/webview-ui/src/i18n/locales/de/settings.json +++ b/webview-ui/src/i18n/locales/de/settings.json @@ -754,5 +754,8 @@ "useCustomArn": "Benutzerdefinierte ARN verwenden..." }, "includeMaxOutputTokens": "Maximale Ausgabe-Tokens einbeziehen", - "includeMaxOutputTokensDescription": "Senden Sie den Parameter für maximale Ausgabe-Tokens in API-Anfragen. Einige Anbieter unterstützen dies möglicherweise nicht." + "includeMaxOutputTokensDescription": "Senden Sie den Parameter für maximale Ausgabe-Tokens in API-Anfragen. Einige Anbieter unterstützen dies möglicherweise nicht.", + "limitMaxTokensDescription": "Begrenze die maximale Anzahl von Tokens in der Antwort", + "maxOutputTokensLabel": "Maximale Ausgabe-Tokens", + "maxTokensGenerateDescription": "Maximale Tokens, die in der Antwort generiert werden" } diff --git a/webview-ui/src/i18n/locales/en/settings.json b/webview-ui/src/i18n/locales/en/settings.json index f2b05b33691..51a1b335289 100644 --- a/webview-ui/src/i18n/locales/en/settings.json +++ b/webview-ui/src/i18n/locales/en/settings.json @@ -754,5 +754,8 @@ "useCustomArn": "Use custom ARN..." }, "includeMaxOutputTokens": "Include max output tokens", - "includeMaxOutputTokensDescription": "Send max output tokens parameter in API requests. 
Some providers may not support this." + "includeMaxOutputTokensDescription": "Send max output tokens parameter in API requests. Some providers may not support this.", + "limitMaxTokensDescription": "Limit the maximum number of tokens in the response", + "maxOutputTokensLabel": "Max output tokens", + "maxTokensGenerateDescription": "Maximum tokens to generate in response" } diff --git a/webview-ui/src/i18n/locales/es/settings.json b/webview-ui/src/i18n/locales/es/settings.json index 3d60bfb45c3..f422169322c 100644 --- a/webview-ui/src/i18n/locales/es/settings.json +++ b/webview-ui/src/i18n/locales/es/settings.json @@ -754,5 +754,8 @@ "useCustomArn": "Usar ARN personalizado..." }, "includeMaxOutputTokens": "Incluir tokens máximos de salida", - "includeMaxOutputTokensDescription": "Enviar parámetro de tokens máximos de salida en solicitudes API. Algunos proveedores pueden no soportar esto." + "includeMaxOutputTokensDescription": "Enviar parámetro de tokens máximos de salida en solicitudes API. Algunos proveedores pueden no soportar esto.", + "limitMaxTokensDescription": "Limitar el número máximo de tokens en la respuesta", + "maxOutputTokensLabel": "Tokens máximos de salida", + "maxTokensGenerateDescription": "Tokens máximos a generar en la respuesta" } diff --git a/webview-ui/src/i18n/locales/fr/settings.json b/webview-ui/src/i18n/locales/fr/settings.json index 2654afcbdc6..dce37331b72 100644 --- a/webview-ui/src/i18n/locales/fr/settings.json +++ b/webview-ui/src/i18n/locales/fr/settings.json @@ -754,5 +754,8 @@ "useCustomArn": "Utiliser un ARN personnalisé..." }, "includeMaxOutputTokens": "Inclure les tokens de sortie maximum", - "includeMaxOutputTokensDescription": "Envoyer le paramètre de tokens de sortie maximum dans les requêtes API. Certains fournisseurs peuvent ne pas supporter cela." + "includeMaxOutputTokensDescription": "Envoyer le paramètre de tokens de sortie maximum dans les requêtes API. 
Certains fournisseurs peuvent ne pas supporter cela.", + "limitMaxTokensDescription": "Limiter le nombre maximum de tokens dans la réponse", + "maxOutputTokensLabel": "Tokens de sortie maximum", + "maxTokensGenerateDescription": "Tokens maximum à générer dans la réponse" } diff --git a/webview-ui/src/i18n/locales/hi/settings.json b/webview-ui/src/i18n/locales/hi/settings.json index 279de29ada2..8d15b258ef7 100644 --- a/webview-ui/src/i18n/locales/hi/settings.json +++ b/webview-ui/src/i18n/locales/hi/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "कस्टम ARN का उपयोग करें..." }, "includeMaxOutputTokens": "अधिकतम आउटपुट टोकन शामिल करें", - "includeMaxOutputTokensDescription": "API अनुरोधों में अधिकतम आउटपुट टोकन पैरामीटर भेजें। कुछ प्रदाता इसका समर्थन नहीं कर सकते हैं।" + "includeMaxOutputTokensDescription": "API अनुरोधों में अधिकतम आउटपुट टोकन पैरामीटर भेजें। कुछ प्रदाता इसका समर्थन नहीं कर सकते हैं।", + "limitMaxTokensDescription": "प्रतिक्रिया में टोकन की अधिकतम संख्या सीमित करें", + "maxOutputTokensLabel": "अधिकतम आउटपुट टोकन", + "maxTokensGenerateDescription": "प्रतिक्रिया में उत्पन्न करने के लिए अधिकतम टोकन" } diff --git a/webview-ui/src/i18n/locales/id/settings.json b/webview-ui/src/i18n/locales/id/settings.json index d54db7122ed..9d7d3595abd 100644 --- a/webview-ui/src/i18n/locales/id/settings.json +++ b/webview-ui/src/i18n/locales/id/settings.json @@ -784,5 +784,8 @@ "useCustomArn": "Gunakan ARN kustom..." }, "includeMaxOutputTokens": "Sertakan token output maksimum", - "includeMaxOutputTokensDescription": "Kirim parameter token output maksimum dalam permintaan API. Beberapa provider mungkin tidak mendukung ini." + "includeMaxOutputTokensDescription": "Kirim parameter token output maksimum dalam permintaan API. 
Beberapa provider mungkin tidak mendukung ini.", + "limitMaxTokensDescription": "Batasi jumlah maksimum token dalam respons", + "maxOutputTokensLabel": "Token output maksimum", + "maxTokensGenerateDescription": "Token maksimum untuk dihasilkan dalam respons" } diff --git a/webview-ui/src/i18n/locales/it/settings.json b/webview-ui/src/i18n/locales/it/settings.json index ed77701914e..13a5b258acc 100644 --- a/webview-ui/src/i18n/locales/it/settings.json +++ b/webview-ui/src/i18n/locales/it/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "Usa ARN personalizzato..." }, "includeMaxOutputTokens": "Includi token di output massimi", - "includeMaxOutputTokensDescription": "Invia il parametro dei token di output massimi nelle richieste API. Alcuni provider potrebbero non supportarlo." + "includeMaxOutputTokensDescription": "Invia il parametro dei token di output massimi nelle richieste API. Alcuni provider potrebbero non supportarlo.", + "limitMaxTokensDescription": "Limita il numero massimo di token nella risposta", + "maxOutputTokensLabel": "Token di output massimi", + "maxTokensGenerateDescription": "Token massimi da generare nella risposta" } diff --git a/webview-ui/src/i18n/locales/ja/settings.json b/webview-ui/src/i18n/locales/ja/settings.json index efa020b320f..741aec350fe 100644 --- a/webview-ui/src/i18n/locales/ja/settings.json +++ b/webview-ui/src/i18n/locales/ja/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "カスタム ARN を使用..." 
}, "includeMaxOutputTokens": "最大出力トークンを含める", - "includeMaxOutputTokensDescription": "APIリクエストで最大出力トークンパラメータを送信します。一部のプロバイダーはこれをサポートしていない場合があります。" + "includeMaxOutputTokensDescription": "APIリクエストで最大出力トークンパラメータを送信します。一部のプロバイダーはこれをサポートしていない場合があります。", + "limitMaxTokensDescription": "レスポンスの最大トークン数を制限する", + "maxOutputTokensLabel": "最大出力トークン", + "maxTokensGenerateDescription": "レスポンスで生成する最大トークン数" } diff --git a/webview-ui/src/i18n/locales/ko/settings.json b/webview-ui/src/i18n/locales/ko/settings.json index 39b8fd6e333..83f7c714de8 100644 --- a/webview-ui/src/i18n/locales/ko/settings.json +++ b/webview-ui/src/i18n/locales/ko/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "사용자 지정 ARN 사용..." }, "includeMaxOutputTokens": "최대 출력 토큰 포함", - "includeMaxOutputTokensDescription": "API 요청에서 최대 출력 토큰 매개변수를 전송합니다. 일부 제공업체는 이를 지원하지 않을 수 있습니다." + "includeMaxOutputTokensDescription": "API 요청에서 최대 출력 토큰 매개변수를 전송합니다. 일부 제공업체는 이를 지원하지 않을 수 있습니다.", + "limitMaxTokensDescription": "응답에서 최대 토큰 수 제한", + "maxOutputTokensLabel": "최대 출력 토큰", + "maxTokensGenerateDescription": "응답에서 생성할 최대 토큰 수" } diff --git a/webview-ui/src/i18n/locales/nl/settings.json b/webview-ui/src/i18n/locales/nl/settings.json index d7eba61045e..8e518012068 100644 --- a/webview-ui/src/i18n/locales/nl/settings.json +++ b/webview-ui/src/i18n/locales/nl/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "Aangepaste ARN gebruiken..." }, "includeMaxOutputTokens": "Maximale output tokens opnemen", - "includeMaxOutputTokensDescription": "Stuur maximale output tokens parameter in API-verzoeken. Sommige providers ondersteunen dit mogelijk niet." + "includeMaxOutputTokensDescription": "Stuur maximale output tokens parameter in API-verzoeken. 
Sommige providers ondersteunen dit mogelijk niet.", + "limitMaxTokensDescription": "Beperk het maximale aantal tokens in het antwoord", + "maxOutputTokensLabel": "Maximale output tokens", + "maxTokensGenerateDescription": "Maximale tokens om te genereren in het antwoord" } diff --git a/webview-ui/src/i18n/locales/pl/settings.json b/webview-ui/src/i18n/locales/pl/settings.json index 960c59724b8..2cfa5dc4bac 100644 --- a/webview-ui/src/i18n/locales/pl/settings.json +++ b/webview-ui/src/i18n/locales/pl/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "Użyj niestandardowego ARN..." }, "includeMaxOutputTokens": "Uwzględnij maksymalne tokeny wyjściowe", - "includeMaxOutputTokensDescription": "Wyślij parametr maksymalnych tokenów wyjściowych w żądaniach API. Niektórzy dostawcy mogą tego nie obsługiwać." + "includeMaxOutputTokensDescription": "Wyślij parametr maksymalnych tokenów wyjściowych w żądaniach API. Niektórzy dostawcy mogą tego nie obsługiwać.", + "limitMaxTokensDescription": "Ogranicz maksymalną liczbę tokenów w odpowiedzi", + "maxOutputTokensLabel": "Maksymalne tokeny wyjściowe", + "maxTokensGenerateDescription": "Maksymalne tokeny do wygenerowania w odpowiedzi" } diff --git a/webview-ui/src/i18n/locales/pt-BR/settings.json b/webview-ui/src/i18n/locales/pt-BR/settings.json index 2bb9d4b9213..562ab777b4c 100644 --- a/webview-ui/src/i18n/locales/pt-BR/settings.json +++ b/webview-ui/src/i18n/locales/pt-BR/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "Usar ARN personalizado..." }, "includeMaxOutputTokens": "Incluir tokens máximos de saída", - "includeMaxOutputTokensDescription": "Enviar parâmetro de tokens máximos de saída nas solicitações de API. Alguns provedores podem não suportar isso." + "includeMaxOutputTokensDescription": "Enviar parâmetro de tokens máximos de saída nas solicitações de API. 
Alguns provedores podem não suportar isso.", + "limitMaxTokensDescription": "Limitar o número máximo de tokens na resposta", + "maxOutputTokensLabel": "Tokens máximos de saída", + "maxTokensGenerateDescription": "Tokens máximos para gerar na resposta" } diff --git a/webview-ui/src/i18n/locales/ru/settings.json b/webview-ui/src/i18n/locales/ru/settings.json index c57a8db3d6b..0feed462bf7 100644 --- a/webview-ui/src/i18n/locales/ru/settings.json +++ b/webview-ui/src/i18n/locales/ru/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "Использовать пользовательский ARN..." }, "includeMaxOutputTokens": "Включить максимальные выходные токены", - "includeMaxOutputTokensDescription": "Отправлять параметр максимальных выходных токенов в API-запросах. Некоторые провайдеры могут не поддерживать это." + "includeMaxOutputTokensDescription": "Отправлять параметр максимальных выходных токенов в API-запросах. Некоторые провайдеры могут не поддерживать это.", + "limitMaxTokensDescription": "Ограничить максимальное количество токенов в ответе", + "maxOutputTokensLabel": "Максимальные выходные токены", + "maxTokensGenerateDescription": "Максимальные токены для генерации в ответе" } diff --git a/webview-ui/src/i18n/locales/tr/settings.json b/webview-ui/src/i18n/locales/tr/settings.json index 576e3c1b538..20917bdde64 100644 --- a/webview-ui/src/i18n/locales/tr/settings.json +++ b/webview-ui/src/i18n/locales/tr/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "Özel ARN kullan..." }, "includeMaxOutputTokens": "Maksimum çıktı tokenlerini dahil et", - "includeMaxOutputTokensDescription": "API isteklerinde maksimum çıktı token parametresini gönder. Bazı sağlayıcılar bunu desteklemeyebilir." + "includeMaxOutputTokensDescription": "API isteklerinde maksimum çıktı token parametresini gönder. 
Bazı sağlayıcılar bunu desteklemeyebilir.", + "limitMaxTokensDescription": "Yanıttaki maksimum token sayısını sınırla", + "maxOutputTokensLabel": "Maksimum çıktı tokenleri", + "maxTokensGenerateDescription": "Yanıtta oluşturulacak maksimum token sayısı" } diff --git a/webview-ui/src/i18n/locales/vi/settings.json b/webview-ui/src/i18n/locales/vi/settings.json index 76a4f9397d6..c30830741e2 100644 --- a/webview-ui/src/i18n/locales/vi/settings.json +++ b/webview-ui/src/i18n/locales/vi/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "Sử dụng ARN tùy chỉnh..." }, "includeMaxOutputTokens": "Bao gồm token đầu ra tối đa", - "includeMaxOutputTokensDescription": "Gửi tham số token đầu ra tối đa trong các yêu cầu API. Một số nhà cung cấp có thể không hỗ trợ điều này." + "includeMaxOutputTokensDescription": "Gửi tham số token đầu ra tối đa trong các yêu cầu API. Một số nhà cung cấp có thể không hỗ trợ điều này.", + "limitMaxTokensDescription": "Giới hạn số lượng token tối đa trong phản hồi", + "maxOutputTokensLabel": "Token đầu ra tối đa", + "maxTokensGenerateDescription": "Token tối đa để tạo trong phản hồi" } diff --git a/webview-ui/src/i18n/locales/zh-CN/settings.json b/webview-ui/src/i18n/locales/zh-CN/settings.json index 748b9f1d5dc..d6c7e0580d5 100644 --- a/webview-ui/src/i18n/locales/zh-CN/settings.json +++ b/webview-ui/src/i18n/locales/zh-CN/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "使用自定义 ARN..."
}, "includeMaxOutputTokens": "包含最大输出 Token 数", - "includeMaxOutputTokensDescription": "在 API 请求中发送最大输出 Token 参数。某些提供商可能不支持此功能。" + "includeMaxOutputTokensDescription": "在 API 请求中发送最大输出 Token 参数。某些提供商可能不支持此功能。", + "limitMaxTokensDescription": "限制响应中的最大 Token 数量", + "maxOutputTokensLabel": "最大输出 Token 数", + "maxTokensGenerateDescription": "响应中生成的最大 Token 数" } diff --git a/webview-ui/src/i18n/locales/zh-TW/settings.json b/webview-ui/src/i18n/locales/zh-TW/settings.json index f4a6149102e..07b94a387fe 100644 --- a/webview-ui/src/i18n/locales/zh-TW/settings.json +++ b/webview-ui/src/i18n/locales/zh-TW/settings.json @@ -755,5 +755,8 @@ "useCustomArn": "使用自訂 ARN..." }, "includeMaxOutputTokens": "包含最大輸出 Token 數", - "includeMaxOutputTokensDescription": "在 API 請求中傳送最大輸出 Token 參數。某些提供商可能不支援此功能。" + "includeMaxOutputTokensDescription": "在 API 請求中傳送最大輸出 Token 參數。某些提供商可能不支援此功能。", + "limitMaxTokensDescription": "限制回應中的最大 Token 數量", + "maxOutputTokensLabel": "最大輸出 Token 數", + "maxTokensGenerateDescription": "回應中產生的最大 Token 數" }