Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions matrix/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging

from typing import Union, Optional, Any
from typing import Optional, Any

from nio import AsyncClient, Event, MatrixRoom

Expand All @@ -15,7 +15,12 @@
from .registry import Registry
from .help import HelpCommand, DefaultHelpCommand
from .scheduler import Scheduler
from .errors import AlreadyRegisteredError, CommandNotFoundError, CheckError
from .errors import (
AlreadyRegisteredError,
CommandNotFoundError,
CheckError,
RoomNotFoundError,
)


class Bot(Registry):
Expand All @@ -36,7 +41,9 @@ def __init__(

self._config: Config | None = None
self._client: AsyncClient | None = None
self._synced: asyncio.Event = asyncio.Event()
self._help: HelpCommand | None = help_

self.extensions: dict[str, Extension] = {}
self.scheduler: Scheduler = Scheduler()
self.log: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -75,10 +82,23 @@ def _auto_register_events(self) -> None:
except ValueError:
continue

def get_room(self, room_id: str) -> Room:
"""Retrieve a Room instance based on the room_id."""
matrix_room = self.client.rooms[room_id]
return Room(matrix_room=matrix_room, client=self.client)
def get_room(self, room_id: str) -> Room | None:
"""Retrieve a `Room` instance by its Matrix room ID.

Returns the `Room` object corresponding to `room_id` if it exists in
the client's known rooms. Returns `None` if the room cannot be found.

## Example

```python
room = bot.get_room("!abc123:matrix.org")
if room:
print(room.name)
```
"""
if matrix_room := self.client.rooms.get(room_id):
return Room(matrix_room=matrix_room, client=self.client)
return None

def load_extension(self, extension: Extension) -> None:
self.log.debug(f"Loading extension: '{extension.name}'")
Expand Down Expand Up @@ -255,17 +275,26 @@ async def run(self) -> None:
login_resp = await self.client.login(self.config.password)
self.log.info("logged in: %s", login_resp)

self.scheduler.start()
sync_task = asyncio.create_task(self.client.sync_forever(timeout=30_000))

await self._wait_until_synced()
await self._on_ready()
await self.client.sync_forever(timeout=30_000)

self.scheduler.start()
await sync_task

async def _wait_until_synced(self) -> None:
await self._synced.wait()

# MATRIX EVENTS

async def on_message(self, room: Room, event: Event) -> None:
await self._process_commands(room, event)

async def _on_matrix_event(self, matrix_room: MatrixRoom, event: Event) -> None:
if not self._synced.is_set():
self._synced.set()

# ignore bot events
if event.sender == self.client.user:
return
Expand All @@ -276,6 +305,10 @@ async def _on_matrix_event(self, matrix_room: MatrixRoom, event: Event) -> None:

try:
room = self.get_room(matrix_room.room_id)

if not room:
raise RoomNotFoundError(f"Room '{matrix_room.room_id}' not found.")

await self._dispatch_matrix_event(room, event)
except Exception as error:
await self._on_error(error)
Expand Down Expand Up @@ -306,6 +339,10 @@ async def _process_commands(self, room: Room, event: Event) -> None:

async def _build_context(self, matrix_room: Room, event: Event) -> Context:
room = self.get_room(matrix_room.room_id)

if not room:
raise RoomNotFoundError(f"Room '{matrix_room.room_id}' not found.")

ctx = Context(bot=self, room=room, event=event)
prefix = self.prefix or self.config.prefix

Expand Down
4 changes: 4 additions & 0 deletions matrix/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ class MatrixError(Exception):
pass


class RoomNotFoundError(MatrixError):
pass


class RegistryError(MatrixError):
pass

Expand Down
15 changes: 14 additions & 1 deletion matrix/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,20 @@ def __init__(self, name: str, prefix: Optional[str] = None) -> None:
self._on_load: Optional[Callable] = None
self._on_unload: Optional[Callable] = None

def get_room(self, room_id: str) -> Room:
def get_room(self, room_id: str) -> Room | None:
"""Retrieve a `Room` instance by its Matrix room ID.

Returns the `Room` object corresponding to `room_id` if it exists in
the client's known rooms. Returns `None` if the room cannot be found.

## Example

```python
room = extension.get_room("!abc123:matrix.org")
if room:
print(room.name)
```
"""
if self.bot is None:
raise RuntimeError("Extension is not loaded")
return self.bot.get_room(room_id)
Expand Down
2 changes: 1 addition & 1 deletion matrix/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
class BotLike(Protocol):
prefix: str | None

def get_room(self, room_id: str) -> Room: ...
def get_room(self, room_id: str) -> Room | None: ...
114 changes: 110 additions & 4 deletions tests/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,30 @@ async def ping(ctx):
assert called, "Expected command handler to be called"


@pytest.mark.asyncio
async def test_command_not_processed_without_prefix(bot, room):
called = False

@bot.command()
async def greet(ctx):
nonlocal called
called = True

event = RoomMessageText.from_dict(
{
"content": {"body": "greet", "msgtype": "m.text"},
"event_id": "$id",
"origin_server_ts": 123456,
"sender": "@user:matrix.org",
"type": "m.room.message",
}
)

await bot._process_commands(room, event)

assert not called


@pytest.mark.asyncio
async def test_error_decorator_requires_coroutine(bot):
with pytest.raises(TypeError):
Expand Down Expand Up @@ -376,18 +400,37 @@ async def cmd2(ctx):
pass


import asyncio


async def start_and_stop(coro):
task = asyncio.create_task(coro)
await asyncio.sleep(0) # allow startup
task.cancel()
await asyncio.gather(task, return_exceptions=True)


@pytest.mark.asyncio
async def test_run_uses_token():
bot = Bot()
bot._load_config("tests/config_fixture_token.yaml")

bot._client.sync_forever = AsyncMock()
bot.on_ready = AsyncMock()
bot._on_ready = AsyncMock()

# unblock readiness
bot._synced.set()

task = asyncio.create_task(bot.run())

await asyncio.sleep(0)
await asyncio.sleep(0)

await bot.run()
task.cancel()
await asyncio.gather(task, return_exceptions=True)

assert bot._client.access_token == "abc123"
bot.on_ready.assert_awaited_once()
bot._on_ready.assert_awaited_once()
bot._client.sync_forever.assert_awaited_once()


Expand All @@ -397,7 +440,15 @@ async def test_run_with_username_and_password(bot):
bot._client.sync_forever = AsyncMock()
bot._on_ready = AsyncMock()

await bot.run()
bot._synced.set()

task = asyncio.create_task(bot.run())

await asyncio.sleep(0)
await asyncio.sleep(0)

task.cancel()
await asyncio.gather(task, return_exceptions=True)

bot._client.login.assert_awaited_once_with("grace1234")
bot._on_ready.assert_awaited_once()
Expand All @@ -418,6 +469,39 @@ def test_start_handles_keyboard_interrupt(caplog):
bot._client.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_on_ready_called_only_once(bot):
# Prepare
bot._synced.set()
bot._on_ready = AsyncMock()

# Simulate run
await bot._wait_until_synced()
await bot._on_ready()

bot._on_ready.assert_awaited_once()


@pytest.mark.asyncio
async def test_scheduler_starts_after_ready(bot):
bot._synced.set()

order = []

async def ready():
order.append("ready")

bot._on_ready = AsyncMock(side_effect=ready)
bot.scheduler.start = MagicMock(side_effect=lambda: order.append("scheduler"))

# Simulate run
await bot._wait_until_synced()
await bot._on_ready()
bot.scheduler.start()

assert order == ["ready", "scheduler"]


@pytest.mark.asyncio
async def test_scheduled_task_in_scheduler(bot):
@bot.schedule("* * * * *")
Expand Down Expand Up @@ -743,3 +827,25 @@ def test_unload_extension_logs_unloading(bot: Bot, loaded_extension: Extension):
bot.unload_extension(loaded_extension.name)

bot.log.debug.assert_any_call("unloaded extension '%s'", loaded_extension.name)


def test_unload_extension_removes_only_its_jobs(bot: Bot):
ext_a = Extension(name="a")
ext_b = Extension(name="b")

@ext_a.schedule("* * * * *")
async def task():
pass

@ext_b.schedule("* * * * *")
async def task():
pass

bot.load_extension(ext_a)
bot.load_extension(ext_b)

bot.unload_extension("a")

job_names = [j.name for j in bot.scheduler.jobs]

assert "task" in job_names
Loading