Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions matrix/extension.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
import inspect
import logging
from typing import Callable, Optional
from typing import Callable

from matrix.protocols import BotLike
from matrix.registry import Registry
from matrix.config import Config
from matrix.room import Room

logger = logging.getLogger(__name__)


class Extension(Registry):
def __init__(self, name: str, prefix: Optional[str] = None) -> None:
def __init__(self, name: str, prefix: str | None = None) -> None:
super().__init__(name, prefix=prefix)

self.bot: Optional[BotLike] = None
self._on_load: Optional[Callable] = None
self._on_unload: Optional[Callable] = None
self._bot: BotLike | None = None
self._on_load: Callable | None = None
self._on_unload: Callable | None = None

@property
def bot(self) -> BotLike:
assert self._bot, "Extension is not loaded"
return self._bot

@property
def config(self) -> Config:
return self.bot.config

def get_room(self, room_id: str) -> Room | None:
"""Retrieve a `Room` instance by its Matrix room ID.
Expand All @@ -31,12 +41,10 @@ def get_room(self, room_id: str) -> Room | None:
print(room.name)
```
"""
if self.bot is None:
raise RuntimeError("Extension is not loaded")
return self.bot.get_room(room_id)

def load(self, bot: BotLike) -> None:
self.bot = bot
self._bot = bot

if self._on_load:
self._on_load()
Expand All @@ -59,7 +67,7 @@ def setup():
return func

def unload(self) -> None:
self.bot = None
self._bot = None

if self._on_unload:
self._on_unload()
Expand Down
4 changes: 4 additions & 0 deletions matrix/protocols.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Protocol

from matrix.config import Config
from matrix.room import Room


class BotLike(Protocol):
prefix: str | None

@property
def config(self) -> Config: ...

def get_room(self, room_id: str) -> Room | None: ...
15 changes: 11 additions & 4 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
from unittest.mock import MagicMock
from typing import Optional

from matrix.config import Config
from matrix.extension import Extension
from matrix.room import Room


class MockBot:
prefix: str = "!"

@property
def config(self) -> Config:
return MagicMock(spec=Config)

def __init__(self, room: Optional[Room] = None) -> None:
self.get_room = MagicMock(return_value=room or MagicMock(spec=Room))

Expand Down Expand Up @@ -41,7 +46,8 @@ def test_init_with_name_only__expect_prefix_is_none():


def test_init__expect_bot_is_none(extension: Extension):
assert extension.bot is None
with pytest.raises(AssertionError):
_ = extension.bot


def test_init__expect_on_load_is_none(extension: Extension):
Expand Down Expand Up @@ -190,7 +196,8 @@ def test_unload__expect_bot_cleared(extension: Extension, bot: MockBot):
extension.load(bot)
extension.unload()

assert extension.bot is None
with pytest.raises(AssertionError):
_ = extension.bot


def test_unload_with_registered_handler__expect_handler_called(
Expand All @@ -217,7 +224,7 @@ def test_unload_with_no_handler__expect_no_error(extension: Extension, bot: Mock


def test_get_room_before_load__expect_runtime_error(extension: Extension):
with pytest.raises(RuntimeError, match="Extension is not loaded"):
with pytest.raises(AssertionError):
extension.get_room("!room:example.com")


Expand All @@ -241,5 +248,5 @@ def test_get_room_after_unload__expect_runtime_error(
extension.load(bot)
extension.unload()

with pytest.raises(RuntimeError, match="Extension is not loaded"):
with pytest.raises(AssertionError):
extension.get_room("!room:example.com")
Loading