Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions src/bub/channels/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,65 @@

import asyncio
from dataclasses import dataclass
from typing import Any, ClassVar

from loguru import logger
from telegram import Update
from telegram import Message, Update
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters

from bub.channels.base import BaseChannel
from bub.channels.bus import MessageBus
from bub.channels.events import InboundMessage, OutboundMessage


class BubMessageFilter(filters.MessageFilter):
GROUP_CHAT_TYPES: ClassVar[set[str]] = {"group", "supergroup"}

def filter(self, message: Message) -> bool | dict[str, list[Any]] | None:
# Only text messages are allowed
text = message.text
if not text:
return False

# Private chat: accept all messages except for commands (starting with /)
if message.chat.type == "private":
return not filters.COMMAND.filter(message)

# Group chat: only allow `/bot`, mention bot, or reply to bot messages.
if message.chat.type in self.GROUP_CHAT_TYPES:
bot = message.get_bot()
bot_id = bot.id
bot_username = (bot.username or "").lower()
if text.startswith("/bot "):
return True

if self._mentions_bot(message, text, bot_id, bot_username):
return True

if self._is_reply_to_bot(message, bot_id):
return True

return False

def _mentions_bot(self, message: Message, text: str, bot_id: int, bot_username: str) -> bool:
for entity in message.entities or ():
if entity.type == "mention" and bot_username:
mention_text = text[entity.offset : entity.offset + entity.length]
if mention_text.lower() == f"@{bot_username}":
return True
continue
if entity.type == "text_mention" and entity.user and entity.user.id == bot_id:
return True
return False

@staticmethod
def _is_reply_to_bot(message: Message, bot_id: int) -> bool:
reply_to_message = message.reply_to_message
if reply_to_message is None or reply_to_message.from_user is None:
return False
return reply_to_message.from_user.id == bot_id


@dataclass(frozen=True)
class TelegramConfig:
"""Telegram adapter config."""
Expand Down Expand Up @@ -41,7 +90,7 @@ async def start(self) -> None:
self._app = Application.builder().token(self._config.token).build()
self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("help", self._on_help))
self._app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text))
self._app.add_handler(MessageHandler(BubMessageFilter(), self._on_text))
await self._app.initialize()
await self._app.start()
updater = self._app.updater
Expand Down
75 changes: 75 additions & 0 deletions tests/test_telegram_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

from dataclasses import dataclass
from types import SimpleNamespace

from bub.channels.telegram import BubMessageFilter


@dataclass
class DummyUser:
id: int


@dataclass
class DummyEntity:
type: str
offset: int = 0
length: int = 0
user: DummyUser | None = None


class DummyMessage:
def __init__(
self,
*,
text: str,
chat_type: str,
bot_id: int = 1000,
bot_username: str = "BubBot",
entities: list[DummyEntity] | None = None,
reply_to_message: object | None = None,
) -> None:
self.text = text
self.chat = SimpleNamespace(type=chat_type)
self.entities = entities or []
self.reply_to_message = reply_to_message
self._bot_id = bot_id
self._bot_username = bot_username

def get_bot(self) -> object:
return SimpleNamespace(id=self._bot_id, username=self._bot_username)


def test_group_allows_bot_prefix() -> None:
message = DummyMessage(text="/bot hello", chat_type="group")
assert BubMessageFilter().filter(message) is True


def test_group_allows_at_mention_by_username_entity() -> None:
message = DummyMessage(
text="@BubBot ping",
chat_type="supergroup",
entities=[DummyEntity(type="mention", offset=0, length=7)],
)
assert BubMessageFilter().filter(message) is True


def test_group_allows_at_mention_by_text_mention_entity() -> None:
message = DummyMessage(
text="ping bot",
chat_type="group",
entities=[DummyEntity(type="text_mention", user=DummyUser(id=1000))],
)
assert BubMessageFilter().filter(message) is True


def test_group_allows_reply_to_bot_message() -> None:
reply_to_message = SimpleNamespace(from_user=SimpleNamespace(id=1000))
message = DummyMessage(text="reply", chat_type="group", reply_to_message=reply_to_message)
assert BubMessageFilter().filter(message) is True


def test_group_rejects_unrelated_text() -> None:
message = DummyMessage(text="hello world", chat_type="group")
assert BubMessageFilter().filter(message) is False
Loading