diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index 6959ad37..0ee4df9a 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -5,6 +5,7 @@ from hashlib import md5 from pathlib import Path +from loguru import logger from prompt_toolkit import PromptSession from prompt_toolkit.completion import WordCompleter from prompt_toolkit.formatted_text import FormattedText @@ -43,10 +44,16 @@ def __init__(self, on_receive: MessageHandler, agent: Agent) -> None: self._mode = "agent" # or "shell" self._main_task: asyncio.Task | None = None self._renderer = CliRenderer(get_console()) + self._log_handler_id = self._install_log_sink() self._last_tape_info: TapeInfo | None = None self._workspace = self._agent.framework.workspace self._prompt = self._build_prompt(self._workspace) + def _install_log_sink(self) -> int: + with contextlib.suppress(ValueError): + logger.remove(0) + return logger.add(self._renderer.log, colorize=False, format="{level:<8} | {message}") + async def _refresh_tape_info(self) -> None: tape = self._agent.tapes.session_tape(self._message_template["session_id"], self._workspace) info = await self._agent.tapes.info(tape.name) @@ -67,6 +74,8 @@ async def stop(self) -> None: self._main_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._main_task + with contextlib.suppress(ValueError): + logger.remove(self._log_handler_id) async def send(self, message: ChannelMessage) -> None: if message.kind != "error": @@ -140,10 +149,11 @@ async def stream_events( content = str(event.data.get("delta", "")) if not content.strip() and not text: continue # skip leading whitespace-only events - if live is None: - live = self._renderer.start_stream(message.kind) text += content - self._renderer.update_stream(live, kind=message.kind, text=text) + if live is None: + live = self._renderer.start_stream(message.kind, text) + else: + self._renderer.update_stream(live, kind=message.kind, text=text) yield event finally: if live is not None: diff --git a/src/bub/channels/cli/renderer.py b/src/bub/channels/cli/renderer.py index 2db91d7c..8f4581f9 100644 --- a/src/bub/channels/cli/renderer.py +++ b/src/bub/channels/cli/renderer.py @@ -52,16 +52,20 @@ def error(self, text: str) -> None: return self.console.print(self.panel("error", text)) - def start_stream(self, kind: MessageKind) -> Live: + def log(self, message: object) -> None: + text = str(message).rstrip() + if text: + self.console.print(text) + + def start_stream(self, kind: MessageKind, text: str) -> Live: live = Live( - self.panel(kind, ""), + self.panel(kind, text), console=self.console, auto_refresh=False, transient=False, vertical_overflow="visible", ) - live.start() - live.refresh() + live.start(refresh=True) return live def update_stream(self, live: Live, *, kind: MessageKind, text: str) -> None: diff --git a/tests/test_channels.py b/tests/test_channels.py index b272ae3b..54f56158 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -10,6 +10,8 @@ from republic import StreamEvent from bub.channels.cli import CliChannel +from bub.channels.cli import renderer as cli_renderer +from bub.channels.cli.renderer import CliRenderer from bub.channels.handler import BufferedMessageHandler from bub.channels.manager import ChannelManager from bub.channels.message import ChannelMessage @@ -301,7 +303,7 @@ async def test_cli_channel_stream_events_renders_stream_and_yields_events() -> N events: list[tuple[str, str, str]] = [] live_handle = object() channel._renderer = SimpleNamespace( - start_stream=lambda kind: events.append(("start", kind, "")) or live_handle, + start_stream=lambda kind, text: events.append(("start", kind, text)) or live_handle, update_stream=lambda live, *, kind, text: events.append(("update", kind, text)), finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)), error=lambda content: events.append(("error", "error", content)), @@ -319,8 +321,7 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]: yielded = [event async for event in channel.stream_events(message, source())] assert events == [ - ("start", "command", ""), - ("update", "command", "hel"), + ("start", "command", "hel"), ("update", "command", "hello"), ("finish", "command", "hello"), ] @@ -337,6 +338,40 @@ def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None: assert result.suffix == ".history" +def test_cli_renderer_stream_uses_live_with_initial_text(monkeypatch: pytest.MonkeyPatch) -> None: + live_calls: list[tuple[str, object]] = [] + + class FakeLive: + def __init__(self, renderable, **kwargs) -> None: + live_calls.append(("init", renderable)) + live_calls.append(("transient", kwargs["transient"])) + self.renderable = renderable + + def start(self, *, refresh: bool = False) -> None: + live_calls.append(("start_refresh", refresh)) + + def update(self, renderable, *, refresh: bool = False) -> None: + live_calls.append(("update_refresh", refresh)) + self.renderable = renderable + + def stop(self) -> None: + live_calls.append(("stop", self.renderable)) + + printed: list[str] = [] + console = SimpleNamespace(print=printed.append) + monkeypatch.setattr(cli_renderer, "Live", FakeLive) + + renderer = CliRenderer(console) # type: ignore[arg-type] + live = renderer.start_stream("normal", "hel") + renderer.update_stream(live, kind="normal", text="hello") # type: ignore[arg-type] + renderer.finish_stream(live, kind="normal", text="hello") # type: ignore[arg-type] + + assert ("transient", False) in live_calls + assert ("start_refresh", True) in live_calls + assert ("update_refresh", True) in live_calls + assert not printed + + def test_bub_message_filter_accepts_private_messages() -> None: message = SimpleNamespace(chat=SimpleNamespace(type="private"), text="hello")