From f9739d814fa5dc5e6abc9a71c321765265a43a51 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 2 Apr 2021 17:35:57 +0900 Subject: [PATCH 1/2] Fix #280 Django thread-local connection cleanup in multi threads --- examples/django/.gitignore | 3 +- examples/django/mysql-docker-compose.yml | 16 ++++ .../slackapp/migrations/0001_initial.py | 34 +++---- examples/django/slackapp/models.py | 67 +++++++++----- examples/django/slackapp/settings.py | 25 ++++-- slack_bolt/adapter/django/handler.py | 90 ++++++++++++++++++- slack_bolt/app/app.py | 11 +++ slack_bolt/app/async_app.py | 11 +++ slack_bolt/listener/async_internals.py | 59 ++++++++++++ .../async_listener_completion_handler.py | 67 ++++++++++++++ .../listener/async_listener_error_handler.py | 72 ++++----------- slack_bolt/listener/asyncio_runner.py | 14 +++ slack_bolt/listener/internals.py | 76 ++++++++++++++++ .../listener/listener_completion_handler.py | 63 +++++++++++++ slack_bolt/listener/listener_error_handler.py | 70 ++++----------- slack_bolt/listener/thread_runner.py | 14 +++ tests/adapter_tests/django/test_django.py | 82 +++++++++++++++++ 17 files changed, 620 insertions(+), 154 deletions(-) create mode 100644 examples/django/mysql-docker-compose.yml create mode 100644 slack_bolt/listener/async_internals.py create mode 100644 slack_bolt/listener/async_listener_completion_handler.py create mode 100644 slack_bolt/listener/internals.py create mode 100644 slack_bolt/listener/listener_completion_handler.py diff --git a/examples/django/.gitignore b/examples/django/.gitignore index ba520ccd8..68712b32b 100644 --- a/examples/django/.gitignore +++ b/examples/django/.gitignore @@ -1 +1,2 @@ -db.sqlite3 \ No newline at end of file +db.sqlite3 +db/ diff --git a/examples/django/mysql-docker-compose.yml b/examples/django/mysql-docker-compose.yml new file mode 100644 index 000000000..e1f543f56 --- /dev/null +++ b/examples/django/mysql-docker-compose.yml @@ -0,0 +1,16 @@ +version: '3.9' +services: + db: + image: mysql:8 + environment: + MYSQL_DATABASE: slackapp + MYSQL_USER: app + MYSQL_PASSWORD: password + MYSQL_ROOT_PASSWORD: password + #command: + # - '--wait_timeout=3' + volumes: + - './db:/var/lib/mysql' + ports: + - 33306:3306 + diff --git a/examples/django/slackapp/migrations/0001_initial.py b/examples/django/slackapp/migrations/0001_initial.py index a2bcb49ef..ebc013537 100644 --- a/examples/django/slackapp/migrations/0001_initial.py +++ b/examples/django/slackapp/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 3.1.4 on 2020-12-04 13:07 +# Generated by Django 3.1.7 on 2021-04-02 05:53 from django.db import migrations, models @@ -22,15 +22,15 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), - ("client_id", models.TextField()), - ("app_id", models.TextField()), - ("enterprise_id", models.TextField(null=True)), + ("client_id", models.CharField(max_length=32)), + ("app_id", models.CharField(max_length=32)), + ("enterprise_id", models.CharField(max_length=32, null=True)), ("enterprise_name", models.TextField(null=True)), - ("team_id", models.TextField(null=True)), + ("team_id", models.CharField(max_length=32, null=True)), ("team_name", models.TextField(null=True)), ("bot_token", models.TextField(null=True)), - ("bot_id", models.TextField(null=True)), - ("bot_user_id", models.TextField(null=True)), + ("bot_id", models.CharField(max_length=32, null=True)), + ("bot_user_id", models.CharField(max_length=32, null=True)), ("bot_scopes", models.TextField(null=True)), ("is_enterprise_install", models.BooleanField(null=True)), ("installed_at", models.DateTimeField()), @@ -48,18 +48,18 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), - ("client_id", models.TextField()), - ("app_id", models.TextField()), - ("enterprise_id", models.TextField(null=True)), + ("client_id", models.CharField(max_length=32)), + ("app_id", models.CharField(max_length=32)), + ("enterprise_id", models.CharField(max_length=32, null=True)), ("enterprise_name", models.TextField(null=True)), ("enterprise_url", models.TextField(null=True)), - ("team_id", models.TextField(null=True)), + ("team_id", models.CharField(max_length=32, null=True)), ("team_name", models.TextField(null=True)), ("bot_token", models.TextField(null=True)), - ("bot_id", models.TextField(null=True)), + ("bot_id", models.CharField(max_length=32, null=True)), ("bot_user_id", models.TextField(null=True)), ("bot_scopes", models.TextField(null=True)), - ("user_id", models.TextField()), + ("user_id", models.CharField(max_length=32)), ("user_token", models.TextField(null=True)), ("user_scopes", models.TextField(null=True)), ("incoming_webhook_url", models.TextField(null=True)), @@ -67,7 +67,7 @@ class Migration(migrations.Migration): ("incoming_webhook_channel_id", models.TextField(null=True)), ("incoming_webhook_configuration_url", models.TextField(null=True)), ("is_enterprise_install", models.BooleanField(null=True)), - ("token_type", models.TextField(null=True)), + ("token_type", models.CharField(max_length=32, null=True)), ("installed_at", models.DateTimeField()), ], ), @@ -83,7 +83,7 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), - ("state", models.TextField()), + ("state", models.CharField(max_length=64)), ("expire_at", models.DateTimeField()), ], ), @@ -97,14 +97,14 @@ class Migration(migrations.Migration): "user_id", "installed_at", ], - name="bolt_slacki_client__62c411_idx", + name="slackapp_sl_client__9b0d3f_idx", ), ), migrations.AddIndex( model_name="slackbot", index=models.Index( fields=["client_id", "enterprise_id", "team_id", "installed_at"], - name="bolt_slackb_client__be066b_idx", + name="slackapp_sl_client__d220d6_idx", ), ), ] diff --git a/examples/django/slackapp/models.py b/examples/django/slackapp/models.py index 737c4bd90..cfbbd8ade 100644 --- a/examples/django/slackapp/models.py +++ b/examples/django/slackapp/models.py @@ -6,15 +6,15 @@ class SlackBot(models.Model): - client_id = models.TextField(null=False) - app_id = models.TextField(null=False) - enterprise_id = models.TextField(null=True) + client_id = models.CharField(null=False, max_length=32) + app_id = models.CharField(null=False, max_length=32) + enterprise_id = models.CharField(null=True, max_length=32) enterprise_name = models.TextField(null=True) - team_id = models.TextField(null=True) + team_id = models.CharField(null=True, max_length=32) team_name = models.TextField(null=True) bot_token = models.TextField(null=True) - bot_id = models.TextField(null=True) - bot_user_id = models.TextField(null=True) + bot_id = models.CharField(null=True, max_length=32) + bot_user_id = models.CharField(null=True, max_length=32) bot_scopes = models.TextField(null=True) is_enterprise_install = models.BooleanField(null=True) installed_at = models.DateTimeField(null=False) @@ -28,18 +28,18 @@ class Meta: class SlackInstallation(models.Model): - client_id = models.TextField(null=False) - app_id = models.TextField(null=False) - enterprise_id = models.TextField(null=True) + client_id = models.CharField(null=False, max_length=32) + app_id = models.CharField(null=False, max_length=32) + enterprise_id = models.CharField(null=True, max_length=32) enterprise_name = models.TextField(null=True) enterprise_url = models.TextField(null=True) - team_id = models.TextField(null=True) + team_id = models.CharField(null=True, max_length=32) team_name = models.TextField(null=True) bot_token = models.TextField(null=True) - bot_id = models.TextField(null=True) + bot_id = models.CharField(null=True, max_length=32) bot_user_id = models.TextField(null=True) bot_scopes = models.TextField(null=True) - user_id = models.TextField(null=False) + user_id = models.CharField(null=False, max_length=32) user_token = models.TextField(null=True) user_scopes = models.TextField(null=True) incoming_webhook_url = models.TextField(null=True) @@ -47,7 +47,7 @@ class SlackInstallation(models.Model): incoming_webhook_channel_id = models.TextField(null=True) incoming_webhook_configuration_url = models.TextField(null=True) is_enterprise_install = models.BooleanField(null=True) - token_type = models.TextField(null=True) + token_type = models.CharField(null=True, max_length=32) installed_at = models.DateTimeField(null=False) class Meta: @@ -65,7 +65,7 @@ class Meta: class SlackOAuthState(models.Model): - state = models.TextField(null=False) + state = models.CharField(null=False, max_length=64) expire_at = models.DateTimeField(null=False) @@ -81,6 +81,7 @@ class SlackOAuthState(models.Model): from django.utils import timezone from slack_sdk.oauth import InstallationStore, OAuthStateStore from slack_sdk.oauth.installation_store import Bot, Installation +from slack_sdk.webhook import WebhookClient class DjangoInstallationStore(InstallationStore): @@ -100,9 +101,13 @@ def logger(self) -> Logger: def save(self, installation: Installation): i = installation.to_dict() + if is_naive(i["installed_at"]): + i["installed_at"] = make_aware(i["installed_at"]) i["client_id"] = self.client_id SlackInstallation(**i).save() b = installation.to_bot().to_dict() + if is_naive(b["installed_at"]): + b["installed_at"] = make_aware(b["installed_at"]) b["client_id"] = self.client_id SlackBot(**b).save() @@ -222,7 +227,7 @@ def consume(self, state: str) -> bool: import logging import os -from slack_bolt import App +from slack_bolt import App, BoltContext from slack_bolt.oauth.oauth_settings import OAuthSettings logger = logging.getLogger(__name__) @@ -249,12 +254,34 @@ def consume(self, state: str) -> bool: ) -@app.event("app_mention") -def event_test(body, say, logger): +def event_test(body, say, context: BoltContext, logger): logger.info(body) - say("What's up?") + say(":wave: What's up?") + + found_rows = list( + SlackInstallation.objects.filter(enterprise_id=context.enterprise_id) + .filter(team_id=context.team_id) + .filter(incoming_webhook_url__isnull=False) + .order_by(F("installed_at").desc())[:1] + ) + if len(found_rows) > 0: + webhook_url = found_rows[0].incoming_webhook_url + logger.info(f"webhook_url: {webhook_url}") + client = WebhookClient(webhook_url) + client.send(text=":wave: This is a message posted using Incoming Webhook!") + + +# lazy listener example +def noop(): + pass + + +app.event("app_mention")( + ack=event_test, + lazy=[noop], +) -@app.command("/hello-bolt-python") +@app.command("/hello-django-app") def command(ack): - ack("This is a Django app!") + ack(":wave: Hello from a Django app :smile:") diff --git a/examples/django/slackapp/settings.py b/examples/django/slackapp/settings.py index 1553e7df3..56bdc50e9 100644 --- a/examples/django/slackapp/settings.py +++ b/examples/django/slackapp/settings.py @@ -22,7 +22,7 @@ }, "root": { "handlers": ["console"], - "level": "INFO", + "level": "DEBUG", }, "loggers": { "django": { @@ -30,7 +30,7 @@ "level": os.getenv("DJANGO_LOG_LEVEL", "INFO"), "propagate": False, }, - "django.db.backends": { + "django.db": { "level": "DEBUG", }, "slack_bolt": { @@ -105,10 +105,25 @@ # https://docs.djangoproject.com/en/3.0/ref/settings/#databases DATABASES = { + # python manage.py migrate + # python manage.py runserver 0.0.0.0:3000 + # "default": { + # "ENGINE": "django.db.backends.sqlite3", + # "NAME": os.path.join(BASE_DIR, "db.sqlite3"), + # }, + + # docker-compose -f mysql-docker-compose.yml up --build + # pip install mysqlclient + # python manage.py migrate + # python manage.py runserver 0.0.0.0:3000 "default": { - "ENGINE": "django.db.backends.sqlite3", - "NAME": os.path.join(BASE_DIR, "db.sqlite3"), - } + "ENGINE": "django.db.backends.mysql", + "NAME": "slackapp", + "USER": "app", + "PASSWORD": "password", + "HOST": "127.0.0.1", + "PORT": 33306, + }, } # Password validation diff --git a/slack_bolt/adapter/django/handler.py b/slack_bolt/adapter/django/handler.py index 24eb91b0e..2862a380b 100644 --- a/slack_bolt/adapter/django/handler.py +++ b/slack_bolt/adapter/django/handler.py @@ -1,8 +1,19 @@ -from typing import Optional +import logging +from logging import Logger +from threading import current_thread, Thread +from typing import Optional, Callable from django.http import HttpRequest, HttpResponse from slack_bolt.app import App +from slack_bolt.error import BoltError +from slack_bolt.lazy_listener import ThreadLazyListenerRunner +from slack_bolt.lazy_listener.internals import build_runnable_function +from slack_bolt.listener.listener_completion_handler import ( + ListenerCompletionHandler, + DefaultListenerCompletionHandler, +) +from slack_bolt.listener.thread_runner import ThreadListenerRunner from slack_bolt.oauth import OAuthFlow from slack_bolt.request import BoltRequest from slack_bolt.response import BoltResponse @@ -43,9 +54,86 @@ def to_django_response(bolt_resp: BoltResponse) -> HttpResponse: return resp +from django.db import connections + + +def release_thread_local_connections(logger: Logger, execution_type: str): + connections.close_all() + if logger.level <= logging.DEBUG: + current: Thread = current_thread() + logger.debug( + f"Released thread-bound DB connections (thread name: {current.name}, execution type: {execution_type})" + ) + + +class DjangoListenerCompletionHandler(ListenerCompletionHandler): + """Django sets DB connections as a thread-local variable per thread. + If the thread is not managed on the Django app side, the connections won't be released by Django. + This handler releases the connections every time a ThreadListenerRunner execution completes. + """ + + def handle(self, request: BoltRequest, response: Optional[BoltResponse]) -> None: + release_thread_local_connections(request.context.logger, "listener") + + +class DjangoThreadLazyListenerRunner(ThreadLazyListenerRunner): + def start(self, function: Callable[..., None], request: BoltRequest) -> None: + func: Callable[[], None] = build_runnable_function( + func=function, + logger=self.logger, + request=request, + ) + + def wrapped_func(): + try: + func() + finally: + release_thread_local_connections( + request.context.logger, "lazy-listener" + ) + + self.executor.submit(wrapped_func) + + class SlackRequestHandler: def __init__(self, app: App): # type: ignore self.app = app + listener_runner = self.app.listener_runner + # This runner closes all thread-local connections in the thread when an execution completes + self.app.listener_runner.lazy_listener_runner = DjangoThreadLazyListenerRunner( + logger=listener_runner.logger, + executor=listener_runner.listener_executor, + ) + + if not isinstance(listener_runner, ThreadListenerRunner): + raise BoltError( + "Custom listener_runners are not compatible with this Django adapter." + ) + + if app.process_before_response is True: + # As long as the app access Django models in the same thread, + # Django cleans the connections up for you. + self.app.logger.debug("App.process_before_response is set to True") + return + + current_completion_handler = listener_runner.listener_completion_handler + if current_completion_handler is not None and not isinstance( + current_completion_handler, DefaultListenerCompletionHandler + ): + message = """As you've already set app.listener_runner.listener_completion_handler to your own one, + Bolt skipped to set it to slack_sdk.adapter.django.DjangoListenerCompletionHandler. + We strongly recommend having the following lines of code in your listener_completion_handler: + + from django.db import connections + connections.close_all() + """ + self.app.logger.warning(message) + return + # for proper management of thread-local Django DB connections + self.app.listener_runner.listener_completion_handler = ( + DjangoListenerCompletionHandler() + ) + self.app.logger.debug("DjangoListenerCompletionHandler has been enabled") def handle(self, req: HttpRequest) -> HttpResponse: if req.method == "GET": diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index b0ce87d7c..1abc01a03 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -21,6 +21,9 @@ from slack_bolt.lazy_listener.thread_runner import ThreadLazyListenerRunner from slack_bolt.listener.custom_listener import CustomListener from slack_bolt.listener.listener import Listener +from slack_bolt.listener.listener_completion_handler import ( + DefaultListenerCompletionHandler, +) from slack_bolt.listener.listener_error_handler import ( DefaultListenerErrorHandler, CustomListenerErrorHandler, @@ -255,12 +258,16 @@ def message_hello(message, say): self._listeners: List[Listener] = [] listener_executor = ThreadPoolExecutor(max_workers=5) + self._process_before_response = process_before_response self._listener_runner = ThreadListenerRunner( logger=self._framework_logger, process_before_response=process_before_response, listener_error_handler=DefaultListenerErrorHandler( logger=self._framework_logger ), + listener_completion_handler=DefaultListenerCompletionHandler( + logger=self._framework_logger + ), listener_executor=listener_executor, lazy_listener_runner=ThreadLazyListenerRunner( logger=self._framework_logger, @@ -339,6 +346,10 @@ def listener_runner(self) -> ThreadListenerRunner: """The thread executor for asynchronously running listeners.""" return self._listener_runner + @property + def process_before_response(self) -> bool: + return self._process_before_response or False + # ------------------------- # standalone server diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index 97ee2f9d5..cf4385701 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -7,6 +7,9 @@ from aiohttp import web from slack_bolt.app.async_server import AsyncSlackAppServer +from slack_bolt.listener.async_listener_completion_handler import ( + AsyncDefaultListenerCompletionHandler, +) from slack_bolt.listener.asyncio_runner import AsyncioListenerRunner from slack_bolt.middleware.message_listener_matches.async_message_listener_matches import ( AsyncMessageListenerMatches, @@ -279,12 +282,16 @@ async def message_hello(message, say): # async function self._async_middleware_list: List[Union[Callable, AsyncMiddleware]] = [] self._async_listeners: List[AsyncListener] = [] + self._process_before_response = process_before_response self._async_listener_runner = AsyncioListenerRunner( logger=self._framework_logger, process_before_response=process_before_response, listener_error_handler=AsyncDefaultListenerErrorHandler( logger=self._framework_logger ), + listener_completion_handler=AsyncDefaultListenerCompletionHandler( + logger=self._framework_logger + ), lazy_listener_runner=AsyncioLazyListenerRunner( logger=self._framework_logger, ), @@ -355,6 +362,10 @@ def listener_runner(self) -> AsyncioListenerRunner: """The asyncio-based executor for asynchronously running listeners.""" return self._async_listener_runner + @property + def process_before_response(self) -> bool: + return self._process_before_response or False + # ------------------------- # standalone server diff --git a/slack_bolt/listener/async_internals.py b/slack_bolt/listener/async_internals.py new file mode 100644 index 000000000..6a8fc8ab1 --- /dev/null +++ b/slack_bolt/listener/async_internals.py @@ -0,0 +1,59 @@ +from logging import Logger +from typing import Dict, Any, Optional + +from slack_bolt.request.async_request import AsyncBoltRequest +from slack_bolt.request.payload_utils import ( + to_options, + to_shortcut, + to_action, + to_view, + to_command, + to_event, + to_message, + to_step, +) +from slack_bolt.response import BoltResponse + + +def _build_all_available_args( + logger: Logger, + request: AsyncBoltRequest, + response: Optional[BoltResponse], + error: Optional[Exception] = None, +) -> Dict[str, Any]: + all_available_args = { + "logger": logger, + "error": error, + "client": request.context.client, + "req": request, + "request": request, + "resp": response, + "response": response, + "context": request.context, + "body": request.body, + # payload + "body": request.body, + "options": to_options(request.body), + "shortcut": to_shortcut(request.body), + "action": to_action(request.body), + "view": to_view(request.body), + "command": to_command(request.body), + "event": to_event(request.body), + "message": to_message(request.body), + "step": to_step(request.body), + # utilities + "say": request.context.say, + "respond": request.context.respond, + } + all_available_args["payload"] = ( + all_available_args["options"] + or all_available_args["shortcut"] + or all_available_args["action"] + or all_available_args["view"] + or all_available_args["command"] + or all_available_args["event"] + or all_available_args["message"] + or all_available_args["step"] + or request.body + ) + return all_available_args diff --git a/slack_bolt/listener/async_listener_completion_handler.py b/slack_bolt/listener/async_listener_completion_handler.py new file mode 100644 index 000000000..e14a58a5d --- /dev/null +++ b/slack_bolt/listener/async_listener_completion_handler.py @@ -0,0 +1,67 @@ +import inspect +from abc import ABCMeta, abstractmethod +from logging import Logger +from typing import Callable, Dict, Any, Awaitable, Optional + +from slack_bolt.listener.async_internals import ( + _build_all_available_args, +) +from slack_bolt.listener.internals import ( + _convert_all_available_args_to_kwargs, +) + +from slack_bolt.request.async_request import AsyncBoltRequest +from slack_bolt.response import BoltResponse + + +class AsyncListenerCompletionHandler(metaclass=ABCMeta): + @abstractmethod + async def handle( + self, + request: AsyncBoltRequest, + response: Optional[BoltResponse], + ) -> None: + """Handles an unhandled exception. + + Args: + error: The raised exception. + request: The request. + response: The response. + """ + raise NotImplementedError() + + +class AsyncCustomListenerCompletionHandler(AsyncListenerCompletionHandler): + def __init__(self, logger: Logger, func: Callable[..., Awaitable[None]]): + self.func = func + self.logger = logger + self.arg_names = inspect.getfullargspec(func).args + + async def handle( + self, + request: AsyncBoltRequest, + response: Optional[BoltResponse], + ) -> None: + all_available_args = _build_all_available_args( + logger=self.logger, + request=request, + response=response, + ) + kwargs: Dict[str, Any] = _convert_all_available_args_to_kwargs( + all_available_args=all_available_args, + arg_names=self.arg_names, + logger=self.logger, + ) + await self.func(**kwargs) + + +class AsyncDefaultListenerCompletionHandler(AsyncListenerCompletionHandler): + def __init__(self, logger: Logger): + self.logger = logger + + async def handle( + self, + request: AsyncBoltRequest, + response: Optional[BoltResponse], + ): + pass diff --git a/slack_bolt/listener/async_listener_error_handler.py b/slack_bolt/listener/async_listener_error_handler.py index c4e55e7d9..3aa0360d8 100644 --- a/slack_bolt/listener/async_listener_error_handler.py +++ b/slack_bolt/listener/async_listener_error_handler.py @@ -3,20 +3,16 @@ from logging import Logger from typing import Callable, Dict, Any, Awaitable, Optional +from slack_bolt.listener.async_internals import ( + _build_all_available_args, +) +from slack_bolt.listener.internals import ( + _convert_all_available_args_to_kwargs, +) + from slack_bolt.request.async_request import AsyncBoltRequest from slack_bolt.response import BoltResponse -from slack_bolt.request.payload_utils import ( - to_options, - to_shortcut, - to_action, - to_view, - to_command, - to_event, - to_message, - to_step, -) - class AsyncListenerErrorHandler(metaclass=ABCMeta): @abstractmethod @@ -48,51 +44,17 @@ async def handle( request: AsyncBoltRequest, response: Optional[BoltResponse], ) -> None: - all_available_args = { - "logger": self.logger, - "error": error, - "client": request.context.client, - "req": request, - "request": request, - "resp": response, - "response": response, - "context": request.context, - "body": request.body, - # payload - "body": request.body, - "options": to_options(request.body), - "shortcut": to_shortcut(request.body), - "action": to_action(request.body), - "view": to_view(request.body), - "command": to_command(request.body), - "event": to_event(request.body), - "message": to_message(request.body), - "step": to_step(request.body), - # utilities - "say": request.context.say, - "respond": request.context.respond, - } - all_available_args["payload"] = ( - all_available_args["options"] - or all_available_args["shortcut"] - or all_available_args["action"] - or all_available_args["view"] - or all_available_args["command"] - or all_available_args["event"] - or all_available_args["message"] - or all_available_args["step"] - or request.body + all_available_args = _build_all_available_args( + logger=self.logger, + error=error, + request=request, + response=response, + ) + kwargs: Dict[str, Any] = _convert_all_available_args_to_kwargs( + all_available_args=all_available_args, + arg_names=self.arg_names, + logger=self.logger, ) - - kwargs: Dict[str, Any] = { # type: ignore - k: v for k, v in all_available_args.items() if k in self.arg_names # type: ignore - } - found_arg_names = kwargs.keys() - for name in self.arg_names: - if name not in found_arg_names: - self.logger.warning(f"{name} is not a valid argument") - kwargs[name] = None - await self.func(**kwargs) diff --git a/slack_bolt/listener/asyncio_runner.py b/slack_bolt/listener/asyncio_runner.py index 8ebcbde45..37a532f3c 100644 --- a/slack_bolt/listener/asyncio_runner.py +++ b/slack_bolt/listener/asyncio_runner.py @@ -7,6 +7,9 @@ from slack_bolt.context.ack.async_ack import AsyncAck from slack_bolt.lazy_listener.async_runner import AsyncLazyListenerRunner from slack_bolt.listener.async_listener import AsyncListener +from slack_bolt.listener.async_listener_completion_handler import ( + AsyncListenerCompletionHandler, +) from slack_bolt.listener.async_listener_error_handler import AsyncListenerErrorHandler from slack_bolt.logger.messages import ( debug_responding, @@ -22,6 +25,7 @@ class AsyncioListenerRunner: logger: Logger process_before_response: bool listener_error_handler: AsyncListenerErrorHandler + listener_completion_handler: AsyncListenerCompletionHandler lazy_listener_runner: AsyncLazyListenerRunner def __init__( @@ -29,11 +33,13 @@ def __init__( logger: Logger, process_before_response: bool, listener_error_handler: AsyncListenerErrorHandler, + listener_completion_handler: AsyncListenerCompletionHandler, lazy_listener_runner: AsyncLazyListenerRunner, ): self.logger = logger self.process_before_response = process_before_response self.listener_error_handler = listener_error_handler + self.listener_completion_handler = listener_completion_handler self.lazy_listener_runner = lazy_listener_runner async def run( @@ -68,6 +74,10 @@ async def run( response=response, ) ack.response = response + finally: + await self.listener_completion_handler.handle( + request=request, response=response + ) for lazy_func in listener.lazy_functions: if request.lazy_function_name: @@ -121,6 +131,10 @@ async def run_ack_function_asynchronously( response=response, ) ack.response = response + finally: + await self.listener_completion_handler.handle( + request=request, response=response + ) _f: Future = asyncio.ensure_future( run_ack_function_asynchronously(ack, request, response) diff --git a/slack_bolt/listener/internals.py b/slack_bolt/listener/internals.py new file mode 100644 index 000000000..9c1ac3a31 --- /dev/null +++ b/slack_bolt/listener/internals.py @@ -0,0 +1,76 @@ +from logging import Logger +from typing import Optional, Dict, Any, List + +from slack_bolt.request.request import BoltRequest +from slack_bolt.response.response import BoltResponse + +from slack_bolt.request.payload_utils import ( + to_options, + to_shortcut, + to_action, + to_view, + to_command, + to_event, + to_message, + to_step, +) + + +def _build_all_available_args( + logger: Logger, + request: BoltRequest, + response: Optional[BoltResponse], + error: Optional[Exception] = None, +) -> Dict[str, Any]: + all_available_args = { + "logger": logger, + "error": error, + "client": request.context.client, + "req": request, + "request": request, + "resp": response, + "response": response, + "context": request.context, + "body": request.body, + # payload + "body": request.body, + "options": to_options(request.body), + "shortcut": to_shortcut(request.body), + "action": to_action(request.body), + "view": to_view(request.body), + "command": to_command(request.body), + "event": to_event(request.body), + "message": to_message(request.body), + "step": to_step(request.body), + # utilities + "say": request.context.say, + "respond": request.context.respond, + } + all_available_args["payload"] = ( + all_available_args["options"] + or all_available_args["shortcut"] + or all_available_args["action"] + or all_available_args["view"] + or all_available_args["command"] + or all_available_args["event"] + or all_available_args["message"] + or all_available_args["step"] + or request.body + ) + return all_available_args + + +def _convert_all_available_args_to_kwargs( + all_available_args: Dict[str, Any], + arg_names: List[str], + logger: Logger, +) -> Dict[str, Any]: + kwargs: Dict[str, Any] = { # type: ignore + k: v for k, v in all_available_args.items() if k in arg_names # type: ignore + } + found_arg_names = kwargs.keys() + for name in arg_names: + if name not in found_arg_names: + logger.warning(f"{name} is not a valid argument") + kwargs[name] = None + return kwargs diff --git a/slack_bolt/listener/listener_completion_handler.py b/slack_bolt/listener/listener_completion_handler.py new file mode 100644 index 000000000..2419bc56e --- /dev/null +++ b/slack_bolt/listener/listener_completion_handler.py @@ -0,0 +1,63 @@ +import inspect +from abc import ABCMeta, abstractmethod +from logging import Logger +from typing import Callable, Dict, Any, Optional + +from slack_bolt.listener.internals import ( + _build_all_available_args, + _convert_all_available_args_to_kwargs, +) +from slack_bolt.request.request import BoltRequest +from slack_bolt.response.response import BoltResponse + + +class ListenerCompletionHandler(metaclass=ABCMeta): + @abstractmethod + def handle( + self, + request: BoltRequest, + response: Optional[BoltResponse], + ) -> None: + """Handles an unhandled exception. + + Args: + request: The request. + response: The response. + """ + raise NotImplementedError() + + +class CustomListenerCompletionHandler(ListenerCompletionHandler): + def __init__(self, logger: Logger, func: Callable[..., None]): + self.func = func + self.logger = logger + self.arg_names = inspect.getfullargspec(func).args + + def handle( + self, + request: BoltRequest, + response: Optional[BoltResponse], + ): + all_available_args = _build_all_available_args( + logger=self.logger, + request=request, + response=response, + ) + kwargs: Dict[str, Any] = _convert_all_available_args_to_kwargs( + all_available_args=all_available_args, + arg_names=self.arg_names, + logger=self.logger, + ) + self.func(**kwargs) + + +class DefaultListenerCompletionHandler(ListenerCompletionHandler): + def __init__(self, logger: Logger): + self.logger = logger + + def handle( + self, + request: BoltRequest, + response: Optional[BoltResponse], + ): + pass diff --git a/slack_bolt/listener/listener_error_handler.py b/slack_bolt/listener/listener_error_handler.py index d4cbd98b0..fe5c9b173 100644 --- a/slack_bolt/listener/listener_error_handler.py +++ b/slack_bolt/listener/listener_error_handler.py @@ -3,20 +3,14 @@ from logging import Logger from typing import Callable, Dict, Any, Optional +from slack_bolt.listener.internals import ( + _build_all_available_args, + _convert_all_available_args_to_kwargs, +) + from slack_bolt.request.request import BoltRequest from slack_bolt.response.response import BoltResponse -from slack_bolt.request.payload_utils import ( - to_options, - to_shortcut, - to_action, - to_view, - to_command, - to_event, - to_message, - to_step, -) - class ListenerErrorHandler(metaclass=ABCMeta): @abstractmethod @@ -48,51 +42,17 @@ def handle( request: BoltRequest, response: Optional[BoltResponse], ): - all_available_args = { - "logger": self.logger, - "error": error, - "client": request.context.client, - "req": request, - "request": request, - "resp": response, - "response": response, - "context": request.context, - "body": request.body, - # payload - "body": request.body, - "options": to_options(request.body), - "shortcut": to_shortcut(request.body), - "action": to_action(request.body), - "view": to_view(request.body), - "command": to_command(request.body), - "event": to_event(request.body), - "message": to_message(request.body), - "step": to_step(request.body), - # utilities - "say": request.context.say, - "respond": request.context.respond, - } - all_available_args["payload"] = ( - all_available_args["options"] - or all_available_args["shortcut"] - or all_available_args["action"] - or all_available_args["view"] - or all_available_args["command"] - or all_available_args["event"] - or all_available_args["message"] - or all_available_args["step"] - or request.body + all_available_args = _build_all_available_args( + logger=self.logger, + error=error, + request=request, + response=response, + ) + kwargs: Dict[str, Any] = _convert_all_available_args_to_kwargs( + all_available_args=all_available_args, + arg_names=self.arg_names, + logger=self.logger, ) - - kwargs: Dict[str, Any] = { # type: ignore - k: v for k, v in all_available_args.items() if k in self.arg_names # type: ignore - } - found_arg_names = kwargs.keys() - for name in self.arg_names: - if name not in found_arg_names: - self.logger.warning(f"{name} is not a valid argument") - kwargs[name] = None - self.func(**kwargs) diff --git a/slack_bolt/listener/thread_runner.py b/slack_bolt/listener/thread_runner.py index 8a02320e4..74d20dd98 100644 --- a/slack_bolt/listener/thread_runner.py +++ b/slack_bolt/listener/thread_runner.py @@ -5,6 +5,7 @@ from slack_bolt.lazy_listener import LazyListenerRunner from slack_bolt.listener import Listener +from slack_bolt.listener.listener_completion_handler import ListenerCompletionHandler from slack_bolt.listener.listener_error_handler import ListenerErrorHandler from slack_bolt.logger.messages import ( debug_responding, @@ -20,6 +21,7 @@ class ThreadListenerRunner: logger: Logger process_before_response: bool listener_error_handler: ListenerErrorHandler + listener_completion_handler: ListenerCompletionHandler listener_executor: ThreadPoolExecutor lazy_listener_runner: LazyListenerRunner @@ -28,12 +30,14 @@ def __init__( logger: Logger, process_before_response: bool, listener_error_handler: ListenerErrorHandler, + listener_completion_handler: ListenerCompletionHandler, listener_executor: ThreadPoolExecutor, lazy_listener_runner: LazyListenerRunner, ): self.logger = logger self.process_before_response = process_before_response self.listener_error_handler = listener_error_handler + self.listener_completion_handler = listener_completion_handler self.listener_executor = listener_executor self.lazy_listener_runner = lazy_listener_runner @@ -69,6 +73,11 @@ def run( # type: ignore response=response, ) ack.response = response + finally: + self.listener_completion_handler.handle( + request=request, + response=response, + ) for lazy_func in listener.lazy_functions: if request.lazy_function_name: @@ -122,6 +131,11 @@ def run_ack_function_asynchronously(): response=response, ) ack.response = response + finally: + self.listener_completion_handler.handle( + request=request, + response=response, + ) self.listener_executor.submit(run_ack_function_asynchronously) diff --git a/tests/adapter_tests/django/test_django.py b/tests/adapter_tests/django/test_django.py index f387427a9..036ab5608 100644 --- a/tests/adapter_tests/django/test_django.py +++ b/tests/adapter_tests/django/test_django.py @@ -174,6 +174,88 @@ def command_handler(ack): assert response.status_code == 200 assert_auth_test_count(self, 1) + def test_commands_process_before_response(self): + app = App( + client=self.web_client, + signing_secret=self.signing_secret, + process_before_response=True, + ) + + def command_handler(ack): + ack() + + app.command("/hello-world")(command_handler) + + input = ( + "token=verification_token" + "&team_id=T111" + "&team_domain=test-domain" + "&channel_id=C111" + "&channel_name=random" + "&user_id=W111" + "&user_name=primary-owner" + "&command=%2Fhello-world" + "&text=Hi" + "&enterprise_id=E111" + "&enterprise_name=Org+Name" + "&response_url=https%3A%2F%2Fhooks.slack.com%2Fcommands%2FT111%2F111%2Fxxxxx" + "&trigger_id=111.111.xxx" + ) + timestamp, body = str(int(time())), input + + request = self.rf.post( + "/slack/events", + data=body, + content_type="application/x-www-form-urlencoded", + ) + request.headers = self.build_headers(timestamp, body) + + response = SlackRequestHandler(app).handle(request) + assert response.status_code == 200 + assert_auth_test_count(self, 1) + + def test_commands_lazy(self): + app = App( + client=self.web_client, + signing_secret=self.signing_secret, + ) + + def command_handler(ack): + ack() + + def lazy_handler(): + pass + + app.command("/hello-world")(ack=command_handler, lazy=[lazy_handler]) + + input = ( + "token=verification_token" + "&team_id=T111" + "&team_domain=test-domain" + "&channel_id=C111" + "&channel_name=random" + "&user_id=W111" + "&user_name=primary-owner" + "&command=%2Fhello-world" + "&text=Hi" + "&enterprise_id=E111" + "&enterprise_name=Org+Name" + "&response_url=https%3A%2F%2Fhooks.slack.com%2Fcommands%2FT111%2F111%2Fxxxxx" + "&trigger_id=111.111.xxx" + ) + timestamp, body = str(int(time())), input + + request = self.rf.post( + "/slack/events", + data=body, + content_type="application/x-www-form-urlencoded", + ) + request.headers = self.build_headers(timestamp, body) + + response = SlackRequestHandler(app).handle(request) + assert response.status_code == 200 + assert_auth_test_count(self, 1) + def test_oauth(self): app = App( client=self.web_client, From 63479fd72ca822ba97e7b85dfde4992a4c4d1335 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Sat, 10 Apr 2021 08:17:28 +0900 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Michael Brooks --- slack_bolt/listener/async_internals.py | 1 - slack_bolt/listener/internals.py | 1 - 2 files changed, 2 deletions(-) diff --git a/slack_bolt/listener/async_internals.py b/slack_bolt/listener/async_internals.py index 6a8fc8ab1..60347bfb6 100644 --- a/slack_bolt/listener/async_internals.py +++ b/slack_bolt/listener/async_internals.py @@ -30,7 +30,6 @@ def _build_all_available_args( "resp": response, "response": response, "context": request.context, - "body": request.body, # payload "body": request.body, "options": to_options(request.body), diff --git a/slack_bolt/listener/internals.py b/slack_bolt/listener/internals.py index 9c1ac3a31..9b0682c1d 100644 --- a/slack_bolt/listener/internals.py +++ b/slack_bolt/listener/internals.py @@ -31,7 +31,6 @@ def _build_all_available_args( "resp": response, "response": response, "context": request.context, - "body": request.body, # payload "body": request.body, "options": to_options(request.body),