Skip to content

Commit 289f585

Browse files
authored
feat: implement stream event handling in CliChannel and update related methods (#160)
Signed-off-by: Frost Ming <me@frostming.com>
1 parent 746238e commit 289f585

6 files changed

Lines changed: 47 additions & 55 deletions

File tree

src/bub/channels/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
from abc import ABC, abstractmethod
3+
from collections.abc import AsyncIterable
34
from typing import ClassVar
45

56
from republic import StreamEvent
@@ -35,7 +36,6 @@ async def send(self, message: ChannelMessage) -> None:
3536
# Do nothing by default
3637
return
3738

38-
async def on_event(self, event: StreamEvent, message: ChannelMessage) -> None:
39-
"""Handle an event from the agent. Optional to implement."""
40-
# Do nothing by default
41-
return
39+
def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]:
40+
"""Optionally wrap the output stream for this channel."""
41+
return stream

src/bub/channels/cli/__init__.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import contextlib
3-
from collections.abc import AsyncGenerator
4-
from dataclasses import dataclass
3+
from collections.abc import AsyncGenerator, AsyncIterable
54
from datetime import datetime
65
from hashlib import md5
76
from pathlib import Path
@@ -20,19 +19,12 @@
2019
from bub.builtin.tape import TapeInfo
2120
from bub.channels.base import Channel
2221
from bub.channels.cli.renderer import CliRenderer
23-
from bub.channels.message import ChannelMessage, MessageKind
22+
from bub.channels.message import ChannelMessage
2423
from bub.envelope import field_of
2524
from bub.tools import REGISTRY
2625
from bub.types import MessageHandler
2726

2827

29-
@dataclass
30-
class _StreamRenderState:
31-
live: Live
32-
kind: MessageKind
33-
text: str = ""
34-
35-
3628
class CliChannel(Channel):
3729
"""A simple CLI channel for testing and debugging."""
3830

@@ -75,6 +67,11 @@ async def stop(self) -> None:
7567
with contextlib.suppress(asyncio.CancelledError):
7668
await self._main_task
7769

70+
async def send(self, message: ChannelMessage) -> None:
71+
if message.kind != "error":
72+
return
73+
self._renderer.error(message.content)
74+
7875
async def _main_loop(self) -> None:
7976
self._renderer.welcome(model=self._agent.settings.model, workspace=str(self._workspace))
8077
await self._refresh_tape_info()
@@ -131,21 +128,25 @@ def _prompt_message(self) -> FormattedText:
131128
symbol = ">" if self._mode == "agent" else ","
132129
return FormattedText([("bold", f"{cwd} {symbol} ")])
133130

134-
async def on_event(self, event: StreamEvent, message: ChannelMessage) -> None:
135-
streams = self._stream_render_states()
136-
state = streams.get(message.session_id)
137-
if event.kind == "text":
138-
if state is None:
139-
state = _StreamRenderState(live=self._renderer.start_stream(message.kind), kind=message.kind)
140-
streams[message.session_id] = state
141-
content = str(event.data.get("delta", ""))
142-
state.text += content
143-
self._renderer.update_stream(state.live, kind=message.kind, text=state.text)
144-
elif event.kind == "final":
145-
if state is None:
146-
return
147-
self._renderer.finish_stream(state.live, kind=state.kind, text=state.text)
148-
streams.pop(message.session_id, None)
131+
async def stream_events(
132+
self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]
133+
) -> AsyncIterable[StreamEvent]:
134+
live: Live | None = None
135+
text = ""
136+
try:
137+
async for event in stream:
138+
if event.kind == "text":
139+
content = str(event.data.get("delta", ""))
140+
if not content.strip() and not text:
141+
continue # skip leading whitespace-only events
142+
if live is None:
143+
live = self._renderer.start_stream(message.kind)
144+
text += content
145+
self._renderer.update_stream(live, kind=message.kind, text=text)
146+
yield event
147+
finally:
148+
if live is not None:
149+
self._renderer.finish_stream(live, kind=message.kind, text=text)
149150

150151
def _build_prompt(self, workspace: Path) -> PromptSession[str]:
151152
kb = KeyBindings()
@@ -188,10 +189,3 @@ def _render_bottom_toolbar(self) -> FormattedText:
188189
def _history_file(home: Path, workspace: Path) -> Path:
189190
workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest()
190191
return home / "history" / f"{workspace_hash}.history"
191-
192-
def _stream_render_states(self) -> dict[str, _StreamRenderState]:
193-
states = getattr(self, "_active_stream_renders", None)
194-
if states is None:
195-
states = {}
196-
self._active_stream_renders = states
197-
return states

src/bub/channels/manager.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import contextlib
33
import functools
4-
from collections.abc import Collection
4+
from collections.abc import AsyncIterable, Collection
55

66
from loguru import logger
77
from pydantic import Field
@@ -94,17 +94,17 @@ async def dispatch_output(self, message: Envelope) -> bool:
9494
await channel.send(outbound)
9595
return True
9696

97-
async def dispatch_event(self, event: StreamEvent, message: Envelope) -> None:
97+
def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]:
9898
channel_name = field_of(message, "output_channel", field_of(message, "channel"))
9999
if channel_name is None:
100-
return
100+
return stream
101101

102102
channel_key = str(channel_name)
103103
channel = self.get_channel(channel_key)
104104
if channel is None:
105-
return
105+
return stream
106106

107-
await channel.on_event(event, message)
107+
return channel.stream_events(message, stream)
108108

109109
async def quit(self, session_id: str) -> None:
110110
tasks = self._ongoing_tasks.pop(session_id, set())

src/bub/framework.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ async def _run_model(
140140
return prompt if isinstance(prompt, str) else content_of(inbound)
141141
else:
142142
parts: list[str] = []
143+
if self._outbound_router is not None:
144+
stream = self._outbound_router.wrap_stream(inbound, stream)
143145
async for event in stream:
144-
await self.dispatch_event_via_router(event, inbound)
145146
if event.kind == "text":
146147
parts.append(str(event.data.get("delta", "")))
147148
elif event.kind == "error":
@@ -163,12 +164,6 @@ async def dispatch_via_router(self, message: Envelope) -> bool:
163164
return False
164165
return await self._outbound_router.dispatch_output(message)
165166

166-
async def dispatch_event_via_router(self, event: Any, message: Envelope) -> bool:
167-
if self._outbound_router is not None:
168-
await self._outbound_router.dispatch_event(event, message)
169-
return True
170-
return False
171-
172167
async def quit_via_router(self, session_id: str) -> None:
173168
if self._outbound_router is not None:
174169
await self._outbound_router.quit(session_id)

src/bub/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Callable, Coroutine
5+
from collections.abc import AsyncIterable, Callable, Coroutine
66
from dataclasses import dataclass, field
77
from typing import Any, Protocol
88

@@ -16,7 +16,7 @@
1616

1717
class OutboundChannelRouter(Protocol):
1818
async def dispatch_output(self, message: Envelope) -> bool: ...
19-
async def dispatch_event(self, event: StreamEvent, message: Envelope) -> None: ...
19+
def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: ...
2020
async def quit(self, session_id: str) -> None: ...
2121

2222

tests/test_channels.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def test_cli_channel_normalize_input_prefixes_shell_commands() -> None:
207207

208208

209209
@pytest.mark.asyncio
210-
async def test_cli_channel_on_event_renders_stream_and_suppresses_followup_send() -> None:
210+
async def test_cli_channel_stream_events_renders_stream_and_yields_events() -> None:
211211
channel = CliChannel.__new__(CliChannel)
212212
events: list[tuple[str, str, str]] = []
213213
live_handle = object()
@@ -222,17 +222,20 @@ async def test_cli_channel_on_event_renders_stream_and_suppresses_followup_send(
222222

223223
message = _message("ignored", channel="cli", kind="command", session_id="cli:1")
224224

225-
await channel.on_event(StreamEvent("text", {"delta": "hel"}), message)
226-
await channel.on_event(StreamEvent("text", {"delta": "lo"}), message)
227-
await channel.on_event(StreamEvent("final", {}), message)
228-
await channel.send(_message("hello", channel="cli", kind="command", session_id="cli:1"))
225+
async def source() -> asyncio.AsyncIterator[StreamEvent]:
226+
yield StreamEvent("text", {"delta": "hel"})
227+
yield StreamEvent("text", {"delta": "lo"})
228+
yield StreamEvent("final", {})
229+
230+
yielded = [event async for event in channel.stream_events(message, source())]
229231

230232
assert events == [
231233
("start", "command", ""),
232234
("update", "command", "hel"),
233235
("update", "command", "hello"),
234236
("finish", "command", "hello"),
235237
]
238+
assert [event.kind for event in yielded] == ["text", "text", "final"]
236239

237240

238241
def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None:

0 commit comments

Comments
 (0)