Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/bub/builtin/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,20 @@ def show_help() -> str:
" ,bash cmd='sleep 5' background=true\n"
" ,bash.output shell_id=bsh-12345678\n"
" ,bash.kill shell_id=bsh-12345678\n"
" ,quit\n"
"Any unknown command after ',' is executed as shell via bash."
)


@tool(name="quit", context=True)
async def quit_tool(*, context: ToolContext) -> str:
"""Quit the tasks of the current session."""
agent = _get_agent(context)
session_id = context.state.get("session_id", "temp/unknown")
await agent.framework.quit_via_router(session_id)
return "Session tasks stopped."


def _resolve_path(context: ToolContext, raw_path: str) -> Path:
workspace = context.state.get("_runtime_workspace")
path = Path(raw_path).expanduser()
Expand Down
35 changes: 24 additions & 11 deletions src/bub/channels/manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import contextlib
import functools
from collections.abc import Collection

from loguru import logger
Expand Down Expand Up @@ -45,7 +46,7 @@ def __init__(self, framework: BubFramework, enabled_channels: Collection[str] |
else:
self._enabled_channels = self._settings.enabled_channels.split(",")
self._messages = asyncio.Queue[ChannelMessage]()
self._ongoing_tasks: set[asyncio.Task] = set()
self._ongoing_tasks: dict[str, set[asyncio.Task]] = {}
self._session_handlers: dict[str, MessageHandler] = {}

async def on_receive(self, message: ChannelMessage) -> None:
Expand Down Expand Up @@ -92,15 +93,26 @@ async def dispatch(self, message: Envelope) -> bool:
await channel.send(outbound)
return True

async def quit(self, session_id: str) -> None:
tasks = self._ongoing_tasks.pop(session_id, set())
for task in tasks:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
logger.info(f"channel.manager quit session_id={session_id}, cancelled {len(tasks)} tasks")

def enabled_channels(self) -> list[Channel]:
if "all" in self._enabled_channels:
# Exclude 'cli' channel from 'all' to prevent interference with other channels
return [channel for name, channel in self._channels.items() if name != "cli"]
return [channel for name, channel in self._channels.items() if name in self._enabled_channels]

def _on_task_done(self, task: asyncio.Task) -> None:
def _on_task_done(self, session_id: str, task: asyncio.Task) -> None:
task.exception() # to log any exception
self._ongoing_tasks.discard(task)
tasks = self._ongoing_tasks.get(session_id, set())
tasks.discard(task)
if not tasks:
self._ongoing_tasks.pop(session_id, None)

async def listen_and_run(self) -> None:
stop_event = asyncio.Event()
Expand All @@ -112,8 +124,8 @@ async def listen_and_run(self) -> None:
while True:
message = await wait_until_stopped(self._messages.get(), stop_event)
task = asyncio.create_task(self.framework.process_inbound(message))
task.add_done_callback(self._on_task_done)
self._ongoing_tasks.add(task)
task.add_done_callback(functools.partial(self._on_task_done, message.session_id))
self._ongoing_tasks.setdefault(message.session_id, set()).add(task)
except asyncio.CancelledError:
logger.info("channel.manager received shutdown signal")
except Exception:
Expand All @@ -126,12 +138,13 @@ async def listen_and_run(self) -> None:

async def shutdown(self) -> None:
count = 0
while self._ongoing_tasks:
task = self._ongoing_tasks.pop()
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
count += 1
for tasks in self._ongoing_tasks.values():
for task in tasks:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
count += 1
self._ongoing_tasks.clear()
logger.info(f"channel.manager cancelled {count} in-flight tasks")
for channel in self.enabled_channels():
await channel.stop()
8 changes: 5 additions & 3 deletions src/bub/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ def load_hooks(self) -> None:

def create_cli_app(self) -> typer.Typer:
"""Create CLI app by collecting commands from hooks. Can be used for custom CLI entry point."""
app = typer.Typer(
name="bub", help="A common shape for agents that live alongside people.", add_completion=False
)
app = typer.Typer(name="bub", help="Batteries-included, hook-first AI framework", add_completion=False)

@app.callback(invoke_without_command=True)
def _main(
Expand Down Expand Up @@ -149,6 +147,10 @@ async def dispatch_via_router(self, message: Envelope) -> bool:
return False
return await self._outbound_router.dispatch(message)

async def quit_via_router(self, session_id: str) -> None:
if self._outbound_router is not None:
await self._outbound_router.quit(session_id)

@staticmethod
def _default_session_id(message: Envelope) -> str:
session_id = field_of(message, "session_id")
Expand Down
1 change: 1 addition & 0 deletions src/bub/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

class OutboundChannelRouter(Protocol):
async def dispatch(self, message: Envelope) -> bool: ...
async def quit(self, session_id: str) -> None: ...


@dataclass(frozen=True)
Expand Down
27 changes: 26 additions & 1 deletion tests/test_channels.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
Expand Down Expand Up @@ -159,7 +160,7 @@ async def never_finish() -> None:
await asyncio.sleep(10)

task = asyncio.create_task(never_finish())
manager._ongoing_tasks.add(task)
manager._ongoing_tasks["telegram:chat"] = {task}

await manager.shutdown()

Expand All @@ -168,6 +169,30 @@ async def never_finish() -> None:
assert cli.stopped is False


@pytest.mark.asyncio
async def test_channel_manager_quit_cancels_only_matching_session_tasks() -> None:
manager = ChannelManager(FakeFramework({"telegram": FakeChannel("telegram")}), enabled_channels=["telegram"])

async def never_finish() -> None:
await asyncio.sleep(10)

target_task = asyncio.create_task(never_finish())
other_task = asyncio.create_task(never_finish())
manager._ongoing_tasks["session:target"] = {target_task}
manager._ongoing_tasks["session:other"] = {other_task}

await manager.quit("session:target")

assert target_task.cancelled()
assert "session:target" not in manager._ongoing_tasks
assert other_task.cancelled() is False
assert manager._ongoing_tasks["session:other"] == {other_task}

other_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await other_task


def test_cli_channel_normalize_input_prefixes_shell_commands() -> None:
channel = CliChannel.__new__(CliChannel)
channel._mode = "shell"
Expand Down
Loading