Skip to content
Merged
203 changes: 203 additions & 0 deletions examples/voice_agents/tool_search_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""Example: ToolSearchToolset and ToolProxyToolset for dynamic tool discovery.

Both toolsets wrap a collection of tools and expose a `tool_search` function so the
LLM can discover tools on demand instead of loading them all upfront.

ToolSearchToolset:
- Matched tools are added directly to the LLM's tool list on the next turn.
- The model uses native tool calls to invoke discovered tools.
- May be simpler for the model to understand.

ToolProxyToolset:
- Exposes exactly two fixed tools: `tool_search` and `call_tool`.
- The tool list never changes, so providers may reuse their prompt cache across turns.
- May be better for many tools or cost-sensitive workloads.
"""

import logging

from dotenv import load_dotenv

from livekit.agents import (
Agent,
AgentServer,
AgentSession,
JobContext,
MetricsCollectedEvent,
cli,
inference,
llm,
metrics,
)
from livekit.agents.beta.toolsets import ToolProxyToolset, ToolSearchToolset
from livekit.agents.metrics.base import LLMMetrics
from livekit.plugins import silero

logger = logging.getLogger("tool-search-example")
logger.setLevel(logging.INFO)

load_dotenv()


class WeatherToolset(llm.Toolset):
    """Weather lookup tools grouped into one discoverable toolset.

    NOTE(review): the method docstrings below likely double as the tool
    descriptions shown to the LLM — confirm before rewording them.
    """

    @llm.function_tool
    async def get_weather(self, location: str) -> str:
        """Get current weather for a location.

        Args:
            location: City name or region
        """
        logger.info(f"Getting weather for {location}")
        # Canned response — this example stubs out the real weather lookup.
        return f"Sunny, 72F in {location}"

    @llm.function_tool
    async def get_forecast(self, location: str, days: int) -> str:
        """Get weather forecast for upcoming days.

        Args:
            location: City name or region
            days: Number of days to forecast
        """
        logger.info(f"Getting {days}-day forecast for {location}")
        # Canned response — this example stubs out the real forecast lookup.
        return f"{days}-day forecast for {location}: mostly sunny"


class FlightToolset(llm.Toolset):
    """Flight search and booking tools grouped into one discoverable toolset.

    NOTE(review): the method docstrings below likely double as the tool
    descriptions shown to the LLM — confirm before rewording them.
    """

    @llm.function_tool
    async def search_flights(self, origin: str, destination: str, date: str) -> str:
        """Search for available flights.

        Args:
            origin: Departure city or airport code
            destination: Arrival city or airport code
            date: Travel date
        """
        logger.info(f"Searching flights {origin} -> {destination} on {date}")
        # Canned response — this example stubs out the real flight search.
        return f"Found 3 flights from {origin} to {destination} on {date}"

    @llm.function_tool
    async def book_flight(self, flight_id: str) -> str:
        """Book a specific flight.

        Args:
            flight_id: The flight identifier to book
        """
        logger.info(f"Booking flight {flight_id}")
        # Canned response — no real booking is performed.
        return f"Flight {flight_id} booked successfully"


class HotelToolset(llm.Toolset):
    """Hotel search and booking tools grouped into one discoverable toolset.

    NOTE(review): the method docstrings below likely double as the tool
    descriptions shown to the LLM — confirm before rewording them.
    """

    @llm.function_tool
    async def search_hotels(self, city: str, check_in: str, check_out: str) -> str:
        """Search for hotels in a city.

        Args:
            city: City to search hotels in
            check_in: Check-in date
            check_out: Check-out date
        """
        logger.info(f"Searching hotels in {city}")
        # Canned response — this example stubs out the real hotel search.
        return f"Found 5 hotels in {city} from {check_in} to {check_out}"

    @llm.function_tool
    async def book_hotel(self, hotel_id: str) -> str:
        """Book a specific hotel.

        Args:
            hotel_id: The hotel identifier to book
        """
        logger.info(f"Booking hotel {hotel_id}")
        # Canned response — no real booking is performed.
        return f"Hotel {hotel_id} booked successfully"


@llm.function_tool
async def convert_currency(amount: float, from_currency: str, to_currency: str) -> str:
    """Convert between currencies.

    Args:
        amount: Amount to convert
        from_currency: Source currency code (e.g. USD)
        to_currency: Target currency code (e.g. EUR)
    """
    logger.info(f"Converting {amount} {from_currency} to {to_currency}")
    # Fixed 0.85 rate regardless of currency pair — example stub only,
    # not a real exchange-rate lookup.
    return f"{amount} {from_currency} = {amount * 0.85} {to_currency}"


class TravelAgent(Agent):
    """Voice travel assistant whose tools are discovered on demand.

    Args:
        use_tool_proxy: When True, wrap the tools in ToolProxyToolset (a
            fixed search/call tool pair, prompt-cache friendly); when False,
            use ToolSearchToolset (matched tools join the tool list).
    """

    def __init__(self, use_tool_proxy: bool = True) -> None:
        if use_tool_proxy:
            toolset_cls = ToolProxyToolset
        else:
            toolset_cls = ToolSearchToolset

        # Toolsets can nest: the discovery toolset wraps whole toolsets as
        # well as bare function tools.
        travel_tools = toolset_cls(
            id="travel_tools",
            tools=[
                WeatherToolset(id="weather"),
                FlightToolset(id="flights"),
                HotelToolset(id="hotels"),
                convert_currency,
            ],
            max_results=3,
        )

        super().__init__(
            instructions="""
            You are a comprehensive travel planning assistant with access to multiple specialized
            toolsets. Your role is to help users plan trips by providing weather information,
            searching and booking flights, finding and reserving hotels, and converting currencies.

            Tool Discovery: You have access to a `tool_search` function that lets you discover
            available tools.

            Voice Interaction Style: This is a voice conversation, not text chat. Keep your responses
            short and natural — one or two sentences at a time. Do not list multiple options in a
            single response; instead, mention the top choice and ask if the user wants to hear more.
            Gather information one piece at a time rather than asking multiple questions at once.

            Remember to use tool_search to find the right tools before trying to help the user.
            """,
            tools=[travel_tools],
        )

    async def on_enter(self):
        # Kick off a spoken greeting as soon as the agent joins the session.
        self.session.generate_reply(
            instructions="Greet the user and let them know you can help with "
            "travel planning: weather, flights, hotels, and currency exchange."
        )


server = AgentServer()


@server.rtc_session()
async def entrypoint(ctx: JobContext):
    """Wire up a voice session with the proxy-based travel agent."""
    session = AgentSession(
        llm=inference.LLM("openai/gpt-4.1-mini"),
        tts=inference.TTS("cartesia/sonic-3"),
        stt=inference.STT("deepgram/nova-3"),
        vad=silero.VAD.load(),
    )

    # Accumulate token usage so prompt-caching behavior is observable:
    # with ToolProxyToolset the tool list never changes, so providers can
    # cache the tool definitions and prompt_cached_tokens should climb
    # after the first turn.
    usage_collector = metrics.UsageCollector()

    @session.on("metrics_collected")
    def _on_metrics_collected(ev: MetricsCollectedEvent):
        usage_collector.collect(ev.metrics)

        # Only LLM requests with a non-empty prompt have a cache ratio.
        if not isinstance(ev.metrics, LLMMetrics) or ev.metrics.prompt_tokens <= 0:
            return
        metrics.log_metrics(ev.metrics)
        cached = ev.metrics.prompt_cached_tokens
        total = ev.metrics.prompt_tokens
        cache_ratio = cached / total
        logger.info(
            f"Prompt cache: {cached}/{total} "
            f"tokens cached ({cache_ratio:.0%})"
        )

    await session.start(
        agent=TravelAgent(use_tool_proxy=True),
        room=ctx.room,
    )


if __name__ == "__main__":
    # Run the agent server through the LiveKit Agents CLI entry point.
    cli.run_app(server)
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/beta/toolsets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Public exports for the beta toolsets package."""

