|
1 | 1 | import asyncio |
2 | 2 | import contextlib |
3 | | -from collections.abc import AsyncGenerator |
4 | | -from dataclasses import dataclass |
| 3 | +from collections.abc import AsyncGenerator, AsyncIterable |
5 | 4 | from datetime import datetime |
6 | 5 | from hashlib import md5 |
7 | 6 | from pathlib import Path |
|
20 | 19 | from bub.builtin.tape import TapeInfo |
21 | 20 | from bub.channels.base import Channel |
22 | 21 | from bub.channels.cli.renderer import CliRenderer |
23 | | -from bub.channels.message import ChannelMessage, MessageKind |
| 22 | +from bub.channels.message import ChannelMessage |
24 | 23 | from bub.envelope import field_of |
25 | 24 | from bub.tools import REGISTRY |
26 | 25 | from bub.types import MessageHandler |
27 | 26 |
|
28 | 27 |
|
29 | | -@dataclass |
30 | | -class _StreamRenderState: |
31 | | - live: Live |
32 | | - kind: MessageKind |
33 | | - text: str = "" |
34 | | - |
35 | | - |
36 | 28 | class CliChannel(Channel): |
37 | 29 | """A simple CLI channel for testing and debugging.""" |
38 | 30 |
|
@@ -75,6 +67,11 @@ async def stop(self) -> None: |
75 | 67 | with contextlib.suppress(asyncio.CancelledError): |
76 | 68 | await self._main_task |
77 | 69 |
|
| 70 | + async def send(self, message: ChannelMessage) -> None: |
| 71 | + if message.kind != "error": |
| 72 | + return |
| 73 | + self._renderer.error(message.content) |
| 74 | + |
78 | 75 | async def _main_loop(self) -> None: |
79 | 76 | self._renderer.welcome(model=self._agent.settings.model, workspace=str(self._workspace)) |
80 | 77 | await self._refresh_tape_info() |
@@ -131,21 +128,25 @@ def _prompt_message(self) -> FormattedText: |
131 | 128 | symbol = ">" if self._mode == "agent" else "," |
132 | 129 | return FormattedText([("bold", f"{cwd} {symbol} ")]) |
133 | 130 |
|
134 | | - async def on_event(self, event: StreamEvent, message: ChannelMessage) -> None: |
135 | | - streams = self._stream_render_states() |
136 | | - state = streams.get(message.session_id) |
137 | | - if event.kind == "text": |
138 | | - if state is None: |
139 | | - state = _StreamRenderState(live=self._renderer.start_stream(message.kind), kind=message.kind) |
140 | | - streams[message.session_id] = state |
141 | | - content = str(event.data.get("delta", "")) |
142 | | - state.text += content |
143 | | - self._renderer.update_stream(state.live, kind=message.kind, text=state.text) |
144 | | - elif event.kind == "final": |
145 | | - if state is None: |
146 | | - return |
147 | | - self._renderer.finish_stream(state.live, kind=state.kind, text=state.text) |
148 | | - streams.pop(message.session_id, None) |
| 131 | + async def stream_events( |
| 132 | + self, message: ChannelMessage, stream: AsyncIterable[StreamEvent] |
| 133 | + ) -> AsyncIterable[StreamEvent]: |
| 134 | + live: Live | None = None |
| 135 | + text = "" |
| 136 | + try: |
| 137 | + async for event in stream: |
| 138 | + if event.kind == "text": |
| 139 | + content = str(event.data.get("delta", "")) |
| 140 | + if not content.strip() and not text: |
| 141 | + continue # skip leading whitespace-only events |
| 142 | + if live is None: |
| 143 | + live = self._renderer.start_stream(message.kind) |
| 144 | + text += content |
| 145 | + self._renderer.update_stream(live, kind=message.kind, text=text) |
| 146 | + yield event |
| 147 | + finally: |
| 148 | + if live is not None: |
| 149 | + self._renderer.finish_stream(live, kind=message.kind, text=text) |
149 | 150 |
|
150 | 151 | def _build_prompt(self, workspace: Path) -> PromptSession[str]: |
151 | 152 | kb = KeyBindings() |
@@ -188,10 +189,3 @@ def _render_bottom_toolbar(self) -> FormattedText: |
188 | 189 | def _history_file(home: Path, workspace: Path) -> Path: |
189 | 190 | workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest() |
190 | 191 | return home / "history" / f"{workspace_hash}.history" |
191 | | - |
192 | | - def _stream_render_states(self) -> dict[str, _StreamRenderState]: |
193 | | - states = getattr(self, "_active_stream_renders", None) |
194 | | - if states is None: |
195 | | - states = {} |
196 | | - self._active_stream_renders = states |
197 | | - return states |
0 commit comments