diff --git a/.changeset/seven-fans-exist.md b/.changeset/seven-fans-exist.md new file mode 100644 index 000000000..fd2c19371 --- /dev/null +++ b/.changeset/seven-fans-exist.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": patch +--- + +feat(stt): add diarization capabilities and speaker_id support diff --git a/agents/src/inference/api_protos.ts b/agents/src/inference/api_protos.ts index 30c24f429..2e9b29ce6 100644 --- a/agents/src/inference/api_protos.ts +++ b/agents/src/inference/api_protos.ts @@ -91,6 +91,7 @@ export const sttWordSchema = z.object({ start: z.number().optional().default(0), end: z.number().optional().default(0), confidence: z.number().optional().default(0.0), + speaker_id: z.string().nullable().optional(), extra: z.unknown().nullable().optional(), }); @@ -104,6 +105,7 @@ export const sttInterimTranscriptEventSchema = z.object({ duration: z.number().optional().default(0), confidence: z.number().optional().default(1.0), words: z.array(sttWordSchema).optional().default([]), + speaker_id: z.string().nullable().optional(), extra: z.unknown().nullable().optional(), }); @@ -117,6 +119,7 @@ export const sttFinalTranscriptEventSchema = z.object({ duration: z.number().optional().default(0), confidence: z.number().optional().default(1.0), words: z.array(sttWordSchema).optional().default([]), + speaker_id: z.string().nullable().optional(), extra: z.unknown().nullable().optional(), }); diff --git a/agents/src/inference/index.ts b/agents/src/inference/index.ts index 096d6fc27..e6704fb14 100644 --- a/agents/src/inference/index.ts +++ b/agents/src/inference/index.ts @@ -25,6 +25,8 @@ export { type STTModels, type ModelWithLanguage as STTModelString, type STTOptions, + type XaiSTTModels, + type XaiOptions as XaiSTTOptions, } from './stt.js'; export { diff --git a/agents/src/inference/stt.test.ts b/agents/src/inference/stt.test.ts index 429b17839..febc3298c 100644 --- a/agents/src/inference/stt.test.ts +++ b/agents/src/inference/stt.test.ts @@ -5,7 +5,13 @@ import { 
beforeAll, describe, expect, it } from 'vitest'; import { normalizeLanguage } from '../language.js'; import { initializeLogger } from '../log.js'; import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../types.js'; -import { STT, type STTFallbackModel, normalizeSTTFallback, parseSTTModelString } from './stt.js'; +import { + STT, + type STTFallbackModel, + type XaiSTTModels, + normalizeSTTFallback, + parseSTTModelString, +} from './stt.js'; beforeAll(() => { initializeLogger({ level: 'silent', pretty: false }); @@ -251,3 +257,70 @@ describe('STT constructor fallback and connOptions', () => { expect(stt['opts'].connOptions!.retryIntervalMs).toBe(2000); }); }); + +describe('STT diarization capabilities', () => { + it('no diarization by default', () => { + const stt = makeStt(); + expect(stt.capabilities.diarization).toBe(false); + }); + + it('diarization enabled with deepgram diarize option', () => { + const stt = makeStt({ modelOptions: { diarize: true } }); + expect(stt.capabilities.diarization).toBe(true); + }); + + it('diarization disabled with diarize false', () => { + const stt = makeStt({ modelOptions: { diarize: false } }); + expect(stt.capabilities.diarization).toBe(false); + }); + + it('diarization enabled with assemblyai speaker_labels', () => { + const stt = makeStt({ + model: 'assemblyai/universal-streaming', + modelOptions: { speaker_labels: true }, + }); + expect(stt.capabilities.diarization).toBe(true); + }); + + it('updateOptions toggles diarization capability', () => { + const stt = makeStt(); + expect(stt.capabilities.diarization).toBe(false); + + stt.updateOptions({ modelOptions: { diarize: true } as Record<string, unknown> }); + expect(stt.capabilities.diarization).toBe(true); + + stt.updateOptions({ modelOptions: { diarize: false } as Record<string, unknown> }); + expect(stt.capabilities.diarization).toBe(false); + }); + + it('diarization enabled with xai diarize option', () => { + const stt = makeStt({ + model: 'xai/stt-1' satisfies XaiSTTModels, + modelOptions: { 
diarize: true }, + }); + expect(stt.capabilities.diarization).toBe(true); + }); + + it('updateOptions preserves unrelated flags when merging', () => { + const stt = makeStt({ modelOptions: { diarize: true } }); + expect(stt.capabilities.diarization).toBe(true); + + stt.updateOptions({ modelOptions: { endpointing: 500 } as Record<string, unknown> }); + expect(stt['opts'].modelOptions).toHaveProperty('diarize', true); + expect(stt['opts'].modelOptions).toHaveProperty('endpointing', 500); + expect(stt.capabilities.diarization).toBe(true); + }); + + it('updateOptions merges modelOptions on associated streams', () => { + const stt = makeStt({ modelOptions: { diarize: true } }); + const stream = stt.stream(); + + stt.updateOptions({ modelOptions: { endpointing: 500 } as Record<string, unknown> }); + + // The stream's local modelOptions must be the merged object, not the partial. + expect(stream['opts'].modelOptions).toHaveProperty('diarize', true); + expect(stream['opts'].modelOptions).toHaveProperty('endpointing', 500); + + stream.close(); + }); +}); diff --git a/agents/src/inference/stt.ts b/agents/src/inference/stt.ts index 03da41f3e..2f542974b 100644 --- a/agents/src/inference/stt.ts +++ b/agents/src/inference/stt.ts @@ -43,6 +43,8 @@ export type AssemblyaiModels = export type ElevenlabsSTTModels = 'elevenlabs/scribe_v2_realtime'; +export type XaiSTTModels = 'xai/stt-1'; + export interface CartesiaOptions { /** Minimum volume threshold. Default: not specified. */ min_volume?: number; @@ -71,6 +73,8 @@ export interface DeepgramOptions { numerals?: boolean; /** Opt out of model improvement program. */ mip_opt_out?: boolean; + /** Enable speaker diarization. Default: false. */ + diarize?: boolean; /** Eager end-of-turn threshold (0.0–1.0). Enables preflight transcripts for preemptive generation. */ eager_eot_threshold?: number; } @@ -86,6 +90,19 @@ export interface AssemblyAIOptions { max_turn_silence?: number; /** Key terms prompt for recognition. Default: not specified. 
*/ keyterms_prompt?: string[]; + /** Enable speaker diarization. Default: false. */ + speaker_labels?: boolean; +} + +export interface XaiOptions { + /** Enable speaker diarization. Default: false. */ + diarize?: boolean; + /** Silence duration in ms before utterance-final (0-5000). */ + endpointing?: number; + /** Enable Inverse Text Normalization. Requires language. */ + format?: boolean; + /** Default true; set false to opt out of interim transcripts. */ + interim_results?: boolean; +} export type STTLanguages = @@ -100,7 +117,19 @@ export type STTLanguages = | 'hi' | AnyString; -type _STTModels = DeepgramModels | CartesiaModels | AssemblyaiModels | ElevenlabsSTTModels; +const DIARIZATION_EXTRA_KEYS = ['diarize', 'speaker_labels'] as const; + +function diarizationEnabled(extraKwargs: Record<string, unknown> | undefined): boolean { + if (!extraKwargs) return false; + return DIARIZATION_EXTRA_KEYS.some((key) => Boolean(extraKwargs[key])); +} + +type _STTModels = + | DeepgramModels + | CartesiaModels + | AssemblyaiModels + | ElevenlabsSTTModels + | XaiSTTModels; export type STTModels = _STTModels | 'auto' | AnyString; @@ -112,7 +141,9 @@ export type STTOptions<TModel extends STTModels = STTModels> = TModel extends DeepgramModels ? CartesiaOptions : TModel extends AssemblyaiModels ? AssemblyAIOptions - : Record<string, unknown>; + : TModel extends XaiSTTModels + ? XaiOptions + : Record<string, unknown>; /** A fallback model with optional extra configuration. Extra fields are passed through to the provider. */ export interface STTFallbackModel { @@ -191,7 +222,13 @@ export class STT<TModel extends STTModels = STTModels> extends BaseSTT { fallback?: STTFallbackModelType | STTFallbackModelType[]; connOptions?: APIConnectOptions; }) { - super({ streaming: true, interimResults: true, alignedTranscript: 'word' }); + const modelOptions = (opts?.modelOptions ?? 
{}) as STTOptions<TModel>; + super({ + streaming: true, + interimResults: true, + alignedTranscript: 'word', + diarization: diarizationEnabled(modelOptions as Record<string, unknown>), + }); const { model, @@ -201,7 +238,6 @@ sampleRate = DEFAULT_SAMPLE_RATE, apiKey, apiSecret, - modelOptions = {} as STTOptions<TModel>, fallback, connOptions, } = opts || {}; @@ -272,13 +308,26 @@ throw new Error('LiveKit STT does not support batch recognition, use stream() instead'); } - updateOptions(opts: Partial, 'model' | 'language'>>): void { + updateOptions( + opts: Partial, 'model' | 'language' | 'modelOptions'>>, + ): void { + const mergedModelOptions = opts.modelOptions + ? ({ ...this.opts.modelOptions, ...opts.modelOptions } as STTOptions<TModel>) + : this.opts.modelOptions; + this.opts = { ...this.opts, ...opts, language: opts.language !== undefined ? normalizeLanguage(opts.language) : this.opts.language, + modelOptions: mergedModelOptions, }; + if (opts.modelOptions) { + this.updateCapabilities({ + diarization: diarizationEnabled(this.opts.modelOptions as Record<string, unknown>), + }); + } + for (const stream of this.streams) { stream.updateOptions(opts); } @@ -377,11 +426,18 @@ return 'inference.SpeechStream'; } - updateOptions(opts: Partial, 'model' | 'language'>>): void { + updateOptions( + opts: Partial, 'model' | 'language' | 'modelOptions'>>, + ): void { + const mergedModelOptions = opts.modelOptions + ? ({ ...this.opts.modelOptions, ...opts.modelOptions } as STTOptions<TModel>) + : this.opts.modelOptions; + this.opts = { ...this.opts, ...opts, language: opts.language !== undefined ? normalizeLanguage(opts.language) : this.opts.language, + modelOptions: mergedModelOptions, }; this.reconnectEvent.set(); } @@ -617,6 +673,7 @@ endTime: this.startTimeOffset + data.start + data.duration, confidence: data.confidence, text, + speakerId: data.speaker_id ?? 
undefined, words: data.words.map( (word): TimedString => createTimedString({ @@ -625,6 +682,7 @@ endTime: word.end + this.startTimeOffset, startTimeOffset: this.startTimeOffset, confidence: word.confidence, + speakerId: word.speaker_id ?? undefined, }), ), }; diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index a16b7f919..69ebd1418 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -62,6 +62,8 @@ export interface SpeechData { confidence: number; /** Word-level timing information. */ words?: TimedString[]; + /** Speaker identifier when the provider supports diarization. */ + speakerId?: string | null; } export interface RecognitionUsage { @@ -97,6 +99,8 @@ export interface STTCapabilities { * - false: Provider does not support aligned transcripts */ alignedTranscript?: 'word' | 'chunk' | false; + /** Whether this STT supports speaker diarization. */ + diarization?: boolean; } export interface STTError { @@ -133,6 +137,10 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter<STTCallbacks>) { + updateCapabilities(caps: Partial<STTCapabilities>): void { + this.#capabilities = { ...this.#capabilities, ...caps }; + } + /** * Get the model name/identifier for this STT instance. * @@ -262,8 +270,9 @@ export abstract class SpeechStream implements AsyncIterableIterator<SpeechEvent> { return await this.run(); } catch (error) { // If the stream was intentionally aborted (e.g. session shutdown), exit - // silently. Downstream listeners may already be detached, so emitting an - // error here would trigger ERR_UNHANDLED_ERROR in Node's EventEmitter. + // silently. Downstream listeners may already be detached by this point, + // and emitting an `error` event here would trigger ERR_UNHANDLED_ERROR + // in Node's EventEmitter. 
if (this.abortController.signal.aborted) { return; } diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index fbcac7012..b0ddf9d30 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -1146,7 +1146,7 @@ export class AgentActivity implements RecognitionHooks { transcript: ev.alternatives![0].text, isFinal: false, language: ev.alternatives![0].language, - // TODO(AJS-106): add multi participant support + speakerId: ev.alternatives![0].speakerId ?? null, }), ); @@ -1167,7 +1167,7 @@ export class AgentActivity implements RecognitionHooks { transcript: ev.alternatives![0].text, isFinal: true, language: ev.alternatives![0].language, - // TODO(AJS-106): add multi participant support + speakerId: ev.alternatives![0].speakerId ?? null, }), ); diff --git a/agents/src/voice/io.ts b/agents/src/voice/io.ts index ff5d8a8b1..620e5868f 100644 --- a/agents/src/voice/io.ts +++ b/agents/src/voice/io.ts @@ -44,6 +44,7 @@ export interface TimedString { endTime?: number; // seconds confidence?: number; startTimeOffset?: number; + speakerId?: string | null; } /** @@ -55,6 +56,7 @@ export function createTimedString(opts: { endTime?: number; confidence?: number; startTimeOffset?: number; + speakerId?: string | null; }): TimedString { return { [TIMED_STRING_SYMBOL]: true, @@ -63,6 +65,7 @@ export function createTimedString(opts: { endTime: opts.endTime, confidence: opts.confidence, startTimeOffset: opts.startTimeOffset, + speakerId: opts.speakerId ?? null, }; }