From 04a478db3fd6105f9a362d29d46403846d736262 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Wed, 6 May 2026 15:51:44 +0800 Subject: [PATCH] fix: allow `provide_tape_store` hook to return a context manager Fixes #194 Signed-off-by: Frost Ming --- src/bub/builtin/cli.py | 7 +- src/bub/channels/manager.py | 37 ++++++----- src/bub/framework.py | 21 +++++- tests/test_builtin_cli.py | 65 +++++++++++++++++++ tests/test_channels.py | 12 ++++ tests/test_framework.py | 32 +++++++++ .../src/content/docs/docs/extending/hooks.mdx | 39 +++++++++++ .../docs/zh-cn/docs/extending/hooks.mdx | 39 +++++++++++ 8 files changed, 232 insertions(+), 20 deletions(-) diff --git a/src/bub/builtin/cli.py b/src/bub/builtin/cli.py index d7d65930..48725103 100644 --- a/src/bub/builtin/cli.py +++ b/src/bub/builtin/cli.py @@ -21,6 +21,7 @@ from bub.channels.message import ChannelMessage from bub.envelope import field_of from bub.framework import BubFramework +from bub.types import TurnResult ONBOARD_BANNER = r""" ███████████ █████ @@ -53,7 +54,11 @@ def run( context={"sender_id": sender_id}, ) - result = asyncio.run(framework.process_inbound(inbound)) + async def _run() -> TurnResult: + async with framework.running(): + return await framework.process_inbound(inbound) + + result = asyncio.run(_run()) for outbound in result.outbounds: rendered = str(field_of(outbound, "content", "")) target_channel = str(field_of(outbound, "channel", "stdout")) diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index 83e4a602..60e32a75 100644 --- a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -142,24 +142,25 @@ def _on_task_done(self, session_id: str, task: asyncio.Task) -> None: async def listen_and_run(self) -> None: stop_event = asyncio.Event() self.framework.bind_outbound_router(self) - for channel in self.enabled_channels(): - await channel.start(stop_event) - logger.info("channel.manager started listening") - try: - while True: - message = await wait_until_stopped(self._messages.get(), stop_event) - task = asyncio.create_task(self.framework.process_inbound(message, self._stream_output)) - task.add_done_callback(functools.partial(self._on_task_done, message.session_id)) - self._ongoing_tasks.setdefault(message.session_id, set()).add(task) - except asyncio.CancelledError: - logger.info("channel.manager received shutdown signal") - except Exception: - logger.exception("channel.manager error") - raise - finally: - self.framework.bind_outbound_router(None) - await self.shutdown() - logger.info("channel.manager stopped") + async with self.framework.running(): + for channel in self.enabled_channels(): + await channel.start(stop_event) + logger.info("channel.manager started listening") + try: + while True: + message = await wait_until_stopped(self._messages.get(), stop_event) + task = asyncio.create_task(self.framework.process_inbound(message, self._stream_output)) + task.add_done_callback(functools.partial(self._on_task_done, message.session_id)) + self._ongoing_tasks.setdefault(message.session_id, set()).add(task) + except asyncio.CancelledError: + logger.info("channel.manager received shutdown signal") + except Exception: + logger.exception("channel.manager error") + raise + finally: + self.framework.bind_outbound_router(None) + await self.shutdown() + logger.info("channel.manager stopped") async def shutdown(self) -> None: count = 0 diff --git a/src/bub/framework.py b/src/bub/framework.py index 8573fea3..cd259c7b 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -2,6 +2,8 @@ from __future__ import annotations +import contextlib +from collections.abc import AsyncGenerator, AsyncIterator, Iterator from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any @@ -46,6 +48,7 @@ def __init__(self, config_file: Path = DEFAULT_CONFIG_FILE) -> None: self._hook_runtime = HookRuntime(self._plugin_manager) self._plugin_status: dict[str, PluginStatus] = {} self._outbound_router: OutboundChannelRouter | None = None + self._tape_store: TapeStore | AsyncTapeStore | None = None configure.load(self.config_file) def _load_builtin_hooks(self) -> None: @@ -253,8 +256,24 @@ def get_channels(self, message_handler: MessageHandler) -> dict[str, Channel]: channels[channel.name] = channel return channels + @contextlib.asynccontextmanager + async def running(self) -> AsyncGenerator[contextlib.AsyncExitStack, None]: + async with contextlib.AsyncExitStack() as stack: + tape_store = self._hook_runtime.call_first_sync("provide_tape_store") + # Allow plugins to return either TapeStore/AsyncTapeStore instances or context managers for them + # This benefits plugins that need to initialize and clean up resources with the tape store. + if isinstance(tape_store, AsyncIterator): + tape_store = await stack.enter_async_context(contextlib.asynccontextmanager(lambda: tape_store)()) + elif isinstance(tape_store, Iterator): + tape_store = stack.enter_context(contextlib.contextmanager(lambda: tape_store)()) + self._tape_store = tape_store + try: + yield stack + finally: + self._tape_store = None + def get_tape_store(self) -> TapeStore | AsyncTapeStore | None: - return self._hook_runtime.call_first_sync("provide_tape_store") + return self._tape_store def get_system_prompt(self, prompt: str | list[dict], state: dict[str, Any]) -> str: return "\n\n".join( diff --git a/tests/test_builtin_cli.py b/tests/test_builtin_cli.py index 9d69686e..43b8010a 100644 --- a/tests/test_builtin_cli.py +++ b/tests/test_builtin_cli.py @@ -239,6 +239,71 @@ def fake_secret(message: str) -> str: assert not config_file.exists() +def test_run_command_processes_inbound_inside_framework_runtime(tmp_path: Path) -> None: + config_file = tmp_path / "config.yml" + framework = BubFramework(config_file=config_file) + observed: dict[str, Any] = {} + + class RecordingTapeStore: + def __init__(self) -> None: + self.enter_count = 0 + self.exit_count = 0 + + tape_store = RecordingTapeStore() + + class RunPlugin: + @hookimpl + def register_cli_commands(self, app: typer.Typer) -> None: + app.command("run")(cli.run) + + @hookimpl + def provide_tape_store(self): + tape_store.enter_count += 1 + try: + yield tape_store + finally: + tape_store.exit_count += 1 + + @hookimpl + def build_prompt(self, message, session_id, state) -> str: + observed["session_id"] = session_id + observed["message_content"] = message.content + observed["sender_id"] = message.context["sender_id"] + return "prompt" + + @hookimpl + async def run_model(self, prompt, session_id, state) -> str: + observed["tape_store"] = framework.get_tape_store() + return "model output" + + @hookimpl + def render_outbound(self, message, session_id, state, model_output): + return [{"channel": "stdout", "chat_id": "local", "content": model_output}] + + @hookimpl + async def dispatch_outbound(self, message) -> bool: + return True + + framework._plugin_manager.register(RunPlugin(), name="run-plugin") + app = framework.create_cli_app() + + result = CliRunner().invoke( + app, + ["run", "hello", "--channel", "cli", "--chat-id", "room", "--sender-id", "frost"], + ) + + assert result.exit_code == 0 + assert "[stdout:local]\nmodel output" in result.stdout + assert observed == { + "session_id": "cli:room", + "message_content": "hello", + "sender_id": "frost", + "tape_store": tape_store, + } + assert tape_store.enter_count == 1 + assert tape_store.exit_count == 1 + + def test_onboard_collects_builtin_runtime_config_with_custom_provider(tmp_path: Path, monkeypatch) -> None: config_file = tmp_path / "config.yml" diff --git a/tests/test_channels.py b/tests/test_channels.py index 54f56158..14d34905 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -66,11 +66,21 @@ def __init__(self, channels: dict[str, FakeChannel]) -> None: self._channels = channels self.router = None self.process_calls: list[tuple[ChannelMessage, bool]] = [] + self.running_entries = 0 + self.running_exits = 0 def get_channels(self, message_handler): self.message_handler = message_handler return self._channels + @contextlib.asynccontextmanager + async def running(self): + self.running_entries += 1 + try: + yield + finally: + self.running_exits += 1 + def bind_outbound_router(self, router) -> None: self.router = router @@ -262,6 +272,8 @@ async def shutdown() -> None: message, stream_output = framework.process_calls[0] assert message.content == "hello" assert stream_output is True + assert framework.running_entries == 1 + assert framework.running_exits == 1 @pytest.mark.asyncio diff --git a/tests/test_framework.py b/tests/test_framework.py index ed925498..63756cfc 100644 --- a/tests/test_framework.py +++ b/tests/test_framework.py @@ -116,6 +116,38 @@ def system_prompt(self, prompt: str, state: dict[str, str]) -> str | None: assert prompt == "low\n\nhigh" +@pytest.mark.asyncio +async def test_running_enters_tape_store_once_and_reuses_it() -> None: + framework = BubFramework() + + class RecordingTapeStore: + def __init__(self) -> None: + self.enter_count = 0 + self.exit_count = 0 + + tape_store = RecordingTapeStore() + + class TapePlugin: + @hookimpl + def provide_tape_store(self): + tape_store.enter_count += 1 + try: + yield tape_store + finally: + tape_store.exit_count += 1 + + framework._plugin_manager.register(TapePlugin(), name="tape") + + async with framework.running(): + assert framework.get_tape_store() is tape_store + assert framework.get_tape_store() is tape_store + assert tape_store.enter_count == 1 + assert tape_store.exit_count == 0 + + assert tape_store.enter_count == 1 + assert tape_store.exit_count == 1 + + def test_builtin_cli_exposes_login_and_gateway_command(write_config) -> None: with patch.dict(os.environ, {}, clear=True): framework = BubFramework(config_file=write_config()) diff --git a/website/src/content/docs/docs/extending/hooks.mdx b/website/src/content/docs/docs/extending/hooks.mdx index b8bf006b..e7ad13b9 100644 --- a/website/src/content/docs/docs/extending/hooks.mdx +++ b/website/src/content/docs/docs/extending/hooks.mdx @@ -53,6 +53,45 @@ Other hook consumers: - `provide_tape_store` - `build_tape_context` +## Tape Store Lifecycle + +`provide_tape_store` is resolved when `BubFramework.running()` starts, not on every `get_tape_store()` call. +Builtin CLI entry points open this runtime scope for you: + +- `bub run` opens it around one inbound turn. +- `bub chat` and `bub gateway` keep it open until the listener exits. + +Return a plain store when no lifecycle management is needed: + +```python +from republic.tape import InMemoryTapeStore + +from bub import hookimpl + + +class MemoryTapePlugin: + @hookimpl + def provide_tape_store(self): + return InMemoryTapeStore() +``` + +Return a sync or async iterator when the store needs process-level setup and cleanup. +The yielded value becomes the active result of `framework.get_tape_store()` until the runtime scope exits. + +```python +from bub import hookimpl + + +class DatabaseTapePlugin: + @hookimpl + def provide_tape_store(self): + store = open_store() + try: + yield store + finally: + store.close() +``` + ## Interactive Onboarding `onboard_config(current_config)` lets a plugin participate in the interactive `bub onboard` flow. diff --git a/website/src/content/docs/zh-cn/docs/extending/hooks.mdx b/website/src/content/docs/zh-cn/docs/extending/hooks.mdx index 82560783..022d2dfc 100644 --- a/website/src/content/docs/zh-cn/docs/extending/hooks.mdx +++ b/website/src/content/docs/zh-cn/docs/extending/hooks.mdx @@ -53,6 +53,45 @@ description: Hook 执行语义、优先级、同步/异步规则、签名匹配 - `provide_tape_store` - `build_tape_context` +## Tape Store 生命周期 + +`provide_tape_store` 会在 `BubFramework.running()` 启动时解析,而不是在每次调用 `get_tape_store()` 时重新解析。 +内置 CLI 入口会自动打开这个运行时作用域: + +- `bub run` 会围绕一个入站 turn 打开作用域。 +- `bub chat` 和 `bub gateway` 会保持作用域打开,直到 listener 退出。 + +当不需要生命周期管理时,可以返回普通 store: + +```python +from republic.tape import InMemoryTapeStore + +from bub import hookimpl + + +class MemoryTapePlugin: + @hookimpl + def provide_tape_store(self): + return InMemoryTapeStore() +``` + +当 store 需要进程级初始化和清理时,可以返回同步或异步迭代器。 +yield 出来的值会成为 `framework.get_tape_store()` 的活跃结果,直到运行时作用域退出。 + +```python +from bub import hookimpl + + +class DatabaseTapePlugin: + @hookimpl + def provide_tape_store(self): + store = open_store() + try: + yield store + finally: + store.close() +``` + ## 交互式引导 `onboard_config(current_config)` 允许插件参与交互式 `bub onboard` 流程。