Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 249 additions & 32 deletions lib/chat/__tests__/handleChatStream.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { NextResponse } from "next/server";

import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId";
import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId";
import { setupChatRequest } from "@/lib/chat/setupChatRequest";
import { setupConversation } from "@/lib/chat/setupConversation";
import { saveChatCompletion } from "@/lib/chat/saveChatCompletion";
import { createUIMessageStream, createUIMessageStreamResponse } from "ai";
import { handleChatStream } from "../handleChatStream";

// Mock all dependencies before importing the module under test
vi.mock("@/lib/auth/getApiKeyAccountId", () => ({
getApiKeyAccountId: vi.fn(),
Expand All @@ -23,52 +31,50 @@ vi.mock("@/lib/organizations/validateOrganizationAccess", () => ({
}));

vi.mock("@/lib/chat/setupConversation", () => ({
setupConversation: vi.fn().mockResolvedValue({ roomId: "mock-room-id", memoryId: "mock-memory-id" }),
setupConversation: vi
.fn()
.mockResolvedValue({ roomId: "mock-room-id", memoryId: "mock-memory-id" }),
}));

vi.mock("@/lib/chat/validateMessages", () => ({
validateMessages: vi.fn((messages) => ({
validateMessages: vi.fn(messages => ({
lastMessage: messages[messages.length - 1] || { id: "mock-id", role: "user", parts: [] },
validMessages: messages,
})),
}));

vi.mock("@/lib/messages/convertToUiMessages", () => ({
default: vi.fn((messages) => messages),
default: vi.fn(messages => messages),
}));

vi.mock("@/lib/chat/setupChatRequest", () => ({
setupChatRequest: vi.fn(),
}));

vi.mock("@/lib/chat/handleChatCompletion", () => ({
handleChatCompletion: vi.fn(),
vi.mock("@/lib/chat/saveChatCompletion", () => ({
saveChatCompletion: vi.fn(),
}));

vi.mock("ai", () => ({
createUIMessageStream: vi.fn(),
createUIMessageStreamResponse: vi.fn(),
}));

import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId";
import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId";
import { setupChatRequest } from "@/lib/chat/setupChatRequest";
import { setupConversation } from "@/lib/chat/setupConversation";
import { createUIMessageStream, createUIMessageStreamResponse } from "ai";
import { handleChatStream } from "../handleChatStream";

const mockGetApiKeyAccountId = vi.mocked(getApiKeyAccountId);
const mockValidateOverrideAccountId = vi.mocked(validateOverrideAccountId);
const mockSetupConversation = vi.mocked(setupConversation);
const mockSetupChatRequest = vi.mocked(setupChatRequest);
const mockSaveChatCompletion = vi.mocked(saveChatCompletion);
const mockCreateUIMessageStream = vi.mocked(createUIMessageStream);
const mockCreateUIMessageStreamResponse = vi.mocked(createUIMessageStreamResponse);

// Helper to create mock NextRequest
function createMockRequest(
body: unknown,
headers: Record<string, string> = {},
): Request {
/**
*
* @param body
* @param headers
*/
function createMockRequest(body: unknown, headers: Record<string, string> = {}): Request {
return {
json: () => Promise.resolve(body),
headers: {
Expand Down Expand Up @@ -97,10 +103,7 @@ describe("handleChatStream", () => {
it("returns 400 error when neither messages nor prompt is provided", async () => {
mockGetApiKeyAccountId.mockResolvedValue("account-123");

const request = createMockRequest(
{ roomId: "room-123" },
{ "x-api-key": "test-key" },
);
const request = createMockRequest({ roomId: "room-123" }, { "x-api-key": "test-key" });

const result = await handleChatStream(request as any);

Expand Down Expand Up @@ -151,10 +154,7 @@ describe("handleChatStream", () => {
const mockResponse = new Response(mockStream);
mockCreateUIMessageStreamResponse.mockReturnValue(mockResponse);

const request = createMockRequest(
{ prompt: "Hello, world!" },
{ "x-api-key": "valid-key" },
);
const request = createMockRequest({ prompt: "Hello, world!" }, { "x-api-key": "valid-key" });

const result = await handleChatStream(request as any);

Expand Down Expand Up @@ -193,10 +193,7 @@ describe("handleChatStream", () => {
mockCreateUIMessageStreamResponse.mockReturnValue(new Response(mockStream));

const messages = [{ role: "user", content: "Hello" }];
const request = createMockRequest(
{ messages },
{ "x-api-key": "valid-key" },
);
const request = createMockRequest({ messages }, { "x-api-key": "valid-key" });

await handleChatStream(request as any);

Expand Down Expand Up @@ -263,10 +260,7 @@ describe("handleChatStream", () => {
mockGetApiKeyAccountId.mockResolvedValue("account-123");
mockSetupChatRequest.mockRejectedValue(new Error("Setup failed"));

const request = createMockRequest(
{ prompt: "Hello" },
{ "x-api-key": "valid-key" },
);
const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" });

const result = await handleChatStream(request as any);

Expand Down Expand Up @@ -321,4 +315,227 @@ describe("handleChatStream", () => {
);
});
});

describe("message persistence", () => {
it("calls saveChatCompletion with text from last assistant message in onFinish", async () => {
mockGetApiKeyAccountId.mockResolvedValue("account-123");

const mockAgent = {
stream: vi.fn().mockResolvedValue({
toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()),
usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }),
}),
tools: {},
};

mockSetupChatRequest.mockResolvedValue({
agent: mockAgent,
model: "gpt-4",
instructions: "You are a helpful assistant",
system: "You are a helpful assistant",
messages: [],
experimental_generateMessageId: vi.fn(),
tools: {},
providerOptions: {},
} as any);

// Capture the onFinish callback
let capturedOnFinish: ((event: any) => Promise<void>) | undefined;
mockCreateUIMessageStream.mockImplementation((options: any) => {
capturedOnFinish = options.onFinish;
return new ReadableStream();
});

mockCreateUIMessageStreamResponse.mockReturnValue(new Response(new ReadableStream()));

const request = createMockRequest(
{ prompt: "Hello", roomId: "test-room-id" },
{ "x-api-key": "valid-key" },
);

await handleChatStream(request as any);

// Simulate onFinish being called with assistant messages
expect(capturedOnFinish).toBeDefined();
await capturedOnFinish!({
isAborted: false,
messages: [
{
id: "msg-1",
role: "assistant",
parts: [{ type: "text", text: "Hello! How can I help you?" }],
},
],
responseMessage: {
id: "msg-fallback",
role: "assistant",
parts: [{ type: "text", text: "Fallback response" }],
},
});

expect(mockSaveChatCompletion).toHaveBeenCalledWith({
text: "Hello! How can I help you?",
roomId: "test-room-id",
});
});

it("uses responseMessage as fallback when no assistant messages", async () => {
mockGetApiKeyAccountId.mockResolvedValue("account-123");

const mockAgent = {
stream: vi.fn().mockResolvedValue({
toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()),
usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }),
}),
tools: {},
};

mockSetupChatRequest.mockResolvedValue({
agent: mockAgent,
model: "gpt-4",
instructions: "test",
system: "test",
messages: [],
experimental_generateMessageId: vi.fn(),
tools: {},
providerOptions: {},
} as any);

let capturedOnFinish: ((event: any) => Promise<void>) | undefined;
mockCreateUIMessageStream.mockImplementation((options: any) => {
capturedOnFinish = options.onFinish;
return new ReadableStream();
});

mockCreateUIMessageStreamResponse.mockReturnValue(new Response(new ReadableStream()));

const request = createMockRequest(
{ prompt: "Hello", roomId: "test-room-id" },
{ "x-api-key": "valid-key" },
);

await handleChatStream(request as any);

await capturedOnFinish!({
isAborted: false,
messages: [], // No assistant messages
responseMessage: {
id: "msg-fallback",
role: "assistant",
parts: [{ type: "text", text: "Fallback response" }],
},
});

expect(mockSaveChatCompletion).toHaveBeenCalledWith({
text: "Fallback response",
roomId: "test-room-id",
});
});

it("does not call saveChatCompletion when stream is aborted", async () => {
mockGetApiKeyAccountId.mockResolvedValue("account-123");

const mockAgent = {
stream: vi.fn().mockResolvedValue({
toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()),
usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }),
}),
tools: {},
};

mockSetupChatRequest.mockResolvedValue({
agent: mockAgent,
model: "gpt-4",
instructions: "test",
system: "test",
messages: [],
experimental_generateMessageId: vi.fn(),
tools: {},
providerOptions: {},
} as any);

let capturedOnFinish: ((event: any) => Promise<void>) | undefined;
mockCreateUIMessageStream.mockImplementation((options: any) => {
capturedOnFinish = options.onFinish;
return new ReadableStream();
});

mockCreateUIMessageStreamResponse.mockReturnValue(new Response(new ReadableStream()));

const request = createMockRequest(
{ prompt: "Hello", roomId: "test-room-id" },
{ "x-api-key": "valid-key" },
);

await handleChatStream(request as any);

await capturedOnFinish!({
isAborted: true,
messages: [],
responseMessage: null,
});

expect(mockSaveChatCompletion).not.toHaveBeenCalled();
});

it("logs error but does not throw when saveChatCompletion fails", async () => {
mockGetApiKeyAccountId.mockResolvedValue("account-123");
mockSaveChatCompletion.mockRejectedValue(new Error("Database error"));

const mockAgent = {
stream: vi.fn().mockResolvedValue({
toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()),
usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }),
}),
tools: {},
};

mockSetupChatRequest.mockResolvedValue({
agent: mockAgent,
model: "gpt-4",
instructions: "test",
system: "test",
messages: [],
experimental_generateMessageId: vi.fn(),
tools: {},
providerOptions: {},
} as any);

let capturedOnFinish: ((event: any) => Promise<void>) | undefined;
mockCreateUIMessageStream.mockImplementation((options: any) => {
capturedOnFinish = options.onFinish;
return new ReadableStream();
});

mockCreateUIMessageStreamResponse.mockReturnValue(new Response(new ReadableStream()));

const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {});

const request = createMockRequest(
{ prompt: "Hello", roomId: "test-room-id" },
{ "x-api-key": "valid-key" },
);

await handleChatStream(request as any);

// Should not throw
await capturedOnFinish!({
isAborted: false,
messages: [
{
id: "msg-1",
role: "assistant",
parts: [{ type: "text", text: "Hello!" }],
},
],
responseMessage: null,
});

expect(consoleErrorSpy).toHaveBeenCalledWith(
"Failed to persist assistant message:",
expect.any(Error),
);
consoleErrorSpy.mockRestore();
});
});
});
Loading