Skip to content

Commit 7d3113f

Browse files
committed
feat: add session management to shell commands and improve task cancellation handling
Signed-off-by: Frost Ming <me@frostming.com>
1 parent 14f2974 commit 7d3113f

5 files changed

Lines changed: 141 additions & 9 deletions

File tree

src/bub/builtin/shell_manager.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class ManagedShell:
1313
shell_id: str
1414
cmd: str
1515
cwd: str | None
16+
session_id: str | None
1617
process: asyncio.subprocess.Process
1718
output_chunks: list[str] = field(default_factory=list)
1819
read_tasks: list[asyncio.Task[None]] = field(default_factory=list)
@@ -36,15 +37,21 @@ class ShellManager:
3637
def __init__(self) -> None:
3738
self._shells: dict[str, ManagedShell] = {}
3839

39-
async def start(self, *, cmd: str, cwd: str | None) -> ManagedShell:
40+
async def start(self, *, cmd: str, cwd: str | None, session_id: str | None = None) -> ManagedShell:
4041
process = await asyncio.create_subprocess_shell(
4142
cmd,
4243
cwd=cwd,
4344
stdout=asyncio.subprocess.PIPE,
4445
stderr=asyncio.subprocess.PIPE,
4546
executable=self.SHELL,
4647
)
47-
shell = ManagedShell(shell_id=f"bash-{uuid.uuid4().hex[:8]}", cmd=cmd, cwd=cwd, process=process)
48+
shell = ManagedShell(
49+
shell_id=f"bash-{uuid.uuid4().hex[:8]}",
50+
cmd=cmd,
51+
cwd=cwd,
52+
session_id=session_id,
53+
process=process,
54+
)
4855
shell.read_tasks.extend([
4956
asyncio.create_task(self._drain_stream(shell, process.stdout)),
5057
asyncio.create_task(self._drain_stream(shell, process.stderr)),
@@ -77,6 +84,13 @@ async def terminate(self, shell_id: str) -> ManagedShell:
7784
await self._finalize_shell(shell)
7885
return shell
7986

87+
async def terminate_session(self, session_id: str) -> int:
88+
shell_ids = [shell.shell_id for shell in self._shells.values() if shell.session_id == session_id]
89+
for shell_id in shell_ids:
90+
with contextlib.suppress(KeyError):
91+
await self.terminate(shell_id)
92+
return len(shell_ids)
93+
8094
async def wait_closed(self, shell_id: str) -> ManagedShell:
8195
shell = self.get(shell_id)
8296
if shell.returncode is None:

src/bub/builtin/tools.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,17 @@ async def bash(
7979
"""Run a shell command. Use background=true to keep it running and fetch output later via bash_output."""
8080
workspace = context.state.get("_runtime_workspace")
8181
target_cwd = cwd or workspace
82-
shell = await shell_manager.start(cmd=cmd, cwd=target_cwd)
82+
raw_session_id = context.state.get("session_id")
83+
session_id = str(raw_session_id) if raw_session_id is not None else None
84+
shell = await shell_manager.start(cmd=cmd, cwd=target_cwd, session_id=session_id)
8385
if background:
8486
return f"started: {shell.shell_id}"
8587
try:
8688
async with asyncio.timeout(timeout_seconds):
8789
shell = await shell_manager.wait_closed(shell.shell_id)
90+
except asyncio.CancelledError:
91+
await shell_manager.terminate(shell.shell_id)
92+
raise
8893
except TimeoutError:
8994
await shell_manager.terminate(shell.shell_id)
9095
return f"command timed out after {timeout_seconds} seconds and was terminated"
@@ -309,7 +314,8 @@ def show_help() -> str:
309314
async def quit_tool(*, context: ToolContext) -> str:
310315
"""Quit the tasks of the current session."""
311316
agent = _get_agent(context)
312-
session_id = context.state.get("session_id", "temp/unknown")
317+
session_id = str(context.state.get("session_id", "temp/unknown"))
318+
await shell_manager.terminate_session(session_id)
313319
await agent.framework.quit_via_router(session_id)
314320
return "Session tasks stopped."
315321

src/bub/channels/manager.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,16 @@ def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) ->
118118

119119
async def quit(self, session_id: str) -> None:
120120
tasks = self._ongoing_tasks.pop(session_id, set())
121+
current_task = asyncio.current_task()
122+
cancelled_count = 0
121123
for task in tasks:
124+
if task is current_task:
125+
continue
122126
task.cancel()
123127
with contextlib.suppress(asyncio.CancelledError):
124128
await task
125-
logger.info(f"channel.manager quit session_id={session_id}, cancelled {len(tasks)} tasks")
129+
cancelled_count += 1
130+
logger.info(f"channel.manager quit session_id={session_id}, cancelled {cancelled_count} tasks")
126131

127132
def enabled_channels(self) -> list[Channel]:
128133
if "all" in self._enabled_channels:
@@ -133,7 +138,10 @@ def enabled_channels(self) -> list[Channel]:
133138
]
134139

135140
def _on_task_done(self, session_id: str, task: asyncio.Task) -> None:
136-
task.exception() # to log any exception
141+
if task.cancelled():
142+
logger.info("channel.manager task cancelled session_id={}", session_id)
143+
else:
144+
task.exception() # to log any exception
137145
tasks = self._ongoing_tasks.get(session_id, set())
138146
tasks.discard(task)
139147
if not tasks:

tests/test_builtin_tools.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextlib
45
import shlex
56
import sys
7+
from types import SimpleNamespace
68

79
import pytest
810
from republic import ToolContext
@@ -11,11 +13,11 @@
1113

1214
import bub.builtin.tools as builtin_tools
1315
from bub.builtin.shell_manager import ShellManager
14-
from bub.builtin.tools import bash, bash_output, kill_bash
16+
from bub.builtin.tools import bash, bash_output, kill_bash, quit_tool
1517

1618

17-
def _tool_context(tmp_path) -> ToolContext:
18-
return ToolContext(tape="test-tape", run_id="test-run", state={"_runtime_workspace": str(tmp_path)})
19+
def _tool_context(tmp_path, **state) -> ToolContext:
20+
return ToolContext(tape="test-tape", run_id="test-run", state={"_runtime_workspace": str(tmp_path), **state})
1921

2022

2123
def _python_shell(code: str) -> str:
@@ -51,6 +53,26 @@ async def test_foreground_bash_releases_shell_when_command_fails(tmp_path, monke
5153
assert manager._shells == {}
5254

5355

56+
@pytest.mark.asyncio
57+
async def test_foreground_bash_terminates_shell_when_cancelled(tmp_path, monkeypatch) -> None:
58+
manager = ShellManager()
59+
monkeypatch.setattr(builtin_tools, "shell_manager", manager)
60+
61+
task = asyncio.create_task(
62+
bash.run(
63+
cmd=_python_shell("import time; time.sleep(10)"),
64+
context=_tool_context(tmp_path, session_id="session:target"),
65+
)
66+
)
67+
await asyncio.sleep(0.1)
68+
69+
task.cancel()
70+
with contextlib.suppress(asyncio.CancelledError):
71+
await task
72+
73+
assert manager._shells == {}
74+
75+
5476
@pytest.mark.asyncio
5577
async def test_bash_non_zero_exit_is_returned_as_tool_error(tmp_path) -> None:
5678
command = _python_shell("import sys; print('boom'); sys.exit(7)")
@@ -124,3 +146,46 @@ async def test_kill_bash_returns_status_when_process_already_finished(tmp_path)
124146
result = await kill_bash.run(shell_id=shell_id)
125147

126148
assert result == f"id: {shell_id}\nstatus: exited\nexit_code: 0"
149+
150+
151+
@pytest.mark.asyncio
152+
async def test_quit_tool_terminates_background_shells_for_current_session(tmp_path, monkeypatch) -> None:
153+
manager = ShellManager()
154+
monkeypatch.setattr(builtin_tools, "shell_manager", manager)
155+
156+
target_started = await bash.run(
157+
cmd=_python_shell("import time; time.sleep(10)"),
158+
background=True,
159+
context=_tool_context(tmp_path, session_id="session:target"),
160+
)
161+
target_shell_id = target_started.removeprefix("started: ").strip()
162+
other_started = await bash.run(
163+
cmd=_python_shell("import time; time.sleep(10)"),
164+
background=True,
165+
context=_tool_context(tmp_path, session_id="session:other"),
166+
)
167+
other_shell_id = other_started.removeprefix("started: ").strip()
168+
169+
class FakeFramework:
170+
def __init__(self) -> None:
171+
self.quit_sessions: list[str] = []
172+
173+
async def quit_via_router(self, session_id: str) -> None:
174+
self.quit_sessions.append(session_id)
175+
176+
framework = FakeFramework()
177+
context = _tool_context(
178+
tmp_path,
179+
session_id="session:target",
180+
_runtime_agent=SimpleNamespace(framework=framework),
181+
)
182+
183+
result = await quit_tool.run(context=context)
184+
185+
assert result == "Session tasks stopped."
186+
assert framework.quit_sessions == ["session:target"]
187+
with pytest.raises(KeyError, match="unknown shell id"):
188+
await bash_output.run(shell_id=target_shell_id)
189+
assert manager.get(other_shell_id).returncode is None
190+
191+
await kill_bash.run(shell_id=other_shell_id)

tests/test_channels.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,45 @@ async def never_finish() -> None:
301301
await other_task
302302

303303

304+
@pytest.mark.asyncio
305+
async def test_channel_manager_quit_skips_current_task(load_config) -> None:
306+
_load_channel_config(load_config, enabled_channels="telegram")
307+
manager = ChannelManager(FakeFramework({"telegram": FakeChannel("telegram")}), enabled_channels=["telegram"])
308+
309+
async def never_finish() -> None:
310+
await asyncio.sleep(10)
311+
312+
current_task = asyncio.current_task()
313+
assert current_task is not None
314+
target_task = asyncio.create_task(never_finish())
315+
manager._ongoing_tasks["session:target"] = {current_task, target_task}
316+
317+
await manager.quit("session:target")
318+
319+
assert current_task.cancelled() is False
320+
assert target_task.cancelled()
321+
assert "session:target" not in manager._ongoing_tasks
322+
323+
324+
@pytest.mark.asyncio
325+
async def test_channel_manager_done_callback_handles_cancelled_task(load_config) -> None:
326+
_load_channel_config(load_config, enabled_channels="telegram")
327+
manager = ChannelManager(FakeFramework({"telegram": FakeChannel("telegram")}), enabled_channels=["telegram"])
328+
329+
async def never_finish() -> None:
330+
await asyncio.sleep(10)
331+
332+
task = asyncio.create_task(never_finish())
333+
manager._ongoing_tasks["session:target"] = {task}
334+
task.cancel()
335+
with contextlib.suppress(asyncio.CancelledError):
336+
await task
337+
338+
manager._on_task_done("session:target", task)
339+
340+
assert "session:target" not in manager._ongoing_tasks
341+
342+
304343
def test_cli_channel_normalize_input_prefixes_shell_commands() -> None:
305344
channel = CliChannel.__new__(CliChannel)
306345
channel._mode = "shell"

0 commit comments

Comments
 (0)