From d4602d97056be4d9c13cc10cee527f06fc8f0942 Mon Sep 17 00:00:00 2001 From: penguinboi Date: Mon, 6 Apr 2026 00:41:08 -0400 Subject: [PATCH] Bot protocol to give get_room access to extension --- matrix/bot.py | 2 +- matrix/extension.py | 19 ++++++-- matrix/protocols.py | 9 ++++ tests/test_extension.py | 98 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 118 insertions(+), 10 deletions(-) create mode 100644 matrix/protocols.py diff --git a/matrix/bot.py b/matrix/bot.py index b863c79..163a847 100644 --- a/matrix/bot.py +++ b/matrix/bot.py @@ -105,7 +105,7 @@ def load_extension(self, extension: Extension) -> None: ) self.extensions[extension.name] = extension - extension.load() + extension.load(self) self.log.debug("loaded extension '%s'", extension.name) def unload_extension(self, ext_name: str) -> None: diff --git a/matrix/extension.py b/matrix/extension.py index 5fe9ec7..8ef807a 100644 --- a/matrix/extension.py +++ b/matrix/extension.py @@ -1,8 +1,10 @@ -import logging import inspect +import logging +from typing import Callable, Optional -from typing import Any, Callable, Coroutine, Optional +from matrix.protocols import BotLike from matrix.registry import Registry +from matrix.room import Room logger = logging.getLogger(__name__) @@ -10,10 +12,19 @@ class Extension(Registry): def __init__(self, name: str, prefix: Optional[str] = None) -> None: super().__init__(name, prefix=prefix) + + self.bot: Optional[BotLike] = None self._on_load: Optional[Callable] = None self._on_unload: Optional[Callable] = None - def load(self) -> None: + def get_room(self, room_id: str) -> Room: + if self.bot is None: + raise RuntimeError("Extension is not loaded") + return self.bot.get_room(room_id) + + def load(self, bot: BotLike) -> None: + self.bot = bot + if self._on_load: self._on_load() @@ -35,6 +46,8 @@ def setup(): return func def unload(self) -> None: + self.bot = None + if self._on_unload: self._on_unload() diff --git a/matrix/protocols.py b/matrix/protocols.py new file mode 100644 index 0000000..91d5c73 --- /dev/null +++ b/matrix/protocols.py @@ -0,0 +1,9 @@ +from typing import Protocol + +from matrix.room import Room + + +class BotLike(Protocol): + prefix: str | None + + def get_room(self, room_id: str) -> Room: ... diff --git a/tests/test_extension.py b/tests/test_extension.py index 8b96013..46806a0 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -1,6 +1,17 @@ import pytest +from unittest.mock import MagicMock +from typing import Optional + from matrix.extension import Extension +from matrix.room import Room + + +class MockBot: + prefix: str = "!" + + def __init__(self, room: Optional[Room] = None) -> None: + self.get_room = MagicMock(return_value=room or MagicMock(spec=Room)) @pytest.fixture @@ -8,6 +19,14 @@ def extension() -> Extension: return Extension(name="test_ext", prefix="!") +@pytest.fixture +def bot() -> MockBot: + return MockBot() + + +# INIT + + def test_init_with_name_and_prefix__expect_attributes_set(): ext = Extension(name="math", prefix="!") @@ -21,6 +40,10 @@ def test_init_with_name_only__expect_prefix_is_none(): assert ext.prefix is None +def test_init__expect_bot_is_none(extension: Extension): + assert extension.bot is None + + def test_init__expect_on_load_is_none(extension: Extension): assert extension._on_load is None @@ -45,6 +68,9 @@ def test_init__expect_empty_checks(extension: Extension): assert extension._checks == [] +# ON LOAD + + def test_on_load_with_sync_function__expect_handler_registered(extension: Extension): @extension.on_load def setup(): @@ -86,20 +112,34 @@ def second(): assert extension._on_load is second -def test_load_with_registered_handler__expect_handler_called(extension: Extension): +# LOAD + + +def test_load__expect_bot_set(extension: Extension, bot: MockBot): + extension.load(bot) + + assert extension.bot is bot + + +def test_load_with_registered_handler__expect_handler_called( + extension: Extension, bot: MockBot +): called = [] @extension.on_load def setup(): called.append(True) - extension.load() + extension.load(bot) assert called == [True] -def test_load_with_no_handler__expect_no_error(extension: Extension): - extension.load() +def test_load_with_no_handler__expect_no_error(extension: Extension, bot: MockBot): + extension.load(bot) + + +# ON UNLOAD def test_on_unload_with_sync_function__expect_handler_registered(extension: Extension): @@ -143,17 +183,63 @@ def second(): assert extension._on_unload is second -def test_unload_with_registered_handler__expect_handler_called(extension: Extension): +# UNLOAD + + +def test_unload__expect_bot_cleared(extension: Extension, bot: MockBot): + extension.load(bot) + extension.unload() + + assert extension.bot is None + + +def test_unload_with_registered_handler__expect_handler_called( + extension: Extension, bot: MockBot +): called = [] @extension.on_unload def teardown(): called.append(True) + extension.load(bot) extension.unload() assert called == [True] -def test_unload_with_no_handler__expect_no_error(extension: Extension): +def test_unload_with_no_handler__expect_no_error(extension: Extension, bot: MockBot): + extension.load(bot) extension.unload() + + +# GET ROOM + + +def test_get_room_before_load__expect_runtime_error(extension: Extension): + with pytest.raises(RuntimeError, match="Extension is not loaded"): + extension.get_room("!room:example.com") + + +def test_get_room_after_load__expect_delegates_to_bot( + extension: Extension, bot: MockBot +): + room_id = "!room:example.com" + expected_room = MagicMock(spec=Room) + bot.get_room.return_value = expected_room + + extension.load(bot) + result = extension.get_room(room_id) + + bot.get_room.assert_called_once_with(room_id) + assert result is expected_room + + +def test_get_room_after_unload__expect_runtime_error( + extension: Extension, bot: MockBot +): + extension.load(bot) + extension.unload() + + with pytest.raises(RuntimeError, match="Extension is not loaded"): + extension.get_room("!room:example.com")