diff --git a/matrix/bot.py b/matrix/bot.py index db083a4..b863c79 100644 --- a/matrix/bot.py +++ b/matrix/bot.py @@ -3,7 +3,7 @@ import asyncio import logging -from typing import Union, Optional +from typing import Union, Optional, Any from nio import AsyncClient, Event, MatrixRoom @@ -49,9 +49,27 @@ def __init__( self.help: HelpCommand = help or DefaultHelpCommand(prefix=self.prefix) self.register_command(self.help) - self.client.add_event_callback(self._on_event, Event) + self.client.add_event_callback(self._on_matrix_event, Event) self._auto_register_events() + def _auto_register_events(self) -> None: + for attr in dir(self): + if not attr.startswith("on_"): + continue + + coro = getattr(self, attr, None) + if not inspect.iscoroutinefunction(coro): + continue + + try: + if attr in self.LIFECYCLE_EVENTS: + self.hook(coro) + + if attr in self.EVENT_MAP: + self.event(coro) + except ValueError: + continue + def get_room(self, room_id: str) -> Room: """Retrieve a Room instance based on the room_id.""" matrix_room = self.client.rooms[room_id] @@ -72,6 +90,9 @@ def load_extension(self, extension: Extension) -> None: for event_type, handlers in extension._event_handlers.items(): self._event_handlers[event_type].extend(handlers) + for hook_name, handlers in extension._hook_handlers.items(): + self._hook_handlers[hook_name].extend(handlers) + self._checks.extend(extension._checks) self._error_handlers.update(extension._error_handlers) self._command_error_handlers.update(extension._command_error_handlers) @@ -118,133 +139,74 @@ def unload_extension(self, ext_name: str) -> None: extension.unload() self.log.debug("unloaded extension '%s'", ext_name) - def _auto_register_events(self) -> None: - for attr in dir(self): - if not attr.startswith("on_"): - continue - coro = getattr(self, attr, None) - if inspect.iscoroutinefunction(coro): - try: - self.event(coro) - except ValueError: # ignore unknown name - continue - - async def _on_event(self, room: MatrixRoom, event: Event) -> None: - # ignore bot events - if event.sender == self.client.user: - return - - # ignore events that happened before the bot started - if self.start_at and self.start_at > (event.server_timestamp / 1000): - return - - try: - await self._dispatch(room, event) - except Exception as error: - await self.on_error(error) - - async def _dispatch(self, room: MatrixRoom, event: Event) -> None: - """Internal type-based fan-out plus optional command handling.""" - for event_type, funcs in self._event_handlers.items(): - if isinstance(event, event_type): - for func in funcs: - await func(room, event) - - async def _process_commands(self, room: MatrixRoom, event: Event) -> None: - """Parse and execute commands""" - ctx = await self._build_context(room, event) - - if ctx.command: - for check in self._checks: - if not await check(ctx): - raise CheckError(ctx.command, check) - - await ctx.command(ctx) - - async def _build_context(self, matrix_room: MatrixRoom, event: Event) -> Context: - room = self.get_room(matrix_room.room_id) - ctx = Context(bot=self, room=room, event=event) - prefix: str | None = None - - if self.prefix is not None and ctx.body.startswith(self.prefix): - prefix = self.prefix - else: - prefix = next( - ( - cmd.prefix - for cmd in self._commands.values() - if cmd.prefix is not None and ctx.body.startswith(cmd.prefix) - ), - self.config.prefix, - ) - - if prefix is None or not ctx.body.startswith(prefix): - return ctx - - if parts := ctx.body[len(prefix) :].split(): - cmd_name = parts[0] - cmd = self._commands.get(cmd_name) - - if cmd and cmd.prefix and not ctx.body.startswith(cmd.prefix): - return ctx - - if not cmd: - raise CommandNotFoundError(cmd_name) - - ctx.command = cmd - - return ctx - - async def on_message(self, room: MatrixRoom, event: Event) -> None: - """ - Invoked when a message event is received. - - This method is automatically called when a :class:`nio.RoomMessageText` - event is detected. It is primarily responsible for detecting and - processing commands that match the bot's defined prefix. - - :param ctx: The context object containing information about the Matrix - room and the message event. - :type ctx: Context - """ - await self._process_commands(room, event) + # LIFECYCLE async def on_ready(self) -> None: - """Invoked after a successful login, before sync starts.""" - self.log.info("bot is ready") + """Override this in a subclass.""" + pass + + async def _on_ready(self) -> None: + """Internal hook — always fires, calls public override then extension handlers.""" + await self.on_ready() + await self._dispatch("on_ready") async def on_error(self, error: Exception) -> None: - """ - Handle errors by invoking a registered error handler, - a generic error callback, or logging the exception. + """Override this in a subclass.""" + self.log.exception("Unhandled error: '%s'", error) - :param error: The exception instance that was raised. - :type error: Exceptipon - """ + async def _on_error(self, error: Exception) -> None: if handler := self._error_handlers.get(type(error)): await handler(error) return - if self._on_error: - await self._on_error(error) + if self._fallback_error_handler: + await self._fallback_error_handler(error) return + + await self._dispatch("on_error", error) + + async def on_command(self, _ctx: Context) -> None: + """Override this in a subclass.""" + pass + + async def _on_command(self, ctx: Context) -> None: + await self._dispatch("on_command", ctx) + + async def on_command_error(self, _ctx: Context, error: Exception) -> None: + """Override this in a subclass.""" self.log.exception("Unhandled error: '%s'", error) - async def on_command_error(self, ctx: "Context", error: Exception) -> None: + async def _on_command_error(self, ctx: Context, error: Exception) -> None: """ Handles errors raised during command invocation. This method is called automatically when a command error occurs. If a specific error handler is registered for the type of the exception, it will be invoked with the current context and error. - - :param ctx: The context in which the command was invoked. - :type ctx: Context - :param error: The exception that was raised during command execution. - :type error: Exception """ if handler := self._command_error_handlers.get(type(error)): await handler(ctx, error) + return + + await self._dispatch("on_command_error", ctx, error) + + # ENTRYPOINT + + def start(self) -> None: + """ + Synchronous entry point for running the bot. + + This is a convenience wrapper that allows running the bot like a + script using a blocking call. It internally calls :meth:`run` within + :func:`asyncio.run`, and ensures the client is closed gracefully + on interruption. + """ + try: + asyncio.run(self.run()) + except KeyboardInterrupt: + self.log.info("bot interrupted by user") + finally: + asyncio.run(self.client.close()) async def run(self) -> None: """ @@ -268,21 +230,67 @@ async def run(self) -> None: self.scheduler.start() - await self.on_ready() + await self._on_ready() await self.client.sync_forever(timeout=30_000) - def start(self) -> None: - """ - Synchronous entry point for running the bot. + # MATRIX EVENTS + + async def on_message(self, room: MatrixRoom, event: Event) -> None: + await self._process_commands(room, event) + + async def _on_matrix_event(self, room: MatrixRoom, event: Event) -> None: + # ignore bot events + if event.sender == self.client.user: + return + + # ignore events that happened before the bot started + if self.start_at and self.start_at > (event.server_timestamp / 1000): + return - This is a convenience wrapper that allows running the bot like a - script using a blocking call. It internally calls :meth:`run` within - :func:`asyncio.run`, and ensures the client is closed gracefully - on interruption. - """ try: - asyncio.run(self.run()) - except KeyboardInterrupt: - self.log.info("bot interrupted by user") - finally: - asyncio.run(self.client.close()) + await self._dispatch_matrix_event(room, event) + except Exception as error: + await self._on_error(error) + + async def _dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: + """Fire all listeners registered for a named lifecycle event.""" + for handler in self._hook_handlers.get(event_name, []): + await handler(*args, **kwargs) + + async def _dispatch_matrix_event(self, room: MatrixRoom, event: Event) -> None: + """Fire all listeners registered for a named matrix event.""" + for event_type, funcs in self._event_handlers.items(): + if isinstance(event, event_type): + for func in funcs: + await func(room, event) + + async def _process_commands(self, room: MatrixRoom, event: Event) -> None: + """Parse and execute commands""" + ctx = await self._build_context(room, event) + + if ctx.command: + for check in self._checks: + if not await check(ctx): + raise CheckError(ctx.command, check) + + await self._on_command(ctx) + await ctx.command(ctx) + + async def _build_context(self, matrix_room: MatrixRoom, event: Event) -> Context: + room = self.get_room(matrix_room.room_id) + ctx = Context(bot=self, room=room, event=event) + prefix = self.prefix or self.config.prefix + + if not ctx.body.startswith(prefix): + return ctx + + if parts := ctx.body[len(prefix) :].split(): + cmd_name = parts[0] + cmd = self._commands.get(cmd_name) + + if not cmd: + raise CommandNotFoundError(cmd_name) + + ctx.command = cmd + + return ctx diff --git a/matrix/registry.py b/matrix/registry.py index 2a9683b..f8a6a66 100644 --- a/matrix/registry.py +++ b/matrix/registry.py @@ -47,6 +47,13 @@ class Registry: "on_member_change": RoomMemberEvent, } + LIFECYCLE_EVENTS: set[str] = { + "on_ready", + "on_error", + "on_command", + "on_command_error", + } + def __init__(self, name: str, prefix: Optional[str] = None): self.name = name self.prefix = prefix @@ -57,7 +64,8 @@ def __init__(self, name: str, prefix: Optional[str] = None): self._scheduler: Scheduler = Scheduler() self._event_handlers: Dict[Type[Event], List[Callback]] = defaultdict(list) - self._on_error: Optional[ErrorCallback] = None + self._hook_handlers: Dict[str, List[Callback]] = defaultdict(list) + self._fallback_error_handler: Optional[ErrorCallback] = None self._error_handlers: Dict[type[Exception], ErrorCallback] = {} self._command_error_handlers: Dict[type[Exception], CommandErrorCallback] = {} @@ -208,17 +216,15 @@ def wrapper(f: Callback) -> Callback: if not inspect.iscoroutinefunction(f): raise TypeError("Event handlers must be coroutines") - if event_spec: - if isinstance(event_spec, str): - event_type = self.EVENT_MAP.get(event_spec) - if event_type is None: - raise ValueError(f"Unknown event string: {event_spec}") - else: - event_type = event_spec - else: - event_type = self.EVENT_MAP.get(f.__name__) - if event_type is None: - raise ValueError(f"Unknown event name: {f.__name__}") + key = event_spec if isinstance(event_spec, str) else f.__name__ + event_type: type[Event] | None = ( + event_spec + if event_spec and not isinstance(event_spec, str) + else self.EVENT_MAP.get(key) + ) + + if event_type is None: + raise ValueError(f"Unknown event: {key!r}") return self.register_event(event_type, f) @@ -238,6 +244,64 @@ def register_event(self, event_type: Type[Event], callback: Callback) -> Callbac ) return callback + def hook( + self, func: Optional[Callback] = None, *, event_name: Optional[str] = None + ) -> Union[Callback, Callable[[Callback], Callback]]: + """Decorator to register a coroutine as a lifecycle event hook. + + Lifecycle events include things like ``on_ready``, ``on_command``, + and ``on_error``. If the event name is not provided, it is inferred + from the function name. Multiple handlers for the same lifecycle + event are supported and called in registration order. + + ## Example + + ```python + @bot.hook + async def on_ready(): + print("Bot is ready!") + + @bot.hook(event_name="on_command") + async def log_command(ctx): + print(f"Command invoked: {ctx.command}") + ``` + """ + + def wrapper(f: Callback) -> Callback: + if not inspect.iscoroutinefunction(f): + raise TypeError("Lifecycle hooks must be coroutines") + + name = event_name or f.__name__ + if name not in self.LIFECYCLE_EVENTS: + raise ValueError(f"Unknown lifecycle event: {name}") + + return self.register_hook(name, f) + + if func is None: + return wrapper + return wrapper(func) + + def register_hook(self, event_name: str, callback: Callback) -> Callback: + """Register a lifecycle event hook directly for a given event name. + + Prefer the :meth:`hook` decorator for typical use. This method + is useful when loading lifecycle hooks from an extension. + """ + if not inspect.iscoroutinefunction(callback): + raise TypeError("Lifecycle hooks must be coroutines") + + if event_name not in self.LIFECYCLE_EVENTS: + raise ValueError(f"Unknown lifecycle event: {event_name}") + + self._hook_handlers[event_name].append(callback) + logger.debug( + "registered lifecycle hook '%s' for event '%s' on %s", + callback.__name__, + event_name, + type(self).__name__, + ) + return callback + def check(self, func: Callback) -> Callback: """Register a global check that must pass before any command is invoked. @@ -321,7 +385,7 @@ def wrapper(func: ErrorCallback) -> ErrorCallback: if exception: self._error_handlers[exception] = func else: - self._on_error = func + self._fallback_error_handler = func logger.debug( "registered error handler '%s' on %s", func.__name__, diff --git a/tests/test_bot.py b/tests/test_bot.py index 8ab4651..95e3d01 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -64,16 +64,15 @@ def test_bot_init_with_invalid_config_file(): def test_auto_register_events_registers_known_events(bot): - # Add a dummy coroutine named on_message_known to bot instance - async def on_message_known(room, event): + async def on_message(room, event): pass - setattr(bot, "on_message_known", on_message_known) + setattr(bot, "on_message", on_message) - with patch.object(bot, "event", wraps=bot.event) as event: + with patch.object(bot, "event", wraps=bot.event) as mock_event: bot._auto_register_events() - event.assert_any_call(on_message_known) + mock_event.assert_any_call(on_message) @pytest.mark.asyncio @@ -100,7 +99,7 @@ async def handler2(room, event): ) room = MatrixRoom("!roomid:matrix.org", "room_alias") - await bot._dispatch(room, event) + await bot._dispatch_matrix_event(room, event) assert "h1" in called assert "h2" in called @@ -114,26 +113,27 @@ async def test_on_event_ignores_self_events(bot): event.sender = "@grace:matrix.org" event.server_timestamp = 123456789 - with patch.object(bot, "_dispatch", new_callable=AsyncMock) as dispatch: - await bot._on_event(MatrixRoom("!room:matrix.org", "alias"), event) + with patch.object( + bot, "_dispatch_matrix_event", new_callable=AsyncMock + ) as dispatch: + await bot._on_matrix_event(MatrixRoom("!room:matrix.org", "alias"), event) dispatch.assert_not_called() @pytest.mark.asyncio async def test_on_event_ignores_old_events(bot, room, event): - # Set start_at after event time bot.client.user = "@somebot:matrix.org" bot.start_at = event.server_timestamp / 1000 + 10 - bot._dispatch = AsyncMock() - await bot._on_event(room, event) + bot._dispatch_matrix_event = AsyncMock() + await bot._on_matrix_event(room, event) - bot._dispatch.assert_not_called() + bot._dispatch_matrix_event.assert_not_called() @pytest.mark.asyncio async def test_on_event_calls_error_handler(bot): - bot._dispatch = AsyncMock(side_effect=Exception("boom")) + bot._dispatch_matrix_event = AsyncMock(side_effect=Exception("boom")) custom_error_handler = AsyncMock() bot.error()(custom_error_handler) @@ -144,7 +144,7 @@ async def test_on_event_calls_error_handler(bot): bot.start_at = 0 bot.client.user = "@grace:matrix.org" - await bot._on_event(MatrixRoom("!roomid", "alias"), event) + await bot._on_matrix_event(MatrixRoom("!roomid", "alias"), event) custom_error_handler.assert_awaited_once() @@ -156,13 +156,28 @@ async def test_on_message_calls_process_commands(bot, room, event): @pytest.mark.asyncio -async def test_on_ready(bot): - await bot.on_ready() - bot.log.info.assert_called_once_with("bot is ready") +async def test_on_ready_dispatches(bot): + with patch.object(bot, "_dispatch", new_callable=AsyncMock) as mock_dispatch: + await bot._on_ready() + mock_dispatch.assert_awaited_once_with("on_ready") @pytest.mark.asyncio -async def test_on_error_calls_custom_handler(bot): +async def test_on_error_calls_specific_handler(bot): + called = False + + @bot.error(ValueError) + async def custom_error_handler(e): + nonlocal called + called = True + + await bot._on_error(ValueError("test error")) + + assert called, "Specific error handler was not called" + + +@pytest.mark.asyncio +async def test_on_error_calls_fallback_handler(bot): called = False @bot.error() @@ -170,15 +185,14 @@ async def custom_error_handler(e): nonlocal called called = True - error = Exception("test error") - await bot.on_error(error) + await bot._fallback_error_handler(Exception("test error")) + await bot.on_error(Exception("test error")) - assert called, "Custom error handler was not called" + assert called, "Fallback error handler was not called" @pytest.mark.asyncio async def test_on_error_logs_when_no_handler(bot): - bot._on_error = None error = Exception("test") await bot.on_error(error) @@ -197,7 +211,6 @@ async def greet(ctx): event.body = "!greet" room = MatrixRoom("!roomid:matrix.org", "alias") - # Patch _build_context to return context with command assigned with patch.object( bot, "_build_context", new_callable=AsyncMock ) as mock_build_context: @@ -383,12 +396,12 @@ async def test_run_uses_token(): async def test_run_with_username_and_password(bot): bot.client.login = AsyncMock(return_value="login_resp") bot.client.sync_forever = AsyncMock() - bot.on_ready = AsyncMock() + bot._on_ready = AsyncMock() await bot.run() bot.client.login.assert_awaited_once_with("grace1234") - bot.on_ready.assert_awaited_once() + bot._on_ready.assert_awaited_once() bot.client.sync_forever.assert_awaited_once() diff --git a/tests/test_registry.py b/tests/test_registry.py index 56f326a..ebf3a60 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -297,12 +297,14 @@ async def on_value_error(error): assert registry._error_handlers[ValueError] is on_value_error -def test_register_generic_error_handler__expect_on_error_set(registry: Registry): +def test_register_generic_error_handler__expect_fallback_error_handler_set( + registry: Registry, +): @registry.error() async def on_any_error(error): pass - assert registry._on_error is on_any_error + assert registry._fallback_error_handler is on_any_error def test_register_error_handler_with_non_coroutine__expect_type_error(