From d928968f00bd10fe5d79768d3acd8e41689e723c Mon Sep 17 00:00:00 2001 From: Simon Roy Date: Tue, 17 Mar 2026 02:50:31 -0400 Subject: [PATCH 1/2] WIP hooks/lifecycle events --- matrix/bot.py | 38 ++++++++++++-------- matrix/registry.py | 89 ++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 101 insertions(+), 26 deletions(-) diff --git a/matrix/bot.py b/matrix/bot.py index db083a4..44ebf0b 100644 --- a/matrix/bot.py +++ b/matrix/bot.py @@ -3,7 +3,7 @@ import asyncio import logging -from typing import Union, Optional +from typing import Union, Optional, Any from nio import AsyncClient, Event, MatrixRoom @@ -72,6 +72,9 @@ def load_extension(self, extension: Extension) -> None: for event_type, handlers in extension._event_handlers.items(): self._event_handlers[event_type].extend(handlers) + for hook_name, handlers in extension._hook_handlers.items(): + self._hook_handlers[hook_name].extend(handlers) + self._checks.extend(extension._checks) self._error_handlers.update(extension._error_handlers) self._command_error_handlers.update(extension._command_error_handlers) @@ -122,12 +125,19 @@ def _auto_register_events(self) -> None: for attr in dir(self): if not attr.startswith("on_"): continue + coro = getattr(self, attr, None) - if inspect.iscoroutinefunction(coro): - try: + if not inspect.iscoroutinefunction(coro): + continue + + try: + if attr in self.LIFECYCLE_EVENTS: + self.hook(coro) + + if attr in self.EVENT_MAP: self.event(coro) - except ValueError: # ignore unknown name - continue + except ValueError: + continue async def _on_event(self, room: MatrixRoom, event: Event) -> None: # ignore bot events @@ -139,11 +149,16 @@ async def _on_event(self, room: MatrixRoom, event: Event) -> None: return try: - await self._dispatch(room, event) + await self._dispatch_matrix_event(room, event) except Exception as error: await self.on_error(error) - async def _dispatch(self, room: MatrixRoom, event: Event) -> None: + async def _dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: + """Fire all listeners registered for a named lifecycle event.""" + for handler in self._hook_handlers.get(event_name, []): + await handler(*args, **kwargs) + + async def _dispatch_matrix_event(self, room: MatrixRoom, event: Event) -> None: """Internal type-based fan-out plus optional command handling.""" for event_type, funcs in self._event_handlers.items(): if isinstance(event, event_type): @@ -202,10 +217,6 @@ async def on_message(self, room: MatrixRoom, event: Event) -> None: This method is automatically called when a :class:`nio.RoomMessageText` event is detected. It is primarily responsible for detecting and processing commands that match the bot's defined prefix. - - :param ctx: The context object containing information about the Matrix - room and the message event. - :type ctx: Context """ await self._process_commands(room, event) @@ -217,9 +228,6 @@ async def on_error(self, error: Exception) -> None: """ Handle errors by invoking a registered error handler, a generic error callback, or logging the exception. - - :param error: The exception instance that was raised. - :type error: Exceptipon """ if handler := self._error_handlers.get(type(error)): await handler(error) @@ -268,7 +276,7 @@ async def run(self) -> None: self.scheduler.start() - await self.on_ready() + await self._dispatch("on_ready") await self.client.sync_forever(timeout=30_000) def start(self) -> None: diff --git a/matrix/registry.py b/matrix/registry.py index 2a9683b..8c5b678 100644 --- a/matrix/registry.py +++ b/matrix/registry.py @@ -47,6 +47,16 @@ class Registry: "on_member_change": RoomMemberEvent, } + LIFECYCLE_EVENTS: set[str] = { + "on_ready", + "on_error", + "on_command", + "on_command_error", + "on_command_invoke", + "on_load", + "on_unload", + } + def __init__(self, name: str, prefix: Optional[str] = None): self.name = name self.prefix = prefix @@ -57,6 +67,7 @@ def __init__(self, name: str, prefix: Optional[str] = None): self._scheduler: Scheduler = Scheduler() self._event_handlers: Dict[Type[Event], List[Callback]] = defaultdict(list) + self._hook_handlers: Dict[str, List[Callback]] = defaultdict(list) self._on_error: Optional[ErrorCallback] = None self._error_handlers: Dict[type[Exception], ErrorCallback] = {} self._command_error_handlers: Dict[type[Exception], CommandErrorCallback] = {} @@ -208,17 +219,15 @@ def wrapper(f: Callback) -> Callback: if not inspect.iscoroutinefunction(f): raise TypeError("Event handlers must be coroutines") - if event_spec: - if isinstance(event_spec, str): - event_type = self.EVENT_MAP.get(event_spec) - if event_type is None: - raise ValueError(f"Unknown event string: {event_spec}") - else: - event_type = event_spec - else: - event_type = self.EVENT_MAP.get(f.__name__) - if event_type is None: - raise ValueError(f"Unknown event name: {f.__name__}") + key = event_spec if isinstance(event_spec, str) else f.__name__ + event_type: type[Event] | None = ( + event_spec + if event_spec and not isinstance(event_spec, str) + else self.EVENT_MAP.get(key) + ) + + if event_type is None: + raise ValueError(f"Unknown event: {key!r}") return self.register_event(event_type, f) @@ -238,6 +247,64 @@ def register_event(self, event_type: Type[Event], callback: Callback) -> Callbac ) return callback + def hook( + self, func: Optional[Callback], *, event_name: Optional[str] = None + ) -> Union[Callback, Callable[[Callback], Callback]]: + """Decorator to register a coroutine as a lifecycle event hook. + + Lifecycle events include things like ``on_ready``, ``on_command``, + and ``on_error``. If the event name is not provided, it is inferred + from the function name. Multiple handlers for the same lifecycle + event are supported and called in registration order. + + ## Example + + ```python + @bot.hook + async def on_ready(): + print("Bot is ready!") + + @bot.hook(event_name="on_command") + async def log_command(ctx): + print(f"Command invoked: {ctx.command}") + ``` + """ + + def wrapper(f: Callback) -> Callback: + if not inspect.iscoroutinefunction(f): + raise TypeError("Lifecycle hooks must be coroutines") + + name = event_name or f.__name__ + if name not in self.LIFECYCLE_EVENTS: + raise ValueError(f"Unknown lifecycle event: {name}") + + return self.register_hook(name, f) + + if func is None: + return wrapper + return wrapper(func) + + def register_hook(self, event_name: str, callback: Callback) -> Callback: + """Register a lifecycle event hook directly for a given event name. + + Prefer the :meth:`hook` decorator for typical use. This method + is useful when loading lifecycle hooks from an extension. + """ + if not inspect.iscoroutinefunction(callback): + raise TypeError("Lifecycle hooks must be coroutines") + + if event_name not in self.LIFECYCLE_EVENTS: + raise ValueError(f"Unknown lifecycle event: {event_name}") + + self._hook_handlers[event_name].append(callback) + logger.debug( + "registered lifecycle hook '%s' for event '%s' on %s", + callback.__name__, + event_name, + type(self).__name__, + ) + return callback + def check(self, func: Callback) -> Callback: """Register a global check that must pass before any command is invoked. From 92036970bf64bc1549b25119f643162a13a6d022 Mon Sep 17 00:00:00 2001 From: penguinboi Date: Sun, 5 Apr 2026 03:34:29 -0400 Subject: [PATCH 2/2] dispatch other lifefcycle events and proper dispatch --- matrix/bot.py | 240 ++++++++++++++++++++--------------------- matrix/registry.py | 9 +- tests/test_bot.py | 63 ++++++----- tests/test_registry.py | 6 +- 4 files changed, 165 insertions(+), 153 deletions(-) diff --git a/matrix/bot.py b/matrix/bot.py index 44ebf0b..b863c79 100644 --- a/matrix/bot.py +++ b/matrix/bot.py @@ -49,9 +49,27 @@ def __init__( self.help: HelpCommand = help or DefaultHelpCommand(prefix=self.prefix) self.register_command(self.help) - self.client.add_event_callback(self._on_event, Event) + self.client.add_event_callback(self._on_matrix_event, Event) self._auto_register_events() + def _auto_register_events(self) -> None: + for attr in dir(self): + if not attr.startswith("on_"): + continue + + coro = getattr(self, attr, None) + if not inspect.iscoroutinefunction(coro): + continue + + try: + if attr in self.LIFECYCLE_EVENTS: + self.hook(coro) + + if attr in self.EVENT_MAP: + self.event(coro) + except ValueError: + continue + def get_room(self, room_id: str) -> Room: """Retrieve a Room instance based on the room_id.""" matrix_room = self.client.rooms[room_id] @@ -121,25 +139,106 @@ def unload_extension(self, ext_name: str) -> None: extension.unload() self.log.debug("unloaded extension '%s'", ext_name) - def _auto_register_events(self) -> None: - for attr in dir(self): - if not attr.startswith("on_"): - continue + # LIFECYCLE - coro = getattr(self, attr, None) - if not inspect.iscoroutinefunction(coro): - continue + async def on_ready(self) -> None: + """Override this in a subclass.""" + pass - try: - if attr in self.LIFECYCLE_EVENTS: - self.hook(coro) + async def _on_ready(self) -> None: + """Internal hook — always fires, calls public override then extension handlers.""" + await self.on_ready() + await self._dispatch("on_ready") - if attr in self.EVENT_MAP: - self.event(coro) - except ValueError: - continue + async def on_error(self, error: Exception) -> None: + """Override this in a subclass.""" + self.log.exception("Unhandled error: '%s'", error) + + async def _on_error(self, error: Exception) -> None: + if handler := self._error_handlers.get(type(error)): + await handler(error) + return + + if self._fallback_error_handler: + await self._fallback_error_handler(error) + return + + await self._dispatch("on_error", error) + + async def on_command(self, _ctx: Context) -> None: + """Override this in a subclass.""" + pass + + async def _on_command(self, ctx: Context) -> None: + await self._dispatch("on_command", ctx) + + async def on_command_error(self, _ctx: Context, error: Exception) -> None: + """Override this in a subclass.""" + self.log.exception("Unhandled error: '%s'", error) + + async def _on_command_error(self, ctx: Context, error: Exception) -> None: + """ + Handles errors raised during command invocation. + + This method is called automatically when a command error occurs. + If a specific error handler is registered for the type of the + exception, it will be invoked with the current context and error. + """ + if handler := self._command_error_handlers.get(type(error)): + await handler(ctx, error) + return + + await self._dispatch("on_command_error", ctx, error) + + # ENTRYPOINT + + def start(self) -> None: + """ + Synchronous entry point for running the bot. + + This is a convenience wrapper that allows running the bot like a + script using a blocking call. It internally calls :meth:`run` within + :func:`asyncio.run`, and ensures the client is closed gracefully + on interruption. + """ + try: + asyncio.run(self.run()) + except KeyboardInterrupt: + self.log.info("bot interrupted by user") + finally: + asyncio.run(self.client.close()) + + async def run(self) -> None: + """ + Log in to the Matrix homeserver and begin syncing events. + + This method should be used within an asynchronous context, + typically via :func:`asyncio.run`. It handles authentication, + calls the :meth:`on_ready` hook, and starts the long-running + sync loop for receiving events. + """ + self.client.user = self.config.user_id + + self.start_at = time.time() + self.log.info("starting – timestamp=%s", self.start_at) - async def _on_event(self, room: MatrixRoom, event: Event) -> None: + if self.config.token: + self.client.access_token = self.config.token + else: + login_resp = await self.client.login(self.config.password) + self.log.info("logged in: %s", login_resp) + + self.scheduler.start() + + await self._on_ready() + await self.client.sync_forever(timeout=30_000) + + # MATRIX EVENTS + + async def on_message(self, room: MatrixRoom, event: Event) -> None: + await self._process_commands(room, event) + + async def _on_matrix_event(self, room: MatrixRoom, event: Event) -> None: # ignore bot events if event.sender == self.client.user: return @@ -151,7 +250,7 @@ async def _on_event(self, room: MatrixRoom, event: Event) -> None: try: await self._dispatch_matrix_event(room, event) except Exception as error: - await self.on_error(error) + await self._on_error(error) async def _dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: """Fire all listeners registered for a named lifecycle event.""" @@ -159,7 +258,7 @@ async def _dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: await handler(*args, **kwargs) async def _dispatch_matrix_event(self, room: MatrixRoom, event: Event) -> None: - """Internal type-based fan-out plus optional command handling.""" + """Fire all listeners registered for a named matrix event.""" for event_type, funcs in self._event_handlers.items(): if isinstance(event, event_type): for func in funcs: @@ -174,123 +273,24 @@ async def _process_commands(self, room: MatrixRoom, event: Event) -> None: if not await check(ctx): raise CheckError(ctx.command, check) + await self._on_command(ctx) await ctx.command(ctx) async def _build_context(self, matrix_room: MatrixRoom, event: Event) -> Context: room = self.get_room(matrix_room.room_id) ctx = Context(bot=self, room=room, event=event) - prefix: str | None = None - - if self.prefix is not None and ctx.body.startswith(self.prefix): - prefix = self.prefix - else: - prefix = next( - ( - cmd.prefix - for cmd in self._commands.values() - if cmd.prefix is not None and ctx.body.startswith(cmd.prefix) - ), - self.config.prefix, - ) + prefix = self.prefix or self.config.prefix - if prefix is None or not ctx.body.startswith(prefix): + if not ctx.body.startswith(prefix): return ctx if parts := ctx.body[len(prefix) :].split(): cmd_name = parts[0] cmd = self._commands.get(cmd_name) - if cmd and cmd.prefix and not ctx.body.startswith(cmd.prefix): - return ctx - if not cmd: raise CommandNotFoundError(cmd_name) ctx.command = cmd return ctx - - async def on_message(self, room: MatrixRoom, event: Event) -> None: - """ - Invoked when a message event is received. - - This method is automatically called when a :class:`nio.RoomMessageText` - event is detected. It is primarily responsible for detecting and - processing commands that match the bot's defined prefix. - """ - await self._process_commands(room, event) - - async def on_ready(self) -> None: - """Invoked after a successful login, before sync starts.""" - self.log.info("bot is ready") - - async def on_error(self, error: Exception) -> None: - """ - Handle errors by invoking a registered error handler, - a generic error callback, or logging the exception. - """ - if handler := self._error_handlers.get(type(error)): - await handler(error) - return - - if self._on_error: - await self._on_error(error) - return - self.log.exception("Unhandled error: '%s'", error) - - async def on_command_error(self, ctx: "Context", error: Exception) -> None: - """ - Handles errors raised during command invocation. - - This method is called automatically when a command error occurs. - If a specific error handler is registered for the type of the - exception, it will be invoked with the current context and error. - - :param ctx: The context in which the command was invoked. - :type ctx: Context - :param error: The exception that was raised during command execution. - :type error: Exception - """ - if handler := self._command_error_handlers.get(type(error)): - await handler(ctx, error) - - async def run(self) -> None: - """ - Log in to the Matrix homeserver and begin syncing events. - - This method should be used within an asynchronous context, - typically via :func:`asyncio.run`. It handles authentication, - calls the :meth:`on_ready` hook, and starts the long-running - sync loop for receiving events. - """ - self.client.user = self.config.user_id - - self.start_at = time.time() - self.log.info("starting – timestamp=%s", self.start_at) - - if self.config.token: - self.client.access_token = self.config.token - else: - login_resp = await self.client.login(self.config.password) - self.log.info("logged in: %s", login_resp) - - self.scheduler.start() - - await self._dispatch("on_ready") - await self.client.sync_forever(timeout=30_000) - - def start(self) -> None: - """ - Synchronous entry point for running the bot. - - This is a convenience wrapper that allows running the bot like a - script using a blocking call. It internally calls :meth:`run` within - :func:`asyncio.run`, and ensures the client is closed gracefully - on interruption. - """ - try: - asyncio.run(self.run()) - except KeyboardInterrupt: - self.log.info("bot interrupted by user") - finally: - asyncio.run(self.client.close()) diff --git a/matrix/registry.py b/matrix/registry.py index 8c5b678..f8a6a66 100644 --- a/matrix/registry.py +++ b/matrix/registry.py @@ -52,9 +52,6 @@ class Registry: "on_error", "on_command", "on_command_error", - "on_command_invoke", - "on_load", - "on_unload", } def __init__(self, name: str, prefix: Optional[str] = None): @@ -68,7 +65,7 @@ def __init__(self, name: str, prefix: Optional[str] = None): self._event_handlers: Dict[Type[Event], List[Callback]] = defaultdict(list) self._hook_handlers: Dict[str, List[Callback]] = defaultdict(list) - self._on_error: Optional[ErrorCallback] = None + self._fallback_error_handler: Optional[ErrorCallback] = None self._error_handlers: Dict[type[Exception], ErrorCallback] = {} self._command_error_handlers: Dict[type[Exception], CommandErrorCallback] = {} @@ -248,7 +245,7 @@ def register_event(self, event_type: Type[Event], callback: Callback) -> Callbac return callback def hook( - self, func: Optional[Callback], *, event_name: Optional[str] = None + self, func: Optional[Callback] = None, *, event_name: Optional[str] = None ) -> Union[Callback, Callable[[Callback], Callback]]: """Decorator to register a coroutine as a lifecycle event hook. @@ -388,7 +385,7 @@ def wrapper(func: ErrorCallback) -> ErrorCallback: if exception: self._error_handlers[exception] = func else: - self._on_error = func + self._fallback_error_handler = func logger.debug( "registered error handler '%s' on %s", func.__name__, diff --git a/tests/test_bot.py b/tests/test_bot.py index 8ab4651..95e3d01 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -64,16 +64,15 @@ def test_bot_init_with_invalid_config_file(): def test_auto_register_events_registers_known_events(bot): - # Add a dummy coroutine named on_message_known to bot instance - async def on_message_known(room, event): + async def on_message(room, event): pass - setattr(bot, "on_message_known", on_message_known) + setattr(bot, "on_message", on_message) - with patch.object(bot, "event", wraps=bot.event) as event: + with patch.object(bot, "event", wraps=bot.event) as mock_event: bot._auto_register_events() - event.assert_any_call(on_message_known) + mock_event.assert_any_call(on_message) @pytest.mark.asyncio @@ -100,7 +99,7 @@ async def handler2(room, event): ) room = MatrixRoom("!roomid:matrix.org", "room_alias") - await bot._dispatch(room, event) + await bot._dispatch_matrix_event(room, event) assert "h1" in called assert "h2" in called @@ -114,26 +113,27 @@ async def test_on_event_ignores_self_events(bot): event.sender = "@grace:matrix.org" event.server_timestamp = 123456789 - with patch.object(bot, "_dispatch", new_callable=AsyncMock) as dispatch: - await bot._on_event(MatrixRoom("!room:matrix.org", "alias"), event) + with patch.object( + bot, "_dispatch_matrix_event", new_callable=AsyncMock + ) as dispatch: + await bot._on_matrix_event(MatrixRoom("!room:matrix.org", "alias"), event) dispatch.assert_not_called() @pytest.mark.asyncio async def test_on_event_ignores_old_events(bot, room, event): - # Set start_at after event time bot.client.user = "@somebot:matrix.org" bot.start_at = event.server_timestamp / 1000 + 10 - bot._dispatch = AsyncMock() - await bot._on_event(room, event) + bot._dispatch_matrix_event = AsyncMock() + await bot._on_matrix_event(room, event) - bot._dispatch.assert_not_called() + bot._dispatch_matrix_event.assert_not_called() @pytest.mark.asyncio async def test_on_event_calls_error_handler(bot): - bot._dispatch = AsyncMock(side_effect=Exception("boom")) + bot._dispatch_matrix_event = AsyncMock(side_effect=Exception("boom")) custom_error_handler = AsyncMock() bot.error()(custom_error_handler) @@ -144,7 +144,7 @@ async def test_on_event_calls_error_handler(bot): bot.start_at = 0 bot.client.user = "@grace:matrix.org" - await bot._on_event(MatrixRoom("!roomid", "alias"), event) + await bot._on_matrix_event(MatrixRoom("!roomid", "alias"), event) custom_error_handler.assert_awaited_once() @@ -156,13 +156,28 @@ async def test_on_message_calls_process_commands(bot, room, event): @pytest.mark.asyncio -async def test_on_ready(bot): - await bot.on_ready() - bot.log.info.assert_called_once_with("bot is ready") +async def test_on_ready_dispatches(bot): + with patch.object(bot, "_dispatch", new_callable=AsyncMock) as mock_dispatch: + await bot._on_ready() + mock_dispatch.assert_awaited_once_with("on_ready") @pytest.mark.asyncio -async def test_on_error_calls_custom_handler(bot): +async def test_on_error_calls_specific_handler(bot): + called = False + + @bot.error(ValueError) + async def custom_error_handler(e): + nonlocal called + called = True + + await bot._on_error(ValueError("test error")) + + assert called, "Specific error handler was not called" + + +@pytest.mark.asyncio +async def test_on_error_calls_fallback_handler(bot): called = False @bot.error() @@ -170,15 +185,14 @@ async def custom_error_handler(e): nonlocal called called = True - error = Exception("test error") - await bot.on_error(error) + await bot._fallback_error_handler(Exception("test error")) + await bot.on_error(Exception("test error")) - assert called, "Custom error handler was not called" + assert called, "Fallback error handler was not called" @pytest.mark.asyncio async def test_on_error_logs_when_no_handler(bot): - bot._on_error = None error = Exception("test") await bot.on_error(error) @@ -197,7 +211,6 @@ async def greet(ctx): event.body = "!greet" room = MatrixRoom("!roomid:matrix.org", "alias") - # Patch _build_context to return context with command assigned with patch.object( bot, "_build_context", new_callable=AsyncMock ) as mock_build_context: @@ -383,12 +396,12 @@ async def test_run_uses_token(): async def test_run_with_username_and_password(bot): bot.client.login = AsyncMock(return_value="login_resp") bot.client.sync_forever = AsyncMock() - bot.on_ready = AsyncMock() + bot._on_ready = AsyncMock() await bot.run() bot.client.login.assert_awaited_once_with("grace1234") - bot.on_ready.assert_awaited_once() + bot._on_ready.assert_awaited_once() bot.client.sync_forever.assert_awaited_once() diff --git a/tests/test_registry.py b/tests/test_registry.py index 56f326a..ebf3a60 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -297,12 +297,14 @@ async def on_value_error(error): assert registry._error_handlers[ValueError] is on_value_error -def test_register_generic_error_handler__expect_on_error_set(registry: Registry): +def test_register_generic_error_handler__expect_fallback_error_handler_set( + registry: Registry, +): @registry.error() async def on_any_error(error): pass - assert registry._on_error is on_any_error + assert registry._fallback_error_handler is on_any_error def test_register_error_handler_with_non_coroutine__expect_type_error(