from .tool_proxy import ToolProxyToolset
from .tool_search import ToolSearchToolset

__all__ = ["ToolProxyToolset", "ToolSearchToolset"]
171 changes: 171 additions & 0 deletions livekit-agents/livekit/agents/beta/toolsets/tool_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from __future__ import annotations

import json
from typing import Any

from pydantic import ValidationError
from typing_extensions import Self

from ...llm.tool_context import (
FunctionTool,
RawFunctionTool,
Tool,
ToolContext,
ToolError,
Toolset,
function_tool,
)
from ...llm.utils import function_arguments_to_pydantic_model, prepare_function_arguments
from ...log import logger
from ...types import NOT_GIVEN, NotGivenOr
from ...voice.events import RunContext
from .tool_search import SearchStrategy, ToolSearchToolset

# Default LLM-facing description for the search tool when the caller does
# not supply `search_description`.
_DEFAULT_SEARCH_DESCRIPTION = (
    "Search for available tools by describing what you need. "
    "Returns the schemas of matching tools. Use call_tool to invoke them."
)

# Default LLM-facing description for call_tool when the caller does not
# supply `call_description`.
# NOTE(review): this text refers to `search_tools`, but the example agent's
# docs refer to `tool_search` — confirm it matches the actual name the
# search tool is registered under in ToolSearchToolset.
_DEFAULT_CALL_DESCRIPTION = (
    "Call a tool by name with the given arguments. "
    "Use search_tools to discover available tools and their schemas if it isn't already loaded."
)


