Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/bub/channels/handler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import asyncio
import re

from loguru import logger

from bub.channels.message import ChannelMessage
from bub.types import MessageHandler

MEDIA_DATA_URL_RE = re.compile(r"data:[^;\s]+;base64,[^\"'\s]+", re.IGNORECASE)


class BufferedMessageHandler:
"""A message handler that buffers incoming messages and processes them in batch with debounce and active time window."""
Expand Down
5 changes: 3 additions & 2 deletions src/bub/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ def tool(
context=context,
)
if isinstance(result, Tool):
REGISTRY[result.name] = result
return _add_logging(result)
tool_instance = _add_logging(result)
REGISTRY[tool_instance.name] = tool_instance
return tool_instance

def decorator(func: Callable) -> Tool:
tool_instance = _add_logging(result(func))
Expand Down
14 changes: 14 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ def failing_tool() -> str:
assert errors[0].startswith("tool.call.error name=tests.failing_tool elapsed_time=")


@pytest.mark.asyncio
async def test_tool_direct_call_registers_wrapped_instance_in_registry() -> None:
tool_name = "tests.direct_call"
REGISTRY.pop(tool_name, None)

def direct_call(value: str) -> str:
return value.upper()

direct_tool = tool(direct_call, name=tool_name)

assert REGISTRY[tool_name] is direct_tool
assert await REGISTRY[tool_name].run("hello") == "HELLO"


def test_model_tools_rewrites_dotted_names_without_mutating_original() -> None:
tool_name = "tests.rename_me"
REGISTRY.pop(tool_name, None)
Expand Down
Loading