From 79656011174e022272423aa5ae3919acc8702ee9 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Thu, 19 Mar 2026 09:06:27 +0800 Subject: [PATCH] feat: add quit tool and enhance session management in ChannelManager Signed-off-by: Frost Ming --- src/bub/builtin/tools.py | 10 ++++++++++ src/bub/channels/manager.py | 35 ++++++++++++++++++++++++----------- src/bub/framework.py | 8 +++++--- src/bub/types.py | 1 + tests/test_channels.py | 27 ++++++++++++++++++++++++++- 5 files changed, 66 insertions(+), 15 deletions(-) diff --git a/src/bub/builtin/tools.py b/src/bub/builtin/tools.py index 1a673eb3..f6dd035a 100644 --- a/src/bub/builtin/tools.py +++ b/src/bub/builtin/tools.py @@ -295,10 +295,20 @@ def show_help() -> str: " ,bash cmd='sleep 5' background=true\n" " ,bash.output shell_id=bsh-12345678\n" " ,bash.kill shell_id=bsh-12345678\n" + " ,quit\n" "Any unknown command after ',' is executed as shell via bash." ) +@tool(name="quit", context=True) +async def quit_tool(*, context: ToolContext) -> str: + """Quit the tasks of the current session.""" + agent = _get_agent(context) + session_id = context.state.get("session_id", "temp/unknown") + await agent.framework.quit_via_router(session_id) + return "Session tasks stopped." + + def _resolve_path(context: ToolContext, raw_path: str) -> Path: workspace = context.state.get("_runtime_workspace") path = Path(raw_path).expanduser() diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index 3c241314..85cd8af4 100644 --- a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import functools from collections.abc import Collection from loguru import logger @@ -45,7 +46,7 @@ def __init__(self, framework: BubFramework, enabled_channels: Collection[str] | else: self._enabled_channels = self._settings.enabled_channels.split(",") self._messages = asyncio.Queue[ChannelMessage]() - self._ongoing_tasks: set[asyncio.Task] = set() + self._ongoing_tasks: dict[str, set[asyncio.Task]] = {} self._session_handlers: dict[str, MessageHandler] = {} async def on_receive(self, message: ChannelMessage) -> None: @@ -92,15 +93,26 @@ async def dispatch(self, message: Envelope) -> bool: await channel.send(outbound) return True + async def quit(self, session_id: str) -> None: + tasks = self._ongoing_tasks.pop(session_id, set()) + for task in tasks: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + logger.info(f"channel.manager quit session_id={session_id}, cancelled {len(tasks)} tasks") + def enabled_channels(self) -> list[Channel]: if "all" in self._enabled_channels: # Exclude 'cli' channel from 'all' to prevent interference with other channels return [channel for name, channel in self._channels.items() if name != "cli"] return [channel for name, channel in self._channels.items() if name in self._enabled_channels] - def _on_task_done(self, task: asyncio.Task) -> None: + def _on_task_done(self, session_id: str, task: asyncio.Task) -> None: task.exception() # to log any exception - self._ongoing_tasks.discard(task) + tasks = self._ongoing_tasks.get(session_id, set()) + tasks.discard(task) + if not tasks: + self._ongoing_tasks.pop(session_id, None) async def listen_and_run(self) -> None: stop_event = asyncio.Event() @@ -112,8 +124,8 @@ async def listen_and_run(self) -> None: while True: message = await wait_until_stopped(self._messages.get(), stop_event) task = asyncio.create_task(self.framework.process_inbound(message)) - task.add_done_callback(self._on_task_done) - self._ongoing_tasks.add(task) + task.add_done_callback(functools.partial(self._on_task_done, message.session_id)) + self._ongoing_tasks.setdefault(message.session_id, set()).add(task) except asyncio.CancelledError: logger.info("channel.manager received shutdown signal") except Exception: @@ -126,12 +138,13 @@ async def listen_and_run(self) -> None: async def shutdown(self) -> None: count = 0 - while self._ongoing_tasks: - task = self._ongoing_tasks.pop() - task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await task - count += 1 + for tasks in self._ongoing_tasks.values(): + for task in tasks: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + count += 1 + self._ongoing_tasks.clear() logger.info(f"channel.manager cancelled {count} in-flight tasks") for channel in self.enabled_channels(): await channel.stop() diff --git a/src/bub/framework.py b/src/bub/framework.py index 158416c8..a00c029d 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -68,9 +68,7 @@ def load_hooks(self) -> None: def create_cli_app(self) -> typer.Typer: """Create CLI app by collecting commands from hooks. Can be used for custom CLI entry point.""" - app = typer.Typer( - name="bub", help="A common shape for agents that live alongside people.", add_completion=False - ) + app = typer.Typer(name="bub", help="Batteries-included, hook-first AI framework", add_completion=False) @app.callback(invoke_without_command=True) def _main( @@ -149,6 +147,10 @@ async def dispatch_via_router(self, message: Envelope) -> bool: return False return await self._outbound_router.dispatch(message) + async def quit_via_router(self, session_id: str) -> None: + if self._outbound_router is not None: + await self._outbound_router.quit(session_id) + @staticmethod def _default_session_id(message: Envelope) -> str: session_id = field_of(message, "session_id") diff --git a/src/bub/types.py b/src/bub/types.py index 917f2132..a1f73c77 100644 --- a/src/bub/types.py +++ b/src/bub/types.py @@ -14,6 +14,7 @@ class OutboundChannelRouter(Protocol): async def dispatch(self, message: Envelope) -> bool: ... + async def quit(self, session_id: str) -> None: ... @dataclass(frozen=True) diff --git a/tests/test_channels.py b/tests/test_channels.py index 9debf1b7..5fbec216 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib from datetime import datetime from pathlib import Path from types import SimpleNamespace @@ -159,7 +160,7 @@ async def never_finish() -> None: await asyncio.sleep(10) task = asyncio.create_task(never_finish()) - manager._ongoing_tasks.add(task) + manager._ongoing_tasks["telegram:chat"] = {task} await manager.shutdown() @@ -168,6 +169,30 @@ async def never_finish() -> None: assert cli.stopped is False +@pytest.mark.asyncio +async def test_channel_manager_quit_cancels_only_matching_session_tasks() -> None: + manager = ChannelManager(FakeFramework({"telegram": FakeChannel("telegram")}), enabled_channels=["telegram"]) + + async def never_finish() -> None: + await asyncio.sleep(10) + + target_task = asyncio.create_task(never_finish()) + other_task = asyncio.create_task(never_finish()) + manager._ongoing_tasks["session:target"] = {target_task} + manager._ongoing_tasks["session:other"] = {other_task} + + await manager.quit("session:target") + + assert target_task.cancelled() + assert "session:target" not in manager._ongoing_tasks + assert other_task.cancelled() is False + assert manager._ongoing_tasks["session:other"] == {other_task} + + other_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await other_task + + def test_cli_channel_normalize_input_prefixes_shell_commands() -> None: channel = CliChannel.__new__(CliChannel) channel._mode = "shell"