Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions livekit-agents/livekit/agents/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from dataclasses import dataclass
from enum import Enum, unique
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal, overload
from urllib.parse import urlparse

import aiohttp
Expand Down Expand Up @@ -64,9 +64,17 @@ def _observability_url(livekit_url: str) -> str | None:
from .voice.report import SessionReport


def get_job_context() -> JobContext:
@overload
def get_job_context(*, required: Literal[True] = True) -> JobContext: ...


@overload
def get_job_context(*, required: Literal[False]) -> JobContext | None: ...


def get_job_context(*, required: bool = True) -> JobContext | None:
ctx = _JobContextVar.get(None)
if ctx is None:
if ctx is None and required:
raise RuntimeError(
"no job context found, are you running this code inside a job entrypoint?"
)
Expand Down
55 changes: 39 additions & 16 deletions livekit-agents/livekit/agents/llm/async_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING, Any, Literal, get_origin, get_type_hints

from .. import utils
from ..job import JobContext, get_job_context
from ..llm.chat_context import ChatItem, FunctionCall
from ..llm.tool_context import (
FunctionTool,
Expand All @@ -23,6 +24,31 @@
if TYPE_CHECKING:
from ..voice.agent_session import AgentSession

# Module-level registry of running async tasks, keyed by (JobContext | None, call_id).
# Registered when a task starts in _wrap_tool, auto-removed via done callback.
_RunningTasks: dict[tuple[JobContext | None, str], _RunningTask] = {}


@function_tool
async def get_running_tasks() -> list[dict]:
    """Get the list of running async tool calls across all async toolsets."""
    # The job context may be None here because it is requested with required=False.
    current_job = get_job_context(required=False)

    calls: list[dict] = []
    # Walk a snapshot of the module-level registry; its keys are
    # (job context, call_id) pairs and only entries registered under the
    # very same job context object (identity, not equality) are reported.
    for (owner, _call_id), running in list(_RunningTasks.items()):
        if owner is not current_job:
            continue
        calls.append(running.ctx.function_call.model_dump())
    return calls
Comment thread
longcw marked this conversation as resolved.


@function_tool
async def cancel_task(call_id: str) -> str:
    """Cancel a running async tool call by call_id."""
    # Registry keys are (job context, call_id); the job context may be None
    # because it is requested with required=False.
    registry_key = (get_job_context(required=False), call_id)
    entry = _RunningTasks.get(registry_key)
    if not entry:
        return f"Task {call_id} not found or already completed."

    # Delegate the actual cancellation to the toolset that owns the task.
    cancelled = await entry.ctx._toolset.cancel(call_id)
    if cancelled:
        return f"Task {call_id} cancelled successfully."
    return f"Task {call_id} not found or already completed."


UPDATE_TEMPLATE = """The tool `{function_name}` has updated, message: {message}
The task is still running, so DON'T make up or give information not included in the message above."""
Expand Down Expand Up @@ -158,27 +184,14 @@ def __init__(
self._wrap_tool(t) if isinstance(t, FunctionTool | RawFunctionTool) else t
for t in self._tools
]
self._tools.extend([get_running_tasks, cancel_task])

self._running_tasks: dict[str, _RunningTask] = {}

# speech delivery — shared across all tools in this toolset
self._pending_updates: list[_PendingUpdate] = []
self._reply_task: asyncio.Task[None] | None = None

@function_tool
async def get_running_tasks(self) -> list[dict]:
"""Get the list of running async tool calls."""
return [task.ctx.function_call.model_dump() for task in self._running_tasks.values()]

@function_tool
async def cancel_task(self, call_id: str) -> str:
"""Cancel a running async tool call by call_id."""
success = await self.cancel(call_id)
if success:
return f"Task {call_id} cancelled successfully."
else:
return f"Task {call_id} not found or already completed."

async def cancel(self, call_id: str) -> bool:
task = self._running_tasks.get(call_id)
if task is not None:
Expand Down Expand Up @@ -274,8 +287,18 @@ async def _execute_tool() -> Any:
exe_task = asyncio.create_task(_execute_tool(), name=f"async_tool_{fnc_name}")
_pass_through_activity_task_info(exe_task)

self._running_tasks[call_id] = _RunningTask(ctx=async_ctx, exe_task=exe_task)
exe_task.add_done_callback(lambda _: self._running_tasks.pop(call_id, None))
running_task = _RunningTask(ctx=async_ctx, exe_task=exe_task)
self._running_tasks[call_id] = running_task

# register in the module-level registry
task_key = (get_job_context(required=False), call_id)
_RunningTasks[task_key] = running_task

def _on_done(_: asyncio.Task[Any]) -> None:
self._running_tasks.pop(call_id, None)
_RunningTasks.pop(task_key, None)

exe_task.add_done_callback(_on_done)

return await async_ctx._pending_fut

Expand Down
7 changes: 5 additions & 2 deletions livekit-agents/livekit/agents/llm/tool_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,11 @@ def add_tool(tool: Tool | Toolset) -> None:
self._provider_tools.append(tool)

elif isinstance(tool, (FunctionTool, RawFunctionTool)):
if tool.info.name in self._fnc_tools_map:
raise ValueError(f"duplicate function name: {tool.info.name}")
existing = self._fnc_tools_map.get(tool.info.name)
if existing is not None:
if existing is not tool:
raise ValueError(f"duplicate function name: {tool.info.name}")
return # same instance, skip
self._fnc_tools_map[tool.info.name] = tool

elif isinstance(tool, Toolset):
Expand Down
14 changes: 4 additions & 10 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .. import cli, inference, llm, stt, tts, utils, vad
from .._exceptions import APIError
from ..job import JobContext, get_job_context
from ..job import get_job_context
from ..llm import AgentHandoff, ChatContext, MetricsReport
from ..llm.chat_context import Instructions
from ..log import logger
Expand Down Expand Up @@ -615,16 +615,10 @@ async def start(

# configure observability first
record_is_given = is_given(record)
job_ctx: JobContext | None = None
try:
job_ctx = get_job_context(required=False)
if not is_given(record):
# defer to server-side setting for recording
job_ctx = get_job_context()
if not is_given(record):
record = job_ctx.job.enable_recording
except RuntimeError:
# JobContext is not available in evals, will not be able to record
if not is_given(record):
record = False
record = job_ctx.job.enable_recording if job_ctx else False

self._recording_options = _resolve_recording_options(record) # type: ignore[arg-type]

Expand Down
53 changes: 46 additions & 7 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,24 @@ def test_toolset_with_regular_tools(self):
assert ctx == ctx_copy
assert toolset in ctx_copy.toolsets

def test_toolset_duplicate_name_conflict(self):
def test_toolset_same_instance_dedup(self):
# same instance appearing in multiple places is allowed (deduplication)
toolset = MockToolset1() # contains mock_tool_1, mock_tool_2
ToolContext([toolset, mock_tool_1]) # same mock_tool_1 instance, no conflict
ToolContext([mock_tool_1, toolset])
ToolContext([MockToolset1(), MockToolset2()]) # same mock_tool_2 instance

# conflict: toolset before tool, tool before toolset, multiple toolsets
with pytest.raises(ValueError, match="duplicate function name"):
ToolContext([toolset, mock_tool_1])
def test_toolset_duplicate_name_conflict(self):
toolset = MockToolset1() # contains mock_tool_1, mock_tool_2

with pytest.raises(ValueError, match="duplicate function name"):
ToolContext([mock_tool_1, toolset])
# different instances with the same name should raise
@function_tool
async def mock_tool_1() -> str:
"""Duplicate name, different instance"""
return ""

with pytest.raises(ValueError, match="duplicate function name"):
ToolContext([MockToolset1(), MockToolset2()]) # both have mock_tool_2
ToolContext([toolset, mock_tool_1])

def test_toolset_equality(self):
toolset = MockToolset1()
Expand Down Expand Up @@ -653,3 +659,36 @@ async def test_invalid_json_surfaces_error(self):
assert result.fnc_call_out.is_error is True
# Should contain error details, not generic message
assert "An internal error occurred" not in result.fnc_call_out.output


class TestAsyncToolsetDedup:
    """Check that several AsyncToolsets can share one ToolContext without name clashes."""

    # Names of the management tools every AsyncToolset appends to its tool list.
    _MANAGEMENT_TOOLS = ("get_running_tasks", "cancel_task")

    @staticmethod
    def _flattened_ids(tool_ctx):
        """Return the ``id`` of every flattened tool that exposes one."""
        collected = []
        for tool in tool_ctx.flatten():
            if hasattr(tool, "id"):
                collected.append(tool.id)
        return collected

    def test_two_async_toolsets_no_conflict(self):
        """Toolsets with different ids reuse the same management tool instances."""
        from livekit.agents.llm.async_toolset import AsyncToolset

        booking = AsyncToolset(id="booking", tools=[mock_tool_1])
        search = AsyncToolset(id="search", tools=[mock_tool_2])

        # Building the context must not raise: the management tools are the
        # same module-level instances in both toolsets.
        ctx = ToolContext([booking, search])

        # Each management tool shows up exactly once after flattening.
        names = self._flattened_ids(ctx)
        for tool_name in self._MANAGEMENT_TOOLS:
            assert names.count(tool_name) == 1

    def test_async_toolset_same_id_no_conflict(self):
        """Toolsets sharing an id must not conflict either."""
        from livekit.agents.llm.async_toolset import AsyncToolset

        first = AsyncToolset(id="same_id", tools=[mock_tool_1])
        second = AsyncToolset(id="same_id", tools=[mock_tool_2])

        # Building the context must not raise.
        ctx = ToolContext([first, second])

        names = self._flattened_ids(ctx)
        for tool_name in self._MANAGEMENT_TOOLS:
            assert names.count(tool_name) == 1
Loading