Skip to content
Merged
61 changes: 60 additions & 1 deletion livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"assemblyai/u3-rt-pro",
]
# Provider-prefixed model identifiers ("<provider>/<model>") accepted for the
# ElevenLabs and xAI providers; each alias currently lists a single model.
ElevenlabsModels = Literal["elevenlabs/scribe_v2_realtime",]
XaiModels = Literal["xai/stt-1",]


class CartesiaOptions(TypedDict, total=False):
Expand All @@ -70,7 +71,7 @@ class DeepgramOptions(TypedDict, total=False):
numerals: bool
mip_opt_out: bool # default: False
vad_events: bool # default: False
diarize: bool
diarize: bool # when True, enables speaker diarization (default off)
dictation: bool
detect_language: bool
no_delay: bool # default: True
Expand Down Expand Up @@ -105,6 +106,7 @@ class AssemblyaiOptions(TypedDict, total=False):
language_detection: bool
inactivity_timeout: float # seconds
prompt: str # default: not specified (u3-rt-pro only, mutually exclusive with keyterms_prompt)
speaker_labels: bool # when True, enables speaker diarization (default off)


class ElevenlabsOptions(TypedDict, total=False):
Expand All @@ -117,6 +119,30 @@ class ElevenlabsOptions(TypedDict, total=False):
language_code: str


class XaiOptions(TypedDict, total=False):
    """Provider-specific ``extra_kwargs`` keys for xAI; every key is optional."""

    # Speaker diarization switch; off unless set to True.
    diarize: bool
    # Milliseconds of silence before an utterance is finalized (0-5000).
    endpointing: int
    # Inverse Text Normalization, e.g. "one hundred dollars" -> "$100"; requires language.
    format: bool
    # Interim transcripts are on by default; set False to opt out.
    interim_results: bool


# Diarization is requested via different extra_kwargs keys across
# providers. Keep this list in one place so adding a new provider is a
# single-line change and there's no divergence between __init__ and
# update_options capability inference.
_DIARIZATION_EXTRA_KEYS: tuple[str, ...] = (
    # shared by Deepgram and xAI
    "diarize",
    # AssemblyAI
    "speaker_labels",
)


def _diarization_enabled(extra_kwargs: dict[str, Any] | None) -> bool:
    """Report whether ``extra_kwargs`` switches on diarization for any known provider.

    Args:
        extra_kwargs: Provider-specific options, or ``None`` when none were supplied.

    Returns:
        ``True`` when at least one key from ``_DIARIZATION_EXTRA_KEYS`` is present with
        a truthy value; ``False`` for ``None``, an empty dict, or when every such key
        is missing or falsy.
    """
    if extra_kwargs:
        for flag in _DIARIZATION_EXTRA_KEYS:
            if extra_kwargs.get(flag):
                return True
    return False


# Language codes offered as a typed Literal. "multi" presumably selects
# multilingual recognition — TODO confirm against the inference service.
STTLanguages = Literal["multi", "en", "de", "es", "fr", "ja", "pt", "zh", "hi"]


