diff --git a/src/services/code-index/embedders/__tests__/openai-compatible.spec.ts b/src/services/code-index/embedders/__tests__/openai-compatible.spec.ts index 271d68cc205..107f3af24dd 100644 --- a/src/services/code-index/embedders/__tests__/openai-compatible.spec.ts +++ b/src/services/code-index/embedders/__tests__/openai-compatible.spec.ts @@ -6,6 +6,9 @@ import { MAX_ITEM_TOKENS, INITIAL_RETRY_DELAY_MS } from "../../constants" // Mock the OpenAI SDK vitest.mock("openai") +// Mock global fetch +global.fetch = vitest.fn() + // Mock i18n vitest.mock("../../../../i18n", () => ({ t: (key: string, params?: Record) => { @@ -613,5 +616,270 @@ describe("OpenAICompatibleEmbedder", () => { expect(returnedArray).toEqual([0.25, 0.5, 0.75, 1.0]) }) }) + + /** + * Test Azure OpenAI compatibility with helper functions for conciseness + */ + describe("Azure OpenAI compatibility", () => { + const azureUrl = + "https://myresource.openai.azure.com/openai/deployments/mymodel/embeddings?api-version=2024-02-01" + const baseUrl = "https://api.openai.com/v1" + + // Helper to create mock fetch response + const createMockResponse = (data: any, status = 200, ok = true) => ({ + ok, + status, + json: vitest.fn().mockResolvedValue(data), + text: vitest.fn().mockResolvedValue(status === 200 ? "" : "Error message"), + }) + + // Helper to create base64 embedding + const createBase64Embedding = (values: number[]) => { + const embedding = new Float32Array(values) + return Buffer.from(embedding.buffer).toString("base64") + } + + // Helper to verify embedding values with floating-point tolerance + const expectEmbeddingValues = (actual: number[], expected: number[]) => { + expect(actual).toHaveLength(expected.length) + expected.forEach((val, i) => expect(actual[i]).toBeCloseTo(val, 5)) + } + + beforeEach(() => { + vitest.clearAllMocks() + ;(global.fetch as MockedFunction).mockReset() + }) + + describe("URL detection", () => { + it.each([ + [ + "https://myresource.openai.azure.com/openai/deployments/mymodel/embeddings?api-version=2024-02-01", + true, + ], + ["https://myresource.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings", true], + ["https://api.openai.com/v1", false], + ["https://api.example.com", false], + ["http://localhost:8080", false], + ])("should detect URL type correctly: %s -> %s", (url, expected) => { + const embedder = new OpenAICompatibleEmbedder(url, testApiKey, testModelId) + const isFullUrl = (embedder as any).isFullEndpointUrl(url) + expect(isFullUrl).toBe(expected) + }) + + // Edge cases where 'embeddings' or 'deployments' appear in non-endpoint contexts + it("should return false for URLs with 'embeddings' in non-endpoint contexts", () => { + const testUrls = [ + "https://api.example.com/embeddings-service/v1", + "https://embeddings.example.com/api", + "https://api.example.com/v1/embeddings-api", + "https://my-embeddings-provider.com/v1", + ] + + testUrls.forEach((url) => { + const embedder = new OpenAICompatibleEmbedder(url, testApiKey, testModelId) + const isFullUrl = (embedder as any).isFullEndpointUrl(url) + expect(isFullUrl).toBe(false) + }) + }) + + it("should return false for URLs with 'deployments' in non-endpoint contexts", () => { + const testUrls = [ + "https://deployments.example.com/api", + "https://api.deployments.com/v1", + "https://my-deployments-service.com/api/v1", + "https://deployments-manager.example.com", + ] + + testUrls.forEach((url) => { + const embedder = new OpenAICompatibleEmbedder(url, testApiKey, testModelId) + const isFullUrl = (embedder as any).isFullEndpointUrl(url) + expect(isFullUrl).toBe(false) + }) + }) + + it("should correctly identify actual endpoint URLs", () => { + const endpointUrls = [ + "https://api.example.com/v1/embeddings", + "https://api.example.com/v1/embeddings?api-version=2024", + "https://myresource.openai.azure.com/openai/deployments/mymodel/embeddings", + "https://api.example.com/embed", + "https://api.example.com/embed?version=1", + ] + + endpointUrls.forEach((url) => { + const embedder = new OpenAICompatibleEmbedder(url, testApiKey, testModelId) + const isFullUrl = (embedder as any).isFullEndpointUrl(url) + expect(isFullUrl).toBe(true) + }) + }) + }) + + describe("direct HTTP requests", () => { + it("should use direct fetch for Azure URLs and SDK for base URLs", async () => { + const testTexts = ["Test text"] + const base64String = createBase64Embedding([0.1, 0.2, 0.3]) + + // Test Azure URL (direct fetch) + const azureEmbedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId) + const mockFetchResponse = createMockResponse({ + data: [{ embedding: base64String }], + usage: { prompt_tokens: 10, total_tokens: 15 }, + }) + ;(global.fetch as MockedFunction).mockResolvedValue(mockFetchResponse as any) + + const azureResult = await azureEmbedder.createEmbeddings(testTexts) + expect(global.fetch).toHaveBeenCalledWith( + azureUrl, + expect.objectContaining({ + method: "POST", + headers: expect.objectContaining({ + "api-key": testApiKey, + Authorization: `Bearer ${testApiKey}`, + }), + }), + ) + expect(mockEmbeddingsCreate).not.toHaveBeenCalled() + expectEmbeddingValues(azureResult.embeddings[0], [0.1, 0.2, 0.3]) + + // Reset and test base URL (SDK) + vitest.clearAllMocks() + const baseEmbedder = new OpenAICompatibleEmbedder(baseUrl, testApiKey, testModelId) + mockEmbeddingsCreate.mockResolvedValue({ + data: [{ embedding: [0.4, 0.5, 0.6] }], + usage: { prompt_tokens: 10, total_tokens: 15 }, + }) + + const baseResult = await baseEmbedder.createEmbeddings(testTexts) + expect(mockEmbeddingsCreate).toHaveBeenCalled() + expect(global.fetch).not.toHaveBeenCalled() + expect(baseResult.embeddings[0]).toEqual([0.4, 0.5, 0.6]) + }) + + it.each([ + [401, "Authentication failed. Please check your API key."], + [500, "Failed to create embeddings after 3 attempts"], + ])("should handle HTTP errors: %d", async (status, expectedMessage) => { + const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId) + const mockResponse = createMockResponse({}, status, false) + ;(global.fetch as MockedFunction).mockResolvedValue(mockResponse as any) + + await expect(embedder.createEmbeddings(["test"])).rejects.toThrow(expectedMessage) + }) + + it("should handle rate limiting with retries", async () => { + vitest.useFakeTimers() + const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId) + const base64String = createBase64Embedding([0.1, 0.2, 0.3]) + + ;(global.fetch as MockedFunction) + .mockResolvedValueOnce(createMockResponse({}, 429, false) as any) + .mockResolvedValueOnce(createMockResponse({}, 429, false) as any) + .mockResolvedValueOnce( + createMockResponse({ + data: [{ embedding: base64String }], + usage: { prompt_tokens: 10, total_tokens: 15 }, + }) as any, + ) + + const resultPromise = embedder.createEmbeddings(["test"]) + await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 3) + const result = await resultPromise + + expect(global.fetch).toHaveBeenCalledTimes(3) + expect(console.warn).toHaveBeenCalledWith(expect.stringContaining("Rate limit hit")) + expectEmbeddingValues(result.embeddings[0], [0.1, 0.2, 0.3]) + vitest.useRealTimers() + }) + + it("should handle multiple embeddings and network errors", async () => { + const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId) + + // Test multiple embeddings + const base64_1 = createBase64Embedding([0.25, 0.5]) + const base64_2 = createBase64Embedding([0.75, 1.0]) + const mockResponse = createMockResponse({ + data: [{ embedding: base64_1 }, { embedding: base64_2 }], + usage: { prompt_tokens: 20, total_tokens: 30 }, + }) + ;(global.fetch as MockedFunction).mockResolvedValue(mockResponse as any) + + const result = await embedder.createEmbeddings(["test1", "test2"]) + expect(result.embeddings).toHaveLength(2) + expectEmbeddingValues(result.embeddings[0], [0.25, 0.5]) + expectEmbeddingValues(result.embeddings[1], [0.75, 1.0]) + + // Test network error + const networkError = new Error("Network failed") + ;(global.fetch as MockedFunction).mockRejectedValue(networkError) + await expect(embedder.createEmbeddings(["test"])).rejects.toThrow( + "Failed to create embeddings after 3 attempts", + ) + }) + }) + }) + }) + + describe("URL detection", () => { + it("should detect Azure deployment URLs as full endpoints", async () => { + const embedder = new OpenAICompatibleEmbedder( + "https://myinstance.openai.azure.com/openai/deployments/my-deployment/embeddings?api-version=2023-05-15", + "test-key", + ) + + // The private method is tested indirectly through the createEmbeddings behavior + // If it's detected as a full URL, it will make a direct HTTP request + const mockFetch = vitest.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + data: [{ embedding: [0.1, 0.2] }], + usage: { prompt_tokens: 10, total_tokens: 15 }, + }), + }) + global.fetch = mockFetch + + await embedder.createEmbeddings(["test"]) + + // Should make direct HTTP request to the full URL + expect(mockFetch).toHaveBeenCalledWith( + "https://myinstance.openai.azure.com/openai/deployments/my-deployment/embeddings?api-version=2023-05-15", + expect.any(Object), + ) + }) + + it("should detect /embed endpoints as full URLs", async () => { + const embedder = new OpenAICompatibleEmbedder("https://api.example.com/v1/embed", "test-key") + + const mockFetch = vitest.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + data: [{ embedding: [0.1, 0.2] }], + usage: { prompt_tokens: 10, total_tokens: 15 }, + }), + }) + global.fetch = mockFetch + + await embedder.createEmbeddings(["test"]) + + // Should make direct HTTP request to the full URL + expect(mockFetch).toHaveBeenCalledWith("https://api.example.com/v1/embed", expect.any(Object)) + }) + + it("should treat base URLs without endpoint patterns as SDK URLs", async () => { + const embedder = new OpenAICompatibleEmbedder("https://api.openai.com/v1", "test-key") + + // Mock the OpenAI SDK's embeddings.create method + const mockCreate = vitest.fn().mockResolvedValue({ + data: [{ embedding: [0.1, 0.2] }], + usage: { prompt_tokens: 10, total_tokens: 15 }, + }) + embedder["embeddingsClient"].embeddings = { + create: mockCreate, + } as any + + await embedder.createEmbeddings(["test"]) + + // Should use SDK which will append /embeddings + expect(mockCreate).toHaveBeenCalled() + }) }) }) diff --git a/src/services/code-index/embedders/openai-compatible.ts b/src/services/code-index/embedders/openai-compatible.ts index f7e679abf65..88eced8a0a8 100644 --- a/src/services/code-index/embedders/openai-compatible.ts +++ b/src/services/code-index/embedders/openai-compatible.ts @@ -26,9 +26,19 @@ interface OpenAIEmbeddingResponse { * OpenAI Compatible implementation of the embedder interface with batching and rate limiting. * This embedder allows using any OpenAI-compatible API endpoint by specifying a custom baseURL. */ +interface HttpError extends Error { + status?: number + response?: { + status?: number + } +} + export class OpenAICompatibleEmbedder implements IEmbedder { private embeddingsClient: OpenAI private readonly defaultModelId: string + private readonly baseUrl: string + private readonly apiKey: string + private readonly isFullUrl: boolean private readonly maxItemTokens: number /** @@ -46,11 +56,15 @@ export class OpenAICompatibleEmbedder implements IEmbedder { throw new Error("API key is required for OpenAI Compatible embedder") } + this.baseUrl = baseUrl + this.apiKey = apiKey this.embeddingsClient = new OpenAI({ baseURL: baseUrl, apiKey: apiKey, }) this.defaultModelId = modelId || getDefaultModelId("openai-compatible") + // Cache the URL type check for performance + this.isFullUrl = this.isFullEndpointUrl(baseUrl) this.maxItemTokens = maxItemTokens || MAX_ITEM_TOKENS } @@ -138,6 +152,65 @@ export class OpenAICompatibleEmbedder implements IEmbedder { return { embeddings: allEmbeddings, usage } } + /** + * Determines if the provided URL is a full endpoint URL or a base URL that needs the endpoint appended by the SDK. + * Uses smart pattern matching for known providers while accepting we can't cover all possible patterns. + * @param url The URL to check + * @returns true if it's a full endpoint URL, false if it's a base URL + */ + private isFullEndpointUrl(url: string): boolean { + // Known patterns for major providers + const patterns = [ + // Azure OpenAI: /deployments/{deployment-name}/embeddings + /\/deployments\/[^\/]+\/embeddings(\?|$)/, + // Direct endpoints: ends with /embeddings (before query params) + /\/embeddings(\?|$)/, + // Some providers use /embed instead of /embeddings + /\/embed(\?|$)/, + ] + + return patterns.some((pattern) => pattern.test(url)) + } + + /** + * Makes a direct HTTP request to the embeddings endpoint + * Used when the user provides a full endpoint URL (e.g., Azure OpenAI with query parameters) + * @param url The full endpoint URL + * @param batchTexts Array of texts to embed + * @param model Model identifier to use + * @returns Promise resolving to OpenAI-compatible response + */ + private async makeDirectEmbeddingRequest( + url: string, + batchTexts: string[], + model: string, + ): Promise { + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + // Azure OpenAI uses 'api-key' header, while OpenAI uses 'Authorization' + // We'll try 'api-key' first for Azure compatibility + "api-key": this.apiKey, + Authorization: `Bearer ${this.apiKey}`, + }, + body: JSON.stringify({ + input: batchTexts, + model: model, + encoding_format: "base64", + }), + }) + + if (!response.ok) { + const errorText = await response.text() + const error = new Error(`HTTP ${response.status}: ${errorText}`) as HttpError + error.status = response.status + throw error + } + + return await response.json() + } + /** * Helper method to handle batch embedding with retries and exponential backoff * @param batchTexts Array of texts to embed in this batch @@ -148,16 +221,27 @@ export class OpenAICompatibleEmbedder implements IEmbedder { batchTexts: string[], model: string, ): Promise<{ embeddings: number[][]; usage: { promptTokens: number; totalTokens: number } }> { + // Use cached value for performance + const isFullUrl = this.isFullUrl + for (let attempts = 0; attempts < MAX_RETRIES; attempts++) { try { - const response = (await this.embeddingsClient.embeddings.create({ - input: batchTexts, - model: model, - // OpenAI package (as of v4.78.1) has a parsing issue that truncates embedding dimensions to 256 - // when processing numeric arrays, which breaks compatibility with models using larger dimensions. - // By requesting base64 encoding, we bypass the package's parser and handle decoding ourselves. - encoding_format: "base64", - })) as OpenAIEmbeddingResponse + let response: OpenAIEmbeddingResponse + + if (isFullUrl) { + // Use direct HTTP request for full endpoint URLs + response = await this.makeDirectEmbeddingRequest(this.baseUrl, batchTexts, model) + } else { + // Use OpenAI SDK for base URLs + response = (await this.embeddingsClient.embeddings.create({ + input: batchTexts, + model: model, + // OpenAI package (as of v4.78.1) has a parsing issue that truncates embedding dimensions to 256 + // when processing numeric arrays, which breaks compatibility with models using larger dimensions. + // By requesting base64 encoding, we bypass the package's parser and handle decoding ourselves. + encoding_format: "base64", + })) as OpenAIEmbeddingResponse + } // Convert base64 embeddings to float32 arrays const processedEmbeddings = response.data.map((item: EmbeddingItem) => { @@ -187,8 +271,9 @@ export class OpenAICompatibleEmbedder implements IEmbedder { totalTokens: response.usage?.total_tokens || 0, }, } - } catch (error: any) { - const isRateLimitError = error?.status === 429 + } catch (error) { + const httpError = error as HttpError + const isRateLimitError = httpError?.status === 429 const hasMoreAttempts = attempts < MAX_RETRIES - 1 if (isRateLimitError && hasMoreAttempts) { @@ -209,19 +294,19 @@ export class OpenAICompatibleEmbedder implements IEmbedder { // Provide more context in the error message using robust error extraction let errorMessage = t("embeddings:unknownError") - if (error?.message) { - errorMessage = error.message + if (httpError?.message) { + errorMessage = httpError.message } else if (typeof error === "string") { errorMessage = error - } else if (error && typeof error.toString === "function") { + } else if (error && typeof error === "object" && "toString" in error) { try { - errorMessage = error.toString() + errorMessage = String(error) } catch { errorMessage = t("embeddings:unknownError") } } - const statusCode = error?.status || error?.response?.status + const statusCode = httpError?.status || httpError?.response?.status if (statusCode === 401) { throw new Error(t("embeddings:authenticationFailed"))