diff --git a/.changeset/beige-seals-tie.md b/.changeset/beige-seals-tie.md new file mode 100644 index 000000000000..6f3a98b088e9 --- /dev/null +++ b/.changeset/beige-seals-tie.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +feat(ai): improved text splitter diff --git a/packages/ai/core/generate-text/smooth-stream.test.ts b/packages/ai/core/generate-text/smooth-stream.test.ts index 7d5bbd5db00d..39bd5ad657e2 100644 --- a/packages/ai/core/generate-text/smooth-stream.test.ts +++ b/packages/ai/core/generate-text/smooth-stream.test.ts @@ -46,6 +46,7 @@ describe('smoothStream', () => { textDelta: 'Hello, ', type: 'text-delta', }, + 'delay 10', { textDelta: 'world!', type: 'text-delta', @@ -107,6 +108,7 @@ describe('smoothStream', () => { textDelta: 'example ', type: 'text-delta', }, + 'delay 10', { textDelta: 'text.', type: 'text-delta', @@ -146,14 +148,14 @@ describe('smoothStream', () => { }, 'delay 10', { - textDelta: 'line \n\n', + textDelta: 'line \n\n ', type: 'text-delta', }, 'delay 10', { - // note: leading whitespace is included here - // because it is part of the new chunk: - textDelta: ' Multiple ', + // note: leading whitespace not included here + // because it is part of the last chunk: + textDelta: 'Multiple ', type: 'text-delta', }, 'delay 10', @@ -161,6 +163,7 @@ describe('smoothStream', () => { textDelta: 'spaces\n ', type: 'text-delta', }, + 'delay 10', { textDelta: 'Indented', type: 'text-delta', @@ -223,6 +226,7 @@ describe('smoothStream', () => { "textDelta": "in ", "type": "text-delta", }, + "delay 10", { "textDelta": "London.", "type": "text-delta", @@ -299,6 +303,7 @@ describe('smoothStream', () => { "textDelta": "in ", "type": "text-delta", }, + "delay 10", { "textDelta": "London.", "type": "text-delta", @@ -333,6 +338,127 @@ describe('smoothStream', () => { ] `); }); + + it('should support kanji', async () => { + const stream = convertArrayToReadableStream([ + { textDelta: 'Vercel', type: 'text-delta' }, + { textDelta: 'はサ', type: 'text-delta' }, + { textDelta: 'サーバーレス', type: 'text-delta' }, + { textDelta: 'の', type: 'text-delta' }, + { textDelta: 'フロントエンド', type: 'text-delta' }, + { textDelta: 'Hello, world!', type: 'text-delta' }, + { type: 'step-finish' }, + { type: 'finish' }, + ]).pipeThrough( + smoothStream({ + delayInMs: 10, + _internal: { delay }, + })({ tools: {} }), + ); + + await consumeStream(stream); + + expect(events).toMatchInlineSnapshot(` + [ + "delay 10", + { + "textDelta": "Vercelは", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "サ", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "サ", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "ー", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "バ", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "ー", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "レ", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "ス", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "の", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "フ", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "ロ", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "ン", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "ト", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "エ", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "ン", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "ド", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "Hello, ", + "type": "text-delta", + }, + "delay 10", + { + "textDelta": "world!", + "type": "text-delta", + }, + { + "type": "step-finish", + }, + { + "type": "finish", + }, + ] + `); + }); }); describe('line chunking', () => { @@ -450,9 +576,7 @@ describe('smoothStream', () => { 'delay 10', { textDelta: 'o', type: 'text-delta' }, 'delay 10', - { textDelta: ',', type: 'text-delta' }, - 'delay 10', - { textDelta: ' ', type: 'text-delta' }, + { textDelta: ', ', type: 'text-delta' }, 'delay 10', { textDelta: 'w', type: 'text-delta' }, 'delay 10', @@ -491,6 +615,7 @@ describe('smoothStream', () => { textDelta: 'Hello, ', type: 'text-delta', }, + 'delay 10', { textDelta: 'world!', type: 'text-delta', @@ -524,6 +649,7 @@ describe('smoothStream', () => { textDelta: 'Hello, ', type: 'text-delta', }, + 'delay 20', { textDelta: 'world!', type: 'text-delta', diff --git a/packages/ai/core/generate-text/smooth-stream.ts b/packages/ai/core/generate-text/smooth-stream.ts index 3cca81c15ee9..31c7f9243c2e 100644 --- a/packages/ai/core/generate-text/smooth-stream.ts +++ b/packages/ai/core/generate-text/smooth-stream.ts @@ -2,10 +2,12 @@ import { InvalidArgumentError } from '@ai-sdk/provider'; import { delay as originalDelay } from '@ai-sdk/provider-utils'; import { TextStreamPart } from './stream-text-result'; import { ToolSet } from './tool-set'; +import { TextSplit, splitText } from './text-splitter'; const CHUNKING_REGEXPS = { - word: /\s*\S+\s+/m, - line: /[^\n]*\n/m, + character: /(?!\s)(?=.)/g, + word: /[\u4E00-\u9FFF\u3040-\u309F\u30A0-\u30FF]|\s+/gm, + line: /\r\n|\r|\n/g, }; /** @@ -22,7 +24,7 @@ export function smoothStream({ _internal: { delay = originalDelay } = {}, }: { delayInMs?: number | null; - chunking?: 'word' | 'line' | RegExp; + chunking?: 'character' | 'word' | 'line' | { split: string } | RegExp; /** * Internal. For test use only. May change without notice. */ @@ -32,10 +34,14 @@ export function smoothStream({ } = {}): (options: { tools: TOOLS; }) => TransformStream, TextStreamPart> { - const chunkingRegexp = - typeof chunking === 'string' ? CHUNKING_REGEXPS[chunking] : chunking; + const chunker = + typeof chunking === 'object' && 'split' in chunking + ? chunking.split + : typeof chunking === 'string' + ? CHUNKING_REGEXPS[chunking] + : chunking; - if (chunkingRegexp == null) { + if (chunker == null) { throw new InvalidArgumentError({ argument: 'chunking', message: `Chunking must be "word" or "line" or a RegExp. Received: ${chunking}`, @@ -44,12 +50,24 @@ export function smoothStream({ return () => { let buffer = ''; + let lastSplits: TextSplit[] = []; + let lastIndex = 0; return new TransformStream, TextStreamPart>({ async transform(chunk, controller) { + const lastSplit = lastSplits.at(-1); + if (chunk.type !== 'text-delta') { - if (buffer.length > 0) { - controller.enqueue({ type: 'text-delta', textDelta: buffer }); + if (lastSplits.length > 1 && delayInMs) { + await delay(delayInMs); + } + + if (lastSplit) { + controller.enqueue({ + type: 'text-delta', + textDelta: lastSplit.text, + }); + lastSplits = []; buffer = ''; } @@ -59,14 +77,24 @@ export function smoothStream({ buffer += chunk.textDelta; - let match; - while ((match = chunkingRegexp.exec(buffer)) != null) { - const chunk = match[0]; - controller.enqueue({ type: 'text-delta', textDelta: chunk }); - buffer = buffer.slice(chunk.length); + const splits = splitText(buffer, chunker); - await delay(delayInMs); + // If there's a new split with the start index greater than the last index, + // push the new split(s) and delay. + const newSplitIndex = splits.findIndex( + split => !lastSplit || split.start >= lastIndex, + ); + + if (newSplitIndex !== -1) { + for (let i = newSplitIndex; i < splits.length - 1; i++) { + const split = splits[i]; + controller.enqueue({ type: 'text-delta', textDelta: split.text }); + lastIndex = split.end; + await delay(delayInMs); + } } + + lastSplits = splits; }, }); }; diff --git a/packages/ai/core/generate-text/text-splitter.ts b/packages/ai/core/generate-text/text-splitter.ts new file mode 100644 index 000000000000..94fc1a825e53 --- /dev/null +++ b/packages/ai/core/generate-text/text-splitter.ts @@ -0,0 +1,84 @@ +export interface TextSplit { + start: number; + end: number; + text: string; +} + +export function splitText( + text: string, + splitter: RegExp | string, +): TextSplit[] { + const splits: TextSplit[] = []; + let lastIndex = 0; + + function getNextMatch() { + if (lastIndex === text.length) { + return null; + } + + if (typeof splitter === 'string') { + if (splitter === '') { + return { index: lastIndex, 0: text.slice(lastIndex, lastIndex + 1) }; + } + + const index = text.indexOf(splitter, lastIndex); + return index === -1 ? null : { index, 0: splitter }; + } + + const regex = splitter.flags.includes('g') + ? splitter + : new RegExp(splitter.source, `${splitter.flags}g`); + regex.lastIndex = lastIndex; + const match = regex.exec(text); + + // If it's a zero-width match, we need to find the next match position + if (match && match[0] === '') { + regex.lastIndex = match.index + 1; + const nextMatch = regex.exec(text); + return { ...match, endIndex: nextMatch ? nextMatch.index : text.length }; + } + + return match; + } + + let match: ReturnType; + + while ((match = getNextMatch())) { + const matchEndIndex = + 'endIndex' in match ? match.endIndex : match.index + match[0].length; + + const end = matchEndIndex; + + if (end > lastIndex) { + const segment = text.slice(lastIndex, end); + + if (!segment.trim()) { + if (splits.length > 0) { + const previousSplit = splits[splits.length - 1]; + if (previousSplit) { + previousSplit.end = end; + previousSplit.text = text.slice(previousSplit.start, end); + } + } + } else { + splits.push({ + start: lastIndex, + end, + text: segment, + }); + } + } + + lastIndex = matchEndIndex; + } + + if (lastIndex < text.length) { + splits.push({ + start: lastIndex, + end: text.length, + text: text.slice(lastIndex), + }); + } + + return splits; +}