Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/seven-fans-exist.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@livekit/agents": patch
---

feat(stt): add diarization capabilities and speaker_id support
3 changes: 3 additions & 0 deletions agents/src/inference/api_protos.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ export const sttWordSchema = z.object({
start: z.number().optional().default(0),
end: z.number().optional().default(0),
confidence: z.number().optional().default(0.0),
speaker_id: z.string().nullable().optional(),
extra: z.unknown().nullable().optional(),
});

Expand All @@ -104,6 +105,7 @@ export const sttInterimTranscriptEventSchema = z.object({
duration: z.number().optional().default(0),
confidence: z.number().optional().default(1.0),
words: z.array(sttWordSchema).optional().default([]),
speaker_id: z.string().nullable().optional(),
extra: z.unknown().nullable().optional(),
});

Expand All @@ -117,6 +119,7 @@ export const sttFinalTranscriptEventSchema = z.object({
duration: z.number().optional().default(0),
confidence: z.number().optional().default(1.0),
words: z.array(sttWordSchema).optional().default([]),
speaker_id: z.string().nullable().optional(),
extra: z.unknown().nullable().optional(),
});

Expand Down
2 changes: 2 additions & 0 deletions agents/src/inference/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ export {
type STTModels,
type ModelWithLanguage as STTModelString,
type STTOptions,
type XaiSTTModels,
type XaiOptions as XaiSTTOptions,
} from './stt.js';

