Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/bub/channels/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from hashlib import md5
from pathlib import Path

from loguru import logger
from prompt_toolkit import PromptSession
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.formatted_text import FormattedText
Expand Down Expand Up @@ -43,10 +44,16 @@ def __init__(self, on_receive: MessageHandler, agent: Agent) -> None:
self._mode = "agent" # or "shell"
self._main_task: asyncio.Task | None = None
self._renderer = CliRenderer(get_console())
self._log_handler_id = self._install_log_sink()
self._last_tape_info: TapeInfo | None = None
self._workspace = self._agent.framework.workspace
self._prompt = self._build_prompt(self._workspace)

def _install_log_sink(self) -> int:
with contextlib.suppress(ValueError):
logger.remove(0)
return logger.add(self._renderer.log, colorize=False, format="{level:<8} | {message}")

async def _refresh_tape_info(self) -> None:
tape = self._agent.tapes.session_tape(self._message_template["session_id"], self._workspace)
info = await self._agent.tapes.info(tape.name)
Expand All @@ -67,6 +74,8 @@ async def stop(self) -> None:
self._main_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._main_task
with contextlib.suppress(ValueError):
logger.remove(self._log_handler_id)

async def send(self, message: ChannelMessage) -> None:
if message.kind != "error":
Expand Down Expand Up @@ -140,10 +149,11 @@ async def stream_events(
content = str(event.data.get("delta", ""))
if not content.strip() and not text:
continue # skip leading whitespace-only events
if live is None:
live = self._renderer.start_stream(message.kind)
text += content
self._renderer.update_stream(live, kind=message.kind, text=text)
if live is None:
live = self._renderer.start_stream(message.kind, text)
else:
self._renderer.update_stream(live, kind=message.kind, text=text)
yield event
finally:
if live is not None:
Expand Down
12 changes: 8 additions & 4 deletions src/bub/channels/cli/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,20 @@ def error(self, text: str) -> None:
return
self.console.print(self.panel("error", text))

def start_stream(self, kind: MessageKind) -> Live:
def log(self, message: object) -> None:
text = str(message).rstrip()
if text:
self.console.print(text)

def start_stream(self, kind: MessageKind, text: str) -> Live:
live = Live(
self.panel(kind, ""),
self.panel(kind, text),
console=self.console,
auto_refresh=False,
transient=False,
vertical_overflow="visible",
)
live.start()
live.refresh()
live.start(refresh=True)
return live

def update_stream(self, live: Live, *, kind: MessageKind, text: str) -> None:
Expand Down
41 changes: 38 additions & 3 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from republic import StreamEvent

from bub.channels.cli import CliChannel
from bub.channels.cli import renderer as cli_renderer
from bub.channels.cli.renderer import CliRenderer
from bub.channels.handler import BufferedMessageHandler
from bub.channels.manager import ChannelManager
from bub.channels.message import ChannelMessage
Expand Down Expand Up @@ -301,7 +303,7 @@ async def test_cli_channel_stream_events_renders_stream_and_yields_events() -> N
events: list[tuple[str, str, str]] = []
live_handle = object()
channel._renderer = SimpleNamespace(
start_stream=lambda kind: events.append(("start", kind, "")) or live_handle,
start_stream=lambda kind, text: events.append(("start", kind, text)) or live_handle,
update_stream=lambda live, *, kind, text: events.append(("update", kind, text)),
finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)),
error=lambda content: events.append(("error", "error", content)),
Expand All @@ -319,8 +321,7 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]:
yielded = [event async for event in channel.stream_events(message, source())]

assert events == [
("start", "command", ""),
("update", "command", "hel"),
("start", "command", "hel"),
("update", "command", "hello"),
("finish", "command", "hello"),
]
Expand All @@ -337,6 +338,40 @@ def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None:
assert result.suffix == ".history"


def test_cli_renderer_stream_uses_live_with_initial_text(monkeypatch: pytest.MonkeyPatch) -> None:
live_calls: list[tuple[str, object]] = []

class FakeLive:
def __init__(self, renderable, **kwargs) -> None:
live_calls.append(("init", renderable))
live_calls.append(("transient", kwargs["transient"]))
self.renderable = renderable

def start(self, *, refresh: bool = False) -> None:
live_calls.append(("start_refresh", refresh))

def update(self, renderable, *, refresh: bool = False) -> None:
live_calls.append(("update_refresh", refresh))
self.renderable = renderable

def stop(self) -> None:
live_calls.append(("stop", self.renderable))

printed: list[str] = []
console = SimpleNamespace(print=printed.append)
monkeypatch.setattr(cli_renderer, "Live", FakeLive)

renderer = CliRenderer(console) # type: ignore[arg-type]
live = renderer.start_stream("normal", "hel")
renderer.update_stream(live, kind="normal", text="hello") # type: ignore[arg-type]
renderer.finish_stream(live, kind="normal", text="hello") # type: ignore[arg-type]

assert ("transient", False) in live_calls
assert ("start_refresh", True) in live_calls
assert ("update_refresh", True) in live_calls
assert not printed


def test_bub_message_filter_accepts_private_messages() -> None:
message = SimpleNamespace(chat=SimpleNamespace(type="private"), text="hello")

Expand Down
Loading