Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions livekit-plugins/livekit-plugins-xai/livekit/plugins/xai/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class STTOptions:
sample_rate: int
enable_diarization: bool
language: STTLanguages | str
endpointing: int


class STT(stt.STT):
Expand All @@ -67,6 +68,7 @@ def __init__(
sample_rate: int = SAMPLE_RATE,
enable_diarization: bool = False,
language: STTLanguages | str = "en",
endpointing: int = 100,
api_key: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
) -> None:
Expand All @@ -77,6 +79,7 @@ def __init__(
sample_rate: The sample rate of the audio in Hz. Defaults to 16000.
enable_diarization: Whether to enable speaker diarization. Words will include a speaker field. Defaults to False.
language: BCP-47 language code for transcription (e.g. "en", "fr", "de"). Defaults to "en".
endpointing: Silence duration in milliseconds before an utterance-final event is fired. xAI's default is 10ms, but we default to 100ms for better compatibility with LK EOT models.
api_key: Your xAI API key. If not provided, will look for XAI_API_KEY environment variable.
http_session: Optional aiohttp ClientSession to use for requests.

Expand Down Expand Up @@ -107,6 +110,7 @@ def __init__(
sample_rate=sample_rate,
enable_diarization=enable_diarization,
language=language,
endpointing=endpointing,
)
self._session = http_session
self._streams = weakref.WeakSet[SpeechStream]()
Expand Down Expand Up @@ -193,6 +197,7 @@ def update_options(
sample_rate: NotGivenOr[int] = NOT_GIVEN,
enable_diarization: NotGivenOr[bool] = NOT_GIVEN,
language: NotGivenOr[STTLanguages | str] = NOT_GIVEN,
endpointing: NotGivenOr[int] = NOT_GIVEN,
) -> None:
if is_given(interim_results):
self._opts.enable_interim_results = interim_results
Expand All @@ -206,12 +211,16 @@ def update_options(
if is_given(language):
self._opts.language = language

if is_given(endpointing):
self._opts.endpointing = endpointing

for stream in self._streams:
stream.update_options(
enable_interim_results=interim_results,
sample_rate=sample_rate,
enable_diarization=enable_diarization,
language=language,
endpointing=endpointing,
)


Expand Down Expand Up @@ -247,6 +256,7 @@ def update_options(
sample_rate: NotGivenOr[int] = NOT_GIVEN,
enable_diarization: NotGivenOr[bool] = NOT_GIVEN,
language: NotGivenOr[STTLanguages | str] = NOT_GIVEN,
endpointing: NotGivenOr[int] = NOT_GIVEN,
) -> None:
if is_given(enable_interim_results):
self._opts.enable_interim_results = enable_interim_results
Expand All @@ -260,6 +270,9 @@ def update_options(
if is_given(language):
self._opts.language = language

if is_given(endpointing):
self._opts.endpointing = endpointing

self._reconnect_event.set()

async def _run(self) -> None:
Expand Down Expand Up @@ -359,6 +372,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
"interim_results": str(self._opts.enable_interim_results).lower(),
"diarize": str(self._opts.enable_diarization).lower(),
"language": str(self._opts.language),
"endpointing": str(self._opts.endpointing),
}
try:
ws = await asyncio.wait_for(
Expand Down
Loading