diff --git a/packages/types/src/codebase-index.ts b/packages/types/src/codebase-index.ts index 89d5b168d78..cfd83611fb8 100644 --- a/packages/types/src/codebase-index.ts +++ b/packages/types/src/codebase-index.ts @@ -42,12 +42,25 @@ export type CodebaseIndexConfig = z.infer * CodebaseIndexModels */ +const modelProfileSchema = z.object({ + /** The fixed dimension for the model, or a fallback for models with variable dimensions. */ + dimension: z.number(), + scoreThreshold: z.number().optional(), + queryPrefix: z.string().optional(), + /** The minimum dimension supported by a variable-dimension model. */ + minDimension: z.number().optional(), + /** The maximum dimension supported by a variable-dimension model. */ + maxDimension: z.number().optional(), + /** The default dimension for a variable-dimension model, used for UI presentation. */ + defaultDimension: z.number().optional(), +}) + export const codebaseIndexModelsSchema = z.object({ - openai: z.record(z.string(), z.object({ dimension: z.number() })).optional(), - ollama: z.record(z.string(), z.object({ dimension: z.number() })).optional(), - "openai-compatible": z.record(z.string(), z.object({ dimension: z.number() })).optional(), - gemini: z.record(z.string(), z.object({ dimension: z.number() })).optional(), - mistral: z.record(z.string(), z.object({ dimension: z.number() })).optional(), + openai: z.record(z.string(), modelProfileSchema).optional(), + ollama: z.record(z.string(), modelProfileSchema).optional(), + "openai-compatible": z.record(z.string(), modelProfileSchema).optional(), + gemini: z.record(z.string(), modelProfileSchema).optional(), + mistral: z.record(z.string(), modelProfileSchema).optional(), }) export type CodebaseIndexModels = z.infer diff --git a/src/services/code-index/embedders/__tests__/gemini.spec.ts b/src/services/code-index/embedders/__tests__/gemini.spec.ts index d41a4dc1e93..73d574b30c3 100644 --- a/src/services/code-index/embedders/__tests__/gemini.spec.ts +++ b/src/services/code-index/embedders/__tests__/gemini.spec.ts @@ -104,7 +104,7 @@ describe("GeminiEmbedder", () => { const result = await embedder.createEmbeddings(texts) // Assert - expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, "gemini-embedding-001") + expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, "gemini-embedding-001", undefined) expect(result).toEqual(mockResponse) }) @@ -124,7 +124,7 @@ describe("GeminiEmbedder", () => { const result = await embedder.createEmbeddings(texts, "gemini-embedding-001") // Assert - expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, "gemini-embedding-001") + expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, "gemini-embedding-001", undefined) expect(result).toEqual(mockResponse) }) @@ -190,4 +190,40 @@ describe("GeminiEmbedder", () => { await expect(embedder.validateConfiguration()).rejects.toThrow("Validation failed") }) }) + + describe("createEmbeddings", () => { + let mockCreateEmbeddings: any + + beforeEach(() => { + mockCreateEmbeddings = vitest.fn() + MockedOpenAICompatibleEmbedder.prototype.createEmbeddings = mockCreateEmbeddings + embedder = new GeminiEmbedder("test-api-key") + }) + + it("should use default model when none is provided", async () => { + // Arrange + const texts = ["text1", "text2"] + mockCreateEmbeddings.mockResolvedValue({ embeddings: [], usage: { promptTokens: 0, totalTokens: 0 } }) + + // Act + await embedder.createEmbeddings(texts) + + // Assert + expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, "gemini-embedding-001", undefined) + }) + + it("should pass model and dimension to the OpenAICompatibleEmbedder", async () => { + // Arrange + const texts = ["text1", "text2"] + const model = "custom-model" + const options = { dimension: 1536 } + mockCreateEmbeddings.mockResolvedValue({ embeddings: [], usage: { promptTokens: 0, totalTokens: 0 } }) + + // Act + await embedder.createEmbeddings(texts, model, options) + + // Assert + expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, model, options) + }) + }) }) diff --git a/src/services/code-index/embedders/gemini.ts b/src/services/code-index/embedders/gemini.ts index 7e795875c9d..df0aa501333 100644 --- a/src/services/code-index/embedders/gemini.ts +++ b/src/services/code-index/embedders/gemini.ts @@ -47,11 +47,15 @@ export class GeminiEmbedder implements IEmbedder { * @param model Optional model identifier (uses constructor model if not provided) * @returns Promise resolving to embedding response */ - async createEmbeddings(texts: string[], model?: string): Promise { + async createEmbeddings( + texts: string[], + model?: string, + options?: { dimension?: number }, + ): Promise { try { // Use the provided model or fall back to the instance's model const modelToUse = model || this.modelId - return await this.openAICompatibleEmbedder.createEmbeddings(texts, modelToUse) + return await this.openAICompatibleEmbedder.createEmbeddings(texts, modelToUse, options) } catch (error) { TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, { error: error instanceof Error ? error.message : String(error), diff --git a/src/services/code-index/embedders/openai-compatible.ts b/src/services/code-index/embedders/openai-compatible.ts index 035f50f3867..e34fb3dbc42 100644 --- a/src/services/code-index/embedders/openai-compatible.ts +++ b/src/services/code-index/embedders/openai-compatible.ts @@ -82,7 +82,11 @@ export class OpenAICompatibleEmbedder implements IEmbedder { * @param model Optional model identifier * @returns Promise resolving to embedding response */ - async createEmbeddings(texts: string[], model?: string): Promise { + async createEmbeddings( + texts: string[], + model?: string, + options?: { dimension?: number }, + ): Promise { const modelToUse = model || this.defaultModelId // Apply model-specific query prefix if required @@ -150,7 +154,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder { } if (currentBatch.length > 0) { - const batchResult = await this._embedBatchWithRetries(currentBatch, modelToUse) + const batchResult = await this._embedBatchWithRetries(currentBatch, modelToUse, options) allEmbeddings.push(...batchResult.embeddings) usage.promptTokens += batchResult.usage.promptTokens usage.totalTokens += batchResult.usage.totalTokens @@ -192,7 +196,18 @@ export class OpenAICompatibleEmbedder implements IEmbedder { url: string, batchTexts: string[], model: string, + options?: { dimension?: number }, ): Promise { + const body: Record = { + input: batchTexts, + model: model, + encoding_format: "base64", + } + + if (options?.dimension) { + body.dimensions = options.dimension + } + const response = await fetch(url, { method: "POST", headers: { @@ -202,11 +217,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder { "api-key": this.apiKey, Authorization: `Bearer ${this.apiKey}`, }, - body: JSON.stringify({ - input: batchTexts, - model: model, - encoding_format: "base64", - }), + body: JSON.stringify(body), }) if (!response || !response.ok) { @@ -245,6 +256,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder { private async _embedBatchWithRetries( batchTexts: string[], model: string, + options?: { dimension?: number }, ): Promise<{ embeddings: number[][]; usage: { promptTokens: number; totalTokens: number } }> { // Use cached value for performance const isFullUrl = this.isFullUrl @@ -258,7 +270,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder { if (isFullUrl) { // Use direct HTTP request for full endpoint URLs - response = await this.makeDirectEmbeddingRequest(this.baseUrl, batchTexts, model) + response = await this.makeDirectEmbeddingRequest(this.baseUrl, batchTexts, model, options) } else { // Use OpenAI SDK for base URLs response = (await this.embeddingsClient.embeddings.create({ @@ -268,6 +280,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder { // when processing numeric arrays, which breaks compatibility with models using larger dimensions. // By requesting base64 encoding, we bypass the package's parser and handle decoding ourselves. encoding_format: "base64", + ...(options?.dimension && { dimensions: options.dimension }), })) as OpenAIEmbeddingResponse } diff --git a/src/services/code-index/interfaces/embedder.ts b/src/services/code-index/interfaces/embedder.ts index c5653ea2b7e..fe7a45dad19 100644 --- a/src/services/code-index/interfaces/embedder.ts +++ b/src/services/code-index/interfaces/embedder.ts @@ -9,7 +9,7 @@ export interface IEmbedder { * @param model Optional model ID to use for embeddings * @returns Promise resolving to an EmbeddingResponse */ - createEmbeddings(texts: string[], model?: string): Promise + createEmbeddings(texts: string[], model?: string, options?: { dimension?: number }): Promise /** * Validates the embedder configuration by testing connectivity and credentials. diff --git a/src/shared/embeddingModels.ts b/src/shared/embeddingModels.ts index a3cd61e6593..539f4b3ee40 100644 --- a/src/shared/embeddingModels.ts +++ b/src/shared/embeddingModels.ts @@ -5,9 +5,13 @@ export type EmbedderProvider = "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" // Add other providers as needed export interface EmbeddingModelProfile { + /** The fixed dimension for the model, or a fallback for models with variable dimensions. */ dimension: number scoreThreshold?: number // Model-specific minimum score threshold for semantic search queryPrefix?: string // Optional prefix required by the model for queries + minDimension?: number // The minimum dimension supported by a variable-dimension model. + maxDimension?: number // The maximum dimension supported by a variable-dimension model. + defaultDimension?: number // The default dimension for a variable-dimension model, used for UI presentation. // Add other model-specific properties if needed, e.g., context window size } @@ -48,7 +52,13 @@ export const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfiles = { }, gemini: { "text-embedding-004": { dimension: 768 }, - "gemini-embedding-001": { dimension: 3072, scoreThreshold: 0.4 }, + "gemini-embedding-001": { + dimension: 3072, // Fallback, but defaultDimension is preferred + minDimension: 128, + maxDimension: 3072, + defaultDimension: 3072, + scoreThreshold: 0.4, + }, }, mistral: { "codestral-embed-2505": { dimension: 1536, scoreThreshold: 0.4 }, diff --git a/webview-ui/src/components/chat/CodeIndexPopover.tsx b/webview-ui/src/components/chat/CodeIndexPopover.tsx index c85aaf6ea5b..8d2c83837f3 100644 --- a/webview-ui/src/components/chat/CodeIndexPopover.tsx +++ b/webview-ui/src/components/chat/CodeIndexPopover.tsx @@ -73,7 +73,7 @@ interface LocalCodeIndexSettings { } // Validation schema for codebase index settings -const createValidationSchema = (provider: EmbedderProvider, t: any) => { +const createValidationSchema = (provider: EmbedderProvider, t: any, models: any) => { const baseSchema = z.object({ codebaseIndexEnabled: z.boolean(), codebaseIndexQdrantUrl: z @@ -121,12 +121,52 @@ const createValidationSchema = (provider: EmbedderProvider, t: any) => { }) case "gemini": - return baseSchema.extend({ - codebaseIndexGeminiApiKey: z.string().min(1, t("settings:codeIndex.validation.geminiApiKeyRequired")), - codebaseIndexEmbedderModelId: z - .string() - .min(1, t("settings:codeIndex.validation.modelSelectionRequired")), - }) + return baseSchema + .extend({ + codebaseIndexGeminiApiKey: z + .string() + .min(1, t("settings:codeIndex.validation.geminiApiKeyRequired")), + codebaseIndexEmbedderModelId: z + .string() + .min(1, t("settings:codeIndex.validation.modelSelectionRequired")), + codebaseIndexEmbedderModelDimension: z.number().optional(), + }) + .refine( + (data) => { + const model = models?.gemini?.[data.codebaseIndexEmbedderModelId || ""] + // If the model supports variable dimensions, a dimension must be provided. + if (model?.minDimension && !data.codebaseIndexEmbedderModelDimension) { + return false // Fails validation if dimension is required but not provided + } + return true + }, + { + message: t("settings:codeIndex.validation.modelDimensionRequired"), + path: ["codebaseIndexEmbedderModelDimension"], + }, + ) + .refine( + (data) => { + const model = models?.gemini?.[data.codebaseIndexEmbedderModelId || ""] + if (model?.minDimension && model?.maxDimension && data.codebaseIndexEmbedderModelDimension) { + return ( + data.codebaseIndexEmbedderModelDimension >= model.minDimension && + data.codebaseIndexEmbedderModelDimension <= model.maxDimension + ) + } + return true + }, + (data) => { + const model = models?.gemini?.[data.codebaseIndexEmbedderModelId || ""] + return { + message: t("settings:codeIndex.validation.invalidDimension", { + min: model?.minDimension, + max: model?.maxDimension, + }), + path: ["codebaseIndexEmbedderModelDimension"], + } + }, + ) case "mistral": return baseSchema.extend({ @@ -193,21 +233,28 @@ export const CodeIndexPopover: React.FC = ({ setIndexingStatus(externalIndexingStatus) }, [externalIndexingStatus]) - // Initialize settings from global state + // Initializes the settings from the global state when it changes useEffect(() => { if (codebaseIndexConfig) { + const provider = codebaseIndexConfig.codebaseIndexEmbedderProvider || "openai" + const modelId = codebaseIndexConfig.codebaseIndexEmbedderModelId || "" + const modelProfile = codebaseIndexModels?.[provider]?.[modelId] + const settings = { codebaseIndexEnabled: codebaseIndexConfig.codebaseIndexEnabled ?? true, codebaseIndexQdrantUrl: codebaseIndexConfig.codebaseIndexQdrantUrl || "", - codebaseIndexEmbedderProvider: codebaseIndexConfig.codebaseIndexEmbedderProvider || "openai", + codebaseIndexEmbedderProvider: provider, codebaseIndexEmbedderBaseUrl: codebaseIndexConfig.codebaseIndexEmbedderBaseUrl || "", - codebaseIndexEmbedderModelId: codebaseIndexConfig.codebaseIndexEmbedderModelId || "", + codebaseIndexEmbedderModelId: modelId, + // Determines the dimension exclusively from the global configuration and model profiles. + // The local 'currentSettings' state is no longer read here. codebaseIndexEmbedderModelDimension: - codebaseIndexConfig.codebaseIndexEmbedderModelDimension || undefined, + codebaseIndexConfig.codebaseIndexEmbedderModelDimension || modelProfile?.defaultDimension, codebaseIndexSearchMaxResults: codebaseIndexConfig.codebaseIndexSearchMaxResults ?? CODEBASE_INDEX_DEFAULTS.DEFAULT_SEARCH_RESULTS, codebaseIndexSearchMinScore: codebaseIndexConfig.codebaseIndexSearchMinScore ?? CODEBASE_INDEX_DEFAULTS.DEFAULT_SEARCH_MIN_SCORE, + // Keys are initially set to empty and populated by a separate effect. codeIndexOpenAiKey: "", codeIndexQdrantApiKey: "", codebaseIndexOpenAiCompatibleBaseUrl: codebaseIndexConfig.codebaseIndexOpenAiCompatibleBaseUrl || "", @@ -218,10 +265,10 @@ export const CodeIndexPopover: React.FC = ({ setInitialSettings(settings) setCurrentSettings(settings) - // Request secret status to check if secrets exist + // Requests the status of the secrets to display placeholders correctly. vscode.postMessage({ type: "requestCodeIndexSecretStatus" }) } - }, [codebaseIndexConfig]) + }, [codebaseIndexConfig, codebaseIndexModels]) // Dependencies are now correct and complete. // Request initial indexing status useEffect(() => { @@ -366,9 +413,30 @@ export const CodeIndexPopover: React.FC = ({ } } + // Handles model changes, ensuring dimension is reset correctly + const handleModelChange = (newModelId: string) => { + const provider = currentSettings.codebaseIndexEmbedderProvider + const modelProfile = codebaseIndexModels?.[provider]?.[newModelId] + const defaultDimension = modelProfile?.defaultDimension + + setCurrentSettings((prev) => ({ + ...prev, + codebaseIndexEmbedderModelId: newModelId, + codebaseIndexEmbedderModelDimension: defaultDimension, + })) + + // Clear validation errors for model and dimension + setFormErrors((prev) => { + const newErrors = { ...prev } + delete newErrors.codebaseIndexEmbedderModelId + delete newErrors.codebaseIndexEmbedderModelDimension + return newErrors + }) + } + // Validation function const validateSettings = (): boolean => { - const schema = createValidationSchema(currentSettings.codebaseIndexEmbedderProvider, t) + const schema = createValidationSchema(currentSettings.codebaseIndexEmbedderProvider, t, codebaseIndexModels) // Prepare data for validation const dataToValidate: any = {} @@ -920,9 +988,7 @@ export const CodeIndexPopover: React.FC = ({ - updateSetting("codebaseIndexEmbedderModelId", e.target.value) - } + onChange={(e: any) => handleModelChange(e.target.value)} className={cn("w-full", { "border-red-500": formErrors.codebaseIndexEmbedderModelId, })}> @@ -952,6 +1018,51 @@ export const CodeIndexPopover: React.FC = ({

)} + {(() => { + const selectedModelProfile = + codebaseIndexModels?.gemini?.[ + currentSettings.codebaseIndexEmbedderModelId + ] + + if ( + selectedModelProfile?.minDimension && + selectedModelProfile?.maxDimension + ) { + return ( +
+ + + updateSetting( + "codebaseIndexEmbedderModelDimension", + e.target.value + ? parseInt(e.target.value, 10) + : undefined, + ) + } + className="w-full" + /> + {formErrors.codebaseIndexEmbedderModelDimension && ( +

+ {formErrors.codebaseIndexEmbedderModelDimension} +

+ )} +
+ ) + } + return null + })()} )} diff --git a/webview-ui/src/i18n/locales/en/settings.json b/webview-ui/src/i18n/locales/en/settings.json index 7c58e679c6e..a29b95c2b04 100644 --- a/webview-ui/src/i18n/locales/en/settings.json +++ b/webview-ui/src/i18n/locales/en/settings.json @@ -98,6 +98,7 @@ "qdrantUrlPlaceholder": "http://localhost:6333", "saveError": "Failed to save settings", "modelDimensions": "({{dimension}} dimensions)", + "dimensionRange": "({{min}}-{{max}})", "saveSuccess": "Settings saved successfully", "saving": "Saving...", "saveSettings": "Save", @@ -118,6 +119,7 @@ "apiKeyRequired": "API key is required", "modelIdRequired": "Model ID is required", "modelDimensionRequired": "Model dimension is required", + "invalidDimension": "Dimension must be between {{min}} and {{max}}", "geminiApiKeyRequired": "Gemini API key is required", "mistralApiKeyRequired": "Mistral API key is required", "ollamaBaseUrlRequired": "Ollama base URL is required",