From fedcbb93d547e61ec6f34b0d74c1d617db4dab4e Mon Sep 17 00:00:00 2001 From: machen Date: Thu, 9 Apr 2026 15:46:12 +0800 Subject: [PATCH 1/2] refactor(ai): tighten openai-chat typing without base extraction Apply typed request/response guards directly in openai-chat-completion and remove any usage in the provider implementation. Add provider-focused tests and clarify config comments for OpenAI-compatible endpoints. --- src/config.ts | 12 + .../ai/providers/openai-chat-completion.ts | 114 +++-- tests/openai-chat-completion-provider.test.ts | 442 ++++++++++++++++++ 3 files changed, 542 insertions(+), 26 deletions(-) create mode 100644 tests/openai-chat-completion-provider.test.ts diff --git a/src/config.ts b/src/config.ts index 0832cbd..8fad5da 100644 --- a/src/config.ts +++ b/src/config.ts @@ -287,6 +287,8 @@ const CONFIG_TEMPLATE = `{ "autoCaptureEnabled": true, // Provider type: "openai-chat" | "openai-responses" | "anthropic" + // Note: "openai-chat" is a generic OpenAI API-compatible mode. + // Any service that follows the OpenAI Chat Completions API can use it via custom "memoryApiUrl". "memoryProvider": "openai-chat", // REQUIRED for auto-capture (all 3 must be set): @@ -300,11 +302,21 @@ const CONFIG_TEMPLATE = `{ // From env variable: "env://LITELLM_API_KEY" // Examples for different providers: + // Any OpenAI-compatible endpoint can use the "openai-chat" provider pattern below. + // Common examples: DeepSeek, Qwen (via Alibaba Cloud ModelStudio), + // Zhipu GLM (BigModel platform), and Kimi (Moonshot AI platform). + // OpenAI Chat Completion (default, backward compatible): // "memoryProvider": "openai-chat" // "memoryModel": "gpt-4o-mini" // "memoryApiUrl": "https://api.openai.com/v1" // "memoryApiKey": "sk-..." + + // DeepSeek (OpenAI-compatible example): + // "memoryProvider": "openai-chat" + // "memoryModel": "deepseek-chat" + // "memoryApiUrl": "https://api.deepseek.com/v1" + // "memoryApiKey": "sk-..." 
// OpenAI Responses API (recommended, with session support): // "memoryProvider": "openai-responses" diff --git a/src/services/ai/providers/openai-chat-completion.ts b/src/services/ai/providers/openai-chat-completion.ts index efd461c..5564e2f 100644 --- a/src/services/ai/providers/openai-chat-completion.ts +++ b/src/services/ai/providers/openai-chat-completion.ts @@ -1,5 +1,11 @@ -import { BaseAIProvider, type ToolCallResult, applySafeExtraParams } from "./base-provider.js"; -import { AISessionManager } from "../session/ai-session-manager.js"; +import { + BaseAIProvider, + type ProviderConfig, + type ToolCallResult, + applySafeExtraParams, +} from "./base-provider.js"; +import type { AISessionManager } from "../session/ai-session-manager.js"; +import type { AIMessage } from "../session/session-types.js"; import type { ChatCompletionTool } from "../tools/tool-schema.js"; import { log } from "../../logger.js"; import { UserProfileValidator } from "../validators/user-profile-validator.js"; @@ -10,7 +16,7 @@ interface ToolCallResponse { content?: string; tool_calls?: Array<{ id: string; - type: string; + type: "function"; function: { name: string; arguments: string; @@ -21,10 +27,46 @@ }>; } +type APIMessage = { + role: AIMessage["role"]; + content: string; + tool_calls?: ToolCallResponse["choices"][number]["message"]["tool_calls"]; + tool_call_id?: string; +}; + +type RequestBody = { + model: string; + messages: APIMessage[]; + tools: ChatCompletionTool[]; + tool_choice: "auto"; + temperature?: number; + [key: string]: unknown; +}; + +type AssistantSessionMessage = Omit<AIMessage, "createdAt">; + +function isErrorResponseBody(data: unknown): data is { status: string; msg: string } { + return ( + typeof data === "object" && + data !== null && + typeof (data as { status?: unknown }).status === "string" && + typeof (data as { msg?: unknown }).msg === "string" + ); +} + +function isToolCallResponse(data: unknown): data is ToolCallResponse { + return ( + typeof data 
=== "object" && + data !== null && + Array.isArray((data as { choices?: unknown }).choices) && + (data as { choices: unknown[] }).choices.length > 0 + ); +} + export class OpenAIChatCompletionProvider extends BaseAIProvider { - private aiSessionManager: AISessionManager; + private readonly aiSessionManager: AISessionManager; - constructor(config: any, aiSessionManager: AISessionManager) { + constructor(config: ProviderConfig, aiSessionManager: AISessionManager) { super(config); this.aiSessionManager = aiSessionManager; } @@ -39,7 +81,7 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { private addToolResponse( sessionId: string, - messages: any[], + messages: APIMessage[], toolCallId: string, content: string ): void { @@ -58,22 +100,26 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { }); } - private filterIncompleteToolCallSequences(messages: any[]): any[] { - const result: any[] = []; + protected filterIncompleteToolCallSequences(messages: AIMessage[]): AIMessage[] { + const result: AIMessage[] = []; let i = 0; while (i < messages.length) { const msg = messages[i]; + if (!msg) { + break; + } if (msg.role === "assistant" && msg.toolCalls && msg.toolCalls.length > 0) { - const toolCallIds = new Set(msg.toolCalls.map((tc: any) => tc.id)); - const toolResponses: any[] = []; + const toolCallIds = new Set(msg.toolCalls.map((tc) => tc.id)); + const toolResponses: AIMessage[] = []; let j = i + 1; - while (j < messages.length && messages[j].role === "tool") { - if (toolCallIds.has(messages[j].toolCallId)) { - toolResponses.push(messages[j]); - toolCallIds.delete(messages[j].toolCallId); + while (j < messages.length && messages[j]?.role === "tool") { + const toolMessage = messages[j]; + if (toolMessage?.toolCallId && toolCallIds.has(toolMessage.toolCallId)) { + toolResponses.push(toolMessage); + toolCallIds.delete(toolMessage.toolCallId); } j++; } @@ -110,12 +156,12 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { 
} const existingMessages = this.aiSessionManager.getMessages(session.id); - const messages: any[] = []; + const messages: APIMessage[] = []; const validatedMessages = this.filterIncompleteToolCallSequences(existingMessages); for (const msg of validatedMessages) { - const apiMsg: any = { + const apiMsg: APIMessage = { role: msg.role, content: msg.content, }; @@ -164,7 +210,7 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { const timeout = setTimeout(() => controller.abort(), iterationTimeout); try { - const requestBody: any = { + const requestBody: RequestBody = { model: this.config.model, messages, tools: [toolSchema], @@ -224,9 +270,9 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { }; } - const data = (await response.json()) as any; + const data: unknown = await response.json(); - if (data.status && data.msg) { + if (isErrorResponseBody(data)) { log("API returned error in response body", { provider: this.getProviderName(), model: this.config.model, @@ -240,13 +286,18 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { }; } - if (!data.choices || !data.choices[0]) { + if (!isToolCallResponse(data)) { + const choices = + typeof data === "object" && data !== null + ? (data as { choices?: unknown }).choices + : undefined; + log("Invalid API response format", { provider: this.getProviderName(), model: this.config.model, response: JSON.stringify(data).slice(0, 1000), - hasChoices: !!data.choices, - choicesLength: data.choices?.length, + hasChoices: Array.isArray(choices), + choicesLength: Array.isArray(choices) ? 
choices.length : undefined, }); return { success: false, @@ -256,9 +307,16 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { } const choice = data.choices[0]; + if (!choice) { + return { + success: false, + error: "Invalid API response format", + iterations, + }; + } const assistantSequence = this.aiSessionManager.getLastSequence(session.id) + 1; - const assistantMsg: any = { + const assistantMsg: AssistantSessionMessage = { aiSessionId: session.id, sequence: assistantSequence, role: "assistant", @@ -270,7 +328,11 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { } this.aiSessionManager.addMessage(assistantMsg); - messages.push(choice.message); + messages.push({ + role: "assistant", + content: choice.message.content || "", + tool_calls: choice.message.tool_calls, + }); if (choice.message.tool_calls && choice.message.tool_calls.length > 0) { for (const toolCall of choice.message.tool_calls) { @@ -356,7 +418,7 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { if (error instanceof Error && error.name === "AbortError") { return { success: false, - error: `API request timeout (${this.config.iterationTimeout}ms)`, + error: `API request timeout (${iterationTimeout}ms)`, iterations, }; } @@ -370,7 +432,7 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { return { success: false, - error: `Max iterations (${this.config.maxIterations}) reached without tool call`, + error: `Max iterations (${maxIterations}) reached without tool call`, iterations, }; } diff --git a/tests/openai-chat-completion-provider.test.ts b/tests/openai-chat-completion-provider.test.ts new file mode 100644 index 0000000..1e3df30 --- /dev/null +++ b/tests/openai-chat-completion-provider.test.ts @@ -0,0 +1,442 @@ +import { afterEach, describe, expect, it } from "bun:test"; +import { OpenAIChatCompletionProvider } from "../src/services/ai/providers/openai-chat-completion.js"; +import type { AIMessage } from 
"../src/services/ai/session/session-types.js"; +import type { ChatCompletionTool } from "../src/services/ai/tools/tool-schema.js"; + +const toolSchema: ChatCompletionTool = { + type: "function", + function: { + name: "save_memories", + description: "Save memories", + parameters: { + type: "object", + properties: {}, + required: [], + }, + }, +}; + +class FakeSessionManager { + private readonly session = { id: "session-1" }; + private readonly messages: any[] = []; + + getSession(): any { + return null; + } + + createSession(): any { + return this.session; + } + + getMessages(): any[] { + return this.messages; + } + + getLastSequence(): number { + return this.messages.length - 1; + } + + addMessage(message: any): void { + this.messages.push(message); + } +} + +class TestableOpenAIChatCompletionProvider extends OpenAIChatCompletionProvider { + filterMessages(messages: AIMessage[]): AIMessage[] { + return this.filterIncompleteToolCallSequences(messages); + } +} + +function makeProvider(config: Record<string, unknown> = {}) { + return new OpenAIChatCompletionProvider( + { model: "gpt-4o-mini", apiKey: "test-key", ...config }, + new FakeSessionManager() as any + ); +} + +function makeTestableProvider(config: Record<string, unknown> = {}) { + return new TestableOpenAIChatCompletionProvider( + { model: "gpt-4o-mini", apiKey: "test-key", ...config }, + new FakeSessionManager() as any + ); +} + +function makeFetch(response: { + ok?: boolean; + status?: number; + statusText?: string; + body?: unknown; +}) { + const textBody = + typeof response.body === "string" ? response.body : JSON.stringify(response.body ?? "error"); + const jsonBody = typeof response.body === "string" ? {} : (response.body ?? {}); + return (async (_input: RequestInfo | URL, _init?: RequestInit) => { + return { + ok: response.ok ?? false, + status: response.status ?? 400, + statusText: response.statusText ?? 
"Bad Request", + text: async () => textBody, + json: async () => jsonBody, + } as Response; + }) as typeof fetch; +} + +describe("OpenAIChatCompletionProvider", () => { + const originalFetch = globalThis.fetch; + + afterEach(() => { + globalThis.fetch = originalFetch; + }); + + it("getProviderName returns openai-chat", () => { + expect(makeProvider().getProviderName()).toBe("openai-chat"); + }); + + it("supportsSession returns true", () => { + expect(makeProvider().supportsSession()).toBe(true); + }); + + it("keeps complete tool call sequences", () => { + const messages: AIMessage[] = [ + { + aiSessionId: "session-1", + sequence: 0, + role: "assistant", + content: "", + toolCalls: [ + { + id: "call-1", + type: "function", + function: { name: "save_memories", arguments: "{}" }, + }, + ], + createdAt: 1, + }, + { + aiSessionId: "session-1", + sequence: 1, + role: "tool", + content: '{"success":true}', + toolCallId: "call-1", + createdAt: 2, + }, + ]; + + expect(makeTestableProvider().filterMessages(messages)).toEqual(messages); + }); + + it("drops trailing incomplete tool call sequences", () => { + const messages: AIMessage[] = [ + { + aiSessionId: "session-1", + sequence: 0, + role: "assistant", + content: "", + toolCalls: [ + { + id: "call-1", + type: "function", + function: { name: "save_memories", arguments: "{}" }, + }, + ], + createdAt: 1, + }, + ]; + + expect(makeTestableProvider().filterMessages(messages)).toEqual([]); + }); + + it("keeps complete prefix and drops later incomplete tool call sequences", () => { + const messages: AIMessage[] = [ + { + aiSessionId: "session-1", + sequence: 0, + role: "assistant", + content: "", + toolCalls: [ + { + id: "call-1", + type: "function", + function: { name: "save_memories", arguments: "{}" }, + }, + ], + createdAt: 1, + }, + { + aiSessionId: "session-1", + sequence: 1, + role: "tool", + content: '{"success":true}', + toolCallId: "call-1", + createdAt: 2, + }, + { + aiSessionId: "session-1", + sequence: 2, + role: 
"assistant", + content: "", + toolCalls: [ + { + id: "call-2", + type: "function", + function: { name: "save_memories", arguments: "{}" }, + }, + ], + createdAt: 3, + }, + ]; + + expect(makeTestableProvider().filterMessages(messages)).toEqual(messages.slice(0, 2)); + }); + + it("uses custom apiUrl for the request", async () => { + let capturedUrl = ""; + globalThis.fetch = (async (input: RequestInfo | URL, _init?: RequestInit) => { + capturedUrl = String(input); + return { ok: false, status: 400, statusText: "Bad", text: async () => "err" } as Response; + }) as typeof fetch; + + await makeProvider({ apiUrl: "https://compatible.example.com/v1" }).executeToolCall( + "system", + "user", + toolSchema, + "session-id" + ); + + expect(capturedUrl).toBe("https://compatible.example.com/v1/chat/completions"); + }); + + it("sends Authorization Bearer header", async () => { + let capturedHeaders: Record<string, string> | undefined; + globalThis.fetch = (async (_input: RequestInfo | URL, init?: RequestInit) => { + capturedHeaders = init?.headers as Record<string, string>; + return { ok: false, status: 400, statusText: "Bad", text: async () => "err" } as Response; + }) as typeof fetch; + + await makeProvider({ apiKey: "sk-mykey", apiUrl: "https://api.openai.com/v1" }).executeToolCall( + "system", + "user", + toolSchema, + "session-id" + ); + + expect(capturedHeaders?.Authorization).toBe("Bearer sk-mykey"); + }); + + it("omits Authorization header when apiKey is not set", async () => { + let capturedHeaders: Record<string, string> | undefined; + globalThis.fetch = (async (_input: RequestInfo | URL, init?: RequestInit) => { + capturedHeaders = init?.headers as Record<string, string>; + return { ok: false, status: 400, statusText: "Bad", text: async () => "err" } as Response; + }) as typeof fetch; + + await makeProvider({ apiKey: undefined, apiUrl: "https://api.openai.com/v1" }).executeToolCall( + "system", + "user", + toolSchema, + "session-id" + ); + + expect(capturedHeaders?.Authorization).toBeUndefined(); + }); + + it("sends model, 
messages, tools, tool_choice in request body", async () => { + let capturedBody: Record<string, unknown> | undefined; + globalThis.fetch = (async (_input: RequestInfo | URL, init?: RequestInit) => { + capturedBody = JSON.parse(String(init?.body ?? "{}")); + return { ok: false, status: 400, statusText: "Bad", text: async () => "err" } as Response; + }) as typeof fetch; + + await makeProvider({ + model: "gpt-4o-mini", + apiUrl: "https://api.openai.com/v1", + }).executeToolCall("system", "user", toolSchema, "session-id"); + + expect(capturedBody?.model).toBe("gpt-4o-mini"); + expect(Array.isArray(capturedBody?.messages)).toBe(true); + expect(Array.isArray(capturedBody?.tools)).toBe(true); + expect(capturedBody?.tool_choice).toBe("auto"); + }); + + it("includes temperature 0.3 by default", async () => { + let capturedBody: Record<string, unknown> | undefined; + globalThis.fetch = (async (_input: RequestInfo | URL, init?: RequestInit) => { + capturedBody = JSON.parse(String(init?.body ?? "{}")); + return { ok: false, status: 400, statusText: "Bad", text: async () => "err" } as Response; + }) as typeof fetch; + + await makeProvider({ apiUrl: "https://api.openai.com/v1" }).executeToolCall( + "system", + "user", + toolSchema, + "session-id" + ); + + expect(capturedBody?.temperature).toBe(0.3); + }); + + it("omits temperature when memoryTemperature is false", async () => { + let capturedBody: Record<string, unknown> | undefined; + globalThis.fetch = (async (_input: RequestInfo | URL, init?: RequestInit) => { + capturedBody = JSON.parse(String(init?.body ?? 
"{}")); + return { ok: false, status: 400, statusText: "Bad", text: async () => "err" } as Response; + }) as typeof fetch; + + await makeProvider({ + memoryTemperature: false, + apiUrl: "https://api.openai.com/v1", + }).executeToolCall("system", "user", toolSchema, "session-id"); + + expect(capturedBody?.temperature).toBeUndefined(); + }); + + it("returns success: false with error message on API error response", async () => { + globalThis.fetch = makeFetch({ ok: false, status: 401, body: "Unauthorized" }); + + const result = await makeProvider({ apiUrl: "https://api.openai.com/v1" }).executeToolCall( + "system", + "user", + toolSchema, + "session-id" + ); + + expect(result.success).toBe(false); + expect(result.error).toContain("401"); + }); + + it("returns friendly message on temperature unsupported error", async () => { + globalThis.fetch = makeFetch({ + ok: false, + status: 400, + body: '{"error": {"type": "unsupported_value", "param": "temperature"}}', + }); + + const result = await makeProvider({ apiUrl: "https://api.openai.com/v1" }).executeToolCall( + "system", + "user", + toolSchema, + "session-id" + ); + + expect(result.success).toBe(false); + expect(result.error).toContain("memoryTemperature"); + }); + + it("returns success: false when response has no choices", async () => { + globalThis.fetch = makeFetch({ ok: true, body: { choices: [] } } as any); + + const result = await makeProvider({ apiUrl: "https://api.openai.com/v1" }).executeToolCall( + "system", + "user", + toolSchema, + "session-id" + ); + + expect(result.success).toBe(false); + expect(result.error).toContain("Invalid API response format"); + }); + + it("returns success: false when API returns error in response body", async () => { + globalThis.fetch = makeFetch({ + ok: true, + body: { status: "error", msg: "quota exceeded" }, + } as any); + + const result = await makeProvider({ apiUrl: "https://api.openai.com/v1" }).executeToolCall( + "system", + "user", + toolSchema, + "session-id" + ); + + 
expect(result.success).toBe(false); + expect(result.error).toContain("quota exceeded"); + }); + + it("returns success: true when model calls the correct tool", async () => { + const validArguments = JSON.stringify({ + preferences: [], + patterns: [], + workflows: [], + codingStyle: {}, + domainKnowledge: [], + }); + + globalThis.fetch = makeFetch({ + ok: true, + body: { + choices: [ + { + message: { + content: null, + tool_calls: [ + { + id: "call-1", + type: "function", + function: { name: "save_memories", arguments: validArguments }, + }, + ], + }, + }, + ], + }, + } as any); + + const result = await makeProvider({ apiUrl: "https://api.openai.com/v1" }).executeToolCall( + "system", + "user", + toolSchema, + "session-id" + ); + + expect(result.success).toBe(true); + expect(result.iterations).toBe(1); + }); + + it("returns success: false after max iterations with no tool call", async () => { + globalThis.fetch = makeFetch({ + ok: true, + body: { + choices: [{ message: { content: "I will not use a tool", tool_calls: undefined } }], + }, + } as any); + + const result = await makeProvider({ + maxIterations: 2, + apiUrl: "https://api.openai.com/v1", + }).executeToolCall("system", "user", toolSchema, "session-id"); + + expect(result.success).toBe(false); + expect(result.error).toContain("Max iterations"); + expect(result.iterations).toBe(2); + }); + + it("returns success: false when model calls wrong tool name", async () => { + globalThis.fetch = makeFetch({ + ok: true, + body: { + choices: [ + { + message: { + content: null, + tool_calls: [ + { + id: "call-1", + type: "function", + function: { name: "wrong_tool", arguments: "{}" }, + }, + ], + }, + }, + ], + }, + } as any); + + const result = await makeProvider({ + maxIterations: 1, + apiUrl: "https://api.openai.com/v1", + }).executeToolCall("system", "user", toolSchema, "session-id"); + + expect(result.success).toBe(false); + }); +}); From 09d3c71381deab459726c39d5c828393d19234e8 Mon Sep 17 00:00:00 2001 From: machen 
Date: Thu, 9 Apr 2026 19:44:41 +0800 Subject: [PATCH 2/2] refactor(ai): strengthen hasNonEmptyChoices guard and fix content nullability --- .../ai/providers/openai-chat-completion.ts | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/services/ai/providers/openai-chat-completion.ts b/src/services/ai/providers/openai-chat-completion.ts index 5564e2f..a1449dd 100644 --- a/src/services/ai/providers/openai-chat-completion.ts +++ b/src/services/ai/providers/openai-chat-completion.ts @@ -13,7 +13,7 @@ import { UserProfileValidator } from "../validators/user-profile-validator.js"; interface ToolCallResponse { choices: Array<{ message: { - content?: string; + content?: string | null; tool_calls?: Array<{ id: string; type: "function"; @@ -29,7 +29,7 @@ interface ToolCallResponse { type APIMessage = { role: AIMessage["role"]; - content: string; + content: string | null; tool_calls?: ToolCallResponse["choices"][number]["message"]["tool_calls"]; tool_call_id?: string; }; @@ -54,13 +54,20 @@ function isErrorResponseBody(data: unknown): data is { status: string; msg: stri ); } -function isToolCallResponse(data: unknown): data is ToolCallResponse { - return ( - typeof data === "object" && - data !== null && - Array.isArray((data as { choices?: unknown }).choices) && - (data as { choices: unknown[] }).choices.length > 0 - ); +function hasNonEmptyChoices(data: unknown): data is ToolCallResponse { + if (typeof data !== "object" || data === null) return false; + const { choices } = data as { choices?: unknown }; + if (!Array.isArray(choices) || choices.length === 0) return false; + + const first = choices[0] as { message?: unknown }; + if (typeof first !== "object" || first === null) return false; + if (typeof first.message !== "object" || first.message === null) return false; + + const { content, tool_calls } = first.message as { content?: unknown; tool_calls?: unknown }; + if (content !== undefined && content !== null && typeof content !== 
"string") return false; + if (tool_calls !== undefined && !Array.isArray(tool_calls)) return false; + + return true; } export class OpenAIChatCompletionProvider extends BaseAIProvider { @@ -286,7 +293,7 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { }; } - if (!isToolCallResponse(data)) { + if (!hasNonEmptyChoices(data)) { const choices = typeof data === "object" && data !== null ? (data as { choices?: unknown }).choices @@ -320,7 +327,7 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { aiSessionId: session.id, sequence: assistantSequence, role: "assistant", - content: choice.message.content || "", + content: choice.message.content ?? "", }; if (choice.message.tool_calls) { @@ -330,7 +337,7 @@ export class OpenAIChatCompletionProvider extends BaseAIProvider { this.aiSessionManager.addMessage(assistantMsg); messages.push({ role: "assistant", - content: choice.message.content || "", + content: choice.message.content ?? null, tool_calls: choice.message.tool_calls, });