class ToolProxyToolset(ToolSearchToolset):
    """Exposes exactly two fixed tools: search_tools and call_tool.

    Unlike ToolSearchToolset which dynamically modifies the tool list,
    ToolProxyToolset keeps a constant tool list. ``search_tools`` returns
    tool schemas as text, and ``call_tool`` executes tools by name.

    This is useful for maximizing prompt cache hit rates with providers
    that cache based on tool definitions (e.g. Anthropic, OpenAI).
    """

    def __init__(
        self,
        *,
        id: str,
        tools: list[Tool | Toolset] | None = None,
        max_results: int = 5,
        search_strategy: NotGivenOr[SearchStrategy] = NOT_GIVEN,
        search_description: NotGivenOr[str] = NOT_GIVEN,
        query_description: NotGivenOr[str] = NOT_GIVEN,
        call_description: NotGivenOr[str] = NOT_GIVEN,
    ) -> None:
        # NOTE(review): `x or default` here assumes NOT_GIVEN (and an empty
        # string) are falsy — confirm NOT_GIVEN's truthiness contract.
        super().__init__(
            id=id,
            tools=tools,
            max_results=max_results,
            search_strategy=search_strategy,
            search_description=search_description or _DEFAULT_SEARCH_DESCRIPTION,
            query_description=query_description,
        )
        # Populated in setup(); used by _handle_call to look up tools by name.
        self._tool_ctx: ToolContext | None = None

        call_description = call_description or _DEFAULT_CALL_DESCRIPTION
        # call_tool is declared with a raw JSON schema so its `parameters`
        # field stays a free-form object — the wrapped tools' argument
        # shapes vary per call and cannot be expressed statically here.
        self._call_tool = function_tool(
            self._handle_call,
            raw_schema={
                "name": "call_tool",
                "description": call_description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "description": "The name of the tool to call",
                        },
                        "parameters": {
                            "type": "object",
                            "description": "The parameters to pass to the tool",
                        },
                    },
                    "required": ["name", "parameters"],
                },
            },
        )

    @property
    def tools(self) -> list[Tool | Toolset]:
        # constant tool list — only search_tools and call_tool
        # (overrides the parent's behavior of surfacing matched tools)
        return [self._search_tool, self._call_tool]

    async def setup(self, *, reload: bool = False) -> Self:
        """Run the parent setup, then build the call_tool dispatch context.

        Must be awaited before _handle_call can execute any wrapped tool.
        """
        await super().setup(reload=reload)

        # build a ToolContext from all wrapped tools for call_tool execution
        self._tool_ctx = ToolContext(self._tools)
        return self

    async def _handle_search(self, raw_arguments: dict[str, object]) -> str:
        """Handler for the search tool: return matching tool schemas as text.

        Raises:
            ToolError: if the query is empty or matches no tools (surfaced
                back to the LLM so it can retry with a different query).
        """
        query = str(raw_arguments.get("query", ""))
        if not query:
            raise ToolError("query cannot be empty")

        tools = await self._search_tools(query)
        if not tools:
            raise ToolError(f"No tools found matching '{query}'.")

        # Wrapping matches in a ToolContext exposes them as flat function
        # tools keyed by name, regardless of toolset nesting.
        tool_ctx = ToolContext(tools)
        schemas = [_build_tool_schema(tool) for tool in tool_ctx.function_tools.values()]
        # One JSON object per line keeps the result compact for the model.
        return "\n".join(json.dumps(schema) for schema in schemas)

    async def _handle_call(self, ctx: RunContext[Any], raw_arguments: dict[str, object]) -> Any:
        """Handler for call_tool: validate arguments and execute a wrapped tool.

        Raises:
            ToolError: empty/unknown tool name, missing parameters, or
                arguments that fail validation — surfaced back to the LLM.
            RuntimeError: if setup() was never awaited.
        """
        name = str(raw_arguments.get("name", ""))
        parameters = raw_arguments.get("parameters")

        if not name:
            raise ToolError("tool name cannot be empty")

        if parameters is None:
            raise ToolError("parameters is required")

        if self._tool_ctx is None:
            raise RuntimeError("toolset not initialized, call setup() first")

        fnc_tool = self._tool_ctx.get_function_tool(name)
        if fnc_tool is None:
            raise ToolError(f"unknown tool '{name}', use search_tools to discover available tools")

        try:
            # NOTE(review): the non-dict fallback uses str(), which for a
            # list/number is Python repr, not guaranteed JSON — presumably
            # intended to pass through already-JSON-encoded strings; confirm.
            json_args = json.dumps(parameters) if isinstance(parameters, dict) else str(parameters)
            fnc_args, fnc_kwargs = prepare_function_arguments(
                fnc=fnc_tool,
                json_arguments=json_args,
                call_ctx=ctx,
            )
        except ValidationError as e:
            # Pydantic validation failure: report the detail to the LLM.
            raise ToolError(
                f"invalid parameters for tool '{name}': {e.json(include_url=False)}"
            ) from e
        except ToolError:
            # Already an LLM-facing error; re-raise unchanged.
            raise
        except Exception as e:
            logger.exception(
                f"error parsing arguments for tool '{name}'",
                extra={"tool": name, "arguments": parameters},
            )
            raise ToolError(f"error calling '{name}': {e}") from e

        return await fnc_tool(*fnc_args, **fnc_kwargs)


def _build_tool_schema(tool: FunctionTool | RawFunctionTool) -> dict[str, Any]:
    """Build a JSON-serializable tool schema with full parameter type info."""
    if not isinstance(tool, FunctionTool):
        # RawFunctionTool: the author already supplied a JSON schema — use it.
        raw = tool.info.raw_schema
        return {
            "name": raw.get("name", tool.id),
            "description": raw.get("description", ""),
            "parameters": raw.get("parameters", {}),
        }

    # FunctionTool: derive the parameter schema from the Python signature.
    parameters_schema = function_arguments_to_pydantic_model(tool).model_json_schema()
    return {
        "name": tool.info.name,
        "description": tool.info.description or "",
        "parameters": parameters_schema,
    }
Loading
Loading