diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index 8f8449018..e4cc99709 100755 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -14,6 +14,5 @@ then black slack_bolt/ tests/ && \ pytest -vv $1 else - black slack_bolt/ tests/ && pytest - fi + black slack_bolt/ tests/ && pytest fi diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index 3fefac341..c72394821 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -735,7 +735,7 @@ def step( elif not isinstance(step, WorkflowStep): raise BoltError(f"Invalid step object ({type(step)})") - self.use(WorkflowStepMiddleware(step, self.listener_runner)) + self.use(WorkflowStepMiddleware(step)) # ------------------------- # global error handler @@ -1350,6 +1350,10 @@ def _init_context(self, req: BoltRequest): ) req.context["client"] = client_per_request + # Most apps do not need this "listener_runner" instance. + # It is intended for apps that start lazy listeners from their custom global middleware. + req.context["listener_runner"] = self.listener_runner + @staticmethod def _to_listener_functions( kwargs: dict, diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index 8f66e3ba3..92bad71b7 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -761,7 +761,7 @@ def step( elif not isinstance(step, AsyncWorkflowStep): raise BoltError(f"Invalid step object ({type(step)})") - self.use(AsyncWorkflowStepMiddleware(step, self._async_listener_runner)) + self.use(AsyncWorkflowStepMiddleware(step)) # ------------------------- # global error handler @@ -1390,6 +1390,10 @@ def _init_context(self, req: AsyncBoltRequest): ) req.context["client"] = client_per_request + # Most apps do not need this "listener_runner" instance. + # It is intended for apps that start lazy listeners from their custom global middleware. + req.context["listener_runner"] = self.listener_runner + @staticmethod def _to_listener_functions( kwargs: dict, diff --git a/slack_bolt/context/async_context.py b/slack_bolt/context/async_context.py index 9381cdd3c..cd051fa31 100644 --- a/slack_bolt/context/async_context.py +++ b/slack_bolt/context/async_context.py @@ -17,20 +17,29 @@ class AsyncBoltContext(BaseContext): def to_copyable(self) -> "AsyncBoltContext": new_dict = {} for prop_name, prop_value in self.items(): - if prop_name in self.standard_property_names: + if prop_name in self.copyable_standard_property_names: # all the standard properties are copiable new_dict[prop_name] = prop_value + elif prop_name in self.non_copyable_standard_property_names: + # Do nothing with this property (e.g., listener_runner) + continue else: try: copied_value = create_copy(prop_value) new_dict[prop_name] = copied_value except TypeError as te: self.logger.debug( - f"Skipped settings '{prop_name}' to a copied request for lazy listeners " + f"Skipped setting '{prop_name}' to a copied request for lazy listeners " f"as it's not possible to make a deep copy (error: {te})" ) return AsyncBoltContext(new_dict) + # The return type is intentionally string to avoid circular imports + @property + def listener_runner(self) -> "AsyncioListenerRunner": # type: ignore[name-defined] + """The properly configured listener_runner that is available for middleware/listeners.""" + return self["listener_runner"] + @property def client(self) -> Optional[AsyncWebClient]: """The `AsyncWebClient` instance available for this request. diff --git a/slack_bolt/context/base_context.py b/slack_bolt/context/base_context.py index c85177664..2c00d8082 100644 --- a/slack_bolt/context/base_context.py +++ b/slack_bolt/context/base_context.py @@ -7,7 +7,7 @@ class BaseContext(dict): """Context object associated with a request from Slack.""" - standard_property_names = [ + copyable_standard_property_names = [ "logger", "token", "enterprise_id", @@ -35,6 +35,11 @@ class BaseContext(dict): "complete", "fail", ] + non_copyable_standard_property_names = [ + "listener_runner", + ] + + standard_property_names = copyable_standard_property_names + non_copyable_standard_property_names @property def logger(self) -> Logger: diff --git a/slack_bolt/context/context.py b/slack_bolt/context/context.py index 8faf8bd27..4c78c7fea 100644 --- a/slack_bolt/context/context.py +++ b/slack_bolt/context/context.py @@ -17,9 +17,12 @@ class BoltContext(BaseContext): def to_copyable(self) -> "BoltContext": new_dict = {} for prop_name, prop_value in self.items(): - if prop_name in self.standard_property_names: + if prop_name in self.copyable_standard_property_names: # all the standard properties are copiable new_dict[prop_name] = prop_value + elif prop_name in self.non_copyable_standard_property_names: + # Do nothing with this property (e.g., listener_runner) + continue else: try: copied_value = create_copy(prop_value) @@ -32,6 +35,12 @@ def to_copyable(self) -> "BoltContext": ) return BoltContext(new_dict) + # The return type is intentionally string to avoid circular imports + @property + def listener_runner(self) -> "ThreadListenerRunner": # type: ignore[name-defined] + """The properly configured listener_runner that is available for middleware/listeners.""" + return self["listener_runner"] + @property def client(self) -> Optional[WebClient]: """The `WebClient` instance available for this request. diff --git a/slack_bolt/listener/asyncio_runner.py b/slack_bolt/listener/asyncio_runner.py index 04f6b038e..01e8641ed 100644 --- a/slack_bolt/listener/asyncio_runner.py +++ b/slack_bolt/listener/asyncio_runner.py @@ -174,12 +174,11 @@ def _start_lazy_function(self, lazy_func: Callable[..., Awaitable[None]], reques copied_request = self._build_lazy_request(request, func_name) self.lazy_listener_runner.start(function=lazy_func, request=copied_request) - @staticmethod - def _build_lazy_request(request: AsyncBoltRequest, lazy_func_name: str) -> AsyncBoltRequest: - copied_request = create_copy(request.to_copyable()) - copied_request.method = "NONE" + def _build_lazy_request(self, request: AsyncBoltRequest, lazy_func_name: str) -> AsyncBoltRequest: + copied_request: AsyncBoltRequest = create_copy(request.to_copyable()) copied_request.lazy_only = True copied_request.lazy_function_name = lazy_func_name + copied_request.context["listener_runner"] = self return copied_request def _debug_log_completion(self, starting_time: float, response: BoltResponse) -> None: diff --git a/slack_bolt/listener/thread_runner.py b/slack_bolt/listener/thread_runner.py index 4821fb70b..c2d87b3d5 100644 --- a/slack_bolt/listener/thread_runner.py +++ b/slack_bolt/listener/thread_runner.py @@ -185,12 +185,11 @@ def _start_lazy_function(self, lazy_func: Callable[..., None], request: BoltRequ copied_request = self._build_lazy_request(request, func_name) self.lazy_listener_runner.start(function=lazy_func, request=copied_request) - @staticmethod - def _build_lazy_request(request: BoltRequest, lazy_func_name: str) -> BoltRequest: - copied_request = create_copy(request.to_copyable()) - copied_request.method = "NONE" + def _build_lazy_request(self, request: BoltRequest, lazy_func_name: str) -> BoltRequest: + copied_request: BoltRequest = create_copy(request.to_copyable()) copied_request.lazy_only = True copied_request.lazy_function_name = lazy_func_name + copied_request.context["listener_runner"] = self return copied_request def _debug_log_completion(self, starting_time: float, response: BoltResponse) -> None: diff --git a/slack_bolt/workflows/step/async_step_middleware.py b/slack_bolt/workflows/step/async_step_middleware.py index 62b3d1afe..5801a51e6 100644 --- a/slack_bolt/workflows/step/async_step_middleware.py +++ b/slack_bolt/workflows/step/async_step_middleware.py @@ -2,7 +2,6 @@ from typing import Callable, Optional, Awaitable from slack_bolt.listener.async_listener import AsyncListener -from slack_bolt.listener.asyncio_runner import AsyncioListenerRunner from slack_bolt.middleware.async_middleware import AsyncMiddleware from slack_bolt.request.async_request import AsyncBoltRequest from slack_bolt.response import BoltResponse @@ -13,9 +12,8 @@ class AsyncWorkflowStepMiddleware(AsyncMiddleware): """Base middleware for step from app specific ones""" - def __init__(self, step: AsyncWorkflowStep, listener_runner: AsyncioListenerRunner): + def __init__(self, step: AsyncWorkflowStep): self.step = step - self.listener_runner = listener_runner async def async_process( self, @@ -40,8 +38,8 @@ async def async_process( return await next() + @staticmethod async def _run( - self, listener: AsyncListener, req: AsyncBoltRequest, resp: BoltResponse, @@ -50,7 +48,7 @@ async def _run( if next_was_not_called: return None - return await self.listener_runner.run( + return await req.context.listener_runner.run( request=req, response=resp, listener_name=get_name_for_callable(listener.ack_function), diff --git a/slack_bolt/workflows/step/step_middleware.py b/slack_bolt/workflows/step/step_middleware.py index 2ea6194d1..59af001a7 100644 --- a/slack_bolt/workflows/step/step_middleware.py +++ b/slack_bolt/workflows/step/step_middleware.py @@ -2,7 +2,6 @@ from typing import Callable, Optional from slack_bolt.listener import Listener -from slack_bolt.listener.thread_runner import ThreadListenerRunner from slack_bolt.middleware import Middleware from slack_bolt.request import BoltRequest from slack_bolt.response import BoltResponse @@ -13,9 +12,8 @@ class WorkflowStepMiddleware(Middleware): """Base middleware for step from app specific ones""" - def __init__(self, step: WorkflowStep, listener_runner: ThreadListenerRunner): + def __init__(self, step: WorkflowStep): self.step = step - self.listener_runner = listener_runner def process( self, @@ -43,8 +41,8 @@ def process( return next() + @staticmethod def _run( - self, listener: Listener, req: BoltRequest, resp: BoltResponse, @@ -53,7 +51,7 @@ def _run( if next_was_not_called: return None - return self.listener_runner.run( + return req.context.listener_runner.run( request=req, response=resp, listener_name=get_name_for_callable(listener.ack_function), diff --git a/tests/scenario_tests/test_middleware.py b/tests/scenario_tests/test_middleware.py index aa16bf620..6553445df 100644 --- a/tests/scenario_tests/test_middleware.py +++ b/tests/scenario_tests/test_middleware.py @@ -1,15 +1,23 @@ import json -from time import time +import logging +from time import time, sleep +from typing import Callable, Optional from slack_sdk.signature import SignatureVerifier from slack_sdk.web import WebClient +from slack_bolt import BoltResponse, CustomListenerMatcher from slack_bolt.app import App +from slack_bolt.listener import CustomListener +from slack_bolt.listener.thread_runner import ThreadListenerRunner +from slack_bolt.middleware import Middleware from slack_bolt.request import BoltRequest +from slack_bolt.request.payload_utils import is_shortcut from tests.mock_web_api_server import ( setup_mock_web_api_server, cleanup_mock_web_api_server, assert_auth_test_count, + assert_received_request_count, ) from tests.utils import remove_os_env_temporarily, restore_os_env @@ -168,6 +176,27 @@ def __call__(self, next_): assert response.body == "acknowledged!" assert_auth_test_count(self, 1) + def test_lazy_listener_middleware(self): + app = App( + client=self.web_client, + signing_secret=self.signing_secret, + ) + unmatch_middleware = LazyListenerStarter("xxxx") + app.use(unmatch_middleware) + + response = app.dispatch(self.build_request()) + assert response.status == 404 + assert_auth_test_count(self, 1) + + my_middleware = LazyListenerStarter("test-shortcut") + app.use(my_middleware) + response = app.dispatch(self.build_request()) + assert response.status == 200 + count = 0 + while count < 20 and my_middleware.lazy_called is False: + sleep(0.05) + assert my_middleware.lazy_called is True + def just_ack(ack): ack("acknowledged!") @@ -183,3 +212,42 @@ def just_next(next): def just_next_(next_): next_() + + +class LazyListenerStarter(Middleware): + lazy_called: bool + callback_id: str + + def __init__(self, callback_id: str): + self.lazy_called = False + self.callback_id = callback_id + + def lazy_listener(self): + self.lazy_called = True + + def process(self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], BoltResponse]) -> Optional[BoltResponse]: + if is_shortcut(req.body): + listener = CustomListener( + app_name="test-app", + ack_function=just_ack, + lazy_functions=[self.lazy_listener], + matchers=[ + CustomListenerMatcher( + app_name="test-app", + func=lambda payload: payload.get("callback_id") == self.callback_id, + ) + ], + middleware=[], + base_logger=req.context.logger, + ) + if listener.matches(req=req, resp=resp): + listener_runner: ThreadListenerRunner = req.context.listener_runner + response = listener_runner.run( + request=req, + response=resp, + listener_name="test", + listener=listener, + ) + if response is not None: + return response + next() diff --git a/tests/scenario_tests_async/test_middleware.py b/tests/scenario_tests_async/test_middleware.py index 694a7316e..6272f17e4 100644 --- a/tests/scenario_tests_async/test_middleware.py +++ b/tests/scenario_tests_async/test_middleware.py @@ -1,12 +1,20 @@ import json +import asyncio from time import time +from typing import Callable, Awaitable, Optional import pytest from slack_sdk.signature import SignatureVerifier from slack_sdk.web.async_client import AsyncWebClient +from slack_bolt import BoltResponse +from slack_bolt.listener.async_listener import AsyncCustomListener +from slack_bolt.listener.asyncio_runner import AsyncioListenerRunner +from slack_bolt.listener_matcher.async_listener_matcher import AsyncCustomListenerMatcher +from slack_bolt.middleware.async_middleware import AsyncMiddleware from slack_bolt.app.async_app import AsyncApp from slack_bolt.request.async_request import AsyncBoltRequest +from slack_bolt.request.payload_utils import is_shortcut from tests.mock_web_api_server import ( cleanup_mock_web_api_server_async, assert_auth_test_count_async, @@ -145,6 +153,27 @@ async def just_next_(next_): assert response.body == "acknowledged!" await assert_auth_test_count_async(self, 1) + @pytest.mark.asyncio + async def test_lazy_listener_middleware(self): + app = AsyncApp( + client=self.web_client, + signing_secret=self.signing_secret, + ) + unmatch_middleware = LazyListenerStarter("xxxx") + app.use(unmatch_middleware) + + response = await app.async_dispatch(self.build_request()) + assert response.status == 404 + + my_middleware = LazyListenerStarter("test-shortcut") + app.use(my_middleware) + response = await app.async_dispatch(self.build_request()) + assert response.status == 200 + count = 0 + while count < 20 and my_middleware.lazy_called is False: + await asyncio.sleep(0.05) + assert my_middleware.lazy_called is True + async def just_ack(ack): await ack("acknowledged!") @@ -160,3 +189,47 @@ async def just_next(next): async def just_next_(next_): await next_() + + +class LazyListenerStarter(AsyncMiddleware): + lazy_called: bool + callback_id: str + + def __init__(self, callback_id: str): + self.lazy_called = False + self.callback_id = callback_id + + async def lazy_listener(self): + self.lazy_called = True + + async def async_process( + self, *, req: AsyncBoltRequest, resp: BoltResponse, next: Callable[[], Awaitable[BoltResponse]] + ) -> Optional[BoltResponse]: + async def is_target(payload: dict): + return payload.get("callback_id") == self.callback_id + + if is_shortcut(req.body): + listener = AsyncCustomListener( + app_name="test-app", + ack_function=just_ack, + lazy_functions=[self.lazy_listener], + matchers=[ + AsyncCustomListenerMatcher( + app_name="test-app", + func=is_target, + ) + ], + middleware=[], + base_logger=req.context.logger, + ) + if await listener.async_matches(req=req, resp=resp): + listener_runner: AsyncioListenerRunner = req.context.listener_runner + response = await listener_runner.run( + request=req, + response=resp, + listener_name="test", + listener=listener, + ) + if response is not None: + return response + await next()