diff --git a/src/bub/builtin/shell_manager.py b/src/bub/builtin/shell_manager.py index 284c965b..68aef08d 100644 --- a/src/bub/builtin/shell_manager.py +++ b/src/bub/builtin/shell_manager.py @@ -53,6 +53,9 @@ def get(self, shell_id: str) -> ManagedShell: except KeyError as exc: raise KeyError(f"unknown shell id: {shell_id}") from exc + def release(self, shell_id: str) -> ManagedShell | None: + return self._shells.pop(shell_id, None) + async def terminate(self, shell_id: str) -> ManagedShell: shell = self.get(shell_id) if shell.returncode is not None: @@ -80,6 +83,7 @@ async def _finalize_shell(self, shell: ManagedShell) -> None: for task in shell.read_tasks: with contextlib.suppress(asyncio.CancelledError): await task + self._shells.pop(shell.shell_id, None) async def _drain_stream( self, diff --git a/tests/test_builtin_tools.py b/tests/test_builtin_tools.py index 36c7e41e..cf731859 100644 --- a/tests/test_builtin_tools.py +++ b/tests/test_builtin_tools.py @@ -9,6 +9,8 @@ from republic.core.errors import ErrorKind from republic.tools.executor import ToolExecutor +import bub.builtin.tools as builtin_tools +from bub.builtin.shell_manager import ShellManager from bub.builtin.tools import bash, bash_output, kill_bash @@ -27,6 +29,28 @@ async def test_bash_returns_stdout_for_foreground_command(tmp_path) -> None: assert result == "hello" +@pytest.mark.asyncio +async def test_foreground_bash_releases_shell_from_shell_manager(tmp_path, monkeypatch) -> None: + manager = ShellManager() + monkeypatch.setattr(builtin_tools, "shell_manager", manager) + + result = await bash.run(cmd=_python_shell("print('hello')"), context=_tool_context(tmp_path)) + + assert result == "hello" + assert manager._shells == {} + + +@pytest.mark.asyncio +async def test_foreground_bash_releases_shell_when_command_fails(tmp_path, monkeypatch) -> None: + manager = ShellManager() + monkeypatch.setattr(builtin_tools, "shell_manager", manager) + + with pytest.raises(RuntimeError, match="command exited with code"): + await bash.run(cmd=_python_shell("import sys; sys.exit(2)"), context=_tool_context(tmp_path)) + + assert manager._shells == {} + + @pytest.mark.asyncio async def test_bash_non_zero_exit_is_returned_as_tool_error(tmp_path) -> None: command = _python_shell("import sys; print('boom'); sys.exit(7)") @@ -71,7 +95,7 @@ async def test_background_bash_exposes_output_via_bash_output(tmp_path) -> None: @pytest.mark.asyncio -async def test_kill_bash_terminates_background_process(tmp_path) -> None: +async def test_kill_bash_terminates_background_process_and_releases_shell(tmp_path) -> None: started = await bash.run( cmd=_python_shell("import time; time.sleep(10)"), background=True, @@ -80,11 +104,11 @@ async def test_kill_bash_terminates_background_process(tmp_path) -> None: shell_id = started.removeprefix("started: ").strip() killed = await kill_bash.run(shell_id=shell_id) - output = await bash_output.run(shell_id=shell_id) assert killed.startswith(f"id: {shell_id}\nstatus: exited\nexit_code: ") assert "exit_code: null" not in killed - assert output.startswith(f"id: {shell_id}\nstatus: exited\n") + with pytest.raises(KeyError, match="unknown shell id"): + await bash_output.run(shell_id=shell_id) @pytest.mark.asyncio