diff --git a/src/bub/builtin/shell_manager.py b/src/bub/builtin/shell_manager.py new file mode 100644 index 00000000..284c965b --- /dev/null +++ b/src/bub/builtin/shell_manager.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import asyncio +import contextlib +import uuid +from dataclasses import dataclass, field + + +@dataclass(slots=True) +class ManagedShell: + shell_id: str + cmd: str + cwd: str | None + process: asyncio.subprocess.Process + output_chunks: list[str] = field(default_factory=list) + read_tasks: list[asyncio.Task[None]] = field(default_factory=list) + + @property + def output(self) -> str: + return "".join(self.output_chunks) + + @property + def returncode(self) -> int | None: + return self.process.returncode + + @property + def status(self) -> str: + return "running" if self.returncode is None else "exited" + + +class ShellManager: + def __init__(self) -> None: + self._shells: dict[str, ManagedShell] = {} + + async def start(self, *, cmd: str, cwd: str | None) -> ManagedShell: + process = await asyncio.create_subprocess_shell( + cmd, + cwd=cwd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + shell = ManagedShell(shell_id=f"bash-{uuid.uuid4().hex[:8]}", cmd=cmd, cwd=cwd, process=process) + shell.read_tasks.extend([ + asyncio.create_task(self._drain_stream(shell, process.stdout)), + asyncio.create_task(self._drain_stream(shell, process.stderr)), + ]) + self._shells[shell.shell_id] = shell + return shell + + def get(self, shell_id: str) -> ManagedShell: + try: + return self._shells[shell_id] + except KeyError as exc: + raise KeyError(f"unknown shell id: {shell_id}") from exc + + async def terminate(self, shell_id: str) -> ManagedShell: + shell = self.get(shell_id) + if shell.returncode is not None: + await self._finalize_shell(shell) + return shell + + shell.process.terminate() + try: + async with asyncio.timeout(3): + await shell.process.wait() + except TimeoutError: + shell.process.kill() + await 
shell.process.wait() + await self._finalize_shell(shell) + return shell + + async def wait_closed(self, shell_id: str) -> ManagedShell: + shell = self.get(shell_id) + if shell.returncode is None: + await shell.process.wait() + await self._finalize_shell(shell) + return shell + + async def _finalize_shell(self, shell: ManagedShell) -> None: + for task in shell.read_tasks: + with contextlib.suppress(asyncio.CancelledError): + await task + + async def _drain_stream( + self, + shell: ManagedShell, + stream: asyncio.StreamReader | None, + ) -> None: + if stream is None: + return + while chunk := await stream.read(4096): + shell.output_chunks.append(chunk.decode("utf-8", errors="replace")) + + +shell_manager = ShellManager() diff --git a/src/bub/builtin/tools.py b/src/bub/builtin/tools.py index e0b89f41..04747917 100644 --- a/src/bub/builtin/tools.py +++ b/src/bub/builtin/tools.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from republic import AsyncTapeStore, TapeQuery, ToolContext +from bub.builtin.shell_manager import shell_manager from bub.skills import discover_skills from bub.tools import tool @@ -59,24 +60,52 @@ class SubAgentInput(BaseModel): @tool(context=True) async def bash( - cmd: str, cwd: str | None = None, timeout_seconds: int = DEFAULT_COMMAND_TIMEOUT_SECONDS, *, context: ToolContext + cmd: str, + cwd: str | None = None, + timeout_seconds: int = DEFAULT_COMMAND_TIMEOUT_SECONDS, + background: bool = False, + *, + context: ToolContext, ) -> str: - """Run a shell command and return its output within a time limit. Raises if the command fails or times out.""" + """Run a shell command. 
@tool(context=True)
async def bash(
    cmd: str,
    cwd: str | None = None,
    timeout_seconds: int = DEFAULT_COMMAND_TIMEOUT_SECONDS,
    background: bool = False,
    *,
    context: ToolContext,
) -> str:
    """Run a shell command.

    Use background=true to keep it running and fetch output later via bash.output.
    In the foreground the combined stdout/stderr is returned; a command that
    exceeds ``timeout_seconds`` is terminated and reported as timed out.

    Raises:
        RuntimeError: if a foreground command exits with a non-zero status.
    """
    workspace = context.state.get("_runtime_workspace")
    target_cwd = cwd or workspace
    shell = await shell_manager.start(cmd=cmd, cwd=target_cwd)
    if background:
        return f"started: {shell.shell_id}"
    # NOTE(review): foreground shells stay registered in shell_manager after
    # they finish although their id is never returned; consider a public
    # removal API on the manager so their buffered output can be released.
    try:
        async with asyncio.timeout(timeout_seconds):
            await shell_manager.wait_closed(shell.shell_id)
    except TimeoutError:
        await shell_manager.terminate(shell.shell_id)
        return f"command timed out after {timeout_seconds} seconds and was terminated"
    output = shell.output.strip()
    if shell.returncode != 0:
        # A failing command must not look like a success to the caller; keep
        # the "exit=N: message" shape the previous implementation raised.
        message = output or f"exit={shell.returncode}"
        raise RuntimeError(f"exit={shell.returncode}: {message}")
    return output or "(no output)"


@tool(name="bash.output")
async def bash_output(shell_id: str, offset: int = 0, limit: int | None = None) -> str:
    """Read buffered output from a background shell, with optional offset/limit for incremental polling.

    ``offset`` and ``limit`` are character positions into the decoded output;
    both are clamped to the available range. The report includes
    ``next_offset``, the value to pass as ``offset`` on the next poll.

    Raises:
        KeyError: if ``shell_id`` is unknown.
    """
    shell = shell_manager.get(shell_id)
    if shell.returncode is not None:
        # The process is gone: let the manager finish draining its pipes so
        # the final report contains the complete output.
        await shell_manager.wait_closed(shell_id)
    output = shell.output
    start = max(0, min(offset, len(output)))
    end = len(output) if limit is None else min(len(output), start + max(0, limit))
    chunk = output[start:end].rstrip()
    exit_code = "null" if shell.returncode is None else str(shell.returncode)
    body = chunk or "(no output)"
    return f"id: {shell.shell_id}\nstatus: {shell.status}\nexit_code: {exit_code}\nnext_offset: {end}\noutput:\n{body}"


@tool(name="bash.kill")
async def kill_bash(shell_id: str) -> str:
    """Terminate a background shell process.

    A shell that has already exited is only drained; either way the report
    carries its final status and exit code.

    Raises:
        KeyError: if ``shell_id`` is unknown.
    """
    shell = shell_manager.get(shell_id)
    if shell.returncode is None:
        shell = await shell_manager.terminate(shell_id)
    else:
        await shell_manager.wait_closed(shell_id)
    return f"id: {shell.shell_id}\nstatus: {shell.status}\nexit_code: {shell.returncode}"
async def _wait_until_exited(shell_id: str, *, timeout: float = 5.0) -> str:
    """Poll ``bash_output`` until the shell reports ``status: exited`` and return that report.

    Interpreter start-up time varies widely between machines, so the tests wait
    for the observable state instead of sleeping for a guessed duration.
    Raises ``TimeoutError`` if the shell has not exited within ``timeout`` seconds.
    """
    async with asyncio.timeout(timeout):
        while True:
            report = await bash_output.run(shell_id=shell_id)
            if report.startswith(f"id: {shell_id}\nstatus: exited\n"):
                return report
            await asyncio.sleep(0.02)


@pytest.mark.asyncio
async def test_background_bash_exposes_output_via_bash_output(tmp_path) -> None:
    """Output written before and after a pause is visible once the shell has exited."""
    command = _python_shell(
        "import sys, time; print('start'); sys.stdout.flush(); time.sleep(0.2); print('done'); sys.stdout.flush()"
    )

    started = await bash.run(cmd=command, background=True, context=_tool_context(tmp_path))
    shell_id = started.removeprefix("started: ").strip()

    output = await _wait_until_exited(shell_id)

    assert output.startswith(f"id: {shell_id}\nstatus: exited\n")
    assert "exit_code: 0" in output
    assert "start" in output
    assert "done" in output


@pytest.mark.asyncio
async def test_kill_bash_terminates_background_process(tmp_path) -> None:
    """Killing a long-running shell leaves it exited with a concrete exit code."""
    started = await bash.run(
        cmd=_python_shell("import time; time.sleep(10)"),
        background=True,
        context=_tool_context(tmp_path),
    )
    shell_id = started.removeprefix("started: ").strip()

    killed = await kill_bash.run(shell_id=shell_id)
    output = await bash_output.run(shell_id=shell_id)

    assert killed.startswith(f"id: {shell_id}\nstatus: exited\nexit_code: ")
    assert "exit_code: null" not in killed
    assert output.startswith(f"id: {shell_id}\nstatus: exited\n")


@pytest.mark.asyncio
async def test_kill_bash_returns_status_when_process_already_finished(tmp_path) -> None:
    """Killing a shell that already finished just reports its real exit code."""
    started = await bash.run(
        cmd=_python_shell("print('done')"),
        background=True,
        context=_tool_context(tmp_path),
    )
    shell_id = started.removeprefix("started: ").strip()

    # Make sure the process really has finished; otherwise kill_bash would
    # terminate it and the exit code would not be 0.
    await _wait_until_exited(shell_id)
    result = await kill_bash.run(shell_id=shell_id)

    assert result == f"id: {shell_id}\nstatus: exited\nexit_code: 0"