Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 129 additions & 58 deletions livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,38 @@ def _build_websocket_url(base_url: str, opts: SarvamSTTOptions) -> str:
return f"{base_url}?{urlencode(params)}"


def _looks_like_error_text(value: object) -> bool:
"""Heuristic to detect server-side error hints in text payloads/reasons."""
if not isinstance(value, str):
return False

lowered = value.lower()
error_hints = (
"error",
"invalid",
"failed",
"forbidden",
"unauthorized",
"not found",
"rate limit",
"timeout",
)
return any(hint in lowered for hint in error_hints)


def _has_error_field(data: dict) -> bool:
"""Check whether a parsed message carries an explicit error indicator."""
if data.get("error") is not None:
return True
nested = data.get("data")
if isinstance(nested, dict):
if nested.get("error") is not None:
return True
if nested.get("event_type") == "error" or nested.get("event") == "error":
return True
return False


class STT(stt.STT):
"""Sarvam.ai Speech-to-Text implementation.

Expand Down Expand Up @@ -439,6 +471,28 @@ def _ensure_session(self) -> aiohttp.ClientSession:
self._session = utils.http_context.http_session()
return self._session

@staticmethod
def _single_attempt_conn_options(conn_options: APIConnectOptions) -> APIConnectOptions:
return APIConnectOptions(
max_retry=0,
retry_interval=conn_options.retry_interval,
timeout=conn_options.timeout,
)

async def recognize(
self,
buffer: AudioBuffer,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> stt.SpeechEvent:
single_attempt_conn_options = self._single_attempt_conn_options(conn_options)
return await super().recognize(
buffer,
language=language,
conn_options=single_attempt_conn_options,
)

def _resolve_opts(
self,
*,
Expand Down Expand Up @@ -517,9 +571,6 @@ async def _recognize_impl(
form_data.add_field("model", str(opts_model))
if _model_supports_mode(opts_model):
form_data.add_field("mode", str(opts_mode))
# input_audio_codec is intentionally not sent here: the audio is always
# converted to WAV above (to_wav_bytes), so the codec is always "wav".
# input_audio_codec is only relevant for the WebSocket streaming path.

if not self._api_key:
raise ValueError("API key cannot be None")
Expand Down Expand Up @@ -598,6 +649,9 @@ async def _recognize_impl(
except aiohttp.ClientError as e:
self._logger.error(f"Sarvam API client error: {e}")
raise APIConnectionError(f"Sarvam API connection error: {e}") from e
except (APIStatusError, APIConnectionError, APITimeoutError):
# Preserve provider-originated status/body/retry metadata.
raise
except Exception as e:
self._logger.error(f"Error during Sarvam STT processing: {e}")
raise APIConnectionError(f"Unexpected error in Sarvam STT: {e}") from e
Expand Down Expand Up @@ -635,6 +689,7 @@ def stream(
opts_input_codec = (
input_audio_codec if is_given(input_audio_codec) else self._opts.input_audio_codec
)
single_attempt_conn_options = self._single_attempt_conn_options(conn_options)

# Create options for the stream
stream_opts = SarvamSTTOptions(
Expand All @@ -657,7 +712,7 @@ def stream(
stream = SpeechStream(
stt=self,
opts=stream_opts,
conn_options=conn_options,
conn_options=single_attempt_conn_options,
api_key=self._api_key,
http_session=stream_session,
)
Expand Down Expand Up @@ -857,51 +912,27 @@ async def _run(self) -> None:
request_id = utils.shortuuid()
self._client_request_id = request_id
self._server_request_id = None
num_retries = 0
max_retries = getattr(self._conn_options, "max_retry_count", 3)

while num_retries <= max_retries:
try:
await self._run_connection()
break # Success, exit retry loop

except (
aiohttp.ClientConnectorError,
asyncio.TimeoutError,
) as e: # TODO: Check if retry should happen for every Exception type
if num_retries == max_retries:
async with self._connection_lock:
self._connection_state = ConnectionState.FAILED
raise APIConnectionError(
f"Failed to connect to STT WebSocket after {max_retries} attempts"
) from e

# Exponential backoff with jitter, max 30 seconds
retry_interval = min(2**num_retries + (num_retries * 0.1), 30)
async with self._connection_lock:
self._connection_state = ConnectionState.RECONNECTING

self._logger.warning(
f"Connection failed, retrying in {retry_interval:.1f}s",
extra={
**self._build_log_context(),
"attempt": num_retries + 1,
"max_retries": max_retries + 1,
"error": str(e),
},
)
await asyncio.sleep(retry_interval)
num_retries += 1

except Exception as e:
async with self._connection_lock:
self._connection_state = ConnectionState.FAILED
self._logger.error(
f"Unrecoverable error in WebSocket connection: {e}",
extra=self._build_log_context(),
exc_info=True,
)
raise APIConnectionError(f"WebSocket connection failed: {e}") from e
try:
await self._run_connection()
except (aiohttp.ClientConnectorError, asyncio.TimeoutError) as e:
async with self._connection_lock:
self._connection_state = ConnectionState.FAILED
self._logger.error(f"Connection failed: {e}", extra=self._build_log_context())
raise APIConnectionError(f"Failed to connect to STT WebSocket: {e}") from e
except (APIStatusError, APIConnectionError, APITimeoutError):
async with self._connection_lock:
self._connection_state = ConnectionState.FAILED
# Preserve provider-originated status/body/retry metadata.
raise
except Exception as e:
async with self._connection_lock:
self._connection_state = ConnectionState.FAILED
self._logger.error(
f"Unexpected error in STT WebSocket session: {e}",
extra=self._build_log_context(),
exc_info=True,
)
raise APIStatusError(f"STT WebSocket session failed: {e}") from e

async def _run_connection(self) -> None:
"""Run a single WebSocket connection attempt."""
Expand Down Expand Up @@ -971,6 +1002,33 @@ async def _run_connection(self) -> None:
self._reconnect_event.clear()
return

# Keep listening for server-side terminal errors when audio finishes first.
if self._audio_task in done and self._message_task in pending:
audio_exc = self._audio_task.exception()
if audio_exc is not None:
raise audio_exc

done2, pending2 = await asyncio.wait(
[self._message_task, reconnect_task],
return_when=asyncio.FIRST_COMPLETED,
timeout=self._conn_options.timeout,
)
done |= done2
pending = pending2

if reconnect_task in done2:
self._logger.info(
"Reconnection requested, closing current connection",
extra=self._build_log_context(),
)
self._reconnect_event.clear()
return

if not done2:
raise APITimeoutError(
"Timed out waiting for STT server response after audio input ended"
)

# Cancel remaining tasks using LiveKit's utility
if pending:
await utils.aio.cancel_and_wait(*pending)
Expand Down Expand Up @@ -1105,12 +1163,18 @@ async def _process_messages(self, ws: aiohttp.ClientWebSocketResponse) -> None:
)

try:
async for msg in ws:
while True:
msg = await ws.receive()
if msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
await self._handle_message(data)
except json.JSONDecodeError as e:
if _looks_like_error_text(msg.data):
raise APIStatusError(
message=(f"Sarvam STT non-JSON error message: {msg.data}"),
body={"raw_message": msg.data},
) from e
self._logger.warning(
"Invalid JSON received from WebSocket",
extra={
Expand All @@ -1135,7 +1199,7 @@ async def _process_messages(self, ws: aiohttp.ClientWebSocketResponse) -> None:
) from e

elif msg.type == aiohttp.WSMsgType.ERROR:
error_msg = f"WebSocket error: {ws.exception()}"
error_msg = f"WebSocket error: {msg.data}"
self._logger.error(error_msg, extra=self._build_log_context())
raise APIConnectionError(error_msg)

Expand All @@ -1147,8 +1211,9 @@ async def _process_messages(self, ws: aiohttp.ClientWebSocketResponse) -> None:
close_code = ws.close_code if ws.close_code is not None else msg.data
close_reason = msg.extra
is_expected_close = close_code in (1000, 1001, None)
has_error_reason = _looks_like_error_text(close_reason)

if not is_expected_close:
if not is_expected_close or has_error_reason:
self._logger.error(
f"WebSocket closed: {msg.type}",
extra={
Expand All @@ -1157,14 +1222,15 @@ async def _process_messages(self, ws: aiohttp.ClientWebSocketResponse) -> None:
"close_reason": close_reason,
},
)
msg_type = getattr(msg.type, "name", str(msg.type))
raw_close = {
"msg_type": str(msg.type),
"msg_type": msg_type,
"close_code": close_code,
"close_reason": close_reason,
}
raise APIStatusError(
message=(
"Sarvam STT WebSocket closed with non-graceful status: "
"Sarvam STT WebSocket closed unexpectedly: "
f"{json.dumps(raw_close, ensure_ascii=False)}"
),
status_code=int(close_code) if isinstance(close_code, int) else -1,
Expand All @@ -1186,7 +1252,7 @@ async def _process_messages(self, ws: aiohttp.ClientWebSocketResponse) -> None:
extra=self._build_log_context(),
)

except (APIStatusError, APIConnectionError):
except (APIStatusError, APIConnectionError, APITimeoutError):
# Already logged at origin — just propagate
raise
except Exception as e:
Expand All @@ -1211,9 +1277,14 @@ async def _handle_message(self, data: dict) -> None:

if msg_type == "data":
await self._handle_transcript_data(data)
elif msg_type == "events":
await self._handle_events(data)
elif msg_type == "error":
elif msg_type in ("events", "event"):
if _has_error_field(data):
await self._handle_error_message(data)
else:
await self._handle_events(data)
elif msg_type in ("error", "errors"):
await self._handle_error_message(data)
elif _has_error_field(data):
await self._handle_error_message(data)
else:
self._logger.debug(
Expand Down
Loading