Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/beige-seals-tie.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat(ai): improved text splitter
140 changes: 133 additions & 7 deletions packages/ai/core/generate-text/smooth-stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ describe('smoothStream', () => {
textDelta: 'Hello, ',
type: 'text-delta',
},
'delay 10',
{
textDelta: 'world!',
type: 'text-delta',
Expand Down Expand Up @@ -107,6 +108,7 @@ describe('smoothStream', () => {
textDelta: 'example ',
type: 'text-delta',
},
'delay 10',
{
textDelta: 'text.',
type: 'text-delta',
Expand Down Expand Up @@ -146,21 +148,22 @@ describe('smoothStream', () => {
},
'delay 10',
{
textDelta: 'line \n\n',
textDelta: 'line \n\n ',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will create an additional delay

type: 'text-delta',
},
'delay 10',
{
// note: leading whitespace is included here
// because it is part of the new chunk:
textDelta: ' Multiple ',
// note: leading whitespace not included here
// because it is part of the last chunk:
textDelta: 'Multiple ',
type: 'text-delta',
},
'delay 10',
{
textDelta: 'spaces\n ',
type: 'text-delta',
},
'delay 10',
{
textDelta: 'Indented',
type: 'text-delta',
Expand Down Expand Up @@ -223,6 +226,7 @@ describe('smoothStream', () => {
"textDelta": "in ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "London.",
"type": "text-delta",
Expand Down Expand Up @@ -299,6 +303,7 @@ describe('smoothStream', () => {
"textDelta": "in ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "London.",
"type": "text-delta",
Expand Down Expand Up @@ -333,6 +338,127 @@ describe('smoothStream', () => {
]
`);
});

it('should support kanji', async () => {
const stream = convertArrayToReadableStream([
{ textDelta: 'Vercel', type: 'text-delta' },
{ textDelta: 'はサ', type: 'text-delta' },
{ textDelta: 'サーバーレス', type: 'text-delta' },
{ textDelta: 'の', type: 'text-delta' },
{ textDelta: 'フロントエンド', type: 'text-delta' },
{ textDelta: 'Hello, world!', type: 'text-delta' },
{ type: 'step-finish' },
{ type: 'finish' },
]).pipeThrough(
smoothStream({
delayInMs: 10,
_internal: { delay },
})({ tools: {} }),
);

await consumeStream(stream);

expect(events).toMatchInlineSnapshot(`
[
"delay 10",
{
"textDelta": "Vercelは",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "サ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "サ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "ー",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "バ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "ー",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "レ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "ス",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "の",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "フ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "ロ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "ン",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "ト",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "エ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "ン",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "ド",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "Hello, ",
"type": "text-delta",
},
"delay 10",
{
"textDelta": "world!",
"type": "text-delta",
},
{
"type": "step-finish",
},
{
"type": "finish",
},
]
`);
});
});

describe('line chunking', () => {
Expand Down Expand Up @@ -450,9 +576,7 @@ describe('smoothStream', () => {
'delay 10',
{ textDelta: 'o', type: 'text-delta' },
'delay 10',
{ textDelta: ',', type: 'text-delta' },
'delay 10',
{ textDelta: ' ', type: 'text-delta' },
{ textDelta: ', ', type: 'text-delta' },
'delay 10',
{ textDelta: 'w', type: 'text-delta' },
'delay 10',
Expand Down Expand Up @@ -491,6 +615,7 @@ describe('smoothStream', () => {
textDelta: 'Hello, ',
type: 'text-delta',
},
'delay 10',
{
textDelta: 'world!',
type: 'text-delta',
Expand Down Expand Up @@ -524,6 +649,7 @@ describe('smoothStream', () => {
textDelta: 'Hello, ',
type: 'text-delta',
},
'delay 20',
{
textDelta: 'world!',
type: 'text-delta',
Expand Down
56 changes: 42 additions & 14 deletions packages/ai/core/generate-text/smooth-stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ import { InvalidArgumentError } from '@ai-sdk/provider';
import { delay as originalDelay } from '@ai-sdk/provider-utils';
import { TextStreamPart } from './stream-text-result';
import { ToolSet } from './tool-set';
import { TextSplit, splitText } from './text-splitter';

const CHUNKING_REGEXPS = {
word: /\s*\S+\s+/m,
line: /[^\n]*\n/m,
character: /(?!\s)(?=.)/g,
word: /[\u4E00-\u9FFF\u3040-\u309F\u30A0-\u30FF]|\s+/gm,
line: /\r\n|\r|\n/g,
Comment on lines +8 to +10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what will happen to custom user regexp?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'll add more tests for custom chunking

};

/**
Expand All @@ -22,7 +24,7 @@ export function smoothStream<TOOLS extends ToolSet>({
_internal: { delay = originalDelay } = {},
}: {
delayInMs?: number | null;
chunking?: 'word' | 'line' | RegExp;
chunking?: 'character' | 'word' | 'line' | { split: string } | RegExp;
/**
* Internal. For test use only. May change without notice.
*/
Expand All @@ -32,10 +34,14 @@ export function smoothStream<TOOLS extends ToolSet>({
} = {}): (options: {
tools: TOOLS;
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>> {
const chunkingRegexp =
typeof chunking === 'string' ? CHUNKING_REGEXPS[chunking] : chunking;
const chunker =
typeof chunking === 'object' && 'split' in chunking
? chunking.split
: typeof chunking === 'string'
? CHUNKING_REGEXPS[chunking]
: chunking;

if (chunkingRegexp == null) {
if (chunker == null) {
throw new InvalidArgumentError({
argument: 'chunking',
message: `Chunking must be "word" or "line" or a RegExp. Received: ${chunking}`,
Expand All @@ -44,12 +50,24 @@ export function smoothStream<TOOLS extends ToolSet>({

return () => {
let buffer = '';
let lastSplits: TextSplit[] = [];
let lastIndex = 0;

return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
async transform(chunk, controller) {
const lastSplit = lastSplits.at(-1);

if (chunk.type !== 'text-delta') {
if (buffer.length > 0) {
controller.enqueue({ type: 'text-delta', textDelta: buffer });
if (lastSplits.length > 1 && delayInMs) {
await delay(delayInMs);
}

if (lastSplit) {
controller.enqueue({
type: 'text-delta',
textDelta: lastSplit.text,
});
lastSplits = [];
buffer = '';
}

Expand All @@ -59,14 +77,24 @@ export function smoothStream<TOOLS extends ToolSet>({

buffer += chunk.textDelta;

let match;
while ((match = chunkingRegexp.exec(buffer)) != null) {
const chunk = match[0];
controller.enqueue({ type: 'text-delta', textDelta: chunk });
buffer = buffer.slice(chunk.length);
const splits = splitText(buffer, chunker);

await delay(delayInMs);
// If there's a new split with the start index greater than the last index,
// push the new split(s) and delay.
const newSplitIndex = splits.findIndex(
split => !lastSplit || split.start >= lastIndex,
);

if (newSplitIndex !== -1) {
for (let i = newSplitIndex; i < splits.length - 1; i++) {
const split = splits[i];
controller.enqueue({ type: 'text-delta', textDelta: split.text });
lastIndex = split.end;
await delay(delayInMs);
}
}

lastSplits = splits;
},
});
};
Expand Down
Loading
Loading