Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/bub/builtin/shell_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

import asyncio
import contextlib
import uuid
from dataclasses import dataclass, field


@dataclass(slots=True)
class ManagedShell:
shell_id: str
cmd: str
cwd: str | None
process: asyncio.subprocess.Process
output_chunks: list[str] = field(default_factory=list)
read_tasks: list[asyncio.Task[None]] = field(default_factory=list)

@property
def output(self) -> str:
return "".join(self.output_chunks)

@property
def returncode(self) -> int | None:
return self.process.returncode

@property
def status(self) -> str:
return "running" if self.returncode is None else "exited"


class ShellManager:
def __init__(self) -> None:
self._shells: dict[str, ManagedShell] = {}

async def start(self, *, cmd: str, cwd: str | None) -> ManagedShell:
process = await asyncio.create_subprocess_shell(
cmd,
cwd=cwd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
shell = ManagedShell(shell_id=f"bash-{uuid.uuid4().hex[:8]}", cmd=cmd, cwd=cwd, process=process)
shell.read_tasks.extend([
asyncio.create_task(self._drain_stream(shell, process.stdout)),
asyncio.create_task(self._drain_stream(shell, process.stderr)),
])
self._shells[shell.shell_id] = shell
return shell

def get(self, shell_id: str) -> ManagedShell:
try:
return self._shells[shell_id]
except KeyError as exc:
raise KeyError(f"unknown shell id: {shell_id}") from exc

async def terminate(self, shell_id: str) -> ManagedShell:
shell = self.get(shell_id)
if shell.returncode is not None:
await self._finalize_shell(shell)
return shell

shell.process.terminate()
try:
async with asyncio.timeout(3):
await shell.process.wait()
except TimeoutError:
shell.process.kill()
await shell.process.wait()
await self._finalize_shell(shell)
return shell

async def wait_closed(self, shell_id: str) -> ManagedShell:
shell = self.get(shell_id)
if shell.returncode is None:
await shell.process.wait()
await self._finalize_shell(shell)
return shell

async def _finalize_shell(self, shell: ManagedShell) -> None:
for task in shell.read_tasks:
with contextlib.suppress(asyncio.CancelledError):
await task

async def _drain_stream(
self,
shell: ManagedShell,
stream: asyncio.StreamReader | None,
) -> None:
if stream is None:
return
while chunk := await stream.read(4096):
shell.output_chunks.append(chunk.decode("utf-8", errors="replace"))


shell_manager = ShellManager()
64 changes: 48 additions & 16 deletions src/bub/builtin/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import BaseModel, Field
from republic import AsyncTapeStore, TapeQuery, ToolContext

from bub.builtin.shell_manager import shell_manager
from bub.skills import discover_skills
from bub.tools import tool

Expand Down Expand Up @@ -59,24 +60,52 @@ class SubAgentInput(BaseModel):

@tool(context=True)
async def bash(
cmd: str, cwd: str | None = None, timeout_seconds: int = DEFAULT_COMMAND_TIMEOUT_SECONDS, *, context: ToolContext
cmd: str,
cwd: str | None = None,
timeout_seconds: int = DEFAULT_COMMAND_TIMEOUT_SECONDS,
background: bool = False,
*,
context: ToolContext,
) -> str:
"""Run a shell command and return its output within a time limit. Raises if the command fails or times out."""
"""Run a shell command. Use background=true to keep it running and fetch output later via bash_output."""
workspace = context.state.get("_runtime_workspace")
completed = await asyncio.create_subprocess_shell(
cmd,
cwd=cwd or workspace,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
async with asyncio.timeout(timeout_seconds):
stdout_bytes, stderr_bytes = await completed.communicate()
stdout_text = (stdout_bytes or b"").decode("utf-8", errors="replace").strip()
stderr_text = (stderr_bytes or b"").decode("utf-8", errors="replace").strip()
if completed.returncode != 0:
message = stderr_text or stdout_text or f"exit={completed.returncode}"
raise RuntimeError(f"exit={completed.returncode}: {message}")
return stdout_text or "(no output)"
target_cwd = cwd or workspace
shell = await shell_manager.start(cmd=cmd, cwd=target_cwd)
if background:
return f"started: {shell.shell_id}"
try:
async with asyncio.timeout(timeout_seconds):
await shell_manager.wait_closed(shell.shell_id)
except TimeoutError:
await shell_manager.terminate(shell.shell_id)
return f"command timed out after {timeout_seconds} seconds and was terminated"
return shell.output.strip() or "(no output)"


@tool(name="bash.output")
async def bash_output(shell_id: str, offset: int = 0, limit: int | None = None) -> str:
"""Read buffered output from a background shell, with optional offset/limit for incremental polling."""
shell = shell_manager.get(shell_id)
if shell.returncode is not None:
await shell_manager.wait_closed(shell_id)
output = shell.output
start = max(0, min(offset, len(output)))
end = len(output) if limit is None else min(len(output), start + max(0, limit))
chunk = output[start:end].rstrip()
exit_code = "null" if shell.returncode is None else str(shell.returncode)
body = chunk or "(no output)"
return f"id: {shell.shell_id}\nstatus: {shell.status}\nexit_code: {exit_code}\nnext_offset: {end}\noutput:\n{body}"


@tool(name="bash.kill")
async def kill_bash(shell_id: str) -> str:
"""Terminate a background shell process."""
shell = shell_manager.get(shell_id)
if shell.returncode is None:
shell = await shell_manager.terminate(shell_id)
else:
await shell_manager.wait_closed(shell_id)
return f"id: {shell.shell_id}\nstatus: {shell.status}\nexit_code: {shell.returncode}"


@tool(context=True, name="fs.read")
Expand Down Expand Up @@ -243,6 +272,9 @@ def show_help() -> str:
" ,fs.read path=README.md\n"
" ,fs.write path=tmp.txt content='hello'\n"
" ,fs.edit path=tmp.txt old=hello new=world\n"
" ,bash cmd='sleep 5' background=true\n"
" ,bash_output shell_id=bsh-12345678\n"
" ,kill_bash shell_id=bsh-12345678\n"
"Any unknown command after ',' is executed as shell via bash."
)

Expand Down
75 changes: 75 additions & 0 deletions tests/test_builtin_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

import asyncio
import shlex
import sys

import pytest
from republic import ToolContext

from bub.builtin.tools import bash, bash_output, kill_bash


def _tool_context(tmp_path) -> ToolContext:
return ToolContext(tape="test-tape", run_id="test-run", state={"_runtime_workspace": str(tmp_path)})


def _python_shell(code: str) -> str:
return f"{shlex.quote(sys.executable)} -c {shlex.quote(code)}"


@pytest.mark.asyncio
async def test_bash_returns_stdout_for_foreground_command(tmp_path) -> None:
result = await bash.run(cmd=_python_shell("print('hello')"), context=_tool_context(tmp_path))

assert result == "hello"


@pytest.mark.asyncio
async def test_background_bash_exposes_output_via_bash_output(tmp_path) -> None:
command = _python_shell(
"import sys, time; print('start'); sys.stdout.flush(); time.sleep(0.2); print('done'); sys.stdout.flush()"
)

started = await bash.run(cmd=command, background=True, context=_tool_context(tmp_path))
shell_id = started.removeprefix("started: ").strip()

await asyncio.sleep(0.35)
output = await bash_output.run(shell_id=shell_id)

assert output.startswith(f"id: {shell_id}\nstatus: exited\n")
assert "exit_code: 0" in output
assert "start" in output
assert "done" in output


@pytest.mark.asyncio
async def test_kill_bash_terminates_background_process(tmp_path) -> None:
started = await bash.run(
cmd=_python_shell("import time; time.sleep(10)"),
background=True,
context=_tool_context(tmp_path),
)
shell_id = started.removeprefix("started: ").strip()

killed = await kill_bash.run(shell_id=shell_id)
output = await bash_output.run(shell_id=shell_id)

assert killed.startswith(f"id: {shell_id}\nstatus: exited\nexit_code: ")
assert "exit_code: null" not in killed
assert output.startswith(f"id: {shell_id}\nstatus: exited\n")


@pytest.mark.asyncio
async def test_kill_bash_returns_status_when_process_already_finished(tmp_path) -> None:
started = await bash.run(
cmd=_python_shell("print('done')"),
background=True,
context=_tool_context(tmp_path),
)
shell_id = started.removeprefix("started: ").strip()

await asyncio.sleep(0.1)
result = await kill_bash.run(shell_id=shell_id)

assert result == f"id: {shell_id}\nstatus: exited\nexit_code: 0"
Loading