Expand Down Expand Up @@ -168,6 +194,7 @@ def _make_fallback(model: FallbackModelType) -> FallbackModel:
| CartesiaModels
| AssemblyAIModels
| ElevenlabsModels
| XaiModels
| Literal["auto"] # automatically select a provider based on the language
)
# Values accepted by the `encoding` constructor parameter; only "pcm_s16le" is listed
# (presumably signed 16-bit little-endian PCM — TODO confirm).
STTEncoding = Literal["pcm_s16le"]
Expand Down Expand Up @@ -277,6 +304,23 @@ def __init__(
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None: ...

# Typing-only overload for xAI models: when `model` is an `XaiModels` literal,
# `extra_kwargs` is narrowed to `XaiOptions` so type checkers can validate the
# xAI-specific keys. The body is a stub (`...`) and is never executed.
@overload
def __init__(
    self,
    model: XaiModels,
    *,
    language: NotGivenOr[str] = NOT_GIVEN,
    base_url: NotGivenOr[str] = NOT_GIVEN,
    encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
    sample_rate: NotGivenOr[int] = NOT_GIVEN,
    api_key: NotGivenOr[str] = NOT_GIVEN,
    api_secret: NotGivenOr[str] = NOT_GIVEN,
    http_session: aiohttp.ClientSession | None = None,
    extra_kwargs: NotGivenOr[XaiOptions] = NOT_GIVEN,
    fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
    conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
) -> None: ...

@overload
def __init__(
self,
Expand Down Expand Up @@ -312,6 +356,7 @@ def __init__(
| DeepgramFluxOptions
| AssemblyaiOptions
| ElevenlabsOptions
| XaiOptions
] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
Expand All @@ -332,10 +377,18 @@ def __init__(
a list of FallbackModel instances.
conn_options (APIConnectOptions, optional): Connection options for request attempts.
"""
# Infer diarization capability from provider-specific extra_kwargs
# keys (see _DIARIZATION_EXTRA_KEYS). xAI uses "diarize" (same as
# Deepgram); AssemblyAI uses "speaker_labels".
diarization_enabled = _diarization_enabled(
dict(extra_kwargs) if is_given(extra_kwargs) else None
)

super().__init__(
capabilities=stt.STTCapabilities(
streaming=True,
interim_results=True,
diarization=diarization_enabled,
aligned_transcript="word",
offline_recognize=False,
),
Expand Down Expand Up @@ -452,6 +505,10 @@ def update_options(
self._opts.language = LanguageCode(language)
if is_given(extra):
self._opts.extra_kwargs.update(extra)
self._capabilities = replace(
self._capabilities,
diarization=_diarization_enabled(self._opts.extra_kwargs),
)

for stream in self._streams:
stream.update_options(model=model, language=language, extra=extra)
Expand Down Expand Up @@ -689,13 +746,15 @@ def _build_speech_data(self, data: dict) -> stt.SpeechData:
end_time=self.start_time_offset + data.get("start", 0) + data.get("duration", 0),
confidence=data.get("confidence", 1.0),
text=data.get("transcript", ""),
speaker_id=data.get("speaker_id"),
words=[
TimedString(
text=word.get("word", ""),
start_time=word.get("start", 0) + self.start_time_offset,
end_time=word.get("end", 0) + self.start_time_offset,
start_time_offset=self.start_time_offset,
confidence=word.get("confidence", 0.0),
speaker_id=word.get("speaker_id"),
)
for word in words
],
Expand Down
20 changes: 18 additions & 2 deletions livekit-agents/livekit/agents/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,27 @@ def _interval_for_retry(self, num_retries: int) -> float:


class TimedString(str):
"""A string with optional start and end timestamps for word-level alignment."""
"""A string with optional start and end timestamps for word-level alignment.

Attributes:
start_time: Word start time in seconds (NOT_GIVEN when unavailable).
end_time: Word end time in seconds (NOT_GIVEN when unavailable).
confidence: Per-word confidence score (NOT_GIVEN when unavailable).
start_time_offset: Offset in seconds relative to the start of the audio
input stream or session. Used by STT plugins to align words against
the session timeline (NOT_GIVEN when unavailable).
speaker_id: Speaker identifier when the provider supports diarization.
Uses ``str | None`` rather than ``NotGivenOr[str]`` because the
absence of a speaker is a routine, expected case across all
providers — not a "not given" boundary condition — and downstream
consumers gate on ``speaker_id is None`` rather than ``is_given``.
"""

start_time: NotGivenOr[float]
end_time: NotGivenOr[float]
confidence: NotGivenOr[float]
start_time_offset: NotGivenOr[float]
# offset relative to the start of the audio input stream or session in seconds, used in STT plugins
speaker_id: str | None

def __new__(
cls,
Expand All @@ -128,10 +142,12 @@ def __new__(
end_time: NotGivenOr[float] = NOT_GIVEN,
confidence: NotGivenOr[float] = NOT_GIVEN,
start_time_offset: NotGivenOr[float] = NOT_GIVEN,
speaker_id: str | None = None,
) -> "TimedString":
obj = super().__new__(cls, text)
obj.start_time = start_time
obj.end_time = end_time
obj.confidence = confidence
obj.start_time_offset = start_time_offset
obj.speaker_id = speaker_id
return obj
53 changes: 53 additions & 0 deletions tests/test_inference_stt_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,56 @@ def test_connect_options_full_custom(self):
assert stt._opts.conn_options.timeout == 60.0
assert stt._opts.conn_options.max_retry == 10
assert stt._opts.conn_options.retry_interval == 2.0


class TestSTTDiarizationCapabilities:
    """Checks that ``capabilities.diarization`` tracks the flags in extra_kwargs."""

    def test_no_diarization_by_default(self):
        """An STT built without diarization flags reports the capability as off."""
        default_stt = _make_stt()
        assert default_stt.capabilities.diarization is False

    def test_diarization_enabled_with_deepgram_diarize(self):
        """``diarize=True`` (the Deepgram key) switches the capability on."""
        options = {"diarize": True}
        diarizing_stt = _make_stt(extra_kwargs=options)
        assert diarizing_stt.capabilities.diarization is True

    def test_diarization_disabled_with_diarize_false(self):
        """An explicit ``diarize=False`` leaves the capability off."""
        options = {"diarize": False}
        plain_stt = _make_stt(extra_kwargs=options)
        assert plain_stt.capabilities.diarization is False

    def test_diarization_enabled_with_assemblyai_speaker_labels(self):
        """``speaker_labels=True`` (the AssemblyAI key) switches the capability on."""
        options = {"speaker_labels": True}
        assemblyai_stt = _make_stt(model="assemblyai/universal-streaming", extra_kwargs=options)
        assert assemblyai_stt.capabilities.diarization is True

    def test_update_options_toggles_diarization(self):
        """``update_options`` flips the capability on and then back off."""
        stt = _make_stt()
        assert stt.capabilities.diarization is False
        # Capabilities are re-read after every update rather than cached.
        for flag in (True, False):
            stt.update_options(extra={"diarize": flag})
            assert stt.capabilities.diarization is flag

    def test_diarization_enabled_with_xai_diarize(self):
        """xAI uses the same ``diarize`` key as Deepgram, so the capability turns on."""
        options = {"diarize": True}
        xai_stt = _make_stt(model="xai/stt-1", extra_kwargs=options)
        assert xai_stt.capabilities.diarization is True

    def test_update_options_extra_preserves_unrelated_flags(self):
        """A partial ``extra`` update is merged into the stored extra_kwargs.

        An earlier ``diarize=True`` must survive an update that only mentions an
        unrelated key, and the capability must stay on.
        """
        stt = _make_stt(extra_kwargs={"diarize": True})
        assert stt.capabilities.diarization is True

        stt.update_options(extra={"endpointing": 500})

        merged = stt._opts.extra_kwargs
        # The earlier flag and the newly supplied key must both be present.
        assert merged.get("diarize") is True
        assert merged.get("endpointing") == 500
        assert stt.capabilities.diarization is True
68 changes: 67 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading