diff --git a/packages/typescript/ai-client/src/chat-client.ts b/packages/typescript/ai-client/src/chat-client.ts
index 65554a44..1575d6e5 100644
--- a/packages/typescript/ai-client/src/chat-client.ts
+++ b/packages/typescript/ai-client/src/chat-client.ts
@@ -4,6 +4,7 @@ import {
   normalizeToUIMessage,
 } from '@tanstack/ai'
 import { DefaultChatClientEventEmitter } from './events'
+import { createDefaultSession } from './session-adapter'
 import type {
   AnyClientTool,
   ContentPart,
@@ -12,6 +13,7 @@ import type {
 } from '@tanstack/ai'
 import type { ConnectionAdapter } from './connection-adapters'
 import type { ChatClientEventEmitter } from './events'
+import type { SessionAdapter } from './session-adapter'
 import type {
   ChatClientOptions,
   ChatClientState,
@@ -23,7 +25,7 @@ import type {
 
 export class ChatClient {
   private processor: StreamProcessor
-  private connection: ConnectionAdapter
+  private session!: SessionAdapter
   private uniqueId: string
   private body: Record<string, unknown> = {}
   private pendingMessageBody: Record<string, unknown> | undefined = undefined
@@ -40,6 +42,9 @@ export class ChatClient {
   private pendingToolExecutions: Map<string, Promise<void>> = new Map()
   // Flag to deduplicate continuation checks during action draining
   private continuationPending = false
+  private subscriptionAbortController: AbortController | null = null
+  private processingResolve: (() => void) | null = null
+  private streamGeneration = 0
 
   private callbacksRef: {
     current: {
@@ -57,7 +62,15 @@
   constructor(options: ChatClientOptions) {
     this.uniqueId = options.id || this.generateUniqueId('chat')
     this.body = options.body || {}
-    this.connection = options.connection
+
+    // Resolve session adapter
+    if (options.session) {
+      this.session = options.session
+    } else if (options.connection) {
+      this.session = createDefaultSession(options.connection)
+    } else {
+      throw new Error('Either connection or session must be provided')
+    }
 
     this.events = new DefaultChatClientEventEmitter(this.uniqueId)
 
     // Build client tools map
@@ -91,10 +104,24 @@
       },
       onStreamStart: () => {
         this.setStatus('streaming')
+        const messages = this.processor.getMessages()
+        const lastAssistant = messages.findLast(
+          (m: UIMessage) => m.role === 'assistant',
+        )
+        if (lastAssistant) {
+          this.currentMessageId = lastAssistant.id
+          this.events.messageAppended(
+            lastAssistant,
+            this.currentStreamId || undefined,
+          )
+        }
       },
       onStreamEnd: (message: UIMessage) => {
         this.callbacksRef.current.onFinish(message)
        this.setStatus('ready')
+        // Resolve the processing-complete promise so streamResponse can continue
+        this.processingResolve?.()
+        this.processingResolve = null
      },
      onError: (error: Error) => {
        this.setError(error)
@@ -226,68 +253,66 @@
   }
 
   /**
-   * Process a stream through the StreamProcessor
+   * Start the background subscription loop.
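+   *
+   * The loop is started lazily via ensureSubscription(), lives for the whole
+   * session rather than a single request, and ends when stop(), reload(), or
+   * updateOptions() aborts subscriptionAbortController.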
   */
-  private async processStream(
-    source: AsyncIterable<StreamChunk>,
-  ): Promise<UIMessage | null> {
-    // Generate a stream ID for this streaming operation
-    this.currentStreamId = this.generateUniqueId('stream')
+  private startSubscription(): void {
+    this.subscriptionAbortController = new AbortController()
+    const signal = this.subscriptionAbortController.signal
 
-    // Prepare for a new assistant message (created lazily on first content)
-    this.processor.prepareAssistantMessage()
+    this.consumeSubscription(signal).catch((err) => {
+      if (err instanceof Error && err.name !== 'AbortError') {
+        this.setError(err)
+        this.setStatus('error')
+        this.callbacksRef.current.onError(err)
+      }
+      // Resolve pending processing so streamResponse doesn't hang
+      this.processingResolve?.()
+      this.processingResolve = null
+    })
+  }
 
-    // Process each chunk
-    for await (const chunk of source) {
+  /**
+   * Consume chunks from the session subscription.
+   */
+  private async consumeSubscription(signal: AbortSignal): Promise<void> {
+    const stream = this.session.subscribe(signal)
+    for await (const chunk of stream) {
+      if (signal.aborted) break
       this.callbacksRef.current.onChunk(chunk)
       this.processor.processChunk(chunk)
-
-      // Track the message ID once the processor lazily creates it
-      if (!this.currentMessageId) {
-        const newMessageId =
-          this.processor.getCurrentAssistantMessageId() ?? null
-        if (newMessageId) {
-          this.currentMessageId = newMessageId
-          // Emit message appended event now that the assistant message exists
-          const assistantMessage = this.processor
-            .getMessages()
-            .find((m: UIMessage) => m.id === newMessageId)
-          if (assistantMessage) {
-            this.events.messageAppended(
-              assistantMessage,
-              this.currentStreamId || undefined,
-            )
-          }
-        }
+      // RUN_FINISHED / RUN_ERROR signal run completion — resolve processing
+      // (redundant if onStreamEnd already resolved it, harmless)
+      if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') {
+        this.processingResolve?.()
+        this.processingResolve = null
       }
-
-      // Yield control back to event loop to allow UI updates
+      // Yield control back to event loop for UI updates
       await new Promise((resolve) => setTimeout(resolve, 0))
     }
+  }
 
-    // Wait for all pending tool executions to complete before finalizing
-    // This ensures client tools finish before we check for continuation
-    if (this.pendingToolExecutions.size > 0) {
-      await Promise.all(this.pendingToolExecutions.values())
-    }
-
-    // Finalize the stream
-    this.processor.finalizeStream()
-
-    // Get the message ID (may be null if no content arrived)
-    const messageId = this.processor.getCurrentAssistantMessageId()
-
-    // Clear the current stream and message IDs
-    this.currentStreamId = null
-    this.currentMessageId = null
-
-    // Return the assistant message if one was created
-    if (messageId) {
-      const messages = this.processor.getMessages()
-      return messages.find((m: UIMessage) => m.id === messageId) || null
+  /**
+   * Ensure subscription loop is running, starting it if needed.
+   */
+  private ensureSubscription(): void {
+    if (
+      !this.subscriptionAbortController ||
+      this.subscriptionAbortController.signal.aborted
+    ) {
+      this.startSubscription()
     }
+  }
 
-    return null
+  /**
+   * Create a promise that resolves when onStreamEnd fires.
+   * Used by streamResponse to await processing completion.
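+   *
+   * This is the stored-resolver idiom in isolation (illustrative sketch,
+   * not an excerpt from this class):
+   *
+   * ```ts
+   * let release: (() => void) | null = null
+   * const done = new Promise<void>((resolve) => { release = resolve })
+   * // later, from an unrelated code path (e.g. a stream-end callback):
+   * release?.()
+   * await done // resolves once release() has run
+   * ```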
+ */ + private waitForProcessing(): Promise { + // Resolve any stale promise (e.g., from a previous aborted request) + this.processingResolve?.() + return new Promise((resolve) => { + this.processingResolve = resolve + }) } /** @@ -407,6 +432,9 @@ export class ChatClient { return } + // Track generation so a superseded stream's cleanup doesn't clobber the new one + const generation = ++this.streamGeneration + this.setIsLoading(true) this.setStatus('submitted') this.setError(undefined) @@ -433,42 +461,78 @@ export class ChatClient { // Clear the pending message body after use this.pendingMessageBody = undefined - // Connect and stream - const stream = this.connection.connect( - messages, - mergedBody, - this.abortController.signal, - ) + // Generate stream ID — assistant message will be created by stream events + this.currentStreamId = this.generateUniqueId('stream') + this.currentMessageId = null + + // Reset processor stream state for new response — prevents stale + // messageStates entries (from a previous stream) from blocking + // creation of a new assistant message (e.g. after reload). + this.processor.prepareAssistantMessage() + + // Ensure subscription loop is running + this.ensureSubscription() + + // Set up promise that resolves when onStreamEnd fires + const processingComplete = this.waitForProcessing() + + // Send through session adapter (pushes chunks to subscription queue) + await this.session.send(messages, mergedBody, this.abortController.signal) + + // Wait for subscription loop to finish processing all chunks + await processingComplete - await this.processStream(stream) + // If this stream was superseded (e.g. by reload()), bail out — + // the new stream owns the processor and processingResolve now. + if (generation !== this.streamGeneration) { + return + } + + // Wait for pending client tool executions + if (this.pendingToolExecutions.size > 0) { + await Promise.all(this.pendingToolExecutions.values()) + } + + // Finalize (idempotent — may already be done by RUN_FINISHED handler) + this.processor.finalizeStream() + + this.currentStreamId = null + this.currentMessageId = null streamCompletedSuccessfully = true } catch (err) { if (err instanceof Error) { if (err.name === 'AbortError') { return } - this.setError(err) - this.setStatus('error') - this.callbacksRef.current.onError(err) + if (generation === this.streamGeneration) { + this.setError(err) + this.setStatus('error') + this.callbacksRef.current.onError(err) + } } } finally { - this.abortController = null - this.setIsLoading(false) - this.pendingMessageBody = undefined // Ensure it's cleared even on error - - // Drain any actions that were queued while the stream was in progress - await this.drainPostStreamActions() - - // Continue conversation if the stream ended with a tool result (server tool completed) - if (streamCompletedSuccessfully) { - const messages = this.processor.getMessages() - const lastPart = messages.at(-1)?.parts.at(-1) - - if (lastPart?.type === 'tool-result' && this.shouldAutoSend()) { - try { - await this.checkForContinuation() - } catch (error) { - console.error('Failed to continue flow after tool result:', error) + // Only clean up if this is still the active stream. + // A superseded stream (e.g. reload() started a new one) must not + // clobber the new stream's abortController or isLoading state. 
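+      // (Generation-counter guard: streamResponse() captures
+      // ++this.streamGeneration on entry, and every continuation that
+      // resumes after an await re-checks its captured copy first.)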
+ if (generation === this.streamGeneration) { + this.abortController = null + this.setIsLoading(false) + this.pendingMessageBody = undefined // Ensure it's cleared even on error + + // Drain any actions that were queued while the stream was in progress + await this.drainPostStreamActions() + + // Continue conversation if the stream ended with a tool result (server tool completed) + if (streamCompletedSuccessfully) { + const messages = this.processor.getMessages() + const lastPart = messages.at(-1)?.parts.at(-1) + + if (lastPart?.type === 'tool-result' && this.shouldAutoSend()) { + try { + await this.checkForContinuation() + } catch (error) { + console.error('Failed to continue flow after tool result:', error) + } } } } @@ -489,6 +553,17 @@ export class ChatClient { if (lastUserMessageIndex === -1) return + // Cancel any active stream before reloading + if (this.isLoading) { + this.abortController?.abort() + this.abortController = null + this.subscriptionAbortController?.abort() + this.subscriptionAbortController = null + this.processingResolve?.() + this.processingResolve = null + this.setIsLoading(false) + } + this.events.reloaded(lastUserMessageIndex) // Remove all messages after the last user message @@ -502,10 +577,20 @@ export class ChatClient { * Stop the current stream */ stop(): void { + // Abort any in-flight send if (this.abortController) { this.abortController.abort() this.abortController = null } + + // Abort the subscription loop + this.subscriptionAbortController?.abort() + this.subscriptionAbortController = null + + // Resolve any pending processing promise (unblock streamResponse) + this.processingResolve?.() + this.processingResolve = null + this.setIsLoading(false) this.setStatus('ready') this.events.stopped() @@ -678,6 +763,7 @@ export class ChatClient { */ updateOptions(options: { connection?: ConnectionAdapter + session?: SessionAdapter body?: Record tools?: ReadonlyArray onResponse?: (response?: Response) => void | Promise @@ -685,8 +771,12 @@ export class ChatClient { onFinish?: (message: UIMessage) => void onError?: (error: Error) => void }): void { - if (options.connection !== undefined) { - this.connection = options.connection + if (options.session !== undefined) { + this.subscriptionAbortController?.abort() + this.session = options.session + } else if (options.connection !== undefined) { + this.subscriptionAbortController?.abort() + this.session = createDefaultSession(options.connection) } if (options.body !== undefined) { this.body = options.body diff --git a/packages/typescript/ai-client/src/index.ts b/packages/typescript/ai-client/src/index.ts index b279605d..0ad066a9 100644 --- a/packages/typescript/ai-client/src/index.ts +++ b/packages/typescript/ai-client/src/index.ts @@ -30,6 +30,7 @@ export { type ConnectionAdapter, type FetchConnectionOptions, } from './connection-adapters' +export { createDefaultSession, type SessionAdapter } from './session-adapter' // Re-export message converters from @tanstack/ai export { diff --git a/packages/typescript/ai-client/src/session-adapter.ts b/packages/typescript/ai-client/src/session-adapter.ts new file mode 100644 index 00000000..5fda2aa0 --- /dev/null +++ b/packages/typescript/ai-client/src/session-adapter.ts @@ -0,0 +1,119 @@ +import type { StreamChunk, UIMessage } from '@tanstack/ai' +import type { ConnectionAdapter } from './connection-adapters' + +/** + * Session adapter interface for persistent stream-based chat sessions. 
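+ *
+ * A minimal hand-rolled implementation might look like this (sketch only;
+ * `liveStream` and `post` are hypothetical helpers, not part of this package):
+ *
+ * ```ts
+ * declare function liveStream(signal?: AbortSignal): AsyncIterable<StreamChunk>
+ * declare function post(url: string, body: unknown, signal?: AbortSignal): Promise<void>
+ *
+ * const session: SessionAdapter = {
+ *   subscribe: (signal) => liveStream(signal),
+ *   send: async (messages, data, signal) => {
+ *     // The response is not returned here; it arrives through subscribe().
+ *     await post('/api/chat', { messages, data }, signal)
+ *   },
+ * }
+ * ```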
+ * + * Unlike ConnectionAdapter (which creates a new stream per request), + * a SessionAdapter maintains a persistent subscription. Responses from + * send() arrive through subscribe(), not as a return value. + * + * The subscribe() stream yields standard AG-UI events (StreamChunk). + * The processor handles whichever event types it supports — currently + * text message lifecycle, tool calls, and MESSAGES_SNAPSHOT. Future + * event handlers (STATE_SNAPSHOT, STATE_DELTA, etc.) are purely additive. + */ +export interface SessionAdapter { + /** + * Subscribe to the session stream. + * Returns an async iterable that yields chunks continuously. + * For durable sessions, this may first yield a MESSAGES_SNAPSHOT + * to hydrate the conversation, then subscribe to the live stream + * from the appropriate offset. + */ + subscribe: (signal?: AbortSignal) => AsyncIterable + + /** + * Send messages to the session. + * For durable sessions, the proxy writes to the stream and forwards to the API. + * The response arrives through subscribe(), not as a return value. + */ + send: ( + messages: Array, + data?: Record, + signal?: AbortSignal, + ) => Promise +} + +/** + * Wraps a ConnectionAdapter into a SessionAdapter using an async queue pattern. + * send() calls connection.connect() and pushes chunks to the queue. + * subscribe() yields chunks from the queue. + * + * Each subscribe() call synchronously replaces the active buffer/waiters + * so that concurrent send() calls write to the current subscription's queue. + * This prevents a race condition where an old subscription's async cleanup + * (clearing the shared buffer after abort) could destroy chunks intended + * for a new subscription. + */ +export function createDefaultSession( + connection: ConnectionAdapter, +): SessionAdapter { + // Active buffer and waiters — replaced synchronously on each subscribe() call + let activeBuffer: Array = [] + let activeWaiters: Array<(chunk: StreamChunk | null) => void> = [] + + function push(chunk: StreamChunk): void { + const waiter = activeWaiters.shift() + if (waiter) { + waiter(chunk) + } else { + activeBuffer.push(chunk) + } + } + + return { + subscribe(signal?: AbortSignal): AsyncIterable { + // Drain any buffered chunks (e.g. from send() before subscribe()) into + // a fresh per-subscription buffer. splice(0) atomically empties the old + // array, so a previous subscription's local reference becomes empty. + const myBuffer: Array = activeBuffer.splice(0) + const myWaiters: Array<(chunk: StreamChunk | null) => void> = [] + activeBuffer = myBuffer + activeWaiters = myWaiters + + return (async function* () { + while (!signal?.aborted) { + let chunk: StreamChunk | null + if (myBuffer.length > 0) { + chunk = myBuffer.shift()! + } else { + chunk = await new Promise((resolve) => { + const onAbort = () => resolve(null) + myWaiters.push((c) => { + signal?.removeEventListener('abort', onAbort) + resolve(c) + }) + signal?.addEventListener('abort', onAbort, { once: true }) + }) + } + if (chunk !== null) yield chunk + } + // No shared-state cleanup needed — myBuffer/myWaiters are local + // and will be garbage collected when this generator is released. + })() + }, + + async send(messages, data, signal) { + try { + const stream = connection.connect(messages, data, signal) + for await (const chunk of stream) { + push(chunk) + } + } catch (err) { + // Push a RUN_ERROR event so subscribe() consumers learn about the + // failure through the standard AG-UI protocol, then re-throw so + // send() callers (e.g. 
streamResponse) can also handle it. + push({ + type: 'RUN_ERROR', + timestamp: Date.now(), + error: { + message: + err instanceof Error ? err.message : 'Unknown error in send()', + }, + }) + throw err + } + }, + } +} diff --git a/packages/typescript/ai-client/src/types.ts b/packages/typescript/ai-client/src/types.ts index 98572548..818b3fd6 100644 --- a/packages/typescript/ai-client/src/types.ts +++ b/packages/typescript/ai-client/src/types.ts @@ -12,6 +12,7 @@ import type { VideoPart, } from '@tanstack/ai' import type { ConnectionAdapter } from './connection-adapters' +import type { SessionAdapter } from './session-adapter' /** * Tool call states - track the lifecycle of a tool call @@ -178,10 +179,18 @@ export interface ChatClientOptions< TTools extends ReadonlyArray = any, > { /** - * Connection adapter for streaming - * Use fetchServerSentEvents(), fetchHttpStream(), or stream() to create adapters + * Connection adapter for streaming (request-response mode). + * Wrapped in a DefaultSessionAdapter internally. + * Provide either `connection` or `session`, not both. */ - connection: ConnectionAdapter + connection?: ConnectionAdapter + + /** + * Session adapter for persistent stream-based sessions. + * When provided, takes over from connection. + * Provide either `connection` or `session`, not both. + */ + session?: SessionAdapter /** * Initial messages to populate the chat diff --git a/packages/typescript/ai-client/tests/chat-client.test.ts b/packages/typescript/ai-client/tests/chat-client.test.ts index 27960378..ede3e2da 100644 --- a/packages/typescript/ai-client/tests/chat-client.test.ts +++ b/packages/typescript/ai-client/tests/chat-client.test.ts @@ -74,6 +74,12 @@ describe('ChatClient', () => { // Message IDs should be unique between clients expect(client1MessageId).not.toBe(client2MessageId) }) + + it('should throw if neither connection nor session is provided', () => { + expect(() => new ChatClient({} as any)).toThrow( + 'Either connection or session must be provided', + ) + }) }) describe('sendMessage', () => { @@ -387,8 +393,11 @@ describe('ChatClient', () => { await client.sendMessage('Hello') - expect(onError).toHaveBeenCalledWith(error) - expect(client.getError()).toBe(error) + expect(onError).toHaveBeenCalled() + expect(onError.mock.calls[0]![0]).toBeInstanceOf(Error) + expect(onError.mock.calls[0]![0].message).toBe('Connection failed') + expect(client.getError()).toBeInstanceOf(Error) + expect(client.getError()?.message).toBe('Connection failed') }) }) @@ -500,7 +509,8 @@ describe('ChatClient', () => { await client.sendMessage('Hello') - expect(client.getError()).toBe(error) + expect(client.getError()).toBeInstanceOf(Error) + expect(client.getError()?.message).toBe('Network error') expect(client.getStatus()).toBe('error') }) diff --git a/packages/typescript/ai-client/tests/session-adapter.test.ts b/packages/typescript/ai-client/tests/session-adapter.test.ts new file mode 100644 index 00000000..0fed006a --- /dev/null +++ b/packages/typescript/ai-client/tests/session-adapter.test.ts @@ -0,0 +1,276 @@ +import { describe, expect, it, vi } from 'vitest' +import { createDefaultSession } from '../src/session-adapter' +import { createMockConnectionAdapter, createTextChunks } from './test-utils' +import type { StreamChunk } from '@tanstack/ai' + +describe('createDefaultSession', () => { + it('should yield chunks sent through send() via subscribe()', async () => { + const chunks = createTextChunks('Hi', 'msg-1') + const connection = createMockConnectionAdapter({ chunks }) + const 
session = createDefaultSession(connection) + + const abortController = new AbortController() + const iterator = session.subscribe(abortController.signal) + + // Send messages — this pushes all chunks into the queue + await session.send([], undefined) + + // Collect chunks from the subscription + const received: Array = [] + for await (const chunk of iterator) { + received.push(chunk) + // Stop after receiving all expected chunks + if (received.length === chunks.length) { + abortController.abort() + } + } + + expect(received).toEqual(chunks) + }) + + it('should deliver chunks from multiple sends in order', async () => { + const chunks1: Array = [ + { + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-1', + model: 'test', + timestamp: Date.now(), + delta: 'A', + content: 'A', + }, + ] + const chunks2: Array = [ + { + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-2', + model: 'test', + timestamp: Date.now(), + delta: 'B', + content: 'B', + }, + ] + + let callCount = 0 + const connection = createMockConnectionAdapter({ + chunks: [], // overridden below + }) + // Override connect to return different chunks per call + connection.connect = function (_messages, _data, _signal) { + callCount++ + const currentChunks = callCount === 1 ? chunks1 : chunks2 + return (async function* () { + for (const chunk of currentChunks) { + yield chunk + } + })() + } + + const session = createDefaultSession(connection) + const abortController = new AbortController() + const iterator = session.subscribe(abortController.signal) + + // Send both in sequence + await session.send([], undefined) + await session.send([], undefined) + + const received: Array = [] + for await (const chunk of iterator) { + received.push(chunk) + if (received.length === 2) { + abortController.abort() + } + } + + expect(received).toEqual([...chunks1, ...chunks2]) + }) + + it('should stop the iterator when the abort signal fires', async () => { + const connection = createMockConnectionAdapter({ chunks: [] }) + const session = createDefaultSession(connection) + + const abortController = new AbortController() + const iterator = session.subscribe(abortController.signal) + + // Abort immediately — the iterator should stop without yielding + abortController.abort() + + const received: Array = [] + for await (const chunk of iterator) { + received.push(chunk) + } + + expect(received).toEqual([]) + }) + + it('should abort a waiting subscriber', async () => { + const connection = createMockConnectionAdapter({ chunks: [] }) + const session = createDefaultSession(connection) + + const abortController = new AbortController() + const iterator = session.subscribe(abortController.signal) + + // Start consuming — this will block waiting for chunks + const resultPromise = (async () => { + const received: Array = [] + for await (const chunk of iterator) { + received.push(chunk) + } + return received + })() + + // Let the subscriber start waiting + await new Promise((resolve) => setTimeout(resolve, 10)) + + // Abort — should unblock the subscriber + abortController.abort() + + const received = await resultPromise + expect(received).toEqual([]) + }) + + it('should propagate errors from connection.connect() through send()', async () => { + const testError = new Error('connection failed') + const connection = createMockConnectionAdapter({ + shouldError: true, + error: testError, + }) + const session = createDefaultSession(connection) + + await expect(session.send([], undefined)).rejects.toThrow( + 'connection failed', + ) + }) + + it('should buffer chunks when 
subscriber is not yet consuming', async () => { + const chunks = createTextChunks('AB', 'msg-1') + const connection = createMockConnectionAdapter({ chunks }) + const session = createDefaultSession(connection) + + // Send first, before subscribing + await session.send([], undefined) + + // Now subscribe and consume + const abortController = new AbortController() + const iterator = session.subscribe(abortController.signal) + + const received: Array = [] + for await (const chunk of iterator) { + received.push(chunk) + if (received.length === chunks.length) { + abortController.abort() + } + } + + expect(received).toEqual(chunks) + }) + + it('should pass messages and data through to connection.connect()', async () => { + const onConnect = vi.fn() + const connection = createMockConnectionAdapter({ + chunks: [ + { + type: 'RUN_FINISHED', + runId: 'r1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + }, + ], + onConnect, + }) + const session = createDefaultSession(connection) + + const messages = [ + { + id: 'u1', + role: 'user' as const, + parts: [{ type: 'text' as const, content: 'hello' }], + }, + ] + const data = { model: 'gpt-4o' } + + await session.send(messages, data) + + expect(onConnect).toHaveBeenCalledWith( + messages, + data, + undefined, // signal + ) + }) + + it('should pass abort signal from send() to connection.connect()', async () => { + const onConnect = vi.fn() + const connection = createMockConnectionAdapter({ + chunks: [], + onConnect, + }) + const session = createDefaultSession(connection) + + const abortController = new AbortController() + await session.send([], undefined, abortController.signal) + + expect(onConnect).toHaveBeenCalledWith( + [], + undefined, + abortController.signal, + ) + }) + + it('should not lose chunks after stop-then-resume subscription cycle', async () => { + const connection = createMockConnectionAdapter({ chunks: [] }) + const session = createDefaultSession(connection) + + // First subscription — abort while waiting (simulates stop) + const ac1 = new AbortController() + const iter1 = session.subscribe(ac1.signal) + + // Start consuming — will block waiting for chunks + const result1Promise = (async () => { + const received: Array = [] + for await (const chunk of iter1) { + received.push(chunk) + } + return received + })() + + // Let the subscriber enter the wait path + await new Promise((resolve) => setTimeout(resolve, 10)) + + // Abort — this resolves the dead waiter with null + ac1.abort() + const received1 = await result1Promise + expect(received1).toEqual([]) + + // Second subscription — should work correctly + const ac2 = new AbortController() + const iter2 = session.subscribe(ac2.signal) + + // Send a chunk — it should be delivered to the new subscriber + const testChunk: StreamChunk = { + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-1', + model: 'test', + timestamp: Date.now(), + delta: 'Hello', + content: 'Hello', + } + + // Override connect to yield the test chunk + connection.connect = function* () { + yield testChunk + } as any + + await session.send([], undefined) + + const received2: Array = [] + for await (const chunk of iter2) { + received2.push(chunk) + if (received2.length === 1) { + ac2.abort() + } + } + + // The chunk should NOT be lost + expect(received2).toEqual([testChunk]) + }) +}) diff --git a/packages/typescript/ai-preact/src/use-chat.ts b/packages/typescript/ai-preact/src/use-chat.ts index cfa9340f..10576277 100644 --- a/packages/typescript/ai-preact/src/use-chat.ts +++ 
b/packages/typescript/ai-preact/src/use-chat.ts @@ -53,6 +53,7 @@ export function useChat = any>( return new ChatClient({ connection: optionsRef.current.connection, + session: optionsRef.current.session, id: clientId, initialMessages: messagesToUse, body: optionsRef.current.body, diff --git a/packages/typescript/ai-react/src/use-chat.ts b/packages/typescript/ai-react/src/use-chat.ts index 2cdc02d1..7a4d78e5 100644 --- a/packages/typescript/ai-react/src/use-chat.ts +++ b/packages/typescript/ai-react/src/use-chat.ts @@ -52,6 +52,7 @@ export function useChat = any>( return new ChatClient({ connection: optionsRef.current.connection, + session: optionsRef.current.session, id: clientId, initialMessages: messagesToUse, body: optionsRef.current.body, diff --git a/packages/typescript/ai-solid/src/use-chat.ts b/packages/typescript/ai-solid/src/use-chat.ts index 77d0edf9..206e27f4 100644 --- a/packages/typescript/ai-solid/src/use-chat.ts +++ b/packages/typescript/ai-solid/src/use-chat.ts @@ -35,6 +35,7 @@ export function useChat = any>( const client = createMemo(() => { return new ChatClient({ connection: options.connection, + session: options.session, id: clientId, initialMessages: options.initialMessages, body: options.body, diff --git a/packages/typescript/ai-svelte/src/create-chat.svelte.ts b/packages/typescript/ai-svelte/src/create-chat.svelte.ts index 5354ae11..e4490483 100644 --- a/packages/typescript/ai-svelte/src/create-chat.svelte.ts +++ b/packages/typescript/ai-svelte/src/create-chat.svelte.ts @@ -55,6 +55,7 @@ export function createChat = any>( // Create ChatClient instance const client = new ChatClient({ connection: options.connection, + session: options.session, id: clientId, initialMessages: options.initialMessages, body: options.body, diff --git a/packages/typescript/ai-vue/src/use-chat.ts b/packages/typescript/ai-vue/src/use-chat.ts index 6042fc53..80ac0bb8 100644 --- a/packages/typescript/ai-vue/src/use-chat.ts +++ b/packages/typescript/ai-vue/src/use-chat.ts @@ -25,6 +25,7 @@ export function useChat = any>( // Create ChatClient instance with callbacks to sync state const client = new ChatClient({ connection: options.connection, + session: options.session, id: clientId, initialMessages: options.initialMessages, body: options.body, diff --git a/packages/typescript/ai/src/activities/chat/stream/processor.ts b/packages/typescript/ai/src/activities/chat/stream/processor.ts index 96d95865..555c46d4 100644 --- a/packages/typescript/ai/src/activities/chat/stream/processor.ts +++ b/packages/typescript/ai/src/activities/chat/stream/processor.ts @@ -12,6 +12,7 @@ * - Thinking/reasoning content * - Recording/replay for testing * - Event-driven architecture for UI updates + * - Per-message stream state tracking for multi-message sessions * * @see docs/chat-architecture.md — Canonical reference for AG-UI chunk ordering, * adapter contract, single-shot flows, and expected UIMessage output. 
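+ *
+ * A minimal driving sequence, mirroring the tests in
+ * stream-processor.test.ts (chunk literals cast as they are there):
+ *
+ * ```ts
+ * const processor = new StreamProcessor()
+ * processor.processChunk({
+ *   type: 'TEXT_MESSAGE_START',
+ *   messageId: 'm1',
+ *   role: 'assistant',
+ *   timestamp: Date.now(),
+ * } as StreamChunk)
+ * processor.processChunk({
+ *   type: 'TEXT_MESSAGE_CONTENT',
+ *   messageId: 'm1',
+ *   delta: 'Hi',
+ *   timestamp: Date.now(),
+ * } as StreamChunk)
+ * processor.finalizeStream()
+ * // getMessages() now returns one assistant message with a single text part: 'Hi'
+ * ```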
@@ -32,6 +33,7 @@ import type { ChunkRecording, ChunkStrategy, InternalToolCallState, + MessageStreamState, ProcessorResult, ProcessorState, ToolCallState, @@ -109,9 +111,8 @@ export interface StreamProcessorOptions { * * State tracking: * - Full message array - * - Current assistant message being streamed - * - Text content accumulation (reset on TEXT_MESSAGE_START) - * - Multiple parallel tool calls + * - Per-message stream state (text, tool calls, thinking) + * - Multiple concurrent message streams * - Tool call completion via TOOL_CALL_END events * * @see docs/chat-architecture.md#streamprocessor-internal-state — State field reference @@ -125,17 +126,14 @@ export class StreamProcessor { // Message state private messages: Array = [] - private currentAssistantMessageId: string | null = null - - // Stream state for current assistant message - // Total accumulated text across all segments (for the final result) - private totalTextContent = '' - // Current segment's text content (for onTextUpdate callbacks) - private currentSegmentText = '' - private lastEmittedText = '' - private thinkingContent = '' - private toolCalls: Map = new Map() - private toolCallOrder: Array = [] + + // Per-message stream state + private messageStates: Map = new Map() + private activeMessageIds: Set = new Set() + private toolCallToMessage: Map = new Map() + private pendingManualMessageId: string | null = null + + // Shared stream state private finishReason: string | null = null private hasError = false private isDone = false @@ -224,18 +222,17 @@ export class StreamProcessor { prepareAssistantMessage(): void { // Reset stream state for new message this.resetStreamState() - // Clear the current assistant message ID so ensureAssistantMessage() - // will create a fresh message on the first content chunk - this.currentAssistantMessageId = null } /** * @deprecated Use prepareAssistantMessage() instead. This eagerly creates * an assistant message which can cause empty message flicker. */ - startAssistantMessage(): string { + startAssistantMessage(messageId?: string): string { this.prepareAssistantMessage() - return this.ensureAssistantMessage() + const { messageId: id } = this.ensureAssistantMessage(messageId) + this.pendingManualMessageId = id + return id } /** @@ -244,39 +241,16 @@ export class StreamProcessor { * has arrived yet. */ getCurrentAssistantMessageId(): string | null { - return this.currentAssistantMessageId - } - - /** - * Lazily create the assistant message if it hasn't been created yet. - * Called by content handlers on the first content-bearing chunk. - * Returns the message ID. - * - * Content-bearing chunks that trigger this: - * TEXT_MESSAGE_CONTENT, TOOL_CALL_START, STEP_FINISHED, RUN_ERROR. - * - * @see docs/chat-architecture.md#streamprocessor-internal-state — Lazy creation pattern - */ - private ensureAssistantMessage(): string { - if (this.currentAssistantMessageId) { - return this.currentAssistantMessageId - } - - const assistantMessage: UIMessage = { - id: generateMessageId(), - role: 'assistant', - parts: [], - createdAt: new Date(), + // Scan all message states (not just active) for the last assistant. + // After finalizeStream() clears activeMessageIds, messageStates retains entries. + // After reset() / resetStreamState(), messageStates is cleared → returns null. 
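+    // (JS Map iteration follows insertion order, so the last matching
+    // entry below is the most recently created assistant message.)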
+ let lastId: string | null = null + for (const [id, state] of this.messageStates) { + if (state.role === 'assistant') { + lastId = id + } } - - this.currentAssistantMessageId = assistantMessage.id - this.messages = [...this.messages, assistantMessage] - - // Emit events - this.events.onStreamStart?.() - this.emitMessagesChange() - - return assistantMessage.id + return lastId } /** @@ -403,7 +377,10 @@ export class StreamProcessor { */ clearMessages(): void { this.messages = [] - this.currentAssistantMessageId = null + this.messageStates.clear() + this.activeMessageIds.clear() + this.toolCallToMessage.clear() + this.pendingManualMessageId = null this.emitMessagesChange() } @@ -444,7 +421,7 @@ export class StreamProcessor { * * Central dispatch for all AG-UI events. Each event type maps to a specific * handler. Events not listed in the switch are intentionally ignored - * (RUN_STARTED, TEXT_MESSAGE_END, STEP_STARTED, STATE_SNAPSHOT, STATE_DELTA). + * (RUN_STARTED, STEP_STARTED, STATE_DELTA). * * @see docs/chat-architecture.md#adapter-contract — Expected event types and ordering */ @@ -461,13 +438,17 @@ export class StreamProcessor { switch (chunk.type) { // AG-UI Events case 'TEXT_MESSAGE_START': - this.handleTextMessageStartEvent() + this.handleTextMessageStartEvent(chunk) break case 'TEXT_MESSAGE_CONTENT': this.handleTextMessageContentEvent(chunk) break + case 'TEXT_MESSAGE_END': + this.handleTextMessageEndEvent(chunk) + break + case 'TOOL_CALL_START': this.handleToolCallStartEvent(chunk) break @@ -492,35 +473,230 @@ export class StreamProcessor { this.handleStepFinishedEvent(chunk) break + case 'MESSAGES_SNAPSHOT': + this.handleMessagesSnapshotEvent(chunk) + break + case 'CUSTOM': this.handleCustomEvent(chunk) break default: - // RUN_STARTED, TEXT_MESSAGE_END, STEP_STARTED, - // STATE_SNAPSHOT, STATE_DELTA - no special handling needed + // RUN_STARTED, STEP_STARTED, STATE_SNAPSHOT, STATE_DELTA - no special handling needed break } } + // ============================================ + // Per-Message State Helpers + // ============================================ + /** - * Handle TEXT_MESSAGE_START event — marks the beginning of a new text segment. - * Resets segment accumulation so text after tool calls starts fresh. - * - * This is the key mechanism for multi-segment text (text before and after tool - * calls becoming separate TextParts). Without this reset, all text would merge - * into a single TextPart and tool-call interleaving would be lost. 
- * - * @see docs/chat-architecture.md#single-shot-text-response — Step-by-step text processing - * @see docs/chat-architecture.md#text-then-tool-interleaving-single-shot — Multi-segment text + * Create a new MessageStreamState for a message + */ + private createMessageState( + messageId: string, + role: 'user' | 'assistant' | 'system', + ): MessageStreamState { + const state: MessageStreamState = { + id: messageId, + role, + totalTextContent: '', + currentSegmentText: '', + lastEmittedText: '', + thinkingContent: '', + toolCalls: new Map(), + toolCallOrder: [], + hasToolCallsSinceTextStart: false, + isComplete: false, + } + this.messageStates.set(messageId, state) + return state + } + + /** + * Get the MessageStreamState for a message */ - private handleTextMessageStartEvent(): void { - // Emit any pending text from a previous segment before resetting - if (this.currentSegmentText !== this.lastEmittedText) { - this.emitTextUpdate() + private getMessageState(messageId: string): MessageStreamState | undefined { + return this.messageStates.get(messageId) + } + + /** + * Get the most recent active assistant message ID. + * Used as fallback for events that don't include a messageId. + */ + private getActiveAssistantMessageId(): string | null { + // Set iteration is insertion-order; convert to array and search from the end + const ids = Array.from(this.activeMessageIds) + for (let i = ids.length - 1; i >= 0; i--) { + const id = ids[i]! + const state = this.messageStates.get(id) + if (state && state.role === 'assistant') { + return id + } } - this.currentSegmentText = '' - this.lastEmittedText = '' + return null + } + + /** + * Ensure an active assistant message exists, creating one if needed. + * Used for backward compat when events arrive without prior TEXT_MESSAGE_START. + */ + private ensureAssistantMessage(preferredId?: string): { + messageId: string + state: MessageStreamState + } { + // Try to find state by preferred ID + if (preferredId) { + const state = this.getMessageState(preferredId) + if (state) return { messageId: preferredId, state } + } + + // Try active assistant message + const activeId = this.getActiveAssistantMessageId() + if (activeId) { + const state = this.getMessageState(activeId)! + return { messageId: activeId, state } + } + + // Auto-create an assistant message (backward compat for process() without TEXT_MESSAGE_START) + const id = preferredId || generateMessageId() + const assistantMessage: UIMessage = { + id, + role: 'assistant', + parts: [], + createdAt: new Date(), + } + this.messages = [...this.messages, assistantMessage] + const state = this.createMessageState(id, 'assistant') + this.activeMessageIds.add(id) + this.pendingManualMessageId = id + this.events.onStreamStart?.() + this.emitMessagesChange() + return { messageId: id, state } + } + + // ============================================ + // Event Handlers + // ============================================ + + /** + * Handle TEXT_MESSAGE_START event + */ + private handleTextMessageStartEvent( + chunk: Extract, + ): void { + const { messageId, role } = chunk + + // Map 'tool' role to 'assistant' for both UIMessage and MessageStreamState + // (UIMessage doesn't support 'tool' role, and lookups like + // getActiveAssistantMessageId() check state.role === 'assistant') + const uiRole: 'system' | 'user' | 'assistant' = + role === 'tool' ? 
'assistant' : role + + // Case 1: A manual message was created via startAssistantMessage() + if (this.pendingManualMessageId) { + const pendingId = this.pendingManualMessageId + this.pendingManualMessageId = null + + if (pendingId !== messageId) { + // Update the message's ID in the messages array + this.messages = this.messages.map((msg) => + msg.id === pendingId ? { ...msg, id: messageId } : msg, + ) + + // Move state to the new key + const existingState = this.messageStates.get(pendingId) + if (existingState) { + existingState.id = messageId + this.messageStates.delete(pendingId) + this.messageStates.set(messageId, existingState) + } + + // Update activeMessageIds + this.activeMessageIds.delete(pendingId) + this.activeMessageIds.add(messageId) + } + + // Ensure state exists + if (!this.messageStates.has(messageId)) { + this.createMessageState(messageId, uiRole) + this.activeMessageIds.add(messageId) + } + + this.emitMessagesChange() + return + } + + // Case 2: Message already exists (dedup) + const existingMsg = this.messages.find((m) => m.id === messageId) + if (existingMsg) { + this.activeMessageIds.add(messageId) + if (!this.messageStates.has(messageId)) { + this.createMessageState(messageId, uiRole) + } else { + const existingState = this.messageStates.get(messageId)! + // If tool calls happened since last text, this TEXT_MESSAGE_START + // signals a new text segment — reset segment accumulation + if (existingState.hasToolCallsSinceTextStart) { + if ( + existingState.currentSegmentText !== existingState.lastEmittedText + ) { + this.emitTextUpdateForMessage(messageId) + } + existingState.currentSegmentText = '' + existingState.lastEmittedText = '' + existingState.hasToolCallsSinceTextStart = false + } + } + return + } + + // Case 3: New message from the stream + const newMessage: UIMessage = { + id: messageId, + role: uiRole, + parts: [], + createdAt: new Date(), + } + + this.messages = [...this.messages, newMessage] + this.createMessageState(messageId, uiRole) + this.activeMessageIds.add(messageId) + + this.events.onStreamStart?.() + this.emitMessagesChange() + } + + /** + * Handle TEXT_MESSAGE_END event + */ + private handleTextMessageEndEvent( + chunk: Extract, + ): void { + const { messageId } = chunk + const state = this.getMessageState(messageId) + if (!state) return + if (state.isComplete) return + + // Emit any pending text for this message + if (state.currentSegmentText !== state.lastEmittedText) { + this.emitTextUpdateForMessage(messageId) + } + + // Complete all tool calls for this message + this.completeAllToolCallsForMessage(messageId) + } + + /** + * Handle MESSAGES_SNAPSHOT event + */ + private handleMessagesSnapshotEvent( + chunk: Extract, + ): void { + this.resetStreamState() + this.messages = [...chunk.messages] + this.emitMessagesChange() } /** @@ -537,17 +713,62 @@ export class StreamProcessor { private handleTextMessageContentEvent( chunk: Extract, ): void { - this.ensureAssistantMessage() + const { messageId, state } = this.ensureAssistantMessage(chunk.messageId) + + // Content arriving means all current tool calls for this message are complete + this.completeAllToolCallsForMessage(messageId) + + const previousSegment = state.currentSegmentText + + // Detect if this is a NEW text segment (after tool calls) vs continuation + const isNewSegment = + state.hasToolCallsSinceTextStart && + previousSegment.length > 0 && + this.isNewTextSegment(chunk, previousSegment) + + if (isNewSegment) { + // Emit any accumulated text before starting new segment + if 
(previousSegment !== state.lastEmittedText) { + this.emitTextUpdateForMessage(messageId) + } + // Reset SEGMENT text accumulation for the new text segment after tool calls + state.currentSegmentText = '' + state.lastEmittedText = '' + state.hasToolCallsSinceTextStart = false + } + + const currentText = state.currentSegmentText + let nextText = currentText + + // Prefer delta over content - delta is the incremental change + // Normalize to empty string to avoid "undefined" string concatenation + const delta = chunk.delta || '' + if (delta !== '') { + nextText = currentText + delta + } else if (chunk.content !== undefined && chunk.content !== '') { + // Fallback: use content if delta is not provided + if (chunk.content.startsWith(currentText)) { + nextText = chunk.content + } else if (currentText.startsWith(chunk.content)) { + nextText = currentText + } else { + nextText = currentText + chunk.content + } + } - this.currentSegmentText += chunk.delta - this.totalTextContent += chunk.delta + // Calculate the delta for totalTextContent + const textDelta = nextText.slice(currentText.length) + state.currentSegmentText = nextText + state.totalTextContent += textDelta + // Use delta for chunk strategy if available + const chunkPortion = chunk.delta || chunk.content || '' const shouldEmit = this.chunkStrategy.shouldEmit( - chunk.delta, - this.currentSegmentText, + chunkPortion, + state.currentSegmentText, ) - if (shouldEmit && this.currentSegmentText !== this.lastEmittedText) { - this.emitTextUpdate() + if (shouldEmit && state.currentSegmentText !== state.lastEmittedText) { + this.emitTextUpdateForMessage(messageId) } } @@ -567,10 +788,18 @@ export class StreamProcessor { private handleToolCallStartEvent( chunk: Extract, ): void { - this.ensureAssistantMessage() + // Determine the message this tool call belongs to + const targetMessageId = + chunk.parentMessageId ?? this.getActiveAssistantMessageId() + const { messageId, state } = this.ensureAssistantMessage( + targetMessageId ?? undefined, + ) + + // Mark that we've seen tool calls since the last text segment + state.hasToolCallsSinceTextStart = true const toolCallId = chunk.toolCallId - const existingToolCall = this.toolCalls.get(toolCallId) + const existingToolCall = state.toolCalls.get(toolCallId) if (!existingToolCall) { // New tool call starting @@ -582,34 +811,31 @@ export class StreamProcessor { arguments: '', state: initialState, parsedArguments: undefined, - index: chunk.index ?? this.toolCalls.size, + index: chunk.index ?? 
state.toolCalls.size, } - this.toolCalls.set(toolCallId, newToolCall) - this.toolCallOrder.push(toolCallId) + state.toolCalls.set(toolCallId, newToolCall) + state.toolCallOrder.push(toolCallId) + + // Store mapping for TOOL_CALL_ARGS/END routing + this.toolCallToMessage.set(toolCallId, messageId) // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateToolCallPart( - this.messages, - this.currentAssistantMessageId, - { - id: chunk.toolCallId, - name: chunk.toolName, - arguments: '', - state: initialState, - }, - ) - this.emitMessagesChange() + this.messages = updateToolCallPart(this.messages, messageId, { + id: chunk.toolCallId, + name: chunk.toolName, + arguments: '', + state: initialState, + }) + this.emitMessagesChange() - // Emit granular event - this.events.onToolCallStateChange?.( - this.currentAssistantMessageId, - chunk.toolCallId, - initialState, - '', - ) - } + // Emit granular event + this.events.onToolCallStateChange?.( + messageId, + chunk.toolCallId, + initialState, + '', + ) } } @@ -629,47 +855,46 @@ export class StreamProcessor { chunk: Extract, ): void { const toolCallId = chunk.toolCallId - const existingToolCall = this.toolCalls.get(toolCallId) + const messageId = this.toolCallToMessage.get(toolCallId) + if (!messageId) return - if (existingToolCall) { - const wasAwaitingInput = existingToolCall.state === 'awaiting-input' + const state = this.getMessageState(messageId) + if (!state) return - // Accumulate arguments from delta - existingToolCall.arguments += chunk.delta || '' + const existingToolCall = state.toolCalls.get(toolCallId) + if (!existingToolCall) return - // Update state - if (wasAwaitingInput && chunk.delta) { - existingToolCall.state = 'input-streaming' - } + const wasAwaitingInput = existingToolCall.state === 'awaiting-input' - // Try to parse the updated arguments - existingToolCall.parsedArguments = this.jsonParser.parse( - existingToolCall.arguments, - ) - - // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateToolCallPart( - this.messages, - this.currentAssistantMessageId, - { - id: existingToolCall.id, - name: existingToolCall.name, - arguments: existingToolCall.arguments, - state: existingToolCall.state, - }, - ) - this.emitMessagesChange() + // Accumulate arguments from delta + existingToolCall.arguments += chunk.delta || '' - // Emit granular event - this.events.onToolCallStateChange?.( - this.currentAssistantMessageId, - existingToolCall.id, - existingToolCall.state, - existingToolCall.arguments, - ) - } + // Update state + if (wasAwaitingInput && chunk.delta) { + existingToolCall.state = 'input-streaming' } + + // Try to parse the updated arguments + existingToolCall.parsedArguments = this.jsonParser.parse( + existingToolCall.arguments, + ) + + // Update UIMessage + this.messages = updateToolCallPart(this.messages, messageId, { + id: existingToolCall.id, + name: existingToolCall.name, + arguments: existingToolCall.arguments, + state: existingToolCall.state, + }) + this.emitMessagesChange() + + // Emit granular event + this.events.onToolCallStateChange?.( + messageId, + existingToolCall.id, + existingToolCall.state, + existingToolCall.arguments, + ) } /** @@ -689,11 +914,17 @@ export class StreamProcessor { private handleToolCallEndEvent( chunk: Extract, ): void { + const messageId = this.toolCallToMessage.get(chunk.toolCallId) + if (!messageId) return + + const msgState = this.getMessageState(messageId) + if (!msgState) return + // Transition the tool call to input-complete (the 
authoritative completion signal) - const existingToolCall = this.toolCalls.get(chunk.toolCallId) + const existingToolCall = msgState.toolCalls.get(chunk.toolCallId) if (existingToolCall && existingToolCall.state !== 'input-complete') { - const index = this.toolCallOrder.indexOf(chunk.toolCallId) - this.completeToolCall(index, existingToolCall) + const index = msgState.toolCallOrder.indexOf(chunk.toolCallId) + this.completeToolCall(messageId, index, existingToolCall) // If TOOL_CALL_END provides parsed input, use it as the canonical parsed // arguments (overrides the accumulated string parse from completeToolCall) if (chunk.input !== undefined) { @@ -701,10 +932,8 @@ export class StreamProcessor { } } - // Update UIMessage if we have a current assistant message and a result - if (this.currentAssistantMessageId && chunk.result) { - const state: ToolResultState = 'complete' - + // Update UIMessage if there's a result + if (chunk.result) { // Step 1: Update the tool-call part's output field (for UI consistency // with client tools — see GitHub issue #176) let output: unknown @@ -720,12 +949,13 @@ export class StreamProcessor { ) // Step 2: Create/update the tool-result part (for LLM conversation history) + const resultState: ToolResultState = 'complete' this.messages = updateToolResultPart( this.messages, - this.currentAssistantMessageId, + messageId, chunk.toolCallId, chunk.result, - state, + resultState, ) this.emitMessagesChange() } @@ -747,6 +977,7 @@ export class StreamProcessor { this.finishReason = chunk.finishReason this.isDone = true this.completeAllToolCalls() + this.finalizeStream() } /** @@ -772,25 +1003,38 @@ export class StreamProcessor { private handleStepFinishedEvent( chunk: Extract, ): void { - this.ensureAssistantMessage() + const { messageId, state } = this.ensureAssistantMessage( + this.getActiveAssistantMessageId() ?? 
undefined, + ) + + const previous = state.thinkingContent + let nextThinking = previous + + // Prefer delta over content + if (chunk.delta && chunk.delta !== '') { + nextThinking = previous + chunk.delta + } else if (chunk.content && chunk.content !== '') { + if (chunk.content.startsWith(previous)) { + nextThinking = chunk.content + } else if (previous.startsWith(chunk.content)) { + nextThinking = previous + } else { + nextThinking = previous + chunk.content + } + } - this.thinkingContent += chunk.delta + state.thinkingContent = nextThinking // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateThinkingPart( - this.messages, - this.currentAssistantMessageId, - this.thinkingContent, - ) - this.emitMessagesChange() + this.messages = updateThinkingPart( + this.messages, + messageId, + state.thinkingContent, + ) + this.emitMessagesChange() - // Emit granular event - this.events.onThinkingUpdate?.( - this.currentAssistantMessageId, - this.thinkingContent, - ) - } + // Emit granular event + this.events.onThinkingUpdate?.(messageId, state.thinkingContent) } /** @@ -806,6 +1050,8 @@ export class StreamProcessor { private handleCustomEvent( chunk: Extract, ): void { + const messageId = this.getActiveAssistantMessageId() + // Handle client tool input availability - trigger client-side execution if (chunk.name === 'tool-input-available' && chunk.data) { const { toolCallId, toolName, input } = chunk.data as { @@ -832,10 +1078,10 @@ export class StreamProcessor { } // Update the tool call part with approval state - if (this.currentAssistantMessageId) { + if (messageId) { this.messages = updateToolCallApproval( this.messages, - this.currentAssistantMessageId, + messageId, toolCallId, approval.id, ) @@ -852,8 +1098,34 @@ export class StreamProcessor { } } + // ============================================ + // Internal Helpers + // ============================================ + + /** + * Detect if an incoming content chunk represents a NEW text segment + */ + private isNewTextSegment( + chunk: Extract, + previous: string, + ): boolean { + // Check if content is present (delta is always defined but may be empty string) + if (chunk.content !== undefined) { + if (chunk.content.length < previous.length) { + return true + } + if ( + !chunk.content.startsWith(previous) && + !previous.startsWith(chunk.content) + ) { + return true + } + } + return false + } + /** - * Complete all tool calls — safety net for stream termination. + * Complete all tool calls across all active messages — safety net for stream termination. * * Called by RUN_FINISHED and finalizeStream(). Force-transitions any tool call * not yet in input-complete state. 
Handles cases where TOOL_CALL_END was @@ -862,10 +1134,22 @@ export class StreamProcessor { * @see docs/chat-architecture.md#single-shot-tool-call-response — Safety net behavior */ private completeAllToolCalls(): void { - this.toolCalls.forEach((toolCall, id) => { + for (const messageId of this.activeMessageIds) { + this.completeAllToolCallsForMessage(messageId) + } + } + + /** + * Complete all tool calls for a specific message + */ + private completeAllToolCallsForMessage(messageId: string): void { + const state = this.getMessageState(messageId) + if (!state) return + + state.toolCalls.forEach((toolCall, id) => { if (toolCall.state !== 'input-complete') { - const index = this.toolCallOrder.indexOf(id) - this.completeToolCall(index, toolCall) + const index = state.toolCallOrder.indexOf(id) + this.completeToolCall(messageId, index, toolCall) } }) } @@ -874,6 +1158,7 @@ export class StreamProcessor { * Mark a tool call as complete and emit event */ private completeToolCall( + messageId: string, _index: number, toolCall: InternalToolCallState, ): void { @@ -883,31 +1168,25 @@ export class StreamProcessor { toolCall.parsedArguments = this.jsonParser.parse(toolCall.arguments) // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateToolCallPart( - this.messages, - this.currentAssistantMessageId, - { - id: toolCall.id, - name: toolCall.name, - arguments: toolCall.arguments, - state: 'input-complete', - }, - ) - this.emitMessagesChange() + this.messages = updateToolCallPart(this.messages, messageId, { + id: toolCall.id, + name: toolCall.name, + arguments: toolCall.arguments, + state: 'input-complete', + }) + this.emitMessagesChange() - // Emit granular event - this.events.onToolCallStateChange?.( - this.currentAssistantMessageId, - toolCall.id, - 'input-complete', - toolCall.arguments, - ) - } + // Emit granular event + this.events.onToolCallStateChange?.( + messageId, + toolCall.id, + 'input-complete', + toolCall.arguments, + ) } /** - * Emit pending text update. + * Emit pending text update for a specific message. * * Calls updateTextPart() which has critical append-vs-replace logic: * - If last UIMessage part is TextPart → replaces its content (same segment). @@ -915,24 +1194,22 @@ export class StreamProcessor { * * @see docs/chat-architecture.md#uimessage-part-ordering-invariants — Replace vs. push logic */ - private emitTextUpdate(): void { - this.lastEmittedText = this.currentSegmentText + private emitTextUpdateForMessage(messageId: string): void { + const state = this.getMessageState(messageId) + if (!state) return + + state.lastEmittedText = state.currentSegmentText // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateTextPart( - this.messages, - this.currentAssistantMessageId, - this.currentSegmentText, - ) - this.emitMessagesChange() + this.messages = updateTextPart( + this.messages, + messageId, + state.currentSegmentText, + ) + this.emitMessagesChange() - // Emit granular event - this.events.onTextUpdate?.( - this.currentAssistantMessageId, - this.currentSegmentText, - ) - } + // Emit granular event + this.events.onTextUpdate?.(messageId, state.currentSegmentText) } /** @@ -952,81 +1229,116 @@ export class StreamProcessor { * @see docs/chat-architecture.md#single-shot-text-response — Finalization step */ finalizeStream(): void { - // Safety net: complete any remaining tool calls (e.g. 
on network errors / aborted streams) - this.completeAllToolCalls() + let lastAssistantMessage: UIMessage | undefined - // Emit any pending text if not already emitted - if (this.currentSegmentText !== this.lastEmittedText) { - this.emitTextUpdate() + // Finalize ALL active messages + for (const messageId of this.activeMessageIds) { + const state = this.getMessageState(messageId) + if (!state) continue + + // Complete any remaining tool calls + this.completeAllToolCallsForMessage(messageId) + + // Emit any pending text if not already emitted + if (state.currentSegmentText !== state.lastEmittedText) { + this.emitTextUpdateForMessage(messageId) + } + + state.isComplete = true + + const msg = this.messages.find((m) => m.id === messageId) + if (msg && msg.role === 'assistant') { + lastAssistantMessage = msg + } } - // Remove the assistant message if it only contains whitespace text - // (no tool calls, no meaningful content). This handles models like Gemini - // that sometimes return just "\n" during auto-continuation. + this.activeMessageIds.clear() + + // Remove whitespace-only assistant messages (handles models like Gemini + // that sometimes return just "\n" during auto-continuation). // Preserve the message on errors so the UI can show error state. - if (this.currentAssistantMessageId && !this.hasError) { - const assistantMessage = this.messages.find( - (m) => m.id === this.currentAssistantMessageId, - ) - if (assistantMessage && this.isWhitespaceOnlyMessage(assistantMessage)) { + if (lastAssistantMessage && !this.hasError) { + if (this.isWhitespaceOnlyMessage(lastAssistantMessage)) { this.messages = this.messages.filter( - (m) => m.id !== this.currentAssistantMessageId, + (m) => m.id !== lastAssistantMessage.id, ) this.emitMessagesChange() - this.currentAssistantMessageId = null return } } - // Emit stream end event (only if a message was actually created) - if (this.currentAssistantMessageId) { - const assistantMessage = this.messages.find( - (m) => m.id === this.currentAssistantMessageId, - ) - if (assistantMessage) { - this.events.onStreamEnd?.(assistantMessage) - } + // Emit stream end for the last assistant message + if (lastAssistantMessage) { + this.events.onStreamEnd?.(lastAssistantMessage) } } /** - * Get completed tool calls in API format + * Get completed tool calls in API format (aggregated across all messages) */ private getCompletedToolCalls(): Array { - return Array.from(this.toolCalls.values()) - .filter((tc) => tc.state === 'input-complete') - .map((tc) => ({ - id: tc.id, - type: 'function' as const, - function: { - name: tc.name, - arguments: tc.arguments, - }, - })) + const result: Array = [] + for (const state of this.messageStates.values()) { + for (const tc of state.toolCalls.values()) { + if (tc.state === 'input-complete') { + result.push({ + id: tc.id, + type: 'function' as const, + function: { + name: tc.name, + arguments: tc.arguments, + }, + }) + } + } + } + return result } /** - * Get current result + * Get current result (aggregated across all messages) */ private getResult(): ProcessorResult { const toolCalls = this.getCompletedToolCalls() + let content = '' + let thinking = '' + + for (const state of this.messageStates.values()) { + content += state.totalTextContent + thinking += state.thinkingContent + } + return { - content: this.totalTextContent, - thinking: this.thinkingContent || undefined, + content, + thinking: thinking || undefined, toolCalls: toolCalls.length > 0 ? 
   /**
-   * Get current processor state
+   * Get current processor state (aggregated across all messages)
    */
   getState(): ProcessorState {
+    let content = ''
+    let thinking = ''
+    const toolCalls = new Map()
+    const toolCallOrder: Array<string> = []
+
+    for (const state of this.messageStates.values()) {
+      content += state.totalTextContent
+      thinking += state.thinkingContent
+      for (const [id, tc] of state.toolCalls) {
+        toolCalls.set(id, tc)
+      }
+      toolCallOrder.push(...state.toolCallOrder)
+    }
+
     return {
-      content: this.totalTextContent,
-      thinking: this.thinkingContent,
-      toolCalls: new Map(this.toolCalls),
-      toolCallOrder: [...this.toolCallOrder],
+      content,
+      thinking,
+      toolCalls,
+      toolCallOrder,
       finishReason: this.finishReason,
       done: this.isDone,
     }
@@ -1056,12 +1368,10 @@ export class StreamProcessor {
    * Reset stream state (but keep messages)
    */
   private resetStreamState(): void {
-    this.totalTextContent = ''
-    this.currentSegmentText = ''
-    this.lastEmittedText = ''
-    this.thinkingContent = ''
-    this.toolCalls.clear()
-    this.toolCallOrder = []
+    this.messageStates.clear()
+    this.activeMessageIds.clear()
+    this.toolCallToMessage.clear()
+    this.pendingManualMessageId = null
     this.finishReason = null
     this.hasError = false
     this.isDone = false
@@ -1074,7 +1384,6 @@ export class StreamProcessor {
   reset(): void {
     this.resetStreamState()
     this.messages = []
-    this.currentAssistantMessageId = null
   }
 
   /**
diff --git a/packages/typescript/ai/src/activities/chat/stream/types.ts b/packages/typescript/ai/src/activities/chat/stream/types.ts
index 2a323507..c1806238 100644
--- a/packages/typescript/ai/src/activities/chat/stream/types.ts
+++ b/packages/typescript/ai/src/activities/chat/stream/types.ts
@@ -45,6 +45,24 @@ export interface ChunkStrategy {
   reset?: () => void
 }
 
+/**
+ * Per-message streaming state.
+ * Tracks the accumulation of text, tool calls, and thinking content
+ * for a single message in the stream.
+ */
+export interface MessageStreamState {
+  id: string
+  role: 'user' | 'assistant' | 'system'
+  totalTextContent: string
+  currentSegmentText: string
+  lastEmittedText: string
+  thinkingContent: string
+  toolCalls: Map<string, InternalToolCallState>
+  toolCallOrder: Array<string>
+  hasToolCallsSinceTextStart: boolean
+  isComplete: boolean
+}
+
 /**
  * Result from processing a stream
  */
diff --git a/packages/typescript/ai/src/types.ts b/packages/typescript/ai/src/types.ts
index 4d7ca6e5..390abfa9 100644
--- a/packages/typescript/ai/src/types.ts
+++ b/packages/typescript/ai/src/types.ts
@@ -702,6 +702,7 @@ export type AGUIEventType =
   | 'TOOL_CALL_END'
   | 'STEP_STARTED'
   | 'STEP_FINISHED'
+  | 'MESSAGES_SNAPSHOT'
   | 'STATE_SNAPSHOT'
   | 'STATE_DELTA'
   | 'CUSTOM'
@@ -778,8 +779,8 @@ export interface TextMessageStartEvent extends BaseAGUIEvent {
   type: 'TEXT_MESSAGE_START'
   /** Unique identifier for this message */
   messageId: string
-  /** Role is always assistant for generated messages */
-  role: 'assistant'
+  /** Role of the message sender */
+  role: 'user' | 'assistant' | 'system' | 'tool'
 }
 
 /**
@@ -813,6 +814,8 @@ export interface ToolCallStartEvent extends BaseAGUIEvent {
   toolCallId: string
   /** Name of the tool being called */
   toolName: string
+  /** ID of the parent message that initiated this tool call */
+  parentMessageId?: string
   /** Index for parallel tool calls */
   index?: number
 }
@@ -869,6 +872,19 @@ export interface StepFinishedEvent extends BaseAGUIEvent {
   content?: string
 }
 
+/**
+ * Emitted to provide a snapshot of all messages in a conversation.
+ *
+ * Unlike StateSnapshot (which carries arbitrary application state),
+ * MessagesSnapshot specifically delivers the conversation transcript.
+ * This is a first-class AG-UI event type.
+ */
+export interface MessagesSnapshotEvent extends BaseAGUIEvent {
+  type: 'MESSAGES_SNAPSHOT'
+  /** Complete array of messages in the conversation */
+  messages: Array<UIMessage>
+}
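
Illustration, not part of the patch: a conforming event literal as a sketch. The UIMessage shape (id, role, parts, createdAt) mirrors the test fixtures further down; `timestamp` is assumed to come from BaseAGUIEvent, matching how every event in the tests carries one:

    const snapshot: MessagesSnapshotEvent = {
      type: 'MESSAGES_SNAPSHOT',
      timestamp: Date.now(), // assumed BaseAGUIEvent field, as used in the tests
      messages: [
        {
          id: 'msg-1',
          role: 'user',
          parts: [{ type: 'text', content: 'Hello' }],
          createdAt: new Date(),
        },
      ],
    }
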
+
 /**
  * Emitted to provide a full state snapshot.
  */
@@ -913,6 +929,7 @@ export type AGUIEvent =
   | ToolCallEndEvent
   | StepStartedEvent
   | StepFinishedEvent
+  | MessagesSnapshotEvent
   | StateSnapshotEvent
   | StateDeltaEvent
   | CustomEvent
diff --git a/packages/typescript/ai/tests/stream-processor.test.ts b/packages/typescript/ai/tests/stream-processor.test.ts
index ddb7f812..033afabc 100644
--- a/packages/typescript/ai/tests/stream-processor.test.ts
+++ b/packages/typescript/ai/tests/stream-processor.test.ts
@@ -621,8 +621,8 @@ describe('StreamProcessor', () => {
     processor.processChunk(ev.textContent('First segment'))
     processor.processChunk(ev.toolStart('tc-1', 'search'))
     processor.processChunk(ev.toolEnd('tc-1', 'search', { input: {} }))
-    processor.processChunk(ev.textStart('msg-2'))
-    processor.processChunk(ev.textContent('Second segment', 'msg-2'))
+    processor.processChunk(ev.textStart())
+    processor.processChunk(ev.textContent('Second segment'))
     processor.processChunk(ev.runFinished('stop'))
     processor.finalizeStream()
@@ -649,10 +649,10 @@
       ev.toolEnd('call_1', 'getWeather', { result: '{"temp":"72F"}' }),
     )
 
-    // Second adapter stream: more text
-    processor.processChunk(ev.textStart('msg-2'))
-    processor.processChunk(ev.textContent("It's 72F in NYC.", 'msg-2'))
-    processor.processChunk(ev.textEnd('msg-2'))
+    // Second adapter stream: more text (same message)
+    processor.processChunk(ev.textStart())
+    processor.processChunk(ev.textContent("It's 72F in NYC."))
+    processor.processChunk(ev.textEnd())
     processor.processChunk(ev.runFinished('stop'))
     processor.finalizeStream()
@@ -685,9 +685,9 @@
     processor.processChunk(ev.textEnd())
     processor.processChunk(ev.toolStart('tc-1', 'tool'))
     processor.processChunk(ev.toolEnd('tc-1', 'tool'))
-    processor.processChunk(ev.textStart('msg-2'))
-    processor.processChunk(ev.textContent('After', 'msg-2'))
-    processor.processChunk(ev.textEnd('msg-2'))
+    processor.processChunk(ev.textStart())
+    processor.processChunk(ev.textContent('After'))
+    processor.processChunk(ev.textEnd())
     processor.processChunk(ev.runFinished('stop'))
     processor.finalizeStream()
@@ -1798,4 +1798,658 @@
     expect(state2.toolCallOrder).toEqual(['tc-1'])
   })
 })
+
+  describe('TEXT_MESSAGE_START', () => {
+    it('should create a message with correct role and messageId', () => {
+      const processor = new StreamProcessor()
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-1',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.finalizeStream()
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe('msg-1')
+      expect(messages[0]?.role).toBe('assistant')
+      expect(messages[0]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello',
+      })
+    })
+
+    it('should create a user message via TEXT_MESSAGE_START', () => {
+      const processor = new StreamProcessor()
+
+      processor.processChunk({
+        type:
'TEXT_MESSAGE_START', + messageId: 'user-msg-1', + role: 'user', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_END', + messageId: 'user-msg-1', + timestamp: Date.now(), + } as StreamChunk) + + const messages = processor.getMessages() + expect(messages).toHaveLength(1) + expect(messages[0]?.id).toBe('user-msg-1') + expect(messages[0]?.role).toBe('user') + }) + + it('should emit onStreamStart when a new message arrives', () => { + const onStreamStart = vi.fn() + const processor = new StreamProcessor({ events: { onStreamStart } }) + + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'msg-1', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + expect(onStreamStart).toHaveBeenCalledTimes(1) + }) + }) + + describe('TEXT_MESSAGE_END', () => { + it('should not emit onStreamEnd (that happens in finalizeStream)', () => { + const onStreamEnd = vi.fn() + const processor = new StreamProcessor({ events: { onStreamEnd } }) + + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'msg-1', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-1', + delta: 'Hello world', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_END', + messageId: 'msg-1', + timestamp: Date.now(), + } as StreamChunk) + + // TEXT_MESSAGE_END means "text segment done", not "message done" + // onStreamEnd fires from finalizeStream(), not TEXT_MESSAGE_END + expect(onStreamEnd).not.toHaveBeenCalled() + + processor.finalizeStream() + + expect(onStreamEnd).toHaveBeenCalledTimes(1) + const endMessage = onStreamEnd.mock.calls[0]![0] as UIMessage + expect(endMessage.id).toBe('msg-1') + expect(endMessage.parts[0]).toEqual({ + type: 'text', + content: 'Hello world', + }) + }) + + it('should emit pending text on TEXT_MESSAGE_END', () => { + const onTextUpdate = vi.fn() + // Use a strategy that never emits during streaming + const processor = new StreamProcessor({ + events: { onTextUpdate }, + chunkStrategy: { + shouldEmit: () => false, + }, + }) + + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'msg-1', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-1', + delta: 'Hello', + timestamp: Date.now(), + } as StreamChunk) + + // Text not emitted yet due to strategy + expect(onTextUpdate).not.toHaveBeenCalled() + + processor.processChunk({ + type: 'TEXT_MESSAGE_END', + messageId: 'msg-1', + timestamp: Date.now(), + } as StreamChunk) + + // TEXT_MESSAGE_END should flush pending text + expect(onTextUpdate).toHaveBeenCalledWith('msg-1', 'Hello') + }) + }) + + describe('interleaved messages', () => { + it('should handle two interleaved assistant messages', () => { + const onMessagesChange = vi.fn() + const processor = new StreamProcessor({ + events: { onMessagesChange }, + }) + + // Start two messages + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'msg-a', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'msg-b', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + // Interleave content + processor.processChunk({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-a', + delta: 'Hello from A', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 
'TEXT_MESSAGE_CONTENT', + messageId: 'msg-b', + delta: 'Hello from B', + timestamp: Date.now(), + } as StreamChunk) + + // End both + processor.processChunk({ + type: 'TEXT_MESSAGE_END', + messageId: 'msg-a', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_END', + messageId: 'msg-b', + timestamp: Date.now(), + } as StreamChunk) + + const messages = processor.getMessages() + expect(messages).toHaveLength(2) + + expect(messages[0]?.id).toBe('msg-a') + expect(messages[0]?.parts[0]).toEqual({ + type: 'text', + content: 'Hello from A', + }) + + expect(messages[1]?.id).toBe('msg-b') + expect(messages[1]?.parts[0]).toEqual({ + type: 'text', + content: 'Hello from B', + }) + }) + + it('should emit onStreamEnd on finalizeStream (not on TEXT_MESSAGE_END)', () => { + const onStreamEnd = vi.fn() + const processor = new StreamProcessor({ + events: { onStreamEnd }, + }) + + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'msg-a', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'msg-b', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-a', + delta: 'A', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_END', + messageId: 'msg-a', + timestamp: Date.now(), + } as StreamChunk) + + // TEXT_MESSAGE_END does not fire onStreamEnd + expect(onStreamEnd).not.toHaveBeenCalled() + + processor.processChunk({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-b', + delta: 'B', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_END', + messageId: 'msg-b', + timestamp: Date.now(), + } as StreamChunk) + + // Still not fired + expect(onStreamEnd).not.toHaveBeenCalled() + + // finalizeStream fires onStreamEnd for the last assistant message + processor.finalizeStream() + expect(onStreamEnd).toHaveBeenCalledTimes(1) + }) + }) + + describe('startAssistantMessage + TEXT_MESSAGE_START dedup', () => { + it('should associate TEXT_MESSAGE_START with pending manual message (different ID)', () => { + const processor = new StreamProcessor() + processor.startAssistantMessage() + + // Server sends TEXT_MESSAGE_START with a different ID + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'server-msg-1', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + // Should have only one message (not two) + const messages = processor.getMessages() + expect(messages).toHaveLength(1) + + // The message should have been updated to the server's ID + expect(messages[0]?.id).toBe('server-msg-1') + expect(messages[0]?.role).toBe('assistant') + + // Content should route to the correct message + processor.processChunk({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'server-msg-1', + delta: 'Hello', + timestamp: Date.now(), + } as StreamChunk) + + processor.finalizeStream() + + expect(processor.getMessages()[0]?.parts[0]).toEqual({ + type: 'text', + content: 'Hello', + }) + }) + + it('should associate TEXT_MESSAGE_START with pending manual message (same ID)', () => { + const processor = new StreamProcessor() + processor.startAssistantMessage('my-msg-id') + + // Server sends TEXT_MESSAGE_START with the same ID + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'my-msg-id', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + // Should still have only 
one message
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe('my-msg-id')
+    })
+
+    it('should work when TEXT_MESSAGE_START arrives without startAssistantMessage', () => {
+      const onStreamStart = vi.fn()
+      const processor = new StreamProcessor({
+        events: { onStreamStart },
+      })
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-1',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-1',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      expect(onStreamStart).toHaveBeenCalledTimes(1)
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe('msg-1')
+      expect(messages[0]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello',
+      })
+    })
+  })
+
+  describe('backward compat: ensureAssistantMessage auto-creation', () => {
+    it('should emit onStreamStart when auto-creating a message from content event', () => {
+      const onStreamStart = vi.fn()
+      const processor = new StreamProcessor({
+        events: { onStreamStart },
+      })
+
+      // No TEXT_MESSAGE_START or startAssistantMessage — content arrives directly
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'auto-msg',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      expect(onStreamStart).toHaveBeenCalledTimes(1)
+      expect(processor.getMessages()).toHaveLength(1)
+      expect(processor.getMessages()[0]?.role).toBe('assistant')
+    })
+  })
+
+  describe('backward compat: startAssistantMessage without TEXT_MESSAGE_START', () => {
+    it('should still work when only startAssistantMessage is used', () => {
+      const processor = new StreamProcessor()
+      const msgId = processor.startAssistantMessage()
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'some-other-id',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'some-other-id',
+        delta: ' world',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.finalizeStream()
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe(msgId)
+      expect(messages[0]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello world',
+      })
+    })
+  })
+
+  describe('MESSAGES_SNAPSHOT', () => {
+    it('should hydrate messages and emit onMessagesChange', () => {
+      const onMessagesChange = vi.fn()
+      const processor = new StreamProcessor({
+        events: { onMessagesChange },
+      })
+
+      const snapshotMessages: Array<UIMessage> = [
+        {
+          id: 'snap-1',
+          role: 'user',
+          parts: [{ type: 'text', content: 'Hello' }],
+          createdAt: new Date(),
+        },
+        {
+          id: 'snap-2',
+          role: 'assistant',
+          parts: [{ type: 'text', content: 'Hi there!' }],
+          createdAt: new Date(),
+        },
+      ]
+
+      processor.processChunk({
+        type: 'MESSAGES_SNAPSHOT',
+        messages: snapshotMessages,
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(2)
+      expect(messages[0]?.id).toBe('snap-1')
+      expect(messages[0]?.role).toBe('user')
+      expect(messages[1]?.id).toBe('snap-2')
+      expect(messages[1]?.role).toBe('assistant')
+      expect(onMessagesChange).toHaveBeenCalled()
+    })
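
Illustration, not part of the patch: the replace (not append) semantics verified next are what make snapshots safe for reconnection. A server-side sketch, with a hypothetical `send` callback standing in for whatever transport delivers AGUIEvent values to a subscriber:

    // Hypothetical hook: greet a new subscriber with the full transcript,
    // then continue with live TEXT_MESSAGE_* / TOOL_CALL_* events.
    function onSubscribe(
      send: (event: AGUIEvent) => void,
      history: Array<UIMessage>,
    ): void {
      send({
        type: 'MESSAGES_SNAPSHOT',
        messages: history,
        timestamp: Date.now(),
      })
    }
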
+
+    it('should replace existing messages (not append)', () => {
+      const processor = new StreamProcessor()
+
+      // Add an initial message
+      processor.addUserMessage('First message')
+      expect(processor.getMessages()).toHaveLength(1)
+
+      // Snapshot replaces all messages
+      processor.processChunk({
+        type: 'MESSAGES_SNAPSHOT',
+        messages: [
+          {
+            id: 'snap-1',
+            role: 'assistant',
+            parts: [{ type: 'text', content: 'Snapshot content' }],
+            createdAt: new Date(),
+          },
+        ],
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe('snap-1')
+      expect(messages[0]?.role).toBe('assistant')
+    })
+  })
+
+  describe('per-message tool calls', () => {
+    it('should route tool calls to the correct message via parentMessageId', () => {
+      const processor = new StreamProcessor()
+
+      // Create a message to attach the tool call to
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-a',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // Tool call on msg-a
+      processor.processChunk({
+        type: 'TOOL_CALL_START',
+        toolCallId: 'tc-1',
+        toolName: 'myTool',
+        parentMessageId: 'msg-a',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TOOL_CALL_ARGS',
+        toolCallId: 'tc-1',
+        delta: '{"arg": "val"}',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TOOL_CALL_END',
+        toolCallId: 'tc-1',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.finalizeStream()
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+
+      const toolCallPart = messages[0]?.parts.find(
+        (p) => p.type === 'tool-call',
+      )
+      expect(toolCallPart).toBeDefined()
+      expect(toolCallPart?.type).toBe('tool-call')
+      if (toolCallPart?.type === 'tool-call') {
+        expect(toolCallPart.name).toBe('myTool')
+        expect(toolCallPart.state).toBe('input-complete')
+      }
+    })
+  })
+
+  describe('double onStreamEnd guard', () => {
+    it('should fire onStreamEnd exactly once when RUN_FINISHED arrives before TEXT_MESSAGE_END', () => {
+      const onStreamEnd = vi.fn()
+      const processor = new StreamProcessor({ events: { onStreamEnd } })
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-1',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // RUN_FINISHED fires first — calls finalizeStream which sets isComplete and fires onStreamEnd
+      processor.processChunk({
+        type: 'RUN_FINISHED',
+        model: 'test',
+        timestamp: Date.now(),
+        finishReason: 'stop',
+      } as StreamChunk)
+
+      expect(onStreamEnd).toHaveBeenCalledTimes(1)
+
+      // TEXT_MESSAGE_END arrives after — should NOT fire onStreamEnd again
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-1',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      expect(onStreamEnd).toHaveBeenCalledTimes(1)
+    })
+  })
+
+  describe('MESSAGES_SNAPSHOT resets transient state', ()
=> { + it('should reset stale state and process subsequent stream events correctly', () => { + const onStreamEnd = vi.fn() + const processor = new StreamProcessor({ events: { onStreamEnd } }) + + // Simulate an active streaming session + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'msg-old', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-old', + delta: 'Old content', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TOOL_CALL_START', + toolCallId: 'tc-old', + toolName: 'oldTool', + parentMessageId: 'msg-old', + timestamp: Date.now(), + } as StreamChunk) + + // MESSAGES_SNAPSHOT replaces everything (e.g., on reconnection) + processor.processChunk({ + type: 'MESSAGES_SNAPSHOT', + messages: [ + { + id: 'snap-user', + role: 'user', + parts: [{ type: 'text', content: 'Hello' }], + createdAt: new Date(), + }, + ], + timestamp: Date.now(), + } as StreamChunk) + + // Verify old messages are replaced + const messagesAfterSnapshot = processor.getMessages() + expect(messagesAfterSnapshot).toHaveLength(1) + expect(messagesAfterSnapshot[0]?.id).toBe('snap-user') + + // New stream events should be processed correctly without stale state + processor.processChunk({ + type: 'TEXT_MESSAGE_START', + messageId: 'msg-new', + role: 'assistant', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-new', + delta: 'New content', + timestamp: Date.now(), + } as StreamChunk) + + processor.processChunk({ + type: 'TEXT_MESSAGE_END', + messageId: 'msg-new', + timestamp: Date.now(), + } as StreamChunk) + + const finalMessages = processor.getMessages() + expect(finalMessages).toHaveLength(2) + expect(finalMessages[1]?.id).toBe('msg-new') + expect(finalMessages[1]?.parts[0]).toEqual({ + type: 'text', + content: 'New content', + }) + + // onStreamEnd fires from finalizeStream, not TEXT_MESSAGE_END + expect(onStreamEnd).not.toHaveBeenCalled() + processor.finalizeStream() + expect(onStreamEnd).toHaveBeenCalledTimes(1) + expect(onStreamEnd.mock.calls[0]![0].id).toBe('msg-new') + }) + }) })
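
Illustration, not part of the patch: taken together, parentMessageId on TOOL_CALL_START plus per-message state lets one stream carry several in-flight messages without ambiguity. A closing sketch of the consumer side, using only APIs exercised by the tests above (import paths assumed, as in the earlier sketches):

    const processor = new StreamProcessor()

    // Two assistant messages open concurrently.
    for (const id of ['a', 'b']) {
      processor.processChunk({
        type: 'TEXT_MESSAGE_START',
        messageId: id,
        role: 'assistant',
        timestamp: Date.now(),
      } as StreamChunk)
    }

    // Each tool call names its parent, so routing is unambiguous.
    processor.processChunk({
      type: 'TOOL_CALL_START',
      toolCallId: 't1',
      toolName: 'search',
      parentMessageId: 'a',
      timestamp: Date.now(),
    } as StreamChunk)
    processor.processChunk({
      type: 'TOOL_CALL_START',
      toolCallId: 't2',
      toolName: 'lookup',
      parentMessageId: 'b',
      timestamp: Date.now(),
    } as StreamChunk)
    processor.processChunk({
      type: 'TOOL_CALL_END',
      toolCallId: 't1',
      timestamp: Date.now(),
    } as StreamChunk)
    processor.processChunk({
      type: 'TOOL_CALL_END',
      toolCallId: 't2',
      timestamp: Date.now(),
    } as StreamChunk)

    processor.finalizeStream()
    // getMessages()[0] now carries the 'search' tool-call part and
    // getMessages()[1] the 'lookup' part, each in 'input-complete' state.
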