Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion matrix/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def load_extension(self, extension: Extension) -> None:
)

self.extensions[extension.name] = extension
extension.load()
extension.load(self)
self.log.debug("loaded extension '%s'", extension.name)

def unload_extension(self, ext_name: str) -> None:
Expand Down
19 changes: 16 additions & 3 deletions matrix/extension.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
import logging
import inspect
import logging
from typing import Callable, Optional

from typing import Any, Callable, Coroutine, Optional
from matrix.protocols import BotLike
from matrix.registry import Registry
from matrix.room import Room

logger = logging.getLogger(__name__)


class Extension(Registry):
def __init__(self, name: str, prefix: Optional[str] = None) -> None:
super().__init__(name, prefix=prefix)

self.bot: Optional[BotLike] = None
self._on_load: Optional[Callable] = None
self._on_unload: Optional[Callable] = None

def load(self) -> None:
def get_room(self, room_id: str) -> Room:
if self.bot is None:
raise RuntimeError("Extension is not loaded")
return self.bot.get_room(room_id)

def load(self, bot: BotLike) -> None:
self.bot = bot

if self._on_load:
self._on_load()

Expand All @@ -35,6 +46,8 @@ def setup():
return func

def unload(self) -> None:
self.bot = None

if self._on_unload:
self._on_unload()

Expand Down
9 changes: 9 additions & 0 deletions matrix/protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Protocol

from matrix.room import Room


class BotLike(Protocol):
prefix: str | None

def get_room(self, room_id: str) -> Room: ...
98 changes: 92 additions & 6 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
import pytest

from unittest.mock import MagicMock
from typing import Optional

from matrix.extension import Extension
from matrix.room import Room


class MockBot:
prefix: str = "!"

def __init__(self, room: Optional[Room] = None) -> None:
self.get_room = MagicMock(return_value=room or MagicMock(spec=Room))


@pytest.fixture
def extension() -> Extension:
return Extension(name="test_ext", prefix="!")


@pytest.fixture
def bot() -> MockBot:
return MockBot()


# INIT


def test_init_with_name_and_prefix__expect_attributes_set():
ext = Extension(name="math", prefix="!")

Expand All @@ -21,6 +40,10 @@ def test_init_with_name_only__expect_prefix_is_none():
assert ext.prefix is None


def test_init__expect_bot_is_none(extension: Extension):
assert extension.bot is None


def test_init__expect_on_load_is_none(extension: Extension):
assert extension._on_load is None

Expand All @@ -45,6 +68,9 @@ def test_init__expect_empty_checks(extension: Extension):
assert extension._checks == []


# ON LOAD


def test_on_load_with_sync_function__expect_handler_registered(extension: Extension):
@extension.on_load
def setup():
Expand Down Expand Up @@ -86,20 +112,34 @@ def second():
assert extension._on_load is second


def test_load_with_registered_handler__expect_handler_called(extension: Extension):
# LOAD


def test_load__expect_bot_set(extension: Extension, bot: MockBot):
extension.load(bot)

assert extension.bot is bot


def test_load_with_registered_handler__expect_handler_called(
extension: Extension, bot: MockBot
):
called = []

@extension.on_load
def setup():
called.append(True)

extension.load()
extension.load(bot)

assert called == [True]


def test_load_with_no_handler__expect_no_error(extension: Extension):
extension.load()
def test_load_with_no_handler__expect_no_error(extension: Extension, bot: MockBot):
extension.load(bot)


# ON UNLOAD


def test_on_unload_with_sync_function__expect_handler_registered(extension: Extension):
Expand Down Expand Up @@ -143,17 +183,63 @@ def second():
assert extension._on_unload is second


def test_unload_with_registered_handler__expect_handler_called(extension: Extension):
# UNLOAD


def test_unload__expect_bot_cleared(extension: Extension, bot: MockBot):
extension.load(bot)
extension.unload()

assert extension.bot is None


def test_unload_with_registered_handler__expect_handler_called(
extension: Extension, bot: MockBot
):
called = []

@extension.on_unload
def teardown():
called.append(True)

extension.load(bot)
extension.unload()

assert called == [True]


def test_unload_with_no_handler__expect_no_error(extension: Extension):
def test_unload_with_no_handler__expect_no_error(extension: Extension, bot: MockBot):
extension.load(bot)
extension.unload()


# GET ROOM


def test_get_room_before_load__expect_runtime_error(extension: Extension):
with pytest.raises(RuntimeError, match="Extension is not loaded"):
extension.get_room("!room:example.com")


def test_get_room_after_load__expect_delegates_to_bot(
extension: Extension, bot: MockBot
):
room_id = "!room:example.com"
expected_room = MagicMock(spec=Room)
bot.get_room.return_value = expected_room

extension.load(bot)
result = extension.get_room(room_id)

bot.get_room.assert_called_once_with(room_id)
assert result is expected_room


def test_get_room_after_unload__expect_runtime_error(
extension: Extension, bot: MockBot
):
extension.load(bot)
extension.unload()

with pytest.raises(RuntimeError, match="Extension is not loaded"):
extension.get_room("!room:example.com")
Loading