diff --git a/src/bub/channels/handler.py b/src/bub/channels/handler.py index 37b1e7a7..24e664da 100644 --- a/src/bub/channels/handler.py +++ b/src/bub/channels/handler.py @@ -1,13 +1,10 @@ import asyncio -import re from loguru import logger from bub.channels.message import ChannelMessage from bub.types import MessageHandler -MEDIA_DATA_URL_RE = re.compile(r"data:[^;\s]+;base64,[^\"'\s]+", re.IGNORECASE) - class BufferedMessageHandler: """A message handler that buffers incoming messages and processes them in batch with debounce and active time window.""" diff --git a/src/bub/tools.py b/src/bub/tools.py index d723ec63..a57c9663 100644 --- a/src/bub/tools.py +++ b/src/bub/tools.py @@ -120,8 +120,9 @@ def tool( context=context, ) if isinstance(result, Tool): - REGISTRY[result.name] = result - return _add_logging(result) + tool_instance = _add_logging(result) + REGISTRY[tool_instance.name] = tool_instance + return tool_instance def decorator(func: Callable) -> Tool: tool_instance = _add_logging(result(func)) diff --git a/tests/test_tools.py b/tests/test_tools.py index 0c113f61..4d97e1e2 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -74,6 +74,20 @@ def failing_tool() -> str: assert errors[0].startswith("tool.call.error name=tests.failing_tool elapsed_time=") +@pytest.mark.asyncio +async def test_tool_direct_call_registers_wrapped_instance_in_registry() -> None: + tool_name = "tests.direct_call" + REGISTRY.pop(tool_name, None) + + def direct_call(value: str) -> str: + return value.upper() + + direct_tool = tool(direct_call, name=tool_name) + + assert REGISTRY[tool_name] is direct_tool + assert await REGISTRY[tool_name].run("hello") == "HELLO" + + def test_model_tools_rewrites_dotted_names_without_mutating_original() -> None: tool_name = "tests.rename_me" REGISTRY.pop(tool_name, None)