From b81eb2c6fea7ddbea2cd3af7c7a9406005ecb0a0 Mon Sep 17 00:00:00 2001 From: penguinboi Date: Mon, 6 Apr 2026 23:57:07 -0400 Subject: [PATCH] Refactor Bot Lifecycle and Event Synchronization --- matrix/bot.py | 53 ++++++++++++++++---- matrix/errors.py | 4 ++ matrix/extension.py | 15 +++++- matrix/protocols.py | 2 +- tests/test_bot.py | 114 ++++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 174 insertions(+), 14 deletions(-) diff --git a/matrix/bot.py b/matrix/bot.py index 26bd1c5..83f8e67 100644 --- a/matrix/bot.py +++ b/matrix/bot.py @@ -3,7 +3,7 @@ import asyncio import logging -from typing import Union, Optional, Any +from typing import Optional, Any from nio import AsyncClient, Event, MatrixRoom @@ -15,7 +15,12 @@ from .registry import Registry from .help import HelpCommand, DefaultHelpCommand from .scheduler import Scheduler -from .errors import AlreadyRegisteredError, CommandNotFoundError, CheckError +from .errors import ( + AlreadyRegisteredError, + CommandNotFoundError, + CheckError, + RoomNotFoundError, +) class Bot(Registry): @@ -36,7 +41,9 @@ def __init__( self._config: Config | None = None self._client: AsyncClient | None = None + self._synced: asyncio.Event = asyncio.Event() self._help: HelpCommand | None = help_ + self.extensions: dict[str, Extension] = {} self.scheduler: Scheduler = Scheduler() self.log: logging.Logger = logging.getLogger(__name__) @@ -75,10 +82,23 @@ def _auto_register_events(self) -> None: except ValueError: continue - def get_room(self, room_id: str) -> Room: - """Retrieve a Room instance based on the room_id.""" - matrix_room = self.client.rooms[room_id] - return Room(matrix_room=matrix_room, client=self.client) + def get_room(self, room_id: str) -> Room | None: + """Retrieve a `Room` instance by its Matrix room ID. + + Returns the `Room` object corresponding to `room_id` if it exists in + the client's known rooms. Returns `None` if the room cannot be found. + + ## Example + + ```python + room = bot.get_room("!abc123:matrix.org") + if room: + print(room.name) + ``` + """ + if matrix_room := self.client.rooms.get(room_id): + return Room(matrix_room=matrix_room, client=self.client) + return None def load_extension(self, extension: Extension) -> None: self.log.debug(f"Loading extension: '{extension.name}'") @@ -255,10 +275,16 @@ async def run(self) -> None: login_resp = await self.client.login(self.config.password) self.log.info("logged in: %s", login_resp) - self.scheduler.start() + sync_task = asyncio.create_task(self.client.sync_forever(timeout=30_000)) + await self._wait_until_synced() await self._on_ready() - await self.client.sync_forever(timeout=30_000) + + self.scheduler.start() + await sync_task + + async def _wait_until_synced(self) -> None: + await self._synced.wait() # MATRIX EVENTS @@ -266,6 +292,9 @@ async def on_message(self, room: Room, event: Event) -> None: await self._process_commands(room, event) async def _on_matrix_event(self, matrix_room: MatrixRoom, event: Event) -> None: + if not self._synced.is_set(): + self._synced.set() + # ignore bot events if event.sender == self.client.user: return @@ -276,6 +305,10 @@ async def _on_matrix_event(self, matrix_room: MatrixRoom, event: Event) -> None: try: room = self.get_room(matrix_room.room_id) + + if not room: + raise RoomNotFoundError(f"Room '{matrix_room.room_id}' not found.") + await self._dispatch_matrix_event(room, event) except Exception as error: await self._on_error(error) @@ -306,6 +339,10 @@ async def _process_commands(self, room: Room, event: Event) -> None: async def _build_context(self, matrix_room: Room, event: Event) -> Context: room = self.get_room(matrix_room.room_id) + + if not room: + raise RoomNotFoundError(f"Room '{matrix_room.room_id}' not found.") + ctx = Context(bot=self, room=room, event=event) prefix = self.prefix or self.config.prefix diff --git a/matrix/errors.py b/matrix/errors.py index 98e9c18..c9cca60 100644 --- a/matrix/errors.py +++ b/matrix/errors.py @@ -13,6 +13,10 @@ class MatrixError(Exception): pass +class RoomNotFoundError(MatrixError): + pass + + class RegistryError(MatrixError): pass diff --git a/matrix/extension.py b/matrix/extension.py index 8ef807a..a364358 100644 --- a/matrix/extension.py +++ b/matrix/extension.py @@ -17,7 +17,20 @@ def __init__(self, name: str, prefix: Optional[str] = None) -> None: self._on_load: Optional[Callable] = None self._on_unload: Optional[Callable] = None - def get_room(self, room_id: str) -> Room: + def get_room(self, room_id: str) -> Room | None: + """Retrieve a `Room` instance by its Matrix room ID. + + Returns the `Room` object corresponding to `room_id` if it exists in + the client's known rooms. Returns `None` if the room cannot be found. + + ## Example + + ```python + room = extension.get_room("!abc123:matrix.org") + if room: + print(room.name) + ``` + """ if self.bot is None: raise RuntimeError("Extension is not loaded") return self.bot.get_room(room_id) diff --git a/matrix/protocols.py b/matrix/protocols.py index 91d5c73..b72fcdf 100644 --- a/matrix/protocols.py +++ b/matrix/protocols.py @@ -6,4 +6,4 @@ class BotLike(Protocol): prefix: str | None - def get_room(self, room_id: str) -> Room: ... + def get_room(self, room_id: str) -> Room | None: ... diff --git a/tests/test_bot.py b/tests/test_bot.py index 00bf5ef..4d52a0e 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -313,6 +313,30 @@ async def ping(ctx): assert called, "Expected command handler to be called" +@pytest.mark.asyncio +async def test_command_not_processed_without_prefix(bot, room): + called = False + + @bot.command() + async def greet(ctx): + nonlocal called + called = True + + event = RoomMessageText.from_dict( + { + "content": {"body": "greet", "msgtype": "m.text"}, + "event_id": "$id", + "origin_server_ts": 123456, + "sender": "@user:matrix.org", + "type": "m.room.message", + } + ) + + await bot._process_commands(room, event) + + assert not called + + @pytest.mark.asyncio async def test_error_decorator_requires_coroutine(bot): with pytest.raises(TypeError): @@ -376,18 +400,37 @@ async def cmd2(ctx): pass +import asyncio + + +async def start_and_stop(coro): + task = asyncio.create_task(coro) + await asyncio.sleep(0) # allow startup + task.cancel() + await asyncio.gather(task, return_exceptions=True) + + @pytest.mark.asyncio async def test_run_uses_token(): bot = Bot() bot._load_config("tests/config_fixture_token.yaml") bot._client.sync_forever = AsyncMock() - bot.on_ready = AsyncMock() + bot._on_ready = AsyncMock() + + # unblock readiness + bot._synced.set() + + task = asyncio.create_task(bot.run()) + + await asyncio.sleep(0) + await asyncio.sleep(0) - await bot.run() + task.cancel() + await asyncio.gather(task, return_exceptions=True) assert bot._client.access_token == "abc123" - bot.on_ready.assert_awaited_once() + bot._on_ready.assert_awaited_once() bot._client.sync_forever.assert_awaited_once() @@ -397,7 +440,15 @@ async def test_run_with_username_and_password(bot): bot._client.sync_forever = AsyncMock() bot._on_ready = AsyncMock() - await bot.run() + bot._synced.set() + + task = asyncio.create_task(bot.run()) + + await asyncio.sleep(0) + await asyncio.sleep(0) + + task.cancel() + await asyncio.gather(task, return_exceptions=True) bot._client.login.assert_awaited_once_with("grace1234") bot._on_ready.assert_awaited_once() @@ -418,6 +469,39 @@ def test_start_handles_keyboard_interrupt(caplog): bot._client.close.assert_awaited_once() +@pytest.mark.asyncio +async def test_on_ready_called_only_once(bot): + # Prepare + bot._synced.set() + bot._on_ready = AsyncMock() + + # Simulate run + await bot._wait_until_synced() + await bot._on_ready() + + bot._on_ready.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_scheduler_starts_after_ready(bot): + bot._synced.set() + + order = [] + + async def ready(): + order.append("ready") + + bot._on_ready = AsyncMock(side_effect=ready) + bot.scheduler.start = MagicMock(side_effect=lambda: order.append("scheduler")) + + # Simulate run + await bot._wait_until_synced() + await bot._on_ready() + bot.scheduler.start() + + assert order == ["ready", "scheduler"] + + @pytest.mark.asyncio async def test_scheduled_task_in_scheduler(bot): @bot.schedule("* * * * *") @@ -743,3 +827,25 @@ def test_unload_extension_logs_unloading(bot: Bot, loaded_extension: Extension): bot.unload_extension(loaded_extension.name) bot.log.debug.assert_any_call("unloaded extension '%s'", loaded_extension.name) + + +def test_unload_extension_removes_only_its_jobs(bot: Bot): + ext_a = Extension(name="a") + ext_b = Extension(name="b") + + @ext_a.schedule("* * * * *") + async def task(): + pass + + @ext_b.schedule("* * * * *") + async def task(): + pass + + bot.load_extension(ext_a) + bot.load_extension(ext_b) + + bot.unload_extension("a") + + job_names = [j.name for j in bot.scheduler.jobs] + + assert "task" in job_names