From 4f643e2fe4a90558d6978ff19e240897799d5ba0 Mon Sep 17 00:00:00 2001 From: Levalicious Date: Sun, 5 Apr 2026 09:14:52 -0400 Subject: [PATCH] Add application command (slash command) interaction support - Add make_interaction_data() to build valid INTERACTION_CREATE gateway payloads - Add FakeWebhookAdapter to intercept interaction responses (create_interaction_response, execute_webhook, edit_webhook_message, delete_webhook_message) - Add InteractionResponse data class with content/embeds/ephemeral/deferred - Add interaction() runner function analogous to message() - Add VerifyInteraction builder for asserting interaction responses - Add interaction_response callback event and queue - Fix _command_tree and __session linkage in configure() - Update run_all_events() to wait for CommandTree-invoker tasks - Export new APIs from __init__.py - Add 7 tests covering commands, options, ephemeral, embeds, verify --- dev-requirements.txt | 6 + discord/ext/test/__init__.py | 2 + discord/ext/test/backend.py | 419 +++++++++++++++++++++++++++++++++- discord/ext/test/callbacks.py | 1 + discord/ext/test/runner.py | 147 +++++++++++- discord/ext/test/verify.py | 230 ++++++++++++++++++- tests/test_interaction.py | 382 +++++++++++++++++++++++++++++++ 7 files changed, 1161 insertions(+), 26 deletions(-) create mode 100644 tests/test_interaction.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 9f9c5a9..463c961 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -10,3 +10,9 @@ flake8~=7.0.0 pynacl typing-extensions mypy +davey +requests +types-Pygments +types-gunicorn +types-pexpect +types-requests \ No newline at end of file diff --git a/discord/ext/test/__init__.py b/discord/ext/test/__init__.py index 1483dfd..2bcd849 100644 --- a/discord/ext/test/__init__.py +++ b/discord/ext/test/__init__.py @@ -8,6 +8,7 @@ from . import backend as backend from .runner import * +from .runner import InteractionResponse as InteractionResponse from .utils import embed_eq as embed_eq from .utils import activity_eq as activity_eq @@ -18,3 +19,4 @@ from .verify import Verify as Verify from .verify import VerifyMessage as VerifyMessage from .verify import VerifyActivity as VerifyActivity +from .verify import VerifyInteraction as VerifyInteraction diff --git a/discord/ext/test/backend.py b/discord/ext/test/backend.py index 6119604..f14231b 100644 --- a/discord/ext/test/backend.py +++ b/discord/ext/test/backend.py @@ -37,6 +37,7 @@ class BackendState(NamedTuple): """ messages: dict[int, list[_types.message.Message]] state: dstate.FakeState + owner: _types.user.User | None log = logging.getLogger("discord.ext.tests") @@ -467,8 +468,8 @@ async def remove_role(self, guild_id: Snowflake, user_id: Snowflake, update_member(member, roles=roles) async def application_info(self) -> _types.appinfo.AppInfo: - # TODO: make these values configurable user = self.state.user + owner_data = get_config().owner or facts.make_user_dict("TestOwner", "0001", "") data: _types.appinfo.AppInfo = { "id": user.id, "name": user.name, @@ -477,7 +478,7 @@ async def application_info(self) -> _types.appinfo.AppInfo: "rpc_origins": [], "bot_public": True, "bot_require_code_grant": False, - "owner": facts.make_user_dict("TestOwner", "0001", ""), + "owner": owner_data, "summary": "", "verify_key": "", "flags": 0, @@ -597,6 +598,207 @@ async def get_guild(self, guild_id: Snowflake, *, with_counts: bool = True) -> _ return facts.dict_from_object(guild) +def _fake_message_dict( + content: str = "", + embeds: list[dict[str, Any]] | None = None, + edited: bool = False, + message_id: int | None = None, +) -> dict[str, Any]: + """Build a minimal fake message payload for webhook adapter returns.""" + now = str(int(datetime.datetime.now().timestamp())) + return { + "id": str(message_id or facts.make_id()), + "channel_id": "0", + "author": facts.make_user_dict("FakeApp", "0001", None), + "content": content, + "timestamp": now, + "edited_timestamp": now if edited else None, + "tts": False, + "mention_everyone": False, + "mentions": [], + "mention_roles": [], + "attachments": [], + "embeds": embeds or [], + "pinned": False, + "type": 0, + } + + +class FakeWebhookAdapter: + """ + A fake webhook adapter that intercepts interaction responses instead of + sending them to Discord. Stores responses in the callback system for verification. + + All interaction responses (ephemeral or not) go into the unified ``sent_queue`` + via ``CallbackEvent.interaction_response``. + """ + + _original_adapter: Any + + def __init__(self) -> None: + self._original_adapter = None + + async def create_interaction_response( + self, + interaction_id: int, + token: str, + *, + session: Any, + proxy: str | None = None, + proxy_auth: Any = None, + params: Any, + ) -> dict[str, Any]: + from .runner import InteractionResponse + + payload_data = params.payload or {} + response_type = payload_data.get("type", 4) + + ir = InteractionResponse(response_type=response_type, payload=payload_data) + await callbacks.dispatch_event(CallbackEvent.interaction_response, ir) + + # Return a valid InteractionCallback dict so discord.py doesn't crash + data_section = payload_data.get("data") or {} + return { + "interaction": { + "id": str(interaction_id), + "type": 2, # APPLICATION_COMMAND + "response_message_id": str(facts.make_id()), + "response_message_loading": False, + "response_message_ephemeral": ir.ephemeral, + }, + "resource": { + "type": response_type, + "message": _fake_message_dict( + content=data_section.get("content", ""), + embeds=data_section.get("embeds", []), + ), + }, + } + + async def execute_webhook( + self, + webhook_id: int, + token: str, + *, + session: Any, + proxy: str | None = None, + proxy_auth: Any = None, + payload: dict[str, Any] | None = None, + multipart: Any = None, + files: Any = None, + thread_id: int | None = None, + wait: bool = False, + params: dict[str, Any] | None = None, + with_components: bool = False, + ) -> dict[str, Any]: + from .runner import InteractionResponse + + content = (payload or {}).get("content", "") + embeds = (payload or {}).get("embeds", []) + flags = (payload or {}).get("flags", 0) + + ir = InteractionResponse( + response_type=4, + payload={"type": 4, "data": {"content": content, "embeds": embeds, "flags": flags}}, + ) + await callbacks.dispatch_event(CallbackEvent.interaction_response, ir) + + return _fake_message_dict(content=content, embeds=embeds) + + async def edit_webhook_message( + self, + webhook_id: int, + token: str, + message_id: int, + *, + session: Any, + proxy: str | None = None, + proxy_auth: Any = None, + payload: dict[str, Any] | None = None, + multipart: Any = None, + files: Any = None, + thread_id: int | None = None, + ) -> dict[str, Any]: + from .runner import InteractionResponse + + content = (payload or {}).get("content", "") + embeds = (payload or {}).get("embeds", []) + + ir = InteractionResponse( + response_type=7, # UPDATE_MESSAGE + payload={"type": 7, "data": {"content": content, "embeds": embeds}}, + ) + await callbacks.dispatch_event(CallbackEvent.interaction_response, ir) + + return _fake_message_dict(content=content, embeds=embeds, edited=True, message_id=message_id) + + async def delete_webhook_message( + self, + webhook_id: int, + token: str, + message_id: int, + *, + session: Any, + proxy: str | None = None, + proxy_auth: Any = None, + thread_id: int | None = None, + ) -> None: + pass + + async def edit_original_interaction_response( + self, + application_id: int, + token: str, + *, + session: Any, + proxy: str | None = None, + proxy_auth: Any = None, + payload: dict[str, Any] | None = None, + multipart: Any = None, + files: Any = None, + ) -> dict[str, Any]: + from .runner import InteractionResponse + + content = (payload or {}).get("content", "") + embeds = (payload or {}).get("embeds", []) + + ir = InteractionResponse( + response_type=7, # UPDATE_MESSAGE + payload={"type": 7, "data": {"content": content, "embeds": embeds}}, + ) + await callbacks.dispatch_event(CallbackEvent.interaction_response, ir) + + return _fake_message_dict(content=content, embeds=embeds, edited=True) + + async def delete_original_interaction_response( + self, + application_id: int, + token: str, + *, + session: Any, + proxy: str | None = None, + proxy_auth: Any = None, + ) -> None: + pass + + async def get_original_interaction_response( + self, + application_id: int, + token: str, + *, + session: Any, + proxy: str | None = None, + proxy_auth: Any = None, + ) -> dict[str, Any]: + return _fake_message_dict() + + def __getattr__(self, name: str) -> Any: + """Fall through to the original adapter for anything we don't handle.""" + if self._original_adapter is not None: + return getattr(self._original_adapter, name) + raise AttributeError(f"FakeWebhookAdapter has no attribute '{name}'") + + def get_state() -> dstate.FakeState: """ Get the current backend state, or raise an error if it hasn't been configured @@ -976,6 +1178,137 @@ def edit_message( return data +def make_interaction_data( + command_name: str, + *, + command_type: int = 1, + options: list[dict[str, Any]] | None = None, + member: discord.Member | None = None, + channel: _types.AnyChannel | None = None, +) -> dict[str, Any]: + """ + Build a valid INTERACTION_CREATE gateway payload dict for an application command. + This is fed into ``ConnectionState.parse_interaction_create`` to trigger the full + discord.py interaction dispatch pipeline. + + :param command_name: Name of the command (supports "group subcommand" syntax) + :param command_type: Application command type (1=CHAT_INPUT, 2=USER, 3=MESSAGE) + :param options: List of option dicts with 'name', 'type', 'value' keys + :param member: Member invoking the command + :param channel: Channel the command is invoked in + :return: A dict matching the INTERACTION_CREATE gateway event shape + """ + state = get_state() + + interaction_id = facts.make_id() + interaction_token = f"fake-token-{interaction_id}" + command_id = facts.make_id() + app_id = state.user.id + + guild = None + if channel is not None and hasattr(channel, "guild"): + guild = channel.guild + + # Build the command data + parts = command_name.split() + if len(parts) == 1: + # Simple command + cmd_data: dict[str, Any] = { + "id": str(command_id), + "name": parts[0], + "type": command_type, + } + if options: + cmd_data["options"] = options + elif len(parts) == 2: + # Subcommand: "group sub" + sub_options = options or [] + cmd_data = { + "id": str(command_id), + "name": parts[0], + "type": command_type, + "options": [{ + "name": parts[1], + "type": 1, # SUB_COMMAND + "options": sub_options, + }], + } + else: + # Subcommand group: "group sub_group sub" + sub_options = options or [] + cmd_data = { + "id": str(command_id), + "name": parts[0], + "type": command_type, + "options": [{ + "name": parts[1], + "type": 2, # SUB_COMMAND_GROUP + "options": [{ + "name": parts[2], + "type": 1, # SUB_COMMAND + "options": sub_options, + }], + }], + } + + if guild is not None: + cmd_data["guild_id"] = str(guild.id) + + # Build the member/user data + if member is not None: + user_data = facts.dict_from_object(member._user) + member_data = facts.dict_from_object(member) + else: + user_data = facts.dict_from_object(state.user) + member_data = None + + # Build the channel data + channel_data: dict[str, Any] = {} + if channel is not None: + channel_data = { + "id": str(channel.id), + "type": channel.type.value, + "name": getattr(channel, "name", ""), + "permissions": str(discord.Permissions.all().value), + } + + payload: dict[str, Any] = { + "id": str(interaction_id), + "application_id": str(app_id), + "type": 2, # APPLICATION_COMMAND + "data": cmd_data, + "token": interaction_token, + "version": 1, + "channel": channel_data, + "locale": "en-US", + "entitlement_sku_ids": [], + "entitlements": [], + "authorizing_integration_owners": {"0": str(app_id)}, + "attachment_size_limit": 26214400, + } + + if guild is not None: + payload["guild_id"] = str(guild.id) + payload["guild"] = { + "id": str(guild.id), + "locale": "en-US", + "features": guild.features, + } + payload["guild_locale"] = "en-US" + payload["app_permissions"] = str(discord.Permissions.all().value) + + if channel is not None: + payload["channel_id"] = str(channel.id) + + if member_data is not None: + payload["member"] = member_data + payload["member"]["permissions"] = str(discord.Permissions.all().value) + else: + payload["user"] = user_data + + return payload + + MEMBER_MENTION: Pattern[str] = re.compile(r"<@!?([0-9]{17,21})>", re.MULTILINE) ROLE_MENTION: Pattern[str] = re.compile(r"<@&([0-9]{17,21})>", re.MULTILINE) CHANNEL_MENTION: Pattern[str] = re.compile(r"<#([0-9]{17,21})>", re.MULTILINE) @@ -1184,15 +1517,26 @@ def configure(client: discord.Client) -> None: ... @overload -def configure(client: discord.Client | None, *, use_dummy: bool = ...) -> None: ... - - -def configure(client: discord.Client | None, *, use_dummy: bool = False) -> None: +def configure( + client: discord.Client | None, + *, + use_dummy: bool = ..., + owner: _types.user.User | None = ..., +) -> None: ... + + +def configure( + client: discord.Client | None, + *, + use_dummy: bool = False, + owner: _types.user.User | None = None, +) -> None: """ Configure the backend, optionally with the provided client :param client: Client to use, or None :param use_dummy: Whether to use a dummy if client param is None, or error + :param owner: User dict to use as the application owner, or None for a default """ global _cur_config @@ -1219,4 +1563,65 @@ def configure(client: discord.Client | None, *, use_dummy: bool = False) -> None client._connection = test_state - _cur_config = BackendState({}, test_state) + # Preserve the command tree reference so slash commands dispatch correctly. + # Bot.__init__ sets _connection._command_tree = self.tree, but we just replaced _connection. + if hasattr(client, 'tree'): + test_state._command_tree = client.tree + + # Interaction.__init__ accesses state.http._HTTPClient__session (mangled). + # Provide a fake value so it doesn't crash. + http._HTTPClient__session = None # type: ignore[attr-defined] + + # Install the fake webhook adapter so interaction responses are intercepted. + # discord.py uses a ContextVar (async_context) to store the webhook adapter. + # InteractionResponse.send_message calls adapter.create_interaction_response. + from discord.webhook.async_ import async_context + fake_adapter = FakeWebhookAdapter() + async_context.set(fake_adapter) # type: ignore[arg-type] + + _cur_config = BackendState({}, test_state, owner) + + # Populate client._application so that bot.application is not None. + # This mirrors what Client.login() does after calling application_info(). + _build_app_info(client, test_state, owner) + + +def _build_app_info( + client: discord.Client, + test_state: dstate.FakeState, + owner: _types.user.User | None = None, +) -> None: + """Build a fake AppInfo and assign it to client._application.""" + owner_data = owner or facts.make_user_dict("TestOwner", "0001", "") + app_data: _types.appinfo.AppInfo = { + "id": test_state.user.id, + "name": test_state.user.name, + "icon": test_state.user.avatar.url if test_state.user.avatar else None, + "description": "A test discord application", + "rpc_origins": [], + "bot_public": True, + "bot_require_code_grant": False, + "owner": owner_data, + "summary": "", + "verify_key": "", + "flags": 0, + } + client._application = discord.AppInfo(test_state, app_data) + if test_state.application_id is None: + test_state.application_id = client._application.id + + +def set_app_owner(client: discord.Client, owner: _types.user.User) -> None: + """ + Update the application owner on the client and backend state. + Called by :py:func:`runner.configure` when the owner needs to be resolved + after members are created. + + :param client: The configured client + :param owner: User dict for the new owner + """ + global _cur_config + if _cur_config is None: + raise ValueError("Dpytest backend not configured") + _cur_config = BackendState(_cur_config.messages, _cur_config.state, owner) + _build_app_info(client, _cur_config.state, owner) diff --git a/discord/ext/test/callbacks.py b/discord/ext/test/callbacks.py index e7660f2..b5449c9 100644 --- a/discord/ext/test/callbacks.py +++ b/discord/ext/test/callbacks.py @@ -45,6 +45,7 @@ class CallbackEvent(Enum): remove_role = "remove_role" app_info = "app_info" get_guilds = "get_guilds" + interaction_response = "interaction_response" _callbacks: dict[CallbackEvent, Callback] = {} diff --git a/discord/ext/test/runner.py b/discord/ext/test/runner.py index a21bd56..72c3d6f 100644 --- a/discord/ext/test/runner.py +++ b/discord/ext/test/runner.py @@ -29,6 +29,53 @@ from .utils import PeekableQueue +class InteractionResponse: + """Container for a captured interaction response from the bot.""" + + response_type: int + payload: dict[str, Any] + + def __init__( + self, + response_type: int, + payload: dict[str, Any] | None = None, + ) -> None: + self.response_type = response_type + self.payload = payload or {} + + @property + def content(self) -> str | None: + data: dict[str, Any] = self.payload.get("data", {}) + result: str | None = data.get("content") + return result + + @property + def embeds(self) -> list[discord.Embed]: + data = self.payload.get("data", {}) + raw = data.get("embeds", []) + return [discord.Embed.from_dict(e) for e in raw] + + @property + def ephemeral(self) -> bool: + data = self.payload.get("data", {}) + flags = data.get("flags", 0) + return bool(flags & 64) + + @property + def is_deferred(self) -> bool: + return self.response_type == 5 # DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE + + def __repr__(self) -> str: + parts = [f"type={self.response_type}"] + if self.content: + parts.append(f'content="{self.content}"') + if self.embeds: + parts.append(f"embeds={len(self.embeds)}") + if self.ephemeral: + parts.append("ephemeral=True") + return f"InteractionResponse({', '.join(parts)})" + + class RunnerConfig(NamedTuple): """ Exposed discord test configuration @@ -43,7 +90,7 @@ class RunnerConfig(NamedTuple): log = logging.getLogger("discord.ext.tests") _cur_config: RunnerConfig | None = None -sent_queue: PeekableQueue[discord.Message] = PeekableQueue() +sent_queue: PeekableQueue[discord.Message | InteractionResponse] = PeekableQueue() error_queue: PeekableQueue[tuple[ commands.Context[commands.Bot | commands.AutoShardedBot], CommandError ]] = PeekableQueue() @@ -88,15 +135,23 @@ async def run_all_events() -> None: Ensure that all dpy related coroutines have completed or been cancelled. If any dpy coroutines are currently running, this will also wait for those. """ + _event_names = {"_run_event", "CommandTree-invoker"} while True: if sys.version_info[1] >= 7: pending = asyncio.all_tasks() else: pending = asyncio.Task.all_tasks() - if not any(map(lambda x: _task_coro_name(x) == "_run_event" and not (x.done() or x.cancelled()), pending)): + + def _is_relevant(t: asyncio.Task[Any]) -> bool: + coro_name = _task_coro_name(t) + task_name = t.get_name() if hasattr(t, 'get_name') else None + return ((coro_name in _event_names or task_name in _event_names) + and not (t.done() or t.cancelled())) + + if not any(map(_is_relevant, pending)): break for task in pending: - if _task_coro_name(task) == "_run_event" and not (task.done() or task.cancelled()): + if _is_relevant(task): await task @@ -116,16 +171,19 @@ async def finish_on_command_error() -> None: def get_message(peek: bool = False) -> discord.Message: """ - Allow the user to retrieve the most recent message sent by the bot + Allow the user to retrieve the most recent message sent by the bot. + Skips any ``InteractionResponse`` items at the front of the queue. :param peek: If true, message will not be removed from the queue :return: Most recent message from the queue """ if peek: - message = sent_queue.peek() + item = sent_queue.peek() else: - message = sent_queue.get_nowait() - return message + item = sent_queue.get_nowait() + if not isinstance(item, discord.Message): + raise TypeError(f"Expected a discord.Message at the front of the queue, got {type(item).__name__}") + return item def get_embed(peek: bool = False) -> discord.Embed: @@ -145,7 +203,7 @@ def get_embed(peek: bool = False) -> discord.Embed: async def empty_queue() -> None: """ - Empty the current message queue. Waits for all events to complete to ensure queue + Empty the current queue. Waits for all events to complete to ensure queue is not immediately added to after running. """ await run_all_events() @@ -164,6 +222,15 @@ async def _message_callback(message: discord.Message) -> None: await sent_queue.put(message) +async def _interaction_response_callback(response: InteractionResponse) -> None: + """ + Internal callback, on an interaction response being sent adds it to the queue + + :param response: InteractionResponse captured from the bot + """ + await sent_queue.put(response) + + async def _edit_member_callback(fields: Any, member: discord.Member, reason: str | None) -> None: """ Internal callback. Updates a guild's voice states to reflect the given Member connecting to the given channel. @@ -362,6 +429,42 @@ async def member_join( return member +@require_config +async def interaction( + command_name: str, + *, + options: list[dict[str, Any]] | None = None, + channel: _types.AnyChannel | int = 0, + member: discord.Member | int = 0, +) -> None: + """ + Fake an application command interaction being sent by a user. + This dispatches through the real discord.py interaction pipeline, meaning + all checks, error handlers, and predicates will run. + + :param command_name: Name of the slash command (e.g. "ping" or "group subcommand") + :param options: List of option dicts, each with 'name', 'type', 'value' keys + :param channel: Channel the interaction is sent in, or index into config list + :param member: Member sending the interaction, or index into config list + """ + if isinstance(channel, int): + channel = get_config().channels[channel] + if isinstance(member, int): + member = get_config().members[member] + + payload = back.make_interaction_data( + command_name, + options=options, + member=member, + channel=channel, + ) + + state = back.get_state() + state.parse_interaction_create(payload) # type: ignore[arg-type] + + await run_all_events() + + def get_config() -> RunnerConfig: """ Get the current runner configuration @@ -377,7 +480,8 @@ def configure(client: discord.Client, guilds: int | list[str] = 1, text_channels: int | list[str] = 1, voice_channels: int | list[str] = 1, - members: int | list[str] = 1) -> None: + members: int | list[str] = 1, + owner: discord.User | discord.Member | bool = True) -> None: """ Set up the runner configuration. This should be done before any tests are run. @@ -386,6 +490,8 @@ def configure(client: discord.Client, :param text_channels: Number or list of names of text channels in each guild to start with. Default is 1 :param voice_channels: Number or list of names of voice channels in each guild to start with. Default is 1. :param members: Number or list of names of members in each guild (other than the client) to start with. Default is 1. + :param owner: The application owner. ``True`` (default) sets the first test member as owner. ``False`` uses a + generic "TestOwner". A specific :class:`discord.User` or :class:`discord.Member` sets that user as owner. """ # noqa: E501 global _cur_config @@ -395,7 +501,16 @@ def configure(client: discord.Client, if isinstance(client, discord.AutoShardedClient): raise TypeError("Sharded clients not yet supported") - back.configure(client) + # Resolve the owner user dict for the backend. True is resolved after members + # are created (below), False / None means use the default, a User/Member is + # converted immediately. + from . import factories as _facts + owner_data: _types.user.User | None = None + if isinstance(owner, (discord.User, discord.Member)): + raw_user = owner._user if isinstance(owner, discord.Member) else owner + owner_data = _facts.dict_from_object(raw_user) + + back.configure(client, owner=owner_data) # Wrap on_error so errors will be reported old_error = None @@ -415,9 +530,11 @@ async def on_command_error(ctx: commands.Context[BotT], error: CommandError) -> client.on_command_error = on_command_error # type: ignore[attr-defined] + CBE = CallbackEvent # Configure global callbacks - callbacks.set_callback(_message_callback, CallbackEvent.send_message) - callbacks.set_callback(_edit_member_callback, CallbackEvent.edit_member) + callbacks.set_callback(_message_callback, CBE.send_message) + callbacks.set_callback(_edit_member_callback, CBE.edit_member) + callbacks.set_callback(_interaction_response_callback, CBE.interaction_response) # type: ignore[call-overload] back.get_state().stop_dispatch() @@ -472,4 +589,10 @@ async def on_command_error(ctx: commands.Context[BotT], error: CommandError) -> back.get_state().start_dispatch() + # Deferred owner resolution: if owner=True, set the first test member as the app owner. + # This must happen after members are created above. + if owner is True and _members: + owner_data = _facts.dict_from_object(_members[0]._user) + back.set_app_owner(client, owner_data) + _cur_config = RunnerConfig(client, _guilds, _channels, _members) diff --git a/discord/ext/test/verify.py b/discord/ext/test/verify.py index f792c77..7e2796a 100644 --- a/discord/ext/test/verify.py +++ b/discord/ext/test/verify.py @@ -15,7 +15,7 @@ import discord -from .runner import sent_queue, get_config +from .runner import sent_queue, get_config, InteractionResponse from .utils import embed_eq, activity_eq from ._types import Undef, undefined @@ -55,7 +55,7 @@ class VerifyMessage: ``assert dpytest.verify().message().content("Hello World!")`` """ - _used: discord.Message | int | Undef | None + _used: discord.Message | InteractionResponse | int | Undef | None _contains: bool _peek: bool @@ -75,7 +75,7 @@ def __init__(self) -> None: self._attachment = undefined def __del__(self) -> None: - if not self._used: + if self._used is undefined: import warnings warnings.warn("VerifyMessage dropped without being used, did you forget an `assert`?", RuntimeWarning) @@ -93,16 +93,21 @@ def __bool__(self) -> bool: return sent_queue.qsize() == 0 if self._peek: - message: discord.Message = sent_queue.peek() + item = sent_queue.peek() else: try: - message = sent_queue.get_nowait() + item = sent_queue.get_nowait() except asyncio.QueueEmpty: # By now we're expecting a message, not getting one is a failure return False - self._used = message - return self._check_msg(message) + if not isinstance(item, discord.Message): + # Wrong type at front of queue — expected a Message, got InteractionResponse + self._used = item + return False + + self._used = item + return self._check_msg(item) def _expectation(self) -> str: if self._nothing: @@ -330,6 +335,209 @@ def type(self, type: discord.ActivityType | None) -> 'VerifyActivity': return self +class VerifyInteraction: + """ + Builder for interaction response verifications. When done building, should be asserted. + + **Example**: + ``assert dpytest.verify().interaction().content("Hello World!")`` + """ + + _used: discord.Message | InteractionResponse | int | Undef | None + + _contains: bool + _peek: bool + _nothing: bool + _content: str | Undef | None + _embed: discord.Embed | Undef | None + _ephemeral: bool | Undef + _deferred: bool | Undef + + def __init__(self) -> None: + self._used = undefined + + self._contains = False + self._peek = False + self._nothing = False + self._content = undefined + self._embed = undefined + self._ephemeral = undefined + self._deferred = undefined + + def __del__(self) -> None: + if self._used is undefined: + import warnings + warnings.warn("VerifyInteraction dropped without being used, did you forget an `assert`?", RuntimeWarning) + + def __repr__(self) -> str: + if self._used is not undefined: + return f"" + else: + return f"" + + def __bool__(self) -> bool: + self._used = None + + if self._nothing: + self._used = sent_queue.qsize() + return sent_queue.qsize() == 0 + + if self._peek: + item = sent_queue.peek() + else: + try: + item = sent_queue.get_nowait() + except asyncio.QueueEmpty: + return False + + if not isinstance(item, InteractionResponse): + # Wrong type at front of queue — expected an InteractionResponse, got Message + self._used = item + return False + + self._used = item + return self._check_response(item) + + def _expectation(self) -> str: + if self._nothing: + return "no interaction responses" + else: + parts: list[str] = [] + if self._contains: + parts.append("contains") + if self._content is not undefined: + if self._content is None: + parts.append("content=Empty") + else: + parts.append(f'content="{self._content}"') + if self._embed is not undefined: + if self._embed is None: + parts.append("embed=Empty") + else: + parts.append(f"embed={self._embed.to_dict()}") + if self._ephemeral is not undefined: + parts.append(f"ephemeral={self._ephemeral}") + if self._deferred is not undefined: + parts.append(f"deferred={self._deferred}") + return " ".join(parts) if parts else "any interaction response" + + def _diff_msg(self) -> str: + if isinstance(self._used, int): + return f"{self._used} interaction responses" + elif isinstance(self._used, InteractionResponse): + return str(self._used) + elif self._used is None: + return "no interaction response" + return "" + + def _check_response(self, resp: InteractionResponse) -> bool: + # Check deferred + if self._deferred is not undefined: + if self._deferred != resp.is_deferred: + return False + + # Check ephemeral + if self._ephemeral is not undefined: + if self._ephemeral != resp.ephemeral: + return False + + # If content is None, check that there is no content + if self._content is None and resp.content: + return False + + # If content is set (not None, not undefined), check match + if self._content is not None and self._content is not undefined: + if self._contains and self._content not in (resp.content or ""): + return False + if not self._contains and self._content != resp.content: + return False + + # Check embed + _embed = self._embed + if _embed is None and resp.embeds: + return False + if _embed is not None and _embed is not undefined: + if self._contains and not any(map(lambda e: embed_eq(_embed, e), resp.embeds)): + return False + if not self._contains and (len(resp.embeds) != 1 or not embed_eq(_embed, resp.embeds[0])): + return False + + return True + + def contains(self) -> 'VerifyInteraction': + """ + Only check whether content/embed list/etc contain the desired input. + + :return: Self for chaining + """ + self._contains = True + return self + + def peek(self) -> 'VerifyInteraction': + """ + Don't remove the verified interaction response from the queue. + + :return: Self for chaining + """ + self._peek = True + return self + + def nothing(self) -> 'VerifyInteraction': + """ + Check that no interaction response was sent. + + :return: Self for chaining + """ + if self._content is not undefined or self._embed is not undefined: + raise ValueError("Verify nothing conflicts with verifying some content or embed") + self._nothing = True + return self + + def content(self, content: str | None) -> 'VerifyInteraction': + """ + Check that the interaction response content matches the input. + + :param content: Content to match against, or None to ensure no content + :return: Self for chaining + """ + if self._nothing: + raise ValueError("Verify content conflicts with verifying nothing") + self._content = content + return self + + def embed(self, embed: discord.Embed | None) -> 'VerifyInteraction': + """ + Check that the interaction response embed matches the input. + + :param embed: Embed to match against, or None to ensure no embed + :return: Self for chaining + """ + if self._nothing: + raise ValueError("Verify embed conflicts with verifying nothing") + self._embed = embed + return self + + def ephemeral(self, value: bool = True) -> 'VerifyInteraction': + """ + Check that the interaction response is ephemeral (or not). + + :param value: Whether to check for ephemeral (True) or not (False) + :return: Self for chaining + """ + self._ephemeral = value + return self + + def deferred(self, value: bool = True) -> 'VerifyInteraction': + """ + Check that the interaction response is a deferred response (or not). + + :param value: Whether to check for deferred (True) or not (False) + :return: Self for chaining + """ + self._deferred = value + return self + + class Verify: """ Base for all kinds of verification builders. Used as an @@ -355,6 +563,14 @@ def activity(self) -> VerifyActivity: """ return VerifyActivity() + def interaction(self) -> VerifyInteraction: + """ + Verify an interaction response from the bot + + :return: Interaction verification builder + """ + return VerifyInteraction() + def verify() -> Verify: """ diff --git a/tests/test_interaction.py b/tests/test_interaction.py new file mode 100644 index 0000000..80e2a46 --- /dev/null +++ b/tests/test_interaction.py @@ -0,0 +1,382 @@ +""" + Tests for application command (slash command) interaction support. +""" + +import pytest +import pytest_asyncio +import discord +import discord.ext.commands as commands +import discord.ext.test as dpytest +from typing import Callable, TypeVar +from discord import app_commands +from discord.client import _LoopSentinel + +T = TypeVar('T') + + +@pytest_asyncio.fixture +async def bot() -> commands.Bot: + intents = discord.Intents.default() + intents.members = True + intents.message_content = True + b = commands.Bot(command_prefix="!", + intents=intents) + if isinstance(b.loop, _LoopSentinel): + await b._async_setup_hook() + + @b.tree.command(name="ping", description="A simple ping command") + async def ping(interaction: discord.Interaction) -> None: + await interaction.response.send_message("Pong!") + + @b.tree.command(name="greet", description="Greet someone") + @app_commands.describe(name="The name to greet") + async def greet(interaction: discord.Interaction, name: str) -> None: + await interaction.response.send_message(f"Hello, {name}!") + + @b.tree.command(name="secret", description="Secret ephemeral command") + async def secret(interaction: discord.Interaction) -> None: + await interaction.response.send_message("This is secret!", ephemeral=True) + + @b.tree.command(name="embed_cmd", description="Send an embed") + async def embed_cmd(interaction: discord.Interaction) -> None: + embed = discord.Embed(title="Test Embed", description="Hello from embed!") + await interaction.response.send_message(embed=embed) + + dpytest.configure(b) + return b + + +@pytest.mark.asyncio +async def test_simple_interaction(bot: commands.Bot) -> None: + """Test a basic slash command interaction.""" + await dpytest.interaction("ping") + assert dpytest.verify().interaction().content("Pong!") + + +@pytest.mark.asyncio +async def test_interaction_with_option(bot: commands.Bot) -> None: + """Test a slash command with options.""" + await dpytest.interaction( + "greet", + options=[{"name": "name", "type": 3, "value": "World"}], + ) + assert dpytest.verify().interaction().content("Hello, World!") + + +@pytest.mark.asyncio +async def test_interaction_ephemeral(bot: commands.Bot) -> None: + """Test an ephemeral slash command response.""" + await dpytest.interaction("secret") + assert dpytest.verify().interaction().ephemeral() + + +@pytest.mark.asyncio +async def test_interaction_embed(bot: commands.Bot) -> None: + """Test a slash command that sends an embed.""" + await dpytest.interaction("embed_cmd") + expected_embed = discord.Embed(title="Test Embed", description="Hello from embed!") + assert dpytest.verify().interaction().embed(expected_embed) + + +@pytest.mark.asyncio +async def test_interaction_nothing(bot: commands.Bot) -> None: + """Test verify nothing when no interaction has been sent.""" + assert dpytest.verify().interaction().nothing() + + +@pytest.mark.asyncio +async def test_interaction_contains(bot: commands.Bot) -> None: + """Test the contains modifier on interaction verification.""" + await dpytest.interaction("ping") + assert dpytest.verify().interaction().contains().content("Pong") + + +@pytest.mark.asyncio +async def test_interaction_peek(bot: commands.Bot) -> None: + """Test peeking at an interaction response without consuming it.""" + await dpytest.interaction("ping") + assert dpytest.verify().interaction().peek().content("Pong!") + # Should still be in the queue + assert dpytest.verify().interaction().content("Pong!") + + +# --- Application owner / bot.application tests --- + + +class Unauthorized(app_commands.CheckFailure): + def __init__(self, user: discord.User | discord.Member) -> None: + self.user: discord.User | discord.Member = user + super().__init__("You do not own this bot.") + + +def is_owner() -> Callable[[T], T]: + def predicate(interaction: discord.Interaction) -> bool: + if interaction.client.application is None: + raise ValueError("This application is a lie") + if interaction.user != interaction.client.application.owner: + raise Unauthorized(interaction.user) + return True + return app_commands.check(predicate) + + +@pytest_asyncio.fixture +async def owner_bot() -> commands.Bot: + """Bot fixture with an owner-only command.""" + intents = discord.Intents.default() + intents.members = True + intents.message_content = True + b = commands.Bot(command_prefix="!", intents=intents) + if isinstance(b.loop, _LoopSentinel): + await b._async_setup_hook() + + @b.tree.command(name="owneronly", description="Only the owner can use this") + @is_owner() + async def owneronly(interaction: discord.Interaction) -> None: + await interaction.response.send_message("You are the owner!") + + @b.tree.command(name="public", description="Anyone can use this") + async def public(interaction: discord.Interaction) -> None: + await interaction.response.send_message("Hello!") + + # Configure with 2 members; default owner=True makes member[0] the owner + dpytest.configure(b, members=2) + return b + + +@pytest.mark.asyncio +async def test_bot_application_not_none(owner_bot: commands.Bot) -> None: + """bot.application should be populated after configure().""" + assert owner_bot.application is not None + assert owner_bot.application.owner is not None + + +@pytest.mark.asyncio +async def test_owner_default_is_first_member(owner_bot: commands.Bot) -> None: + """With owner=True (default), the first test member is the app owner.""" + cfg = dpytest.get_config() + assert owner_bot.application is not None + assert owner_bot.application.owner.id == cfg.members[0].id + + +@pytest.mark.asyncio +async def test_owner_command_succeeds_for_owner(owner_bot: commands.Bot) -> None: + """The owner can invoke an owner-only command.""" + await dpytest.interaction("owneronly", member=0) + assert dpytest.verify().interaction().content("You are the owner!") + + +@pytest.mark.asyncio +async def test_owner_command_fails_for_non_owner(owner_bot: commands.Bot) -> None: + """A non-owner invoking an owner-only command should trigger the check failure.""" + # member=1 is NOT the owner + await dpytest.interaction("owneronly", member=1) + # The check failure means no response is sent + assert dpytest.verify().interaction().nothing() + + +@pytest.mark.asyncio +async def test_non_owner_can_use_public_command(owner_bot: commands.Bot) -> None: + """A non-owner can still invoke commands without the owner check.""" + await dpytest.interaction("public", member=1) + assert dpytest.verify().interaction().content("Hello!") + + +@pytest.mark.asyncio +async def test_interaction_member_param() -> None: + """The member param on interaction() controls who sends the interaction.""" + intents = discord.Intents.default() + intents.members = True + intents.message_content = True + b = commands.Bot(command_prefix="!", intents=intents) + if isinstance(b.loop, _LoopSentinel): + await b._async_setup_hook() + + @b.tree.command(name="whoami", description="Tells you who you are") + async def whoami(interaction: discord.Interaction) -> None: + await interaction.response.send_message(f"You are {interaction.user.name}") + + dpytest.configure(b, members=["Alice", "Bob"]) + cfg = dpytest.get_config() + + await dpytest.interaction("whoami", member=cfg.members[0]) + assert dpytest.verify().interaction().content(f"You are {cfg.members[0].name}") + + await dpytest.interaction("whoami", member=cfg.members[1]) + assert dpytest.verify().interaction().content(f"You are {cfg.members[1].name}") + + +@pytest.mark.asyncio +async def test_configure_explicit_owner() -> None: + """Passing a specific user as owner= sets that user as the app owner.""" + intents = discord.Intents.default() + intents.members = True + intents.message_content = True + b = commands.Bot(command_prefix="!", intents=intents) + if isinstance(b.loop, _LoopSentinel): + await b._async_setup_hook() + + @b.tree.command(name="ping", description="Ping") + async def ping(interaction: discord.Interaction) -> None: + await interaction.response.send_message("Pong!") + + dpytest.configure(b, members=["Alice", "Bob"], owner=False) + cfg = dpytest.get_config() + + # With owner=False, the owner should be the default "TestOwner", not any test member + assert b.application is not None + assert b.application.owner.name != cfg.members[0].name + assert b.application.owner.name != cfg.members[1].name + + +# --- Followup / defer / edit_original_response tests --- + + +@pytest_asyncio.fixture +async def followup_bot() -> commands.Bot: + """Bot fixture with commands that use defer, followup, and edit_original_response.""" + intents = discord.Intents.default() + intents.members = True + intents.message_content = True + b = commands.Bot(command_prefix="!", intents=intents) + if isinstance(b.loop, _LoopSentinel): + await b._async_setup_hook() + + @b.tree.command(name="deferred", description="Defers then sends a followup") + async def deferred(interaction: discord.Interaction) -> None: + await interaction.response.defer() + await interaction.followup.send("Here is the result!") + + @b.tree.command(name="editoriginal", description="Responds then edits") + async def editoriginal(interaction: discord.Interaction) -> None: + await interaction.response.send_message("Please wait...") + await interaction.edit_original_response(content="Done!") + + @b.tree.command(name="deleteoriginal", description="Responds then deletes") + async def deleteoriginal(interaction: discord.Interaction) -> None: + await interaction.response.send_message("Temporary") + await interaction.delete_original_response() + + dpytest.configure(b) + return b + + +@pytest.mark.asyncio +async def test_defer_then_followup(followup_bot: commands.Bot) -> None: + """A command that defers then sends a followup should produce two responses.""" + await dpytest.interaction("deferred") + + # First response: the defer + assert dpytest.verify().interaction().deferred() + # Second response: the followup + assert dpytest.verify().interaction().content("Here is the result!") + + +@pytest.mark.asyncio +async def test_edit_original_response(followup_bot: commands.Bot) -> None: + """A command that sends then edits the original should produce two responses.""" + await dpytest.interaction("editoriginal") + + # First response: the initial message + assert dpytest.verify().interaction().content("Please wait...") + # Second response: the edit (response_type=7 = UPDATE_MESSAGE) + assert dpytest.verify().interaction().content("Done!") + + +@pytest.mark.asyncio +async def test_delete_original_response(followup_bot: commands.Bot) -> None: + """A command that sends then deletes the original should not crash.""" + await dpytest.interaction("deleteoriginal") + + # First response: the initial message + assert dpytest.verify().interaction().content("Temporary") + # The delete doesn't produce a queued response, so nothing else + assert dpytest.verify().interaction().nothing() + + +# --- Unified queue behavior: all responses in one queue --- + + +@pytest_asyncio.fixture +async def dual_queue_bot() -> commands.Bot: + """Bot fixture for testing that all responses go through the unified queue.""" + intents = discord.Intents.default() + intents.members = True + intents.message_content = True + b = commands.Bot(command_prefix="!", intents=intents) + if isinstance(b.loop, _LoopSentinel): + await b._async_setup_hook() + + @b.tree.command(name="public_reply", description="Non-ephemeral response") + async def public_reply(interaction: discord.Interaction) -> None: + await interaction.response.send_message("Everyone can see this!") + + @b.tree.command(name="secret_reply", description="Ephemeral response") + async def secret_reply(interaction: discord.Interaction) -> None: + await interaction.response.send_message("Only you can see this!", ephemeral=True) + + @b.tree.command(name="defer_public", description="Defers then sends public followup") + async def defer_public(interaction: discord.Interaction) -> None: + await interaction.response.defer() + await interaction.followup.send("Followup for everyone!") + + @b.tree.command(name="defer_ephemeral", description="Defers ephemerally then sends ephemeral followup") + async def defer_ephemeral(interaction: discord.Interaction) -> None: + await interaction.response.defer(ephemeral=True) + await interaction.followup.send("Secret followup!", ephemeral=True) + + dpytest.configure(b) + return b + + +@pytest.mark.asyncio +async def test_non_ephemeral_response_in_queue(dual_queue_bot: commands.Bot) -> None: + """A non-ephemeral interaction response is in the unified queue.""" + await dpytest.interaction("public_reply") + assert dpytest.verify().interaction().content("Everyone can see this!") + + +@pytest.mark.asyncio +async def test_ephemeral_response_in_queue(dual_queue_bot: commands.Bot) -> None: + """An ephemeral interaction response is also in the unified queue.""" + await dpytest.interaction("secret_reply") + assert dpytest.verify().interaction().ephemeral().content("Only you can see this!") + + +@pytest.mark.asyncio +async def test_nothing_after_ephemeral(dual_queue_bot: commands.Bot) -> None: + """After consuming an ephemeral response, the queue is empty.""" + await dpytest.interaction("secret_reply") + assert dpytest.verify().interaction().ephemeral().content("Only you can see this!") + assert dpytest.verify().interaction().nothing() + + +@pytest.mark.asyncio +async def test_followup_non_ephemeral_ordering(dual_queue_bot: commands.Bot) -> None: + """A defer + followup produces two items in the unified queue in order.""" + await dpytest.interaction("defer_public") + + # First: the defer + assert dpytest.verify().interaction().deferred() + # Second: the followup + assert dpytest.verify().interaction().content("Followup for everyone!") + + +@pytest.mark.asyncio +async def test_followup_ephemeral_ordering(dual_queue_bot: commands.Bot) -> None: + """An ephemeral defer + ephemeral followup produces two items in the unified queue.""" + await dpytest.interaction("defer_ephemeral") + + # First: the defer + assert dpytest.verify().interaction().deferred() + # Second: the ephemeral followup + assert dpytest.verify().interaction().ephemeral().content("Secret followup!") + # Queue is empty + assert dpytest.verify().interaction().nothing() + + +@pytest.mark.asyncio +async def test_message_verify_fails_on_interaction_response(dual_queue_bot: commands.Bot) -> None: + """verify().message() should fail when the front of the queue is an InteractionResponse.""" + await dpytest.interaction("public_reply") + # The queue has an InteractionResponse, not a Message — so message verification fails + assert not dpytest.verify().message().content("Everyone can see this!")