Skip to content

Commit 14f2974

Browse files
authored
fix: allow provide_tape_store hook to return a context manager (#199)
Fixes #194 Signed-off-by: Frost Ming <me@frostming.com>
1 parent acb76a5 commit 14f2974

8 files changed

Lines changed: 232 additions & 20 deletions

File tree

src/bub/builtin/cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from bub.channels.message import ChannelMessage
2222
from bub.envelope import field_of
2323
from bub.framework import BubFramework
24+
from bub.types import TurnResult
2425

2526
ONBOARD_BANNER = r"""
2627
███████████ █████
@@ -53,7 +54,11 @@ def run(
5354
context={"sender_id": sender_id},
5455
)
5556

56-
result = asyncio.run(framework.process_inbound(inbound))
57+
async def _run() -> TurnResult:
58+
async with framework.running():
59+
return await framework.process_inbound(inbound)
60+
61+
result = asyncio.run(_run())
5762
for outbound in result.outbounds:
5863
rendered = str(field_of(outbound, "content", ""))
5964
target_channel = str(field_of(outbound, "channel", "stdout"))

src/bub/channels/manager.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -142,24 +142,25 @@ def _on_task_done(self, session_id: str, task: asyncio.Task) -> None:
142142
async def listen_and_run(self) -> None:
143143
stop_event = asyncio.Event()
144144
self.framework.bind_outbound_router(self)
145-
for channel in self.enabled_channels():
146-
await channel.start(stop_event)
147-
logger.info("channel.manager started listening")
148-
try:
149-
while True:
150-
message = await wait_until_stopped(self._messages.get(), stop_event)
151-
task = asyncio.create_task(self.framework.process_inbound(message, self._stream_output))
152-
task.add_done_callback(functools.partial(self._on_task_done, message.session_id))
153-
self._ongoing_tasks.setdefault(message.session_id, set()).add(task)
154-
except asyncio.CancelledError:
155-
logger.info("channel.manager received shutdown signal")
156-
except Exception:
157-
logger.exception("channel.manager error")
158-
raise
159-
finally:
160-
self.framework.bind_outbound_router(None)
161-
await self.shutdown()
162-
logger.info("channel.manager stopped")
145+
async with self.framework.running():
146+
for channel in self.enabled_channels():
147+
await channel.start(stop_event)
148+
logger.info("channel.manager started listening")
149+
try:
150+
while True:
151+
message = await wait_until_stopped(self._messages.get(), stop_event)
152+
task = asyncio.create_task(self.framework.process_inbound(message, self._stream_output))
153+
task.add_done_callback(functools.partial(self._on_task_done, message.session_id))
154+
self._ongoing_tasks.setdefault(message.session_id, set()).add(task)
155+
except asyncio.CancelledError:
156+
logger.info("channel.manager received shutdown signal")
157+
except Exception:
158+
logger.exception("channel.manager error")
159+
raise
160+
finally:
161+
self.framework.bind_outbound_router(None)
162+
await self.shutdown()
163+
logger.info("channel.manager stopped")
163164

164165
async def shutdown(self) -> None:
165166
count = 0

src/bub/framework.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import contextlib
6+
from collections.abc import AsyncGenerator, AsyncIterator, Iterator
57
from dataclasses import dataclass
68
from pathlib import Path
79
from typing import TYPE_CHECKING, Any
@@ -46,6 +48,7 @@ def __init__(self, config_file: Path = DEFAULT_CONFIG_FILE) -> None:
4648
self._hook_runtime = HookRuntime(self._plugin_manager)
4749
self._plugin_status: dict[str, PluginStatus] = {}
4850
self._outbound_router: OutboundChannelRouter | None = None
51+
self._tape_store: TapeStore | AsyncTapeStore | None = None
4952
configure.load(self.config_file)
5053

5154
def _load_builtin_hooks(self) -> None:
@@ -253,8 +256,24 @@ def get_channels(self, message_handler: MessageHandler) -> dict[str, Channel]:
253256
channels[channel.name] = channel
254257
return channels
255258

259+
@contextlib.asynccontextmanager
260+
async def running(self) -> AsyncGenerator[contextlib.AsyncExitStack, None]:
261+
async with contextlib.AsyncExitStack() as stack:
262+
tape_store = self._hook_runtime.call_first_sync("provide_tape_store")
263+
# Allow plugins to return either TapeStore/AsyncTapeStore instances or context managers for them
264+
# This benefits plugins that need to initialize and clean up resources with the tape store.
265+
if isinstance(tape_store, AsyncIterator):
266+
tape_store = await stack.enter_async_context(contextlib.asynccontextmanager(lambda: tape_store)())
267+
elif isinstance(tape_store, Iterator):
268+
tape_store = stack.enter_context(contextlib.contextmanager(lambda: tape_store)())
269+
self._tape_store = tape_store
270+
try:
271+
yield stack
272+
finally:
273+
self._tape_store = None
274+
256275
def get_tape_store(self) -> TapeStore | AsyncTapeStore | None:
257-
return self._hook_runtime.call_first_sync("provide_tape_store")
276+
return self._tape_store
258277

259278
def get_system_prompt(self, prompt: str | list[dict], state: dict[str, Any]) -> str:
260279
return "\n\n".join(

tests/test_builtin_cli.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,71 @@ def fake_secret(message: str) -> str:
239239
assert not config_file.exists()
240240

241241

242+
def test_run_command_processes_inbound_inside_framework_runtime(tmp_path: Path) -> None:
243+
config_file = tmp_path / "config.yml"
244+
framework = BubFramework(config_file=config_file)
245+
observed: dict[str, Any] = {}
246+
247+
class RecordingTapeStore:
248+
def __init__(self) -> None:
249+
self.enter_count = 0
250+
self.exit_count = 0
251+
252+
tape_store = RecordingTapeStore()
253+
254+
class RunPlugin:
255+
@hookimpl
256+
def register_cli_commands(self, app: typer.Typer) -> None:
257+
app.command("run")(cli.run)
258+
259+
@hookimpl
260+
def provide_tape_store(self):
261+
tape_store.enter_count += 1
262+
try:
263+
yield tape_store
264+
finally:
265+
tape_store.exit_count += 1
266+
267+
@hookimpl
268+
def build_prompt(self, message, session_id, state) -> str:
269+
observed["session_id"] = session_id
270+
observed["message_content"] = message.content
271+
observed["sender_id"] = message.context["sender_id"]
272+
return "prompt"
273+
274+
@hookimpl
275+
async def run_model(self, prompt, session_id, state) -> str:
276+
observed["tape_store"] = framework.get_tape_store()
277+
return "model output"
278+
279+
@hookimpl
280+
def render_outbound(self, message, session_id, state, model_output):
281+
return [{"channel": "stdout", "chat_id": "local", "content": model_output}]
282+
283+
@hookimpl
284+
async def dispatch_outbound(self, message) -> bool:
285+
return True
286+
287+
framework._plugin_manager.register(RunPlugin(), name="run-plugin")
288+
app = framework.create_cli_app()
289+
290+
result = CliRunner().invoke(
291+
app,
292+
["run", "hello", "--channel", "cli", "--chat-id", "room", "--sender-id", "frost"],
293+
)
294+
295+
assert result.exit_code == 0
296+
assert "[stdout:local]\nmodel output" in result.stdout
297+
assert observed == {
298+
"session_id": "cli:room",
299+
"message_content": "hello",
300+
"sender_id": "frost",
301+
"tape_store": tape_store,
302+
}
303+
assert tape_store.enter_count == 1
304+
assert tape_store.exit_count == 1
305+
306+
242307
def test_onboard_collects_builtin_runtime_config_with_custom_provider(tmp_path: Path, monkeypatch) -> None:
243308
config_file = tmp_path / "config.yml"
244309

tests/test_channels.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,21 @@ def __init__(self, channels: dict[str, FakeChannel]) -> None:
6666
self._channels = channels
6767
self.router = None
6868
self.process_calls: list[tuple[ChannelMessage, bool]] = []
69+
self.running_entries = 0
70+
self.running_exits = 0
6971

7072
def get_channels(self, message_handler):
7173
self.message_handler = message_handler
7274
return self._channels
7375

76+
@contextlib.asynccontextmanager
77+
async def running(self):
78+
self.running_entries += 1
79+
try:
80+
yield
81+
finally:
82+
self.running_exits += 1
83+
7484
def bind_outbound_router(self, router) -> None:
7585
self.router = router
7686

@@ -262,6 +272,8 @@ async def shutdown() -> None:
262272
message, stream_output = framework.process_calls[0]
263273
assert message.content == "hello"
264274
assert stream_output is True
275+
assert framework.running_entries == 1
276+
assert framework.running_exits == 1
265277

266278

267279
@pytest.mark.asyncio

tests/test_framework.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,38 @@ def system_prompt(self, prompt: str, state: dict[str, str]) -> str | None:
116116
assert prompt == "low\n\nhigh"
117117

118118

119+
@pytest.mark.asyncio
120+
async def test_running_enters_tape_store_once_and_reuses_it() -> None:
121+
framework = BubFramework()
122+
123+
class RecordingTapeStore:
124+
def __init__(self) -> None:
125+
self.enter_count = 0
126+
self.exit_count = 0
127+
128+
tape_store = RecordingTapeStore()
129+
130+
class TapePlugin:
131+
@hookimpl
132+
def provide_tape_store(self):
133+
tape_store.enter_count += 1
134+
try:
135+
yield tape_store
136+
finally:
137+
tape_store.exit_count += 1
138+
139+
framework._plugin_manager.register(TapePlugin(), name="tape")
140+
141+
async with framework.running():
142+
assert framework.get_tape_store() is tape_store
143+
assert framework.get_tape_store() is tape_store
144+
assert tape_store.enter_count == 1
145+
assert tape_store.exit_count == 0
146+
147+
assert tape_store.enter_count == 1
148+
assert tape_store.exit_count == 1
149+
150+
119151
def test_builtin_cli_exposes_login_and_gateway_command(write_config) -> None:
120152
with patch.dict(os.environ, {}, clear=True):
121153
framework = BubFramework(config_file=write_config())

website/src/content/docs/docs/extending/hooks.mdx

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,45 @@ Other hook consumers:
5353
- `provide_tape_store`
5454
- `build_tape_context`
5555

56+
## Tape Store Lifecycle
57+
58+
`provide_tape_store` is resolved when `BubFramework.running()` starts, not on every `get_tape_store()` call.
59+
Builtin CLI entry points open this runtime scope for you:
60+
61+
- `bub run` opens it around one inbound turn.
62+
- `bub chat` and `bub gateway` keep it open until the listener exits.
63+
64+
Return a plain store when no lifecycle management is needed:
65+
66+
```python
67+
from republic.tape import InMemoryTapeStore
68+
69+
from bub import hookimpl
70+
71+
72+
class MemoryTapePlugin:
73+
@hookimpl
74+
def provide_tape_store(self):
75+
return InMemoryTapeStore()
76+
```
77+
78+
Return a sync or async iterator when the store needs process-level setup and cleanup.
79+
The yielded value becomes the active result of `framework.get_tape_store()` until the runtime scope exits.
80+
81+
```python
82+
from bub import hookimpl
83+
84+
85+
class DatabaseTapePlugin:
86+
@hookimpl
87+
def provide_tape_store(self):
88+
store = open_store()
89+
try:
90+
yield store
91+
finally:
92+
store.close()
93+
```
94+
5695
## Interactive Onboarding
5796

5897
`onboard_config(current_config)` lets a plugin participate in the interactive `bub onboard` flow.

website/src/content/docs/zh-cn/docs/extending/hooks.mdx

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,45 @@ description: Hook 执行语义、优先级、同步/异步规则、签名匹配
5353
- `provide_tape_store`
5454
- `build_tape_context`
5555

56+
## Tape Store 生命周期
57+
58+
`provide_tape_store` 会在 `BubFramework.running()` 启动时解析,而不是在每次调用 `get_tape_store()` 时重新解析。
59+
内置 CLI 入口会自动打开这个运行时作用域:
60+
61+
- `bub run` 会围绕一个入站 turn 打开作用域。
62+
- `bub chat``bub gateway` 会保持作用域打开,直到 listener 退出。
63+
64+
当不需要生命周期管理时,可以返回普通 store:
65+
66+
```python
67+
from republic.tape import InMemoryTapeStore
68+
69+
from bub import hookimpl
70+
71+
72+
class MemoryTapePlugin:
73+
@hookimpl
74+
def provide_tape_store(self):
75+
return InMemoryTapeStore()
76+
```
77+
78+
当 store 需要进程级初始化和清理时,可以返回同步或异步迭代器。
79+
yield 出来的值会成为 `framework.get_tape_store()` 的活跃结果,直到运行时作用域退出。
80+
81+
```python
82+
from bub import hookimpl
83+
84+
85+
class DatabaseTapePlugin:
86+
@hookimpl
87+
def provide_tape_store(self):
88+
store = open_store()
89+
try:
90+
yield store
91+
finally:
92+
store.close()
93+
```
94+
5695
## 交互式引导
5796

5897
`onboard_config(current_config)` 允许插件参与交互式 `bub onboard` 流程。

0 commit comments

Comments
 (0)