diff --git a/src/bub/channels/telegram.py b/src/bub/channels/telegram.py index 45267d94..d5731a36 100644 --- a/src/bub/channels/telegram.py +++ b/src/bub/channels/telegram.py @@ -4,9 +4,10 @@ import asyncio from dataclasses import dataclass +from typing import Any, ClassVar from loguru import logger -from telegram import Update +from telegram import Message, Update from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters from bub.channels.base import BaseChannel @@ -14,6 +15,54 @@ from bub.channels.events import InboundMessage, OutboundMessage +class BubMessageFilter(filters.MessageFilter): + GROUP_CHAT_TYPES: ClassVar[set[str]] = {"group", "supergroup"} + + def filter(self, message: Message) -> bool | dict[str, list[Any]] | None: + # Only text messages are allowed + text = message.text + if not text: + return False + + # Private chat: accept all messages except for commands (starting with /) + if message.chat.type == "private": + return not filters.COMMAND.filter(message) + + # Group chat: only allow `/bot`, mention bot, or reply to bot messages. + if message.chat.type in self.GROUP_CHAT_TYPES: + bot = message.get_bot() + bot_id = bot.id + bot_username = (bot.username or "").lower() + if text.startswith("/bot "): + return True + + if self._mentions_bot(message, text, bot_id, bot_username): + return True + + if self._is_reply_to_bot(message, bot_id): + return True + + return False + + def _mentions_bot(self, message: Message, text: str, bot_id: int, bot_username: str) -> bool: + for entity in message.entities or (): + if entity.type == "mention" and bot_username: + mention_text = text[entity.offset : entity.offset + entity.length] + if mention_text.lower() == f"@{bot_username}": + return True + continue + if entity.type == "text_mention" and entity.user and entity.user.id == bot_id: + return True + return False + + @staticmethod + def _is_reply_to_bot(message: Message, bot_id: int) -> bool: + reply_to_message = message.reply_to_message + if reply_to_message is None or reply_to_message.from_user is None: + return False + return reply_to_message.from_user.id == bot_id + + @dataclass(frozen=True) class TelegramConfig: """Telegram adapter config.""" @@ -41,7 +90,7 @@ async def start(self) -> None: self._app = Application.builder().token(self._config.token).build() self._app.add_handler(CommandHandler("start", self._on_start)) self._app.add_handler(CommandHandler("help", self._on_help)) - self._app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text)) + self._app.add_handler(MessageHandler(BubMessageFilter(), self._on_text)) await self._app.initialize() await self._app.start() updater = self._app.updater diff --git a/tests/test_telegram_filter.py b/tests/test_telegram_filter.py new file mode 100644 index 00000000..9ad72f94 --- /dev/null +++ b/tests/test_telegram_filter.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +from bub.channels.telegram import BubMessageFilter + + +@dataclass +class DummyUser: + id: int + + +@dataclass +class DummyEntity: + type: str + offset: int = 0 + length: int = 0 + user: DummyUser | None = None + + +class DummyMessage: + def __init__( + self, + *, + text: str, + chat_type: str, + bot_id: int = 1000, + bot_username: str = "BubBot", + entities: list[DummyEntity] | None = None, + reply_to_message: object | None = None, + ) -> None: + self.text = text + self.chat = SimpleNamespace(type=chat_type) + self.entities = entities or [] + self.reply_to_message = reply_to_message + self._bot_id = bot_id + self._bot_username = bot_username + + def get_bot(self) -> object: + return SimpleNamespace(id=self._bot_id, username=self._bot_username) + + +def test_group_allows_bot_prefix() -> None: + message = DummyMessage(text="/bot hello", chat_type="group") + assert BubMessageFilter().filter(message) is True + + +def test_group_allows_at_mention_by_username_entity() -> None: + message = DummyMessage( + text="@BubBot ping", + chat_type="supergroup", + entities=[DummyEntity(type="mention", offset=0, length=7)], + ) + assert BubMessageFilter().filter(message) is True + + +def test_group_allows_at_mention_by_text_mention_entity() -> None: + message = DummyMessage( + text="ping bot", + chat_type="group", + entities=[DummyEntity(type="text_mention", user=DummyUser(id=1000))], + ) + assert BubMessageFilter().filter(message) is True + + +def test_group_allows_reply_to_bot_message() -> None: + reply_to_message = SimpleNamespace(from_user=SimpleNamespace(id=1000)) + message = DummyMessage(text="reply", chat_type="group", reply_to_message=reply_to_message) + assert BubMessageFilter().filter(message) is True + + +def test_group_rejects_unrelated_text() -> None: + message = DummyMessage(text="hello world", chat_type="group") + assert BubMessageFilter().filter(message) is False