Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2217,7 +2217,7 @@ def _on_first_frame(fut: asyncio.Future[float] | asyncio.Future[None]) -> None:
if self._audio_recognition:
self._audio_recognition.on_start_of_agent_speech(started_at=started_speaking_at)
if self.interruption_enabled:
self._interruption_by_audio_activity_enabled = False
self._disable_vad_interruption_soon()

audio_out: _AudioOutput | None = None
tts_gen_data: _TTSGenerationData | None = None
Expand Down Expand Up @@ -2574,7 +2574,7 @@ def _on_first_frame(fut: asyncio.Future[float] | asyncio.Future[None]) -> None:
if self._audio_recognition:
self._audio_recognition.on_start_of_agent_speech(started_at=started_speaking_at)
if self.interruption_enabled:
self._interruption_by_audio_activity_enabled = False
self._disable_vad_interruption_soon()

audio_out: _AudioOutput | None = None
if audio_output is not None:
Expand Down Expand Up @@ -3035,7 +3035,7 @@ def _on_first_frame(
if self._audio_recognition:
self._audio_recognition.on_start_of_agent_speech(started_at=started_speaking_at)
if self.interruption_enabled:
self._interruption_by_audio_activity_enabled = False
self._disable_vad_interruption_soon()

tasks: list[asyncio.Task[Any]] = []
tees: list[utils.aio.itertools.Tee[Any]] = []
Expand Down Expand Up @@ -3472,7 +3472,7 @@ def _on_false_interruption() -> None:
if self._audio_recognition and self._paused_speech.agent_state == "speaking":
self._audio_recognition.on_start_of_agent_speech(started_at=time.time())
if self.interruption_enabled:
self._interruption_by_audio_activity_enabled = False
self._disable_vad_interruption_soon()
audio_output.resume()
resumed = True
logger.debug("resumed false interrupted speech", extra={"timeout": timeout})
Expand Down Expand Up @@ -3526,7 +3526,27 @@ async def _cancel_speech_pause(
):
self._session.output.audio.resume()

def _disable_vad_interruption_soon(self) -> None:
    """Turn off VAD-based interruption, deferring until the backchannel boundary expires.

    While the backchannel boundary window is still open, the flag is left
    untouched and a callback is registered on the recognizer so the flip
    happens once the window ends — and only if the agent is still speaking.
    Otherwise the flag is cleared immediately.
    """
    recognition = self._audio_recognition
    if not (recognition and recognition.backchannel_boundary_active):
        # no boundary window pending -- disable right away
        self._interruption_by_audio_activity_enabled = False
        return

    def _on_boundary_expired() -> None:
        # only flip the flag while the agent is still talking; otherwise the
        # state has already been restored elsewhere and must not be clobbered
        still_speaking = self._session.agent_state == "speaking"
        if still_speaking and self._interruption_by_audio_activity_enabled:
            logger.trace("backchannel boundary expired")
            self._interruption_by_audio_activity_enabled = False

    recognition.backchannel_boundary_callback = _on_boundary_expired

def _restore_interruption_by_audio_activity(self) -> None:
    """Reset VAD interruption to its configured default for this session."""
    recognition = self._audio_recognition
    if recognition:
        # drop any pending boundary timer/callback so a stale expiry
        # cannot flip the flag back off after we restore it
        recognition._cancel_backchannel_boundary()

    default_enabled = self._default_interruption_by_audio_activity_enabled
    self._interruption_by_audio_activity_enabled = default_enabled
Expand Down
104 changes: 96 additions & 8 deletions livekit-agents/livekit/agents/voice/audio_recognition.py
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import time
from collections import deque
from collections.abc import AsyncIterable
from collections.abc import AsyncIterable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol

Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(

self._tasks: set[asyncio.Task[Any]] = set()

# used for adaptive interruption detection
# region: adaptive interruption detection
self._interruption_atask: asyncio.Task[None] | None = None
self._interruption_detection = interruption_detection
self._interruption_ch: aio.Chan[inference.InterruptionDataFrameType] | None = None
Expand All @@ -184,6 +184,22 @@ def __init__(
self._interruption_enabled: bool = interruption_detection is not None and vad is not None
self._agent_speaking: bool = False

_backchannel_boundary: float | tuple[float, float] | None = (
session.options.interruption.get("backchannel_boundary")
)
self._backchannel_boundary: tuple[float, float] | None = (
(_backchannel_boundary, _backchannel_boundary)
if isinstance(_backchannel_boundary, int | float)
else _backchannel_boundary
)
if self._backchannel_boundary and (
len(self._backchannel_boundary) != 2 or any(x < 0.0 for x in self._backchannel_boundary)
):
raise ValueError("backchannel_boundary must be a tuple of two non-negative floats")
self._backchannel_boundary_timer: asyncio.TimerHandle | None = None
self.backchannel_boundary_callback: Callable[[], None] | None = None
# endregion

self._user_turn_span: trace.Span | None = None
self._stt_request_ids: list[str] = []
self._closing = asyncio.Event()
Expand Down Expand Up @@ -236,14 +252,45 @@ def adaptive_interruption_active(self) -> bool:
and not self._interruption_ch.closed
)

# region: boundary for adaptive interruption detection

@property
def backchannel_boundary_active(self) -> bool:
    """Whether the start-of-agent-speech backchannel boundary window is still pending."""
    timer = self._backchannel_boundary_timer
    return timer is not None

def _on_backchannel_boundary_done(self) -> None:
    """Timer-expiry handler: clear boundary state and fire the registered callback once."""
    self._backchannel_boundary_timer = None
    # take the callback before invoking so it can never run twice
    callback = self.backchannel_boundary_callback
    self.backchannel_boundary_callback = None
    if callback is not None:
        callback()

def _cancel_backchannel_boundary(self) -> None:
    """Cancel any pending boundary timer and discard the registered callback."""
    timer = self._backchannel_boundary_timer
    if timer is not None:
        timer.cancel()
        self._backchannel_boundary_timer = None
    self.backchannel_boundary_callback = None

# endregion

def on_start_of_agent_speech(self, started_at: float) -> None:
    """Mark the agent as speaking and (re)arm the backchannel boundary timer.

    Args:
        started_at: wall-clock timestamp when the agent's audio actually
            started playing out.
    """
    self._agent_speaking = True
    self._endpointing.on_start_of_agent_speech(started_at=started_at)

    if self._backchannel_boundary:
        start_cooldown = self._backchannel_boundary[0]
        if start_cooldown > 0:
            # restart the boundary window for this utterance; any previous
            # timer/callback is stale and must not fire
            self._cancel_backchannel_boundary()
            loop = asyncio.get_running_loop()
            self._backchannel_boundary_timer = loop.call_later(
                start_cooldown, self._on_backchannel_boundary_done
            )

    if self.adaptive_interruption_active:
        self._interruption_ch.send_nowait(_AgentSpeechStartedSentinel())  # type: ignore[union-attr]

def on_end_of_agent_speech(self, *, ignore_user_transcript_until: float) -> None:
self._cancel_backchannel_boundary()

if self._agent_speaking:
self._endpointing.on_end_of_agent_speech(ended_at=time.time())

Expand All @@ -257,14 +304,27 @@ def on_end_of_agent_speech(self, *, ignore_user_transcript_until: float) -> None
# no interruption is detected, end the inference (idempotent)
if not is_given(self._ignore_user_transcript_until):
self.on_end_of_overlap_speech(ended_at=time.time())
self._ignore_user_transcript_until = (

end_cooldown: float = (
self._backchannel_boundary[1] if self._backchannel_boundary else 0.0
)

ignore_until = (
ignore_user_transcript_until
if not is_given(self._ignore_user_transcript_until)
else min(ignore_user_transcript_until, self._ignore_user_transcript_until)
)
logger.trace(
"flushing held transcripts",
extra={
"ignore_until": ignore_until,
"end_cooldown": end_cooldown,
},
)
self._ignore_user_transcript_until = ignore_until - end_cooldown

# flush held transcripts if possible
task = asyncio.create_task(self._flush_held_transcripts())
task = asyncio.create_task(self._flush_held_transcripts(cooldown=end_cooldown))
task.add_done_callback(lambda _: self._tasks.discard(task))
self._tasks.add(task)

Expand Down Expand Up @@ -329,8 +389,8 @@ def on_end_of_overlap_speech(
_OverlapSpeechEndedSentinel(ended_at=ended_at or time.time())
)

async def _flush_held_transcripts(self) -> None:
"""Flush held transcripts whose *end time* is after the ignore_user_transcript_until timestamp.
async def _flush_held_transcripts(self, cooldown: float) -> None:
"""Flush held transcripts whose *end time* is after the ignore_user_transcript_until - cooldown timestamp.

If the event has no timestamps, we assume it is the same as the next valid event.
"""
Expand Down Expand Up @@ -374,13 +434,27 @@ async def _flush_held_transcripts(self) -> None:
if emit_from_index is not None and should_flush
else []
)
_ignore_user_transcript_until = self._ignore_user_transcript_until
self._reset_interruption_detection()

for ev in events_to_emit:
added_delay = 0.0
if ev.alternatives and ev.alternatives[0].end_time > 0:
added_delay = max(
0,
(
ev.alternatives[0].end_time
+ self._input_started_at
- _ignore_user_transcript_until
)
+ (cooldown or 0.0),
)
logger.trace(
"re-emitting held user transcript",
extra={
"event": ev.type,
"cooldown": cooldown,
"added_delay": added_delay,
},
)
await self._on_stt_event(ev)
Expand Down Expand Up @@ -471,6 +545,11 @@ async def aclose(self) -> None:
if self._end_of_turn_task is not None:
await self._end_of_turn_task

if self._backchannel_boundary_timer is not None:
self._backchannel_boundary_timer.cancel()
self._backchannel_boundary_timer = None
self.backchannel_boundary_callback = None

def update_stt(self, stt: io.STTNode | None, *, pipeline: _STTPipeline | None = None) -> None:
self._stt = stt
if pipeline is None and stt is not None:
Expand Down Expand Up @@ -557,6 +636,7 @@ def update_interruption_detection(
self._tasks.add(task)
self._interruption_atask = None
self._interruption_ch = None
self._cancel_backchannel_boundary()

self._interruption_enabled = (
self._interruption_detection is not None and self._vad is not None
Expand Down Expand Up @@ -714,8 +794,12 @@ async def _on_stt_event(self, ev: stt.SpeechEvent) -> None:
)
self._transcript_buffer.append(ev)
return
elif self._transcript_buffer:
await self._flush_held_transcripts()

if self._transcript_buffer:
end_cooldown: float = (
self._backchannel_boundary[1] if self._backchannel_boundary else 0.0
)
await self._flush_held_transcripts(cooldown=end_cooldown)
# no return here to allow the new event to be processed normally

if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
Expand Down Expand Up @@ -919,6 +1003,10 @@ async def _on_vad_event(self, ev: vad.VADEvent) -> None:
self._session.amd._on_user_speech_ended(ev.silence_duration)

async def _on_overlap_speech_event(self, ev: inference.OverlappingSpeechEvent) -> None:
if self.backchannel_boundary_active:
logger.trace("ignoring overlap speech event during backchannel boundary cooldown")
return

if ev.is_interruption:
self._hooks.on_interruption(ev)

Expand Down
10 changes: 10 additions & 0 deletions livekit-agents/livekit/agents/voice/turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class InterruptionOptions(TypedDict, total=False):
false_interruption_timeout: float | None
"""Seconds of silence after an interruption before it is
classified as false. ``None`` disables. Defaults to ``2.0``."""
backchannel_boundary: float | tuple[float, float] | None
"""Seconds to suppress adaptive interruption handling when the agent
starts or stops speaking each turn to allow for easier turn correction.
Use tuple to apply different values for start and end separately.
``None`` disables. Defaults to ``(1.0, 3.5)``. End value should be higher
to account for STT transcript timestamp inaccuracy."""


_INTERRUPTION_DEFAULTS: InterruptionOptions = {
Expand All @@ -109,6 +115,10 @@ class InterruptionOptions(TypedDict, total=False):
"min_words": 0,
"resume_false_interruption": True,
"false_interruption_timeout": 2.0,
"backchannel_boundary": (
1.0,
3.5, # higher value for the end as STT timestamps aren't very reliable
),
}


Expand Down
Loading
Loading