17 changes: 17 additions & 0 deletions packages/types/src/providers/huggingface.ts
@@ -0,0 +1,17 @@
/**
* HuggingFace provider constants
*/

// Default values for HuggingFace models
export const HUGGINGFACE_DEFAULT_MAX_TOKENS = 2048
export const HUGGINGFACE_MAX_TOKENS_FALLBACK = 8192
export const HUGGINGFACE_DEFAULT_CONTEXT_WINDOW = 128_000

// UI constants
export const HUGGINGFACE_SLIDER_STEP = 256
export const HUGGINGFACE_SLIDER_MIN = 1
export const HUGGINGFACE_TEMPERATURE_MAX_VALUE = 2

// API constants
export const HUGGINGFACE_API_URL = "https://router.huggingface.co/v1/models?collection=roocode"
export const HUGGINGFACE_CACHE_DURATION = 1000 * 60 * 60 // 1 hour
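For context, a minimal sketch of how the slider constants above might be consumed by a settings UI. SliderConfig and maxTokensSliderConfig are illustrative names and not part of this PR; the clamping behavior is an assumption:

import {
	HUGGINGFACE_SLIDER_MIN,
	HUGGINGFACE_SLIDER_STEP,
	HUGGINGFACE_MAX_TOKENS_FALLBACK,
} from "@roo-code/types"

// Illustrative shape for a max-tokens slider; not defined in this PR.
interface SliderConfig {
	min: number
	max: number
	step: number
}

// Bound the slider by the model's context window when known,
// falling back to HUGGINGFACE_MAX_TOKENS_FALLBACK otherwise.
function maxTokensSliderConfig(contextWindow?: number): SliderConfig {
	return {
		min: HUGGINGFACE_SLIDER_MIN,
		max: contextWindow ?? HUGGINGFACE_MAX_TOKENS_FALLBACK,
		step: HUGGINGFACE_SLIDER_STEP,
	}
}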
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -6,6 +6,7 @@ export * from "./deepseek.js"
export * from "./gemini.js"
export * from "./glama.js"
export * from "./groq.js"
export * from "./huggingface.js"
export * from "./lite-llm.js"
export * from "./lm-studio.js"
export * from "./mistral.js"
56 changes: 50 additions & 6 deletions src/api/huggingface-models.ts
@@ -1,4 +1,10 @@
import { fetchHuggingFaceModels, type HuggingFaceModel } from "../services/huggingface-models"
import {
getHuggingFaceModels as fetchModels,
getCachedRawHuggingFaceModels,
type HuggingFaceModel,
} from "./providers/fetchers/huggingface"
import axios from "axios"
import { HUGGINGFACE_API_URL } from "@roo-code/types"

export interface HuggingFaceModelsResponse {
models: HuggingFaceModel[]
@@ -7,11 +13,49 @@ export interface HuggingFaceModelsResponse {
}

export async function getHuggingFaceModels(): Promise<HuggingFaceModelsResponse> {
const models = await fetchHuggingFaceModels()
try {
// First, trigger the fetch to populate cache
await fetchModels()

return {
models,
cached: false, // We could enhance this to track if data came from cache
timestamp: Date.now(),
// Get the raw models from cache
const cachedRawModels = getCachedRawHuggingFaceModels()

if (cachedRawModels) {
return {
models: cachedRawModels,
cached: true,
timestamp: Date.now(),
}
}

// If no cached raw models, fetch directly from API
const response = await axios.get(HUGGINGFACE_API_URL, {
headers: {
"Upgrade-Insecure-Requests": "1",
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
Priority: "u=0, i",
Pragma: "no-cache",
"Cache-Control": "no-cache",
},
timeout: 10000,
})

const models = response.data?.data || []

return {
models,
cached: false,
timestamp: Date.now(),
}
} catch (error) {
console.error("Failed to get HuggingFace models:", error)
return {
models: [],
cached: false,
timestamp: Date.now(),
}
}
}
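A hypothetical caller of the rewritten getHuggingFaceModels above, showing how the cached flag and timestamp in HuggingFaceModelsResponse can be surfaced. logModelSummary and the import path are illustrative, not part of this PR:

import { getHuggingFaceModels } from "../api/huggingface-models"

async function logModelSummary(): Promise<void> {
	const { models, cached, timestamp } = await getHuggingFaceModels()
	// The function never rejects: on failure it resolves with an empty model list.
	console.log(
		`${models.length} HuggingFace models (cached=${cached}) at ${new Date(timestamp).toISOString()}`,
	)
}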
229 changes: 229 additions & 0 deletions src/api/providers/fetchers/huggingface.ts
@@ -0,0 +1,229 @@
import axios from "axios"
import { z } from "zod"
import type { ModelInfo } from "@roo-code/types"
import {
HUGGINGFACE_API_URL,
HUGGINGFACE_CACHE_DURATION,
HUGGINGFACE_DEFAULT_MAX_TOKENS,
HUGGINGFACE_DEFAULT_CONTEXT_WINDOW,
} from "@roo-code/types"
import type { ModelRecord } from "../../../shared/api"

/**
* HuggingFace Provider Schema
*/
const huggingFaceProviderSchema = z.object({
provider: z.string(),
status: z.enum(["live", "staging", "error"]),
supports_tools: z.boolean().optional(),
supports_structured_output: z.boolean().optional(),
context_length: z.number().optional(),
pricing: z
.object({
input: z.number(),
output: z.number(),
})
.optional(),
})

/**
* Represents a provider that can serve a HuggingFace model
* @property provider - The provider identifier (e.g., "sambanova", "together")
* @property status - The current status of the provider
* @property supports_tools - Whether the provider supports tool/function calling
* @property supports_structured_output - Whether the provider supports structured output
* @property context_length - The maximum context length supported by this provider
* @property pricing - The pricing information for input/output tokens
*/
export type HuggingFaceProvider = z.infer<typeof huggingFaceProviderSchema>

/**
* HuggingFace Model Schema
*/
const huggingFaceModelSchema = z.object({
id: z.string(),
object: z.literal("model"),
created: z.number(),
owned_by: z.string(),
providers: z.array(huggingFaceProviderSchema),
})

/**
* Represents a HuggingFace model available through the router API
* @property id - The unique identifier of the model
* @property object - The object type (always "model")
* @property created - Unix timestamp of when the model was created
* @property owned_by - The organization that owns the model
* @property providers - List of providers that can serve this model
*/
export type HuggingFaceModel = z.infer<typeof huggingFaceModelSchema>

/**
* HuggingFace API Response Schema
*/
const huggingFaceApiResponseSchema = z.object({
object: z.string(),
data: z.array(huggingFaceModelSchema),
})

/**
* Represents the response from the HuggingFace router API
* @property object - The response object type
* @property data - Array of available models
*/
type HuggingFaceApiResponse = z.infer<typeof huggingFaceApiResponseSchema>

/**
* Cache entry for storing fetched models
* @property data - The cached model records
* @property rawModels - The raw HuggingFace models, kept for UI display
* @property timestamp - Unix timestamp of when the cache was last updated
*/
interface CacheEntry {
data: ModelRecord
rawModels?: HuggingFaceModel[]
timestamp: number
}

let cache: CacheEntry | null = null

/**
* Parse a HuggingFace model into ModelInfo format
* @param model - The HuggingFace model to parse
* @param provider - Optional specific provider to use for capabilities
* @returns ModelInfo object compatible with the application's model system
*/
function parseHuggingFaceModel(model: HuggingFaceModel, provider?: HuggingFaceProvider): ModelInfo {
// Use provider-specific values if available, otherwise find first provider with values
const contextLength =
provider?.context_length ||
model.providers.find((p) => p.context_length)?.context_length ||
HUGGINGFACE_DEFAULT_CONTEXT_WINDOW

const pricing = provider?.pricing || model.providers.find((p) => p.pricing)?.pricing

// Include provider name in description if specific provider is given
const description = provider ? `${model.id} via ${provider.provider}` : `${model.id} via HuggingFace`

return {
maxTokens: Math.min(contextLength, HUGGINGFACE_DEFAULT_MAX_TOKENS),
contextWindow: contextLength,
supportsImages: false, // HuggingFace API doesn't provide this info yet
supportsPromptCache: false,
supportsComputerUse: false,
inputPrice: pricing?.input,
outputPrice: pricing?.output,
description,
}
}

/**
* Fetches available models from HuggingFace
*
* @returns A promise that resolves to a record of model IDs to model info
* @throws Will throw an error if the request fails
*/
export async function getHuggingFaceModels(): Promise<ModelRecord> {
const now = Date.now()

// Check cache
if (cache && now - cache.timestamp < HUGGINGFACE_CACHE_DURATION) {
return cache.data
}

const models: ModelRecord = {}

try {
const response = await axios.get<HuggingFaceApiResponse>(HUGGINGFACE_API_URL, {
headers: {
"Upgrade-Insecure-Requests": "1",
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
Priority: "u=0, i",
Pragma: "no-cache",
"Cache-Control": "no-cache",
},
timeout: 10000, // 10 second timeout
})

const result = huggingFaceApiResponseSchema.safeParse(response.data)

if (!result.success) {
console.error("HuggingFace models response validation failed:", result.error.format())
throw new Error("Invalid response format from HuggingFace API")
}

const validModels = result.data.data.filter((model) => model.providers.length > 0)

for (const model of validModels) {
// Add the base model
models[model.id] = parseHuggingFaceModel(model)

// Add provider-specific variants for all live providers
for (const provider of model.providers) {
if (provider.status === "live") {
const providerKey = `${model.id}:${provider.provider}`
const providerModel = parseHuggingFaceModel(model, provider)

// Always add provider variants to show all available providers
models[providerKey] = providerModel
}
}
}

// Update cache
cache = {
data: models,
rawModels: validModels,
timestamp: now,
}

return models
} catch (error) {
console.error("Error fetching HuggingFace models:", error)

// Return cached data if available
if (cache) {
return cache.data
}

// Re-throw with more context
if (axios.isAxiosError(error)) {
if (error.response) {
throw new Error(
`Failed to fetch HuggingFace models: ${error.response.status} ${error.response.statusText}`,
)
} else if (error.request) {
throw new Error(
"Failed to fetch HuggingFace models: No response from server. Check your internet connection.",
)
}
}

throw new Error(
`Failed to fetch HuggingFace models: ${error instanceof Error ? error.message : "Unknown error"}`,
)
}
}

/**
* Get cached models without making an API request
*/
export function getCachedHuggingFaceModels(): ModelRecord | null {
return cache?.data || null
}

/**
* Get cached raw models for UI display
*/
export function getCachedRawHuggingFaceModels(): HuggingFaceModel[] | null {
return cache?.rawModels || null
}

/**
* Clear the cache
*/
export function clearHuggingFaceCache(): void {
cache = null
}
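A small sketch of the caching contract exported above, assuming the one-hour TTL from HUGGINGFACE_CACHE_DURATION and the module-level cache shown in this file. forceRefreshHuggingFaceModels is an illustrative helper, not part of this PR:

import {
	getHuggingFaceModels,
	getCachedHuggingFaceModels,
	clearHuggingFaceCache,
} from "./huggingface"

async function forceRefreshHuggingFaceModels() {
	clearHuggingFaceCache() // drop the module-level cache so the next call hits the API
	try {
		return await getHuggingFaceModels() // repopulates both data and rawModels
	} catch {
		// With the cache cleared, a failed fetch re-throws; this then returns null.
		return getCachedHuggingFaceModels()
	}
}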
35 changes: 34 additions & 1 deletion src/api/providers/huggingface.ts
@@ -1,16 +1,18 @@
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import type { ApiHandlerOptions } from "../../shared/api"
import type { ApiHandlerOptions, ModelRecord } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import { getHuggingFaceModels, getCachedHuggingFaceModels } from "./fetchers/huggingface"

export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler {
private client: OpenAI
private options: ApiHandlerOptions
private modelCache: ModelRecord | null = null

constructor(options: ApiHandlerOptions) {
super()
@@ -25,6 +27,20 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
apiKey: this.options.huggingFaceApiKey,
defaultHeaders: DEFAULT_HEADERS,
})

// Try to get cached models first
this.modelCache = getCachedHuggingFaceModels()

// Fetch models asynchronously
this.fetchModels()
}

private async fetchModels() {
try {
this.modelCache = await getHuggingFaceModels()
} catch (error) {
console.error("Failed to fetch HuggingFace models:", error)
}
}

override async *createMessage(
@@ -43,6 +59,11 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
stream_options: { include_usage: true },
}

// Add max_tokens if specified
if (this.options.includeMaxTokens && this.options.modelMaxTokens) {
params.max_tokens = this.options.modelMaxTokens
}

const stream = await this.client.chat.completions.create(params)

for await (const chunk of stream) {
@@ -86,6 +107,18 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion

override getModel() {
const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"

// Try to get model info from cache
const modelInfo = this.modelCache?.[modelId]

if (modelInfo) {
return {
id: modelId,
info: modelInfo,
}
}

// Fallback to default values if model not found in cache
return {
id: modelId,
info: {
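To illustrate the new cache-backed lookup in getModel (the tail of this diff is truncated above), a hypothetical instantiation; huggingFaceApiKey and huggingFaceModelId follow ApiHandlerOptions as used in this file, and the assumption is that the remaining option fields are optional:

const handler = new HuggingFaceHandler({
	huggingFaceApiKey: process.env.HF_API_KEY ?? "",
	huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct",
})

// Until the constructor's async fetchModels() resolves, getModel() serves
// either the warm cache from getCachedHuggingFaceModels() or the default values.
const { id, info } = handler.getModel()
console.log(id, info.contextWindow)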