Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion livekit-agents/livekit/agents/stt/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ class SpeechData:
words: list[TimedString] | None = None
source_languages: list[LanguageCode] | None = None
"""the source languages spoken by the user. populated by STT services that support translation,
where `language` holds the target language and `source_languages` holds the original spoken language(s).
where `language` holds the target language and `source_languages` holds the original spoken language(s),
or by multi-language detection services where `language` holds the dominant language and
`source_languages` holds all detected languages sorted by prevalence.
may contain multiple entries when a single utterance spans multiple source languages."""
source_texts: list[str] | None = None
"""the original transcription segments in the source language(s), when translation is active.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"whisper-large",
]

V2Models = Literal["flux-general-en"]
V2Models = Literal["flux-general-en", "flux-general-multi"]

DeepgramLanguages = Literal[
"zh",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class STTOptions:
eot_timeout_ms: NotGivenOr[int] = NOT_GIVEN
mip_opt_out: bool = False
tags: NotGivenOr[list[str]] = NOT_GIVEN
language_hint: NotGivenOr[list[str]] = NOT_GIVEN


class STTv2(stt.STT):
Expand All @@ -71,6 +72,7 @@ def __init__(
eot_timeout_ms: NotGivenOr[int] = NOT_GIVEN,
keyterm: NotGivenOr[str | list[str]] = NOT_GIVEN,
tags: NotGivenOr[list[str]] = NOT_GIVEN,
language_hint: NotGivenOr[list[str]] = NOT_GIVEN,
api_key: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
base_url: str = "wss://api.deepgram.com/v2/listen",
Expand All @@ -88,6 +90,7 @@ def __init__(
eot_timeout_ms: The timeout for end of speech detection. Defaults to 3000.
keyterm: str or list of str of key terms to improve recognition accuracy. Defaults to None.
tags: List of tags to add to the requests for usage reporting. Defaults to NOT_GIVEN.
language_hint: List of str of language hints to bias the model for improved accuracy. Only usable with `flux-general-multi`. Defaults to NOT_GIVEN.
api_key: Your Deepgram API key. If not provided, will look for DEEPGRAM_API_KEY environment variable.
http_session: Optional aiohttp ClientSession to use for requests.
base_url: The base URL for Deepgram API. Defaults to "wss://api.deepgram.com/v2/listen".
Expand Down Expand Up @@ -128,13 +131,19 @@ def __init__(
f"eager_eot_threshold ({eager_eot_threshold}) must be less than or equal to eot_threshold "
f"({effective_eot}); increase eot_threshold (max 0.9) to use a higher eager value"
)
if language_hint and model != "flux-general-multi":
logger.warning(
"`language_hint` is only supported by `flux-general-multi` and will be ignored for model '%s'",
model,
)

self._opts = STTOptions(
model=model,
sample_rate=sample_rate,
keyterm=keyterm if is_given(keyterm) else [],
mip_opt_out=mip_opt_out,
tags=_validate_tags(tags) if is_given(tags) else [],
language_hint=language_hint if is_given(language_hint) else [],
Comment thread
tinalenguyen marked this conversation as resolved.
eager_eot_threshold=eager_eot_threshold,
eot_threshold=eot_threshold,
eot_timeout_ms=eot_timeout_ms,
Expand Down Expand Up @@ -196,6 +205,7 @@ def update_options(
keyterm: NotGivenOr[str | list[str]] = NOT_GIVEN,
mip_opt_out: NotGivenOr[bool] = NOT_GIVEN,
tags: NotGivenOr[list[str]] = NOT_GIVEN,
language_hint: NotGivenOr[list[str]] = NOT_GIVEN,
endpoint_url: NotGivenOr[str] = NOT_GIVEN,
# deprecated
keyterms: NotGivenOr[list[str]] = NOT_GIVEN,
Expand Down Expand Up @@ -231,6 +241,13 @@ def update_options(
self._opts.mip_opt_out = mip_opt_out
if is_given(tags):
self._opts.tags = _validate_tags(tags)
if is_given(language_hint):
self._opts.language_hint = language_hint
if language_hint and self._opts.model != "flux-general-multi":
logger.warning(
"`language_hint` is only supported by `flux-general-multi` and will be ignored for model '%s'",
self._opts.model,
)
if is_given(endpoint_url):
self._opts.endpoint_url = endpoint_url
if is_given(eager_eot_threshold):
Expand All @@ -246,6 +263,7 @@ def update_options(
mip_opt_out=mip_opt_out,
endpoint_url=endpoint_url,
tags=tags,
language_hint=language_hint,
eager_eot_threshold=eager_eot_threshold,
)

Expand Down Expand Up @@ -289,6 +307,7 @@ def update_options(
keyterm: NotGivenOr[str | list[str]] = NOT_GIVEN,
mip_opt_out: NotGivenOr[bool] = NOT_GIVEN,
tags: NotGivenOr[list[str]] = NOT_GIVEN,
language_hint: NotGivenOr[list[str]] = NOT_GIVEN,
endpoint_url: NotGivenOr[str] = NOT_GIVEN,
eager_eot_threshold: NotGivenOr[float] = NOT_GIVEN,
# deprecated
Expand All @@ -313,6 +332,8 @@ def update_options(
self._opts.mip_opt_out = mip_opt_out
if is_given(tags):
self._opts.tags = _validate_tags(tags)
if is_given(language_hint):
self._opts.language_hint = language_hint
Comment thread
u9g marked this conversation as resolved.
if is_given(endpoint_url):
self._opts.endpoint_url = endpoint_url
if is_given(eager_eot_threshold):
Expand Down Expand Up @@ -456,6 +477,9 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
if self._opts.tags:
live_config["tag"] = self._opts.tags

if self._opts.language_hint:
live_config["language_hint"] = self._opts.language_hint

try:
ws = await asyncio.wait_for(
self._session.ws_connect(
Expand Down Expand Up @@ -560,12 +584,18 @@ def _parse_transcription(
return []
confidence = sum(word["confidence"] for word in words) / len(words) if words else 0

detected_languages = data.get("languages") or []
primary_language = (
LanguageCode(detected_languages[0]) if detected_languages else LanguageCode(language)
)

sd = stt.SpeechData(
language=LanguageCode(language),
language=primary_language,
start_time=data.get("audio_window_start", 0) + start_time_offset,
end_time=data.get("audio_window_end", 0) + start_time_offset,
confidence=confidence,
text=transcript or "",
source_languages=[LanguageCode(lang) for lang in detected_languages] or None,
words=[
TimedString(
text=word.get("word", ""),
Expand Down
Loading