export {
Expand Down
75 changes: 74 additions & 1 deletion agents/src/inference/stt.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ import { beforeAll, describe, expect, it } from 'vitest';
import { normalizeLanguage } from '../language.js';
import { initializeLogger } from '../log.js';
import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../types.js';
import { STT, type STTFallbackModel, normalizeSTTFallback, parseSTTModelString } from './stt.js';
import {
STT,
type STTFallbackModel,
type XaiSTTModels,
normalizeSTTFallback,
parseSTTModelString,
} from './stt.js';

beforeAll(() => {
initializeLogger({ level: 'silent', pretty: false });
Expand Down Expand Up @@ -251,3 +257,70 @@ describe('STT constructor fallback and connOptions', () => {
expect(stt['opts'].connOptions!.retryIntervalMs).toBe(2000);
});
});

// Verifies that `capabilities.diarization` follows the diarization-related
// model options (`diarize`, `speaker_labels`) at construction time and after
// `updateOptions`, and that `updateOptions` merges `modelOptions` rather than
// replacing them.
describe('STT diarization capabilities', () => {
  it('no diarization by default', () => {
    const stt = makeStt();
    expect(stt.capabilities.diarization).toBe(false);
  });

  it('diarization enabled with deepgram diarize option', () => {
    const stt = makeStt({ modelOptions: { diarize: true } });
    expect(stt.capabilities.diarization).toBe(true);
  });

  it('diarization disabled with diarize false', () => {
    const stt = makeStt({ modelOptions: { diarize: false } });
    expect(stt.capabilities.diarization).toBe(false);
  });

  it('diarization enabled with assemblyai speaker_labels', () => {
    const stt = makeStt({
      model: 'assemblyai/universal-streaming',
      modelOptions: { speaker_labels: true },
    });
    expect(stt.capabilities.diarization).toBe(true);
  });

  it('updateOptions toggles diarization capability', () => {
    const stt = makeStt();
    expect(stt.capabilities.diarization).toBe(false);

    stt.updateOptions({ modelOptions: { diarize: true } as Record<string, unknown> });
    expect(stt.capabilities.diarization).toBe(true);

    stt.updateOptions({ modelOptions: { diarize: false } as Record<string, unknown> });
    expect(stt.capabilities.diarization).toBe(false);
  });

  it('diarization enabled with xai diarize option', () => {
    const stt = makeStt({
      model: 'xai/stt-1' satisfies XaiSTTModels,
      modelOptions: { diarize: true },
    });
    expect(stt.capabilities.diarization).toBe(true);
  });

  it('updateOptions preserves unrelated flags when merging', () => {
    const stt = makeStt({ modelOptions: { diarize: true } });
    expect(stt.capabilities.diarization).toBe(true);

    stt.updateOptions({ modelOptions: { endpointing: 500 } as Record<string, unknown> });
    expect(stt['opts'].modelOptions).toHaveProperty('diarize', true);
    expect(stt['opts'].modelOptions).toHaveProperty('endpointing', 500);
    expect(stt.capabilities.diarization).toBe(true);
  });

  it('updateOptions merges modelOptions on associated streams', () => {
    const stt = makeStt({ modelOptions: { diarize: true } });
    const stream = stt.stream();

    // Close the stream in `finally` so a failed expectation below cannot
    // leave it open for the rest of the test run.
    try {
      stt.updateOptions({ modelOptions: { endpointing: 500 } as Record<string, unknown> });

      // The stream's local modelOptions must be the merged object, not the partial.
      expect(stream['opts'].modelOptions).toHaveProperty('diarize', true);
      expect(stream['opts'].modelOptions).toHaveProperty('endpointing', 500);
    } finally {
      stream.close();
    }
  });
});
70 changes: 64 additions & 6 deletions agents/src/inference/stt.ts
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ export type AssemblyaiModels =

export type ElevenlabsSTTModels = 'elevenlabs/scribe_v2_realtime';

/** Model identifiers declared for the xAI STT provider. */
export type XaiSTTModels = 'xai/stt-1';

export interface CartesiaOptions {
/** Minimum volume threshold. Default: not specified. */
min_volume?: number;
Expand Down Expand Up @@ -71,6 +73,8 @@ export interface DeepgramOptions {
numerals?: boolean;
/** Opt out of model improvement program. */
mip_opt_out?: boolean;
/** Enable speaker diarization. Default: false. */
diarize?: boolean;
/** Eager end-of-turn threshold (0.0–1.0). Enables preflight transcripts for preemptive generation. */
eager_eot_threshold?: number;
}
Expand All @@ -86,6 +90,19 @@ export interface AssemblyAIOptions {
max_turn_silence?: number;
/** Key terms prompt for recognition. Default: not specified. */
keyterms_prompt?: string[];
/** Enable speaker diarization. Default: false. */
speaker_labels?: boolean;
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
}

/**
 * Model options for xAI STT models. `STTOptions` resolves to this interface
 * when the model type is one of `XaiSTTModels`. Every field is optional.
 */
export interface XaiOptions {
/** Enable speaker diarization. Default: false. */
diarize?: boolean;
/** Silence duration in ms before utterance-final (0-5000). */
endpointing?: number;
/** Enable Inverse Text Normalization. Requires language. */
format?: boolean;
/** Default true; set false to opt out of interim transcripts. */
interim_results?: boolean;
}

export type STTLanguages =
Expand All @@ -100,7 +117,19 @@ export type STTLanguages =
| 'hi'
| AnyString;

type _STTModels = DeepgramModels | CartesiaModels | AssemblyaiModels | ElevenlabsSTTModels;
/**
 * Model-option keys that signal speaker diarization: `diarize` (declared on the
 * Deepgram and xAI option interfaces) and `speaker_labels` (declared on the
 * AssemblyAI option interface).
 */
const DIARIZATION_EXTRA_KEYS = ['diarize', 'speaker_labels'] as const;
Comment thread
toubatbrian marked this conversation as resolved.

/**
 * Reports whether a set of model options turns speaker diarization on.
 *
 * @param extraKwargs - Model options to inspect; may be `undefined`.
 * @returns `true` when at least one key listed in `DIARIZATION_EXTRA_KEYS`
 *   holds a truthy value, otherwise `false` (including when no options are given).
 */
function diarizationEnabled(extraKwargs: Record<string, unknown> | undefined): boolean {
  if (extraKwargs) {
    for (const flag of DIARIZATION_EXTRA_KEYS) {
      // Any truthy value counts as "enabled", not only the literal `true`.
      if (extraKwargs[flag]) return true;
    }
  }
  return false;
}

/** Union of the model identifiers declared for each known STT provider. */
type _STTModels =
| DeepgramModels
| CartesiaModels
| AssemblyaiModels
| ElevenlabsSTTModels
| XaiSTTModels;

/**
 * Accepted STT model selector: a known provider model, the literal `'auto'`,
 * or `AnyString`.
 * NOTE(review): `AnyString` is declared outside this view; presumably it admits
 * arbitrary strings while keeping literal autocompletion. Confirm.
 */
export type STTModels = _STTModels | 'auto' | AnyString;

Expand All @@ -112,7 +141,9 @@ export type STTOptions<TModel extends STTModels> = TModel extends DeepgramModels
? CartesiaOptions
: TModel extends AssemblyaiModels
? AssemblyAIOptions
: Record<string, unknown>;
: TModel extends XaiSTTModels
? XaiOptions
: Record<string, unknown>;

/** A fallback model with optional extra configuration. Extra fields are passed through to the provider. */
export interface STTFallbackModel {
Expand Down Expand Up @@ -191,7 +222,13 @@ export class STT<TModel extends STTModels> extends BaseSTT {
fallback?: STTFallbackModelType | STTFallbackModelType[];
connOptions?: APIConnectOptions;
}) {
super({ streaming: true, interimResults: true, alignedTranscript: 'word' });
const modelOptions = (opts?.modelOptions ?? {}) as STTOptions<TModel>;
super({
streaming: true,
interimResults: true,
alignedTranscript: 'word',
diarization: diarizationEnabled(modelOptions as Record<string, unknown>),
});

const {
model,
Expand All @@ -201,7 +238,6 @@ export class STT<TModel extends STTModels> extends BaseSTT {
sampleRate = DEFAULT_SAMPLE_RATE,
apiKey,
apiSecret,
modelOptions = {} as STTOptions<TModel>,
fallback,
connOptions,
} = opts || {};
Expand Down Expand Up @@ -272,13 +308,26 @@ export class STT<TModel extends STTModels> extends BaseSTT {
throw new Error('LiveKit STT does not support batch recognition, use stream() instead');
}

updateOptions(opts: Partial<Pick<InferenceSTTOptions<TModel>, 'model' | 'language'>>): void {
updateOptions(
opts: Partial<Pick<InferenceSTTOptions<TModel>, 'model' | 'language' | 'modelOptions'>>,
): void {
const mergedModelOptions = opts.modelOptions
? ({ ...this.opts.modelOptions, ...opts.modelOptions } as STTOptions<TModel>)
: this.opts.modelOptions;

this.opts = {
...this.opts,
...opts,
language: opts.language !== undefined ? normalizeLanguage(opts.language) : this.opts.language,
modelOptions: mergedModelOptions,
};

if (opts.modelOptions) {
this.updateCapabilities({
diarization: diarizationEnabled(this.opts.modelOptions as Record<string, unknown>),
});
}

for (const stream of this.streams) {
stream.updateOptions(opts);
}
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
Expand Down Expand Up @@ -377,11 +426,18 @@ export class SpeechStream<TModel extends STTModels> extends BaseSpeechStream {
return 'inference.SpeechStream';
}

updateOptions(opts: Partial<Pick<InferenceSTTOptions<TModel>, 'model' | 'language'>>): void {
/**
 * Applies a partial options update to this stream, then sets `reconnectEvent`.
 *
 * - `modelOptions`, when provided, is shallow-merged over the current
 *   `modelOptions` instead of replacing it.
 * - `language`, when provided (not `undefined`), is passed through
 *   `normalizeLanguage`; otherwise the current language is kept.
 * - Any other provided field overwrites the current value.
 *
 * NOTE(review): setting `reconnectEvent` presumably makes the stream reconnect
 * with the new options. Confirm against the run loop.
 *
 * @param opts - Subset of `model`, `language` and `modelOptions` to change.
 */
updateOptions(
  opts: Partial<Pick<InferenceSTTOptions<TModel>, 'model' | 'language' | 'modelOptions'>>,
): void {
  const current = this.opts;
  const { language, modelOptions } = opts;

  // Merge rather than replace, so previously set provider flags survive.
  let nextModelOptions = current.modelOptions;
  if (modelOptions) {
    nextModelOptions = { ...current.modelOptions, ...modelOptions } as STTOptions<TModel>;
  }

  let nextLanguage = current.language;
  if (language !== undefined) {
    nextLanguage = normalizeLanguage(language);
  }

  this.opts = {
    ...current,
    ...opts,
    language: nextLanguage,
    modelOptions: nextModelOptions,
  };
  this.reconnectEvent.set();
}
Expand Down Expand Up @@ -617,6 +673,7 @@ export class SpeechStream<TModel extends STTModels> extends BaseSpeechStream {
endTime: this.startTimeOffset + data.start + data.duration,
confidence: data.confidence,
text,
speakerId: data.speaker_id ?? undefined,
words: data.words.map(
(word): TimedString =>
createTimedString({
Expand All @@ -625,6 +682,7 @@ export class SpeechStream<TModel extends STTModels> extends BaseSpeechStream {
endTime: word.end + this.startTimeOffset,
startTimeOffset: this.startTimeOffset,
confidence: word.confidence,
speakerId: word.speaker_id ?? undefined,
}),
),
};
Expand Down
13 changes: 11 additions & 2 deletions agents/src/stt/stt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ export interface SpeechData {
confidence: number;
/** Word-level timing information. */
words?: TimedString[];
/** Speaker identifier when the provider supports diarization. */
speakerId?: string | null;
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
}

export interface RecognitionUsage {
Expand Down Expand Up @@ -97,6 +99,8 @@ export interface STTCapabilities {
* - false: Provider does not support aligned transcripts
*/
alignedTranscript?: 'word' | 'chunk' | false;
/** Whether this STT supports speaker diarization. */
diarization?: boolean;
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
}

export interface STTError {
Expand Down Expand Up @@ -133,6 +137,10 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter<STTCal
return this.#capabilities;
}

/**
 * Overlays the given capability flags on the current capabilities.
 * Fields absent from `caps` keep their current value. A new object is stored,
 * so the previous capabilities object is not mutated.
 *
 * @param caps - Capability fields to overwrite.
 */
protected updateCapabilities(caps: Partial<STTCapabilities>): void {
  const merged = { ...this.#capabilities, ...caps };
  this.#capabilities = merged;
}

/**
* Get the model name/identifier for this STT instance.
*
Expand Down Expand Up @@ -262,8 +270,9 @@ export abstract class SpeechStream implements AsyncIterableIterator<SpeechEvent>
return await this.run();
} catch (error) {
// If the stream was intentionally aborted (e.g. session shutdown), exit
// silently. Downstream listeners may already be detached, so emitting an
// error here would trigger ERR_UNHANDLED_ERROR in Node's EventEmitter.
// silently. Downstream listeners may already be detached by this point,
// and emitting an `error` event here would trigger ERR_UNHANDLED_ERROR
// in Node's EventEmitter.
if (this.abortController.signal.aborted) {
return;
}
Expand Down
4 changes: 2 additions & 2 deletions agents/src/voice/agent_activity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,7 @@ export class AgentActivity implements RecognitionHooks {
transcript: ev.alternatives![0].text,
isFinal: false,
language: ev.alternatives![0].language,
// TODO(AJS-106): add multi participant support
speakerId: ev.alternatives![0].speakerId ?? null,
}),
);

Expand All @@ -1167,7 +1167,7 @@ export class AgentActivity implements RecognitionHooks {
transcript: ev.alternatives![0].text,
isFinal: true,
language: ev.alternatives![0].language,
// TODO(AJS-106): add multi participant support
speakerId: ev.alternatives![0].speakerId ?? null,
}),
);

Expand Down
3 changes: 3 additions & 0 deletions agents/src/voice/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export interface TimedString {
endTime?: number; // seconds
confidence?: number;
startTimeOffset?: number;
speakerId?: string | null;
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
}

/**
Expand All @@ -55,6 +56,7 @@ export function createTimedString(opts: {
endTime?: number;
confidence?: number;
startTimeOffset?: number;
speakerId?: string | null;
}): TimedString {
return {
[TIMED_STRING_SYMBOL]: true,
Expand All @@ -63,6 +65,7 @@ export function createTimedString(opts: {
endTime: opts.endTime,
confidence: opts.confidence,
startTimeOffset: opts.startTimeOffset,
speakerId: opts.speakerId ?? null,
};
}

Expand Down
Loading