From 0a033ca0557efcc8e9d0054027a0db00e1aae6db Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Thu, 16 Apr 2026 10:18:13 +0800 Subject: [PATCH] feat: Switchable streaming and non-streaming model output - Updated the `run_model` and `run_model_stream` hooks so a turn can run in either non-streaming or streaming mode. - Introduced a `stream_output` setting in `ChannelSettings` and a matching `stream_output` parameter on `BubFramework.process_inbound` to control streaming behavior. - Modified `process_inbound` to handle streaming output based on the new flag. - Changed `Agent.run` to return plain text and added a `run_stream` method for streaming output. - Updated documentation to reflect changes in streaming capabilities and usage. - Added tests to verify the correct behavior of streaming and non-streaming executions. Signed-off-by: Frost Ming --- docs/architecture.md | 11 +- docs/channels/cli.md | 6 + docs/channels/index.md | 8 +- docs/extension-guide.md | 14 +- docs/features.md | 5 +- docs/index.md | 8 +- src/bub/builtin/agent.py | 267 +++++++++++++++++++++++++++++--- src/bub/builtin/hook_impl.py | 6 +- src/bub/builtin/tools.py | 2 +- src/bub/channels/manager.py | 3 +- src/bub/framework.py | 21 ++- src/bub/hook_runtime.py | 17 +- tests/test_builtin_agent.py | 10 +- tests/test_builtin_hook_impl.py | 26 +++- tests/test_channels.py | 73 +++++++++ tests/test_framework.py | 112 ++++++++++++++ tests/test_hook_runtime.py | 35 +++++ tests/test_subagent_tool.py | 24 +-- 18 files changed, 581 insertions(+), 67 deletions(-) diff --git a/docs/architecture.md b/docs/architecture.md index 9b948061..e5cb685f 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -16,14 +16,14 @@ 2. Initialize state with `_runtime_workspace` from `BubFramework.workspace`. 3. Merge all `load_state(message, session_id)` dicts. 4. Build prompt via `build_prompt(message, session_id, state)` (fallback to inbound `content` if empty). -5. Execute `run_model_stream(prompt, session_id, state)`. -6. 
For each stream event, call `OutboundChannelRouter.dispatch_event(...)`, which forwards to `channel.on_event(event, message)` when the target channel exists. +5. Execute `run_model(prompt, session_id, state)` by default, or `run_model_stream(prompt, session_id, state)` when the caller opts into streaming. +6. In streaming mode, forward each stream event through the outbound router before collecting final text. 7. Always execute `save_state(...)` in a `finally` block. 8. Render outbound batches via `render_outbound(...)`, then flatten them. 9. If no outbound exists, emit one fallback outbound. 10. Dispatch each outbound via `dispatch_outbound(message)`. -If no plugin implements `run_model_stream`, `HookRuntime` falls back to `run_model(prompt, session_id, state)` and adapts the returned text into a stream with a single text chunk. +`HookRuntime` keeps both directions compatible: `run_model()` can consume a streaming plugin by concatenating text chunks, and `run_model_stream()` can consume a plain `run_model()` plugin by adapting its text into a single-chunk stream. ## Hook Priority Semantics @@ -50,7 +50,8 @@ If no plugin implements `run_model_stream`, `HookRuntime` falls back to `run_mod Builtin `BuiltinImpl` behavior includes: - `build_prompt`: supports comma command mode; non-command text may include `context_str`. -- `run_model_stream`: delegates to `Agent.run()`. +- `run_model`: delegates to `Agent.run()`. +- `run_model_stream`: delegates to `Agent.run_stream()`. - `system_prompt`: combines a default prompt with workspace `AGENTS.md`. - `register_cli_commands`: installs `run`, `gateway`, `chat`, plus hidden diagnostic commands. - `provide_channels`: returns `telegram` and `cli` channel adapters. @@ -67,6 +68,8 @@ Channels have two different outbound surfaces: If a channel does not implement any special event behavior, it can ignore `on_event` and rely entirely on `send()`. +Channel streaming is opt-in through `BUB_STREAM_OUTPUT=true` (used by `ChannelManager`). 
When disabled, channels only receive the final rendered outbound message. + ## Boundaries - `Envelope` stays intentionally weakly typed (`Any` + accessor helpers). diff --git a/docs/channels/cli.md b/docs/channels/cli.md index ab88c56d..27f947a5 100644 --- a/docs/channels/cli.md +++ b/docs/channels/cli.md @@ -56,6 +56,12 @@ Enable only selected channels: uv run bub gateway --enable-channel telegram ``` +Forward streaming model events to channel adapters instead of waiting for the final rendered message: + +```bash +BUB_STREAM_OUTPUT=true uv run bub gateway --enable-channel telegram +``` + ## `bub chat` Start an interactive REPL session via the `cli` channel. diff --git a/docs/channels/index.md b/docs/channels/index.md index b1d76756..678a6243 100644 --- a/docs/channels/index.md +++ b/docs/channels/index.md @@ -27,6 +27,12 @@ Enable only Telegram: uv run bub gateway --enable-channel telegram ``` +Enable streaming event delivery for channel listeners: + +```bash +BUB_STREAM_OUTPUT=true uv run bub gateway --enable-channel telegram +``` + ## Session Semantics - `run` command default session id: `:` @@ -43,7 +49,7 @@ Channel adapters can receive outbound data in two forms: Use `on_event` for incremental UX such as live text updates, typing indicators, progress bars, or chunk-level logging. Use `send` for the final durable outbound payload. -`on_event` is optional. A channel that does not need streaming behavior can ignore it and only implement `send`. +`on_event` is optional. A channel that does not need streaming behavior can ignore it and only implement `send`. `ChannelManager` only forwards stream events when `BUB_STREAM_OUTPUT=true`; otherwise channels receive final outbounds only. ## Debounce Behavior diff --git a/docs/extension-guide.md b/docs/extension-guide.md index a13b67d2..5182e712 100644 --- a/docs/extension-guide.md +++ b/docs/extension-guide.md @@ -100,16 +100,16 @@ Current `process_inbound()` hook usage: 1. `resolve_session` (`call_first`) 2. 
`load_state` (`call_many`, then merged by framework) 3. `build_prompt` (`call_first`) -4. `run_model_stream` (`call_first`) +4. `run_model` / `run_model_stream` (`call_first`) 5. `save_state` (`call_many`, always executed in `finally`) 6. `render_outbound` (`call_many`) 7. `dispatch_outbound` (`call_many`, per outbound) Compatibility note: -- `run_model_stream` is the primary model hook. -- If no plugin implements `run_model_stream`, Bub falls back to `run_model`. -- The `run_model` return value is wrapped into a stream with exactly one text chunk. +- Bub can execute either `run_model` or `run_model_stream`, depending on whether the caller requests streaming. +- `HookRuntime.run_model()` can consume a streaming plugin by concatenating text chunks. +- `HookRuntime.run_model_stream()` can consume a plain `run_model()` implementation by wrapping it into a one-chunk stream. - A plugin should implement one of these hooks, not both. Other hook consumers: @@ -182,8 +182,8 @@ uv run bub hooks uv run bub run "hello" ``` -Check that your plugin is listed for `build_prompt` / `run_model_stream`, and output reflects your override. -If you intentionally use the legacy compatibility hook, check for `run_model`. +Check that your plugin is listed for `build_prompt` plus whichever model hook you implement, and output reflects your override. +If you intentionally use the non-streaming path, check for `run_model`; if you need incremental output, check for `run_model_stream`. ## 10) Listen To Parent Stream @@ -229,7 +229,7 @@ class StreamTapPlugin: Use this when you need to log chunks, redact text, inject extra events, or measure stream timing without reimplementing the underlying model call. -If you also need to support parents that only implement legacy `run_model`, add your own fallback path and wrap that text result into a one-chunk stream. 
+If you also need to support parents that only implement `run_model`, add your own fallback path and wrap that text result into a one-chunk stream. ## 11) Common Pitfalls diff --git a/docs/features.md b/docs/features.md index 2c1f4ab0..f8b165e5 100644 --- a/docs/features.md +++ b/docs/features.md @@ -5,8 +5,8 @@ Every turn stage is a [pluggy](https://pluggy.readthedocs.io/) hook. Builtins are ordinary plugins — override any stage by registering your own. Both first-result hooks (override) and broadcast hooks (observer) are supported. -`run_model_stream` is the primary model hook. -Legacy `run_model` hooks still work and are adapted into a single text chunk stream. +`run_model` is the default model hook for turn execution. +`run_model_stream` remains available for incremental channel output, and either hook shape can be adapted to the other. Safe fallback to prompt text when no model hook returns a value (with `on_error` notification). Automatic fallback outbound when `render_outbound` produces nothing. @@ -21,6 +21,7 @@ Context is reconstructed from tape records, not accumulated in session state. - **Model runtime**: agent loop with tool use, backed by [Republic](https://github.com/bubbuild/republic). - **Comma commands**: `,help`, `,skill`, `,fs.read`, etc. Unknown commands fall back to shell. - **Channels**: `cli` and `telegram` ship as defaults. +- **Streaming toggle**: channel event streaming is controlled by `BUB_STREAM_OUTPUT` and is off by default. All of these are hook implementations. Replace what you need. diff --git a/docs/index.md b/docs/index.md index 4b9b9245..9dcbe0b2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -32,12 +32,12 @@ uv run bub gateway # channel listener mode Every inbound message goes through one turn pipeline. Each stage is a hook. 
```text -resolve_session → load_state → build_prompt → run_model_stream - ↓ - dispatch_outbound ← render_outbound ← save_state +resolve_session → load_state → build_prompt → run_model / run_model_stream + ↓ + dispatch_outbound ← render_outbound ← save_state ``` -`run_model` remains supported as a compatibility hook and is adapted into a single-chunk stream when `run_model_stream` is absent. +By default Bub executes `run_model` and expects plain text. Streaming remains available through `run_model_stream`, and `HookRuntime` adapts either hook shape to the other for compatibility. Builtins are plugins registered first. Later plugins override earlier ones. No special cases. diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py index 4b1180f4..46367b17 100644 --- a/src/bub/builtin/agent.py +++ b/src/bub/builtin/agent.py @@ -13,7 +13,7 @@ from datetime import UTC, datetime from functools import cached_property from pathlib import Path -from typing import Any +from typing import Any, Literal, overload from loguru import logger from republic import ( @@ -24,6 +24,7 @@ StreamEvent, StreamState, TapeContext, + ToolAutoResult, ToolContext, ) from republic.tape import InMemoryTapeStore, Tape @@ -90,6 +91,29 @@ async def run( model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, + ) -> str: + if not prompt: + return "error: empty prompt" + tape = self.tapes.session_tape(session_id, workspace_from_state(state)) + tape.context = replace(tape.context, state=state) + merge_back = not session_id.startswith("temp/") + async with self.tapes.fork_tape(tape.name, merge_back=merge_back): + await self.tapes.ensure_bootstrap_anchor(tape.name) + if isinstance(prompt, str) and prompt.strip().startswith(","): + return await self._run_command(tape=tape, line=prompt.strip()) + return await self._agent_loop( + tape=tape, prompt=prompt, model=model, allowed_skills=allowed_skills, allowed_tools=allowed_tools + ) + + async 
def run_stream( + self, + *, + session_id: str, + prompt: str | list[dict], + state: State, + model: str | None = None, + allowed_skills: Collection[str] | None = None, + allowed_tools: Collection[str] | None = None, ) -> AsyncStreamEvents: if not prompt: events = [ @@ -114,7 +138,12 @@ async def run( ]) else: events = await self._agent_loop( - tape=tape, prompt=prompt, model=model, allowed_skills=allowed_skills, allowed_tools=allowed_tools + tape=tape, + prompt=prompt, + model=model, + allowed_skills=allowed_skills, + allowed_tools=allowed_tools, + stream_output=True, ) return self._events_with_callback(events, callback=stack.aclose) @@ -158,6 +187,30 @@ async def _run_command(self, tape: Tape, *, line: str) -> str: } await self.tapes.append_event(tape.name, "command", event_payload) + @overload + async def _agent_loop( + self, + *, + tape: Tape, + prompt: str | list[dict], + model: str | None = ..., + allowed_skills: Collection[str] | None = ..., + allowed_tools: Collection[str] | None = ..., + stream_output: Literal[False] = ..., + ) -> str: ... + + @overload + async def _agent_loop( + self, + *, + tape: Tape, + prompt: str | list[dict], + model: str | None = ..., + allowed_skills: Collection[str] | None = ..., + allowed_tools: Collection[str] | None = ..., + stream_output: Literal[True] = ..., + ) -> AsyncStreamEvents: ... 
+ async def _agent_loop( self, *, @@ -166,7 +219,8 @@ async def _agent_loop( model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, - ) -> AsyncStreamEvents: + stream_output: bool = False, + ) -> AsyncStreamEvents | str: next_prompt: str | list[dict] = prompt display_model = model or self.settings.model await self.tapes.append_event( @@ -179,16 +233,137 @@ async def _agent_loop( "allowed_tools": list(allowed_tools) if allowed_tools else None, }, ) - state = StreamState() - iterator = self._stream_events_with_auto_handoff( - tape=tape, - prompt=next_prompt, - state=state, - model=model, - allowed_skills=allowed_skills, - allowed_tools=allowed_tools, - ) - return AsyncStreamEvents(iterator, state=state) + if stream_output: + state = StreamState() + iterator = self._stream_events_with_auto_handoff( + tape=tape, + prompt=next_prompt, + state=state, + model=model, + allowed_skills=allowed_skills, + allowed_tools=allowed_tools, + ) + return AsyncStreamEvents(iterator, state=state) + else: + return await self._run_tools_with_auto_handoff( + tape=tape, + prompt=next_prompt, + model=model, + allowed_skills=allowed_skills, + allowed_tools=allowed_tools, + ) + + async def _run_tools_with_auto_handoff( + self, + tape: Tape, + prompt: str | list[dict], + model: str | None = None, + allowed_skills: Collection[str] | None = None, + allowed_tools: Collection[str] | None = None, + ) -> str: + auto_handoff_remaining = MAX_AUTO_HANDOFF_RETRIES + display_model = model or self.settings.model + next_prompt = prompt + for step in range(1, self.settings.max_steps + 1): + start = time.monotonic() + logger.info("loop.step step={} tape={} model={}", step, tape.name, display_model) + await self.tapes.append_event(tape.name, "loop.step.start", {"step": step, "prompt": next_prompt}) + try: + output = await self._run_once( + tape=tape, + prompt=next_prompt, + model=model, + allowed_skills=allowed_skills, + 
allowed_tools=allowed_tools, + ) + except Exception as exc: + elapsed_ms = int((time.monotonic() - start) * 1000) + await self.tapes.append_event( + tape.name, + "loop.step", + { + "step": step, + "elapsed_ms": elapsed_ms, + "status": "error", + "error": f"{exc!s}", + "date": datetime.now(UTC).isoformat(), + }, + ) + raise + + outcome = _resolve_tool_auto_result(output) + elapsed_ms = int((time.monotonic() - start) * 1000) + if outcome.kind == "text": + await self.tapes.append_event( + tape.name, + "loop.step", + { + "step": step, + "elapsed_ms": elapsed_ms, + "status": "ok", + "date": datetime.now(UTC).isoformat(), + }, + ) + return outcome.text + if outcome.kind == "continue": + if "context" in tape.context.state: + next_prompt = f"{CONTINUE_PROMPT} [context: {tape.context.state['context']}]" + else: + next_prompt = CONTINUE_PROMPT + await self.tapes.append_event( + tape.name, + "loop.step", + { + "step": step, + "elapsed_ms": elapsed_ms, + "status": "continue", + "date": datetime.now(UTC).isoformat(), + }, + ) + continue + + # Check if this is a context-length error that can be recovered via auto-handoff + if auto_handoff_remaining > 0 and _is_context_length_error(outcome.error): + auto_handoff_remaining -= 1 + logger.warning( + "auto_handoff: context length exceeded, performing automatic handoff. 
tape={} step={}", + tape.name, + step, + ) + await self.tapes.handoff( + tape.name, + name="auto_handoff/context_overflow", + state={"reason": "context_length_exceeded", "error": outcome.error}, + ) + await self.tapes.append_event( + tape.name, + "loop.step", + { + "step": step, + "elapsed_ms": elapsed_ms, + "status": "auto_handoff", + "error": outcome.error, + "date": datetime.now(UTC).isoformat(), + }, + ) + # Retry with original prompt — the handoff anchor will truncate history + next_prompt = prompt + continue + + await self.tapes.append_event( + tape.name, + "loop.step", + { + "step": step, + "elapsed_ms": elapsed_ms, + "status": "error", + "error": outcome.error, + "date": datetime.now(UTC).isoformat(), + }, + ) + raise RuntimeError(outcome.error) + + raise RuntimeError(f"max_steps_reached={self.settings.max_steps}") async def _stream_events_with_auto_handoff( self, @@ -213,6 +388,7 @@ async def _stream_events_with_auto_handoff( model=model, allowed_skills=allowed_skills, allowed_tools=allowed_tools, + stream_output=True, ) async for event in output: yield event @@ -316,6 +492,30 @@ def _load_skills_prompt(self, prompt: str, workspace: Path, allowed_skills: set[ expanded_skills = set(HINT_RE.findall(prompt)) & set(skill_index.keys()) return render_skills_prompt(list(skill_index.values()), expanded_skills=expanded_skills) + @overload + async def _run_once( + self, + *, + tape: Tape, + prompt: str | list[dict], + model: str | None = ..., + allowed_skills: Collection[str] | None = ..., + allowed_tools: Collection[str] | None = ..., + stream_output: Literal[False] = ..., + ) -> ToolAutoResult: ... + + @overload + async def _run_once( + self, + *, + tape: Tape, + prompt: str | list[dict], + model: str | None = ..., + allowed_skills: Collection[str] | None = ..., + allowed_tools: Collection[str] | None = ..., + stream_output: Literal[True] = ..., + ) -> AsyncStreamEvents: ... 
+ async def _run_once( self, *, @@ -324,7 +524,8 @@ async def _run_once( model: str | None = None, allowed_tools: Collection[str] | None = None, allowed_skills: Collection[str] | None = None, - ) -> AsyncStreamEvents: + stream_output: bool = False, + ) -> AsyncStreamEvents | ToolAutoResult: prompt_text = prompt if isinstance(prompt, str) else _extract_text_from_parts(prompt) if allowed_tools is not None: allowed_tools = {name.casefold() for name in allowed_tools} @@ -336,13 +537,26 @@ async def _run_once( else: tools = list(REGISTRY.values()) async with asyncio.timeout(self.settings.model_timeout_seconds): - return await tape.stream_events_async( - prompt=prompt, - system_prompt=self._system_prompt(prompt_text, state=tape.context.state, allowed_skills=allowed_skills), - max_tokens=self.settings.max_tokens, - tools=model_tools(tools), - model=model, - ) + if stream_output: + return await tape.stream_events_async( + prompt=prompt, + system_prompt=self._system_prompt( + prompt_text, state=tape.context.state, allowed_skills=allowed_skills + ), + max_tokens=self.settings.max_tokens, + tools=model_tools(tools), + model=model, + ) + else: + return await tape.run_tools_async( + prompt=prompt, + system_prompt=self._system_prompt( + prompt_text, state=tape.context.state, allowed_skills=allowed_skills + ), + max_tokens=self.settings.max_tokens, + tools=model_tools(tools), + model=model, + ) def _system_prompt(self, prompt: str, state: State, allowed_skills: set[str] | None = None) -> str: blocks: list[str] = [] @@ -373,6 +587,17 @@ def _resolve_final_data(final_data: dict[str, Any], error: RepublicError | None) return _ToolAutoOutcome(kind="error", error=error_message or "unknown error") +def _resolve_tool_auto_result(output: ToolAutoResult) -> _ToolAutoOutcome: + if output.kind == "text": + return _ToolAutoOutcome(kind="text", text=output.text or "") + if output.kind == "tools" or output.tool_calls or output.tool_results: + return _ToolAutoOutcome(kind="continue") + if 
output.error is None: + return _ToolAutoOutcome(kind="error", error="tool_auto_error: unknown") + error_kind = getattr(output.error.kind, "value", str(output.error.kind)) + return _ToolAutoOutcome(kind="error", error=f"{error_kind}: {output.error.message}") + + def _build_llm(settings: AgentSettings, tape_store: AsyncTapeStore, tape_context: TapeContext) -> LLM: from republic.auth.openai_codex import openai_codex_oauth_resolver diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index a80acf6e..591ecab6 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -106,9 +106,13 @@ async def build_prompt(self, message: ChannelMessage, session_id: str, state: St return text @hookimpl - async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents: + async def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> str: return await self.agent.run(session_id=session_id, prompt=prompt, state=state) + @hookimpl + async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents: + return await self.agent.run_stream(session_id=session_id, prompt=prompt, state=state) + @hookimpl def register_cli_commands(self, app: typer.Typer) -> None: from bub.builtin import cli diff --git a/src/bub/builtin/tools.py b/src/bub/builtin/tools.py index fe2b034a..92644f3c 100644 --- a/src/bub/builtin/tools.py +++ b/src/bub/builtin/tools.py @@ -267,7 +267,7 @@ async def run_subagent(param: SubAgentInput, *, context: ToolContext) -> str: state = {**context.state, "session_id": subagent_session} allowed_tools = resolve_tool_names(param.allowed_tools or None, exclude={"subagent"}) output = "" - async for event in await agent.run( + async for event in await agent.run_stream( session_id=subagent_session, prompt=param.prompt, state=state, diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index 589d5d08..ab9c1e24 100644 --- 
a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -35,6 +35,7 @@ class ChannelSettings(BaseSettings): default=60.0, description="Time window in seconds to consider a channel active for processing messages.", ) + stream_output: bool = Field(default=False, description="Whether to stream model output to channels in real-time.") class ChannelManager: @@ -138,7 +139,7 @@ async def listen_and_run(self) -> None: try: while True: message = await wait_until_stopped(self._messages.get(), stop_event) - task = asyncio.create_task(self.framework.process_inbound(message)) + task = asyncio.create_task(self.framework.process_inbound(message, self._settings.stream_output)) task.add_done_callback(functools.partial(self._on_task_done, message.session_id)) self._ongoing_tasks.setdefault(message.session_id, set()).add(task) except asyncio.CancelledError: diff --git a/src/bub/framework.py b/src/bub/framework.py index 005f2221..603eccf2 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -87,7 +87,7 @@ def _main( self._hook_runtime.call_many_sync("register_cli_commands", app=app) return app - async def process_inbound(self, inbound: Envelope) -> TurnResult: + async def process_inbound(self, inbound: Envelope, stream_output: bool = False) -> TurnResult: """Run one inbound message through hooks and return turn result.""" try: @@ -109,7 +109,7 @@ async def process_inbound(self, inbound: Envelope) -> TurnResult: prompt = content_of(inbound) model_output = "" try: - model_output = await self._run_model(inbound, prompt, session_id, state) + model_output = await self._run_model(inbound, prompt, session_id, state, stream_output) finally: await self._hook_runtime.call_many( "save_state", @@ -129,8 +129,23 @@ async def process_inbound(self, inbound: Envelope) -> TurnResult: raise async def _run_model( - self, inbound: Envelope, prompt: str | list[dict], session_id: str, state: dict[str, Any] + self, + inbound: Envelope, + prompt: str | list[dict], + session_id: str, + 
state: dict[str, Any], + stream_output: bool, ) -> str: + if not stream_output: + output = await self._hook_runtime.run_model(prompt=prompt, session_id=session_id, state=state) + if output is None: + await self._hook_runtime.notify_error( + stage="run_model", + error=RuntimeError("no model skill returned output"), + message=inbound, + ) + return prompt if isinstance(prompt, str) else content_of(inbound) + return output stream = await self._hook_runtime.run_model_stream(prompt=prompt, session_id=session_id, state=state) if stream is None: await self._hook_runtime.notify_error( diff --git a/src/bub/hook_runtime.py b/src/bub/hook_runtime.py index c69cd1db..3d58712c 100644 --- a/src/bub/hook_runtime.py +++ b/src/bub/hook_runtime.py @@ -4,7 +4,7 @@ import inspect from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, cast import pluggy from loguru import logger @@ -160,6 +160,21 @@ def _iter_hookimpls(self, hook_name: str) -> list[Any]: def _kwargs_for_impl(impl: Any, kwargs: dict[str, Any]) -> dict[str, Any]: return {name: kwargs[name] for name in impl.argnames if name in kwargs} + async def run_model(self, prompt: str | list[dict], session_id: str, state: dict[str, Any]) -> str | None: + """Run the first `run_model` hook found and return its result.""" + for _, plugin in reversed(self._plugin_manager.list_name_plugin()): + if hasattr(plugin, "run_model"): + output = await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) + return cast(str, output) + elif hasattr(plugin, "run_model_stream"): + stream = await self.call_first("run_model_stream", prompt=prompt, session_id=session_id, state=state) + text = "" + async for event in stream: + if event.kind == "text": + text += str(event.data.get("delta", "")) + return text + return None + async def run_model_stream( self, prompt: str | list[dict], session_id: str, state: dict[str, Any] ) -> AsyncStreamEvents | None: diff --git a/tests/test_builtin_agent.py 
b/tests/test_builtin_agent.py index f55169ea..be359b70 100644 --- a/tests/test_builtin_agent.py +++ b/tests/test_builtin_agent.py @@ -120,7 +120,7 @@ async def test_agent_run_regular_session_merges_back() -> None: fork_capture = _ForkCapture() agent.tapes = _FakeTapeService(fork_capture) # type: ignore[assignment] - result = await agent.run(session_id="user/session1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + result = await agent.run_stream(session_id="user/session1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 [event async for event in result] assert fork_capture.merge_back_values == [True] @@ -133,7 +133,7 @@ async def test_agent_run_temp_session_does_not_merge_back() -> None: fork_capture = _ForkCapture() agent.tapes = _FakeTapeService(fork_capture) # type: ignore[assignment] - result = await agent.run(session_id="temp/abc123", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + result = await agent.run_stream(session_id="temp/abc123", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 [event async for event in result] assert fork_capture.merge_back_values == [False] @@ -147,7 +147,7 @@ async def test_agent_run_passes_model_to_llm() -> None: fake_tapes = _FakeTapeService(fork_capture) agent.tapes = fake_tapes # type: ignore[assignment] - result = await agent.run( + result = await agent.run_stream( session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}, # noqa: S108 @@ -163,7 +163,7 @@ async def test_agent_run_empty_prompt_returns_error() -> None: agent = _make_agent() agent.tapes = MagicMock() # type: ignore[assignment] - result = await agent.run(session_id="user/s1", prompt="", state={}) + result = await agent.run_stream(session_id="user/s1", prompt="", state={}) events = [event async for event in result] assert [(event.kind, event.data) for event in events] == [ @@ -180,7 +180,7 @@ async def test_agent_run_model_defaults_to_none() -> None: fake_tapes = 
_FakeTapeService(fork_capture) agent.tapes = fake_tapes # type: ignore[assignment] - result = await agent.run(session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + result = await agent.run_stream(session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 [event async for event in result] assert fake_tapes.run_tools_model is None diff --git a/tests/test_builtin_hook_impl.py b/tests/test_builtin_hook_impl.py index 717004c3..2c2397ff 100644 --- a/tests/test_builtin_hook_impl.py +++ b/tests/test_builtin_hook_impl.py @@ -27,10 +27,15 @@ async def __aexit__(self, exc_type, exc, traceback) -> None: class FakeAgent: def __init__(self, home: Path) -> None: self.settings = SimpleNamespace(home=home) - self.calls: list[tuple[str, str, dict[str, object]]] = [] + self.run_calls: list[tuple[str, str, dict[str, object]]] = [] + self.run_stream_calls: list[tuple[str, str, dict[str, object]]] = [] - async def run(self, *, session_id: str, prompt: str, state: dict[str, object]) -> AsyncStreamEvents: - self.calls.append((session_id, prompt, state)) + async def run(self, *, session_id: str, prompt: str, state: dict[str, object]) -> str: + self.run_calls.append((session_id, prompt, state)) + return "agent-output" + + async def run_stream(self, *, session_id: str, prompt: str, state: dict[str, object]) -> AsyncStreamEvents: + self.run_stream_calls.append((session_id, prompt, state)) async def iterator(): yield StreamEvent("text", {"delta": "agent-output"}) @@ -117,6 +122,18 @@ async def test_build_prompt_marks_commands_and_prefixes_context(tmp_path: Path) assert prompt_lines[2] == "hello" +@pytest.mark.asyncio +async def test_run_model_delegates_to_agent(tmp_path: Path) -> None: + _, impl, agent = _build_impl(tmp_path) + state = {"context": "ctx"} + + result = await impl.run_model(prompt="prompt", session_id="session", state=state) + + assert result == "agent-output" + assert agent.run_calls == [("session", "prompt", 
state)] + assert agent.run_stream_calls == [] + + @pytest.mark.asyncio async def test_run_model_stream_delegates_to_agent(tmp_path: Path) -> None: _, impl, agent = _build_impl(tmp_path) @@ -126,7 +143,8 @@ async def test_run_model_stream_delegates_to_agent(tmp_path: Path) -> None: events = [event async for event in stream] assert [(event.kind, event.data) for event in events] == [("text", {"delta": "agent-output"})] - assert agent.calls == [("session", "prompt", state)] + assert agent.run_stream_calls == [("session", "prompt", state)] + assert agent.run_calls == [] def test_system_prompt_appends_workspace_agents_file(tmp_path: Path) -> None: diff --git a/tests/test_channels.py b/tests/test_channels.py index 7e5bee79..90b73035 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -47,6 +47,7 @@ class FakeFramework: def __init__(self, channels: dict[str, FakeChannel]) -> None: self._channels = channels self.router = None + self.process_calls: list[tuple[ChannelMessage, bool]] = [] def get_channels(self, message_handler): self.message_handler = message_handler @@ -55,6 +56,13 @@ def get_channels(self, message_handler): def bind_outbound_router(self, router) -> None: self.router = router + async def process_inbound(self, message: ChannelMessage, stream_output: bool = False): + self.process_calls.append((message, stream_output)) + stop_event = getattr(self, "_stop_event", None) + if stop_event is not None: + stop_event.set() + return None + def _message( content: str, @@ -174,6 +182,71 @@ async def never_finish() -> None: assert cli.stopped is False +@pytest.mark.asyncio +async def test_channel_manager_listen_and_run_passes_stream_output_setting( + monkeypatch: pytest.MonkeyPatch, +) -> None: + framework = FakeFramework({"telegram": FakeChannel("telegram")}) + + class StubChannelSettings: + enabled_channels = "telegram" + debounce_seconds = 1.0 + max_wait_seconds = 10.0 + active_time_window = 60.0 + stream_output = True + + import bub.channels.manager as 
manager_module + + monkeypatch.setattr(manager_module, "ChannelSettings", StubChannelSettings) + manager = ChannelManager(framework) + calls = 0 + spawned_coroutines = [] + original_create_task = manager_module.asyncio.create_task + + class DummyTask: + def add_done_callback(self, callback) -> None: + return None + + def cancel(self) -> None: + return None + + def exception(self): + return None + + def create_task(coro): + spawned_coroutines.append(coro) + return DummyTask() + + async def wait_until_stopped(awaitable, current_stop_event): + nonlocal calls + calls += 1 + if calls == 1: + return await awaitable + close = getattr(awaitable, "close", None) + if callable(close): + close() + raise asyncio.CancelledError + + async def shutdown() -> None: + return None + + manager.shutdown = shutdown # type: ignore[method-assign] + monkeypatch.setattr(manager_module.asyncio, "create_task", create_task) + monkeypatch.setattr(manager_module, "wait_until_stopped", wait_until_stopped) + + listen_task = original_create_task(manager.listen_and_run()) + await asyncio.sleep(0) + await manager.on_receive(_message("hello", channel="telegram")) + await listen_task + assert len(spawned_coroutines) == 1 + await spawned_coroutines[0] + + assert len(framework.process_calls) == 1 + message, stream_output = framework.process_calls[0] + assert message.content == "hello" + assert stream_output is True + + @pytest.mark.asyncio async def test_channel_manager_quit_cancels_only_matching_session_tasks() -> None: manager = ChannelManager(FakeFramework({"telegram": FakeChannel("telegram")}), enabled_channels=["telegram"]) diff --git a/tests/test_framework.py b/tests/test_framework.py index 699f61ac..5749d0f0 100644 --- a/tests/test_framework.py +++ b/tests/test_framework.py @@ -1,11 +1,15 @@ from __future__ import annotations from pathlib import Path +from types import SimpleNamespace +import pytest import typer +from republic import AsyncStreamEvents, StreamEvent from typer.testing import 
CliRunner from bub.channels.base import Channel +from bub.channels.message import ChannelMessage from bub.framework import BubFramework from bub.hookspecs import hookimpl @@ -110,3 +114,111 @@ def test_builtin_cli_exposes_login_and_gateway_command() -> None: assert gateway_result.exit_code == 0 assert "bub gateway" in gateway_result.stdout assert "Start message listeners" in gateway_result.stdout + + +@pytest.mark.asyncio +async def test_process_inbound_defaults_to_non_streaming_run_model() -> None: + framework = BubFramework() + saved_outputs: list[str] = [] + + class NonStreamingPlugin: + @hookimpl + def resolve_session(self, message) -> str: + return "session" + + @hookimpl + def load_state(self, message, session_id) -> dict[str, str]: + return {} + + @hookimpl + def build_prompt(self, message, session_id, state) -> str: + return "prompt" + + @hookimpl + async def run_model(self, prompt, session_id, state) -> str: + return "plain-text" + + @hookimpl + async def save_state(self, session_id, state, message, model_output) -> None: + saved_outputs.append(model_output) + + @hookimpl + def render_outbound(self, message, session_id, state, model_output): + return [{"content": model_output, "channel": "cli", "chat_id": "room"}] + + @hookimpl + async def dispatch_outbound(self, message) -> bool: + return True + + framework._plugin_manager.register(NonStreamingPlugin(), name="non-streaming") + + result = await framework.process_inbound( + ChannelMessage(session_id="s", channel="cli", chat_id="room", content="hi") + ) + + assert result.model_output == "plain-text" + assert saved_outputs == ["plain-text"] + + +@pytest.mark.asyncio +async def test_process_inbound_streams_when_requested() -> None: # noqa: C901 + framework = BubFramework() + stream_calls: list[str] = [] + wrapped_events: list[str] = [] + + class StreamingPlugin: + @hookimpl + def resolve_session(self, message) -> str: + return "session" + + @hookimpl + def load_state(self, message, session_id) -> dict[str, 
str]: + return {} + + @hookimpl + def build_prompt(self, message, session_id, state) -> str: + return "prompt" + + @hookimpl + async def run_model_stream(self, prompt, session_id, state): + stream_calls.append(prompt) + + async def iterator(): + yield StreamEvent("text", {"delta": "stream"}) + yield StreamEvent("text", {"delta": "ed"}) + yield StreamEvent("final", {"text": "streamed", "ok": True}) + + return AsyncStreamEvents(iterator(), state=SimpleNamespace(error=None, usage=None)) + + @hookimpl + async def save_state(self, session_id, state, message, model_output) -> None: + return None + + @hookimpl + def render_outbound(self, message, session_id, state, model_output): + return [{"content": model_output, "channel": "cli", "chat_id": "room"}] + + @hookimpl + async def dispatch_outbound(self, message) -> bool: + return True + + class RecordingRouter: + def wrap_stream(self, message, stream): + async def iterator(): + async for event in stream: + wrapped_events.append(event.kind) + yield event + + return iterator() + + framework._plugin_manager.register(StreamingPlugin(), name="streaming") + framework.bind_outbound_router(RecordingRouter()) + + result = await framework.process_inbound( + ChannelMessage(session_id="s", channel="cli", chat_id="room", content="hi"), + stream_output=True, + ) + + assert stream_calls == ["prompt"] + assert wrapped_events == ["text", "text", "final"] + assert result.model_output == "streamed" diff --git a/tests/test_hook_runtime.py b/tests/test_hook_runtime.py index 337a2bda..d2afadc9 100644 --- a/tests/test_hook_runtime.py +++ b/tests/test_hook_runtime.py @@ -1,5 +1,6 @@ import pluggy import pytest +from republic import AsyncStreamEvents, StreamEvent from bub.hook_runtime import HookRuntime from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs, hookimpl @@ -104,3 +105,37 @@ def resolve_session(self, message): assert "resolve_session" in report assert report["resolve_session"] == ["session"] + + +@pytest.mark.asyncio +async def 
test_run_model_uses_streaming_hook_when_plain_hook_absent() -> None: + class StreamPlugin: + @hookimpl + async def run_model_stream(self, prompt, session_id, state): + async def iterator(): + yield StreamEvent("text", {"delta": "stream"}) + yield StreamEvent("text", {"delta": "ed"}) + + return AsyncStreamEvents(iterator()) + + runtime = _runtime_with_plugins(("stream", StreamPlugin())) + + result = await runtime.run_model(prompt="hello", session_id="s", state={}) + + assert result == "streamed" + + +@pytest.mark.asyncio +async def test_run_model_stream_falls_back_to_plain_hook() -> None: + class PlainPlugin: + @hookimpl + async def run_model(self, prompt, session_id, state): + return "plain" + + runtime = _runtime_with_plugins(("plain", PlainPlugin())) + + stream = await runtime.run_model_stream(prompt="hello", session_id="s", state={}) + + assert stream is not None + events = [event async for event in stream] + assert [(event.kind, event.data) for event in events] == [("text", {"delta": "plain"})] diff --git a/tests/test_subagent_tool.py b/tests/test_subagent_tool.py index 371592e9..14ec3472 100644 --- a/tests/test_subagent_tool.py +++ b/tests/test_subagent_tool.py @@ -20,9 +20,9 @@ def __init__(self, state: dict[str, Any]) -> None: class FakeAgent: def __init__(self) -> None: - self.run = AsyncMock(side_effect=self._run) + self.run_stream = AsyncMock(side_effect=self._run_stream) - async def _run(self, **kwargs: Any) -> AsyncStreamEvents: + async def _run_stream(self, **kwargs: Any) -> AsyncStreamEvents: async def iterator(): yield StreamEvent("text", {"delta": "agent result"}) @@ -37,8 +37,8 @@ async def test_subagent_inherit_session() -> None: result = await run_subagent.run(prompt="do something", session="inherit", context=ctx) assert result == "agent result" - agent.run.assert_called_once() - call_kwargs = agent.run.call_args.kwargs + agent.run_stream.assert_called_once() + call_kwargs = agent.run_stream.call_args.kwargs assert call_kwargs["session_id"] == 
"user/abc" assert call_kwargs["prompt"] == "do something" assert call_kwargs["model"] is None @@ -51,7 +51,7 @@ async def test_subagent_temp_session() -> None: await run_subagent.run(prompt="task", session="temp", context=ctx) - call_kwargs = agent.run.call_args.kwargs + call_kwargs = agent.run_stream.call_args.kwargs assert call_kwargs["session_id"].startswith("temp/") assert call_kwargs["session_id"] != "user/abc" @@ -63,7 +63,7 @@ async def test_subagent_custom_session() -> None: await run_subagent.run(prompt="task", session="custom/session-1", context=ctx) - call_kwargs = agent.run.call_args.kwargs + call_kwargs = agent.run_stream.call_args.kwargs assert call_kwargs["session_id"] == "custom/session-1" @@ -74,7 +74,7 @@ async def test_subagent_passes_model() -> None: await run_subagent.run(prompt="task", model="openai:gpt-4o", context=ctx) - call_kwargs = agent.run.call_args.kwargs + call_kwargs = agent.run_stream.call_args.kwargs assert call_kwargs["model"] == "openai:gpt-4o" @@ -85,7 +85,7 @@ async def test_subagent_state_includes_session_id() -> None: await run_subagent.run(prompt="task", session="temp", context=ctx) - call_kwargs = agent.run.call_args.kwargs + call_kwargs = agent.run_stream.call_args.kwargs state = call_kwargs["state"] # State should contain the subagent session_id, not the original assert state["session_id"] == call_kwargs["session_id"] @@ -100,7 +100,7 @@ async def test_subagent_default_session_when_missing() -> None: await run_subagent.run(prompt="task", session="inherit", context=ctx) - call_kwargs = agent.run.call_args.kwargs + call_kwargs = agent.run_stream.call_args.kwargs assert call_kwargs["session_id"] == "temp/unknown" @@ -118,7 +118,7 @@ def allowed_tool_default() -> str: await run_subagent.run(prompt="task", allowed_tools=[], context=ctx) - allowed_tools = agent.run.call_args.kwargs["allowed_tools"] + allowed_tools = agent.run_stream.call_args.kwargs["allowed_tools"] assert tool_name in allowed_tools assert "subagent" not in 
allowed_tools @@ -137,7 +137,7 @@ def resolve_subagent() -> str: await run_subagent.run(prompt="task", allowed_tools=[" tests_resolve_subagent "], context=ctx) - assert agent.run.call_args.kwargs["allowed_tools"] == {tool_name} + assert agent.run_stream.call_args.kwargs["allowed_tools"] == {tool_name} @pytest.mark.asyncio @@ -148,4 +148,4 @@ async def test_subagent_rejects_unknown_allowed_tools() -> None: with pytest.raises(ValueError, match="tests_missing_tool"): await run_subagent.run(prompt="task", allowed_tools=[" tests_missing_tool "], context=ctx) - agent.run.assert_not_called() + agent.run_stream.assert_not_called()