From 6908b42c66a4bd2b9a7b97de532a8210acc9b924 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Wed, 21 Jun 2023 11:44:21 -0700 Subject: [PATCH 1/2] add async wait method to the "with logging" aws utils also changed the status formatting in the logs so that it'd not be done if log level is not including INFO --- .../amazon/aws/utils/waiter_with_logging.py | 88 ++++++++++++++++--- .../aws/utils/test_waiter_with_logging.py | 59 ++++++++++++- 2 files changed, 131 insertions(+), 16 deletions(-) diff --git a/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/airflow/providers/amazon/aws/utils/waiter_with_logging.py index 8c9e33077f6ed..62888f1b4b982 100644 --- a/airflow/providers/amazon/aws/utils/waiter_with_logging.py +++ b/airflow/providers/amazon/aws/utils/waiter_with_logging.py @@ -17,8 +17,10 @@ from __future__ import annotations +import asyncio import logging import time +from typing import Any import jmespath from botocore.exceptions import WaiterError @@ -31,10 +33,10 @@ def wait( waiter: Waiter, waiter_delay: int, max_attempts: int, - args: dict, + args: dict[str, Any], failure_message: str, status_message: str, - status_args: list, + status_args: list[str], ) -> None: """ Use a boto waiter to poll an AWS service for the specified state. Although this function @@ -47,7 +49,7 @@ def wait( :param args: The arguments to pass to the waiter. :param failure_message: The message to log if a failure state is reached. :param status_message: The message logged when printing the status of the service. - :param status_args: A list containing the arguments to retrieve status information from + :param status_args: A list containing the JMESPath queries to retrieve status information from the waiter response. e.g. response = {"Cluster": {"state": "CREATING"}} @@ -68,23 +70,81 @@ def wait( except WaiterError as error: if "terminal failure" in str(error): raise AirflowException(f"{failure_message}: {error}") - status_string = _format_status_string(status_args, error.last_response) - log.info("%s: %s", status_message, status_string) + log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response)) time.sleep(waiter_delay) if attempt >= max_attempts: raise AirflowException("Waiter error: max attempts reached") -def _format_status_string(args, response): +async def async_wait( + waiter: Waiter, + waiter_delay: int, + max_attempts: int, + args: dict[str, Any], + failure_message: str, + status_message: str, + status_args: list[str], +): """ - Loops through the supplied args list and generates a string - which contains values from the waiter response. + Use an async boto waiter to poll an AWS service for the specified state. Although this function + uses boto waiters to poll the state of the service, it logs the response of the service + after every attempt, which is not currently supported by boto waiters. + + :param waiter: The boto waiter to use. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param max_attempts: The maximum number of attempts to be made. + :param args: The arguments to pass to the waiter. + :param failure_message: The message to log if a failure state is reached. + :param status_message: The message logged when printing the status of the service. + :param status_args: A list containing the JMESPath queries to retrieve status information from + the waiter response. + e.g. + response = {"Cluster": {"state": "CREATING"}} + status_args = ["Cluster.state"] + + response = { + "Clusters": [{"state": "CREATING", "details": "User initiated."},] + } + status_args = ["Clusters[0].state", "Clusters[0].details"] + """ + log = logging.getLogger(__name__) + attempt = 0 + while True: + attempt += 1 + try: + await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1}) + break + except WaiterError as error: + if "terminal failure" in str(error): + raise AirflowException(f"{failure_message}: {error}") + log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response)) + await asyncio.sleep(waiter_delay) + + if attempt >= max_attempts: + raise AirflowException("Waiter error: max attempts reached") + + +class _LazyStatusFormatter: + """ + a wrapper containing the info necessary to extract the status from a response, + that'll only compute the value when necessary. + Used to avoid computations if the logs are disabled at the given level. """ - values = [] - for arg in args: - value = jmespath.search(arg, response) - if value is not None and value != "": - values.append(str(value)) - return " - ".join(values) + def __init__(self, jmespath_queries: list[str], response: dict[str, Any]): + self.jmespath_queries = jmespath_queries + self.response = response + + def __str__(self): + """ + Loops through the supplied args list and generates a string + which contains values from the waiter response. + """ + values = [] + for query in self.jmespath_queries: + value = jmespath.search(query, self.response) + if value is not None and value != "": + values.append(str(value)) + + return " - ".join(values) diff --git a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py index 2ca74936d7d71..4580c210548a8 100644 --- a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py +++ b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py @@ -20,12 +20,13 @@ import logging from typing import Any from unittest import mock +from unittest.mock import AsyncMock import pytest from botocore.exceptions import WaiterError from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.utils.waiter_with_logging import wait +from airflow.providers.amazon.aws.utils.waiter_with_logging import _LazyStatusFormatter, async_wait, wait def generate_response(state: str) -> dict[str, Any]: @@ -63,7 +64,7 @@ def test_wait(self, mock_sleep, caplog): "MaxAttempts": 1, }, ) - mock_waiter.wait.call_count == 3 + assert mock_waiter.wait.call_count == 3 mock_sleep.assert_called_with(123) assert ( caplog.record_tuples @@ -77,6 +78,36 @@ def test_wait(self, mock_sleep, caplog): * 2 ) + @pytest.mark.asyncio + async def test_async_wait(self, caplog): + mock_waiter = mock.MagicMock() + error = WaiterError( + name="test_waiter", + reason="test_reason", + last_response=generate_response("Pending"), + ) + mock_waiter.wait = AsyncMock() + mock_waiter.wait.side_effect = [error, error, True] + + await async_wait( + waiter=mock_waiter, + waiter_delay=0, + max_attempts=456, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Status.State"], + ) + + mock_waiter.wait.assert_called_with( + **{"test_arg": "test_value"}, + WaiterConfig={ + "MaxAttempts": 1, + }, + ) + assert mock_waiter.wait.call_count == 3 + assert caplog.messages == ["test status message: Pending", "test status message: Pending"] + @mock.patch("time.sleep") def test_wait_max_attempts_exceeded(self, mock_sleep, caplog): mock_sleep.return_value = True @@ -302,3 +333,27 @@ def test_wait_with_multiple_args(self, mock_sleep, caplog): ] * 2 ) + + @mock.patch.object(_LazyStatusFormatter, "__str__") + def test_status_formatting_not_done_if_higher_log_level(self, status_format_mock: mock.MagicMock, caplog): + mock_waiter = mock.MagicMock() + error = WaiterError( + name="test_waiter", + reason="test_reason", + last_response=generate_response("Pending"), + ) + mock_waiter.wait.side_effect = [error, error, True] + + with caplog.at_level(level=logging.WARNING): + wait( + waiter=mock_waiter, + waiter_delay=0, + max_attempts=456, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Status.State"], + ) + + assert len(caplog.messages) == 0 + status_format_mock.assert_not_called() From 90147277578707fc6997ec9086ee2b95b9f411d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 22 Jun 2023 09:37:12 -0700 Subject: [PATCH 2/2] no sleep if last attempt --- .../providers/amazon/aws/utils/waiter_with_logging.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/airflow/providers/amazon/aws/utils/waiter_with_logging.py index 62888f1b4b982..b883e36bdb12a 100644 --- a/airflow/providers/amazon/aws/utils/waiter_with_logging.py +++ b/airflow/providers/amazon/aws/utils/waiter_with_logging.py @@ -70,12 +70,13 @@ def wait( except WaiterError as error: if "terminal failure" in str(error): raise AirflowException(f"{failure_message}: {error}") - log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response)) - time.sleep(waiter_delay) + log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response)) if attempt >= max_attempts: raise AirflowException("Waiter error: max attempts reached") + time.sleep(waiter_delay) + async def async_wait( waiter: Waiter, @@ -118,12 +119,13 @@ async def async_wait( except WaiterError as error: if "terminal failure" in str(error): raise AirflowException(f"{failure_message}: {error}") - log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response)) - await asyncio.sleep(waiter_delay) + log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response)) if attempt >= max_attempts: raise AirflowException("Waiter error: max attempts reached") + await asyncio.sleep(waiter_delay) + class _LazyStatusFormatter: """