Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import { MAX_ITEM_TOKENS, INITIAL_RETRY_DELAY_MS } from "../../constants"
// Mock the OpenAI SDK module ("openai") for this test file
vitest.mock("openai")

// Replace the global fetch with a mock function; tests below cast it to
// MockedFunction<typeof fetch> in order to program its responses
global.fetch = vitest.fn()

// Mock i18n
vitest.mock("../../../../i18n", () => ({
t: (key: string, params?: Record<string, any>) => {
Expand Down Expand Up @@ -613,5 +616,270 @@ describe("OpenAICompatibleEmbedder", () => {
expect(returnedArray).toEqual([0.25, 0.5, 0.75, 1.0])
})
})

/**
* Test Azure OpenAI compatibility with helper functions for conciseness
*/
describe("Azure OpenAI compatibility", () => {
const azureUrl =
"https://myresource.openai.azure.com/openai/deployments/mymodel/embeddings?api-version=2024-02-01"
const baseUrl = "https://api.openai.com/v1"

// Helper building a minimal fetch-Response stand-in.
// `json()` resolves to `data`; `text()` resolves to an empty string for
// status 200 and to "Error message" for any other status.
const createMockResponse = (data: any, status = 200, ok = true) => {
	const bodyText = status === 200 ? "" : "Error message"
	const json = vitest.fn().mockResolvedValue(data)
	const text = vitest.fn().mockResolvedValue(bodyText)
	return { ok, status, json, text }
}

// Helper encoding a list of numbers as the base64 form of their Float32 bytes
const createBase64Embedding = (values: number[]) => {
	const floats = Float32Array.from(values)
	const rawBytes = Buffer.from(floats.buffer)
	return rawBytes.toString("base64")
}

// Helper asserting that `actual` has the same length as `expected` and that
// each element matches within toBeCloseTo precision 5 (floating-point tolerance)
const expectEmbeddingValues = (actual: number[], expected: number[]) => {
	expect(actual).toHaveLength(expected.length)
	for (const [position, wanted] of expected.entries()) {
		expect(actual[position]).toBeCloseTo(wanted, 5)
	}
}

beforeEach(() => {
	// Clear recorded calls on every mock, then fully reset the fetch mock
	// so no programmed response carries over between tests
	vitest.clearAllMocks()
	const fetchMock = global.fetch as MockedFunction<typeof fetch>
	fetchMock.mockReset()
})

describe("URL detection", () => {
// Table of [url, expected result of isFullEndpointUrl(url)]
const urlDetectionCases: Array<[string, boolean]> = [
	["https://myresource.openai.azure.com/openai/deployments/mymodel/embeddings?api-version=2024-02-01", true],
	["https://myresource.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings", true],
	["https://api.openai.com/v1", false],
	["https://api.example.com", false],
	["http://localhost:8080", false],
]

it.each(urlDetectionCases)("should detect URL type correctly: %s -> %s", (candidateUrl, shouldBeFullEndpoint) => {
	const subject = new OpenAICompatibleEmbedder(candidateUrl, testApiKey, testModelId)
	// isFullEndpointUrl is reached through an `any` cast on the instance
	const detected = (subject as any).isFullEndpointUrl(candidateUrl)
	expect(detected).toBe(shouldBeFullEndpoint)
})

// Edge case: the word 'embeddings' appears in the host name or as part of a
// longer path segment, so the URL must not be treated as a full endpoint
it("should return false for URLs with 'embeddings' in non-endpoint contexts", () => {
	const nonEndpointUrls = [
		"https://api.example.com/embeddings-service/v1",
		"https://embeddings.example.com/api",
		"https://api.example.com/v1/embeddings-api",
		"https://my-embeddings-provider.com/v1",
	]

	for (const candidate of nonEndpointUrls) {
		const subject = new OpenAICompatibleEmbedder(candidate, testApiKey, testModelId)
		const detected = (subject as any).isFullEndpointUrl(candidate)
		expect(detected).toBe(false)
	}
})

// Edge case: the word 'deployments' appears only in the host name, so the
// URL must not be treated as a full endpoint
it("should return false for URLs with 'deployments' in non-endpoint contexts", () => {
	const nonEndpointUrls = [
		"https://deployments.example.com/api",
		"https://api.deployments.com/v1",
		"https://my-deployments-service.com/api/v1",
		"https://deployments-manager.example.com",
	]

	for (const candidate of nonEndpointUrls) {
		const subject = new OpenAICompatibleEmbedder(candidate, testApiKey, testModelId)
		const detected = (subject as any).isFullEndpointUrl(candidate)
		expect(detected).toBe(false)
	}
})

// Positive cases: paths ending in /embeddings or /embed, with or without a
// query string, are expected to be reported as full endpoint URLs
it("should correctly identify actual endpoint URLs", () => {
	const fullEndpointUrls = [
		"https://api.example.com/v1/embeddings",
		"https://api.example.com/v1/embeddings?api-version=2024",
		"https://myresource.openai.azure.com/openai/deployments/mymodel/embeddings",
		"https://api.example.com/embed",
		"https://api.example.com/embed?version=1",
	]

	for (const candidate of fullEndpointUrls) {
		const subject = new OpenAICompatibleEmbedder(candidate, testApiKey, testModelId)
		const detected = (subject as any).isFullEndpointUrl(candidate)
		expect(detected).toBe(true)
	}
})
})

describe("direct HTTP requests", () => {
it("should use direct fetch for Azure URLs and SDK for base URLs", async () => {
	const fetchMock = global.fetch as MockedFunction<typeof fetch>
	const inputs = ["Test text"]
	const usage = { prompt_tokens: 10, total_tokens: 15 }
	const encodedVector = createBase64Embedding([0.1, 0.2, 0.3])

	// --- Part 1: full Azure endpoint URL must go through fetch directly ---
	const azureEmbedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId)
	const azureResponse = createMockResponse({
		data: [{ embedding: encodedVector }],
		usage,
	})
	fetchMock.mockResolvedValue(azureResponse as any)

	const azureResult = await azureEmbedder.createEmbeddings(inputs)

	// The request is expected to carry both the Azure-style and Bearer auth headers
	const expectedHeaders = expect.objectContaining({
		"api-key": testApiKey,
		Authorization: `Bearer ${testApiKey}`,
	})
	expect(global.fetch).toHaveBeenCalledWith(
		azureUrl,
		expect.objectContaining({ method: "POST", headers: expectedHeaders }),
	)
	expect(mockEmbeddingsCreate).not.toHaveBeenCalled()
	expectEmbeddingValues(azureResult.embeddings[0], [0.1, 0.2, 0.3])

	// --- Part 2: plain base URL must go through the SDK client instead ---
	// Wipe recorded calls so the "not called" assertion below only sees part 2
	vitest.clearAllMocks()
	const baseEmbedder = new OpenAICompatibleEmbedder(baseUrl, testApiKey, testModelId)
	mockEmbeddingsCreate.mockResolvedValue({
		data: [{ embedding: [0.4, 0.5, 0.6] }],
		usage,
	})

	const baseResult = await baseEmbedder.createEmbeddings(inputs)

	expect(mockEmbeddingsCreate).toHaveBeenCalled()
	expect(global.fetch).not.toHaveBeenCalled()
	expect(baseResult.embeddings[0]).toEqual([0.4, 0.5, 0.6])
})

// Each row: [HTTP status returned by fetch, substring expected in the thrown error]
it.each([
	[401, "Authentication failed. Please check your API key."],
	[500, "Failed to create embeddings after 3 attempts"],
])("should handle HTTP errors: %d", async (status, expectedMessage) => {
	const fetchMock = global.fetch as MockedFunction<typeof fetch>
	const subject = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId)
	// Every fetch call answers with the same non-ok response
	const failingResponse = createMockResponse({}, status, false)
	fetchMock.mockResolvedValue(failingResponse as any)

	const pending = subject.createEmbeddings(["test"])
	await expect(pending).rejects.toThrow(expectedMessage)
})

it("should handle rate limiting with retries", async () => {
	vitest.useFakeTimers()
	try {
		const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId)
		const base64String = createBase64Embedding([0.1, 0.2, 0.3])
		const fetchMock = global.fetch as MockedFunction<typeof fetch>

		// Two 429 responses followed by one successful response
		fetchMock
			.mockResolvedValueOnce(createMockResponse({}, 429, false) as any)
			.mockResolvedValueOnce(createMockResponse({}, 429, false) as any)
			.mockResolvedValueOnce(
				createMockResponse({
					data: [{ embedding: base64String }],
					usage: { prompt_tokens: 10, total_tokens: 15 },
				}) as any,
			)

		// Start the call, then advance fake time so the retry delays elapse
		const resultPromise = embedder.createEmbeddings(["test"])
		await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 3)
		const result = await resultPromise

		expect(global.fetch).toHaveBeenCalledTimes(3)
		expect(console.warn).toHaveBeenCalledWith(expect.stringContaining("Rate limit hit"))
		expectEmbeddingValues(result.embeddings[0], [0.1, 0.2, 0.3])
	} finally {
		// Restore real timers even when an assertion above throws; otherwise
		// fake timers would stay installed for every test that runs afterwards
		vitest.useRealTimers()
	}
})

it("should handle multiple embeddings and network errors", async () => {
	const fetchMock = global.fetch as MockedFunction<typeof fetch>
	const subject = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId)

	// Scenario 1: one response carrying two base64-encoded vectors
	const firstVector = [0.25, 0.5]
	const secondVector = [0.75, 1.0]
	const batchResponse = createMockResponse({
		data: [firstVector, secondVector].map((vector) => ({ embedding: createBase64Embedding(vector) })),
		usage: { prompt_tokens: 20, total_tokens: 30 },
	})
	fetchMock.mockResolvedValue(batchResponse as any)

	const { embeddings } = await subject.createEmbeddings(["test1", "test2"])
	expect(embeddings).toHaveLength(2)
	expectEmbeddingValues(embeddings[0], firstVector)
	expectEmbeddingValues(embeddings[1], secondVector)

	// Scenario 2: fetch itself rejects, which must surface as the retry-exhausted error
	fetchMock.mockRejectedValue(new Error("Network failed"))
	const pending = subject.createEmbeddings(["test"])
	await expect(pending).rejects.toThrow("Failed to create embeddings after 3 attempts")
})
})
})
})

describe("URL detection", () => {
it("should detect Azure deployment URLs as full endpoints", async () => {
	const fullUrl =
		"https://myinstance.openai.azure.com/openai/deployments/my-deployment/embeddings?api-version=2023-05-15"
	const embedder = new OpenAICompatibleEmbedder(fullUrl, "test-key")

	// The private method is tested indirectly through the createEmbeddings behavior
	// If it's detected as a full URL, it will make a direct HTTP request
	const mockFetch = vitest.fn().mockResolvedValue({
		ok: true,
		json: async () => ({
			data: [{ embedding: [0.1, 0.2] }],
			usage: { prompt_tokens: 10, total_tokens: 15 },
		}),
	})

	// Remember the current global fetch so it can be put back afterwards;
	// leaving this test's mock installed would leak state into later tests
	const previousFetch = global.fetch
	global.fetch = mockFetch
	try {
		await embedder.createEmbeddings(["test"])

		// Should make direct HTTP request to the full URL
		expect(mockFetch).toHaveBeenCalledWith(fullUrl, expect.any(Object))
	} finally {
		global.fetch = previousFetch
	}
})

it("should detect /embed endpoints as full URLs", async () => {
	const fullUrl = "https://api.example.com/v1/embed"
	const embedder = new OpenAICompatibleEmbedder(fullUrl, "test-key")

	const mockFetch = vitest.fn().mockResolvedValue({
		ok: true,
		json: async () => ({
			data: [{ embedding: [0.1, 0.2] }],
			usage: { prompt_tokens: 10, total_tokens: 15 },
		}),
	})

	// Remember the current global fetch so it can be put back afterwards;
	// leaving this test's mock installed would leak state into later tests
	const previousFetch = global.fetch
	global.fetch = mockFetch
	try {
		await embedder.createEmbeddings(["test"])

		// Should make direct HTTP request to the full URL
		expect(mockFetch).toHaveBeenCalledWith(fullUrl, expect.any(Object))
	} finally {
		global.fetch = previousFetch
	}
})

it("should treat base URLs without endpoint patterns as SDK URLs", async () => {
	const subject = new OpenAICompatibleEmbedder("https://api.openai.com/v1", "test-key")

	// Swap the client's embeddings API for a stub whose create() resolves
	// to a canned SDK-style payload
	const sdkPayload = {
		data: [{ embedding: [0.1, 0.2] }],
		usage: { prompt_tokens: 10, total_tokens: 15 },
	}
	const createStub = vitest.fn().mockResolvedValue(sdkPayload)
	subject["embeddingsClient"].embeddings = { create: createStub } as any

	await subject.createEmbeddings(["test"])

	// A base URL is routed through the SDK (which appends /embeddings itself)
	expect(createStub).toHaveBeenCalled()
})
})
})
Loading
Loading