diff --git a/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/airflow/providers/amazon/aws/utils/waiter_with_logging.py
index 8c9e33077f6ed..b883e36bdb12a 100644
--- a/airflow/providers/amazon/aws/utils/waiter_with_logging.py
+++ b/airflow/providers/amazon/aws/utils/waiter_with_logging.py
@@ -17,8 +17,10 @@
 
 from __future__ import annotations
 
+import asyncio
 import logging
 import time
+from typing import Any
 
 import jmespath
 from botocore.exceptions import WaiterError
@@ -31,10 +33,10 @@ def wait(
     waiter: Waiter,
     waiter_delay: int,
     max_attempts: int,
-    args: dict,
+    args: dict[str, Any],
     failure_message: str,
     status_message: str,
-    status_args: list,
+    status_args: list[str],
 ) -> None:
     """
     Use a boto waiter to poll an AWS service for the specified state. Although this function
@@ -47,7 +49,7 @@ def wait(
     :param args: The arguments to pass to the waiter.
     :param failure_message: The message to log if a failure state is reached.
     :param status_message: The message logged when printing the status of the service.
-    :param status_args: A list containing the arguments to retrieve status information from
+    :param status_args: A list containing the JMESPath queries to retrieve status information from
         the waiter response.
         e.g.
         response = {"Cluster": {"state": "CREATING"}}
@@ -68,23 +70,85 @@ def wait(
         except WaiterError as error:
             if "terminal failure" in str(error):
                 raise AirflowException(f"{failure_message}: {error}")
-            status_string = _format_status_string(status_args, error.last_response)
-            log.info("%s: %s", status_message, status_string)
+
+            log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response))
+            if attempt >= max_attempts:
+                raise AirflowException("Waiter error: max attempts reached")
+
             time.sleep(waiter_delay)
 
+
+async def async_wait(
+    waiter: Waiter,
+    waiter_delay: int,
+    max_attempts: int,
+    args: dict[str, Any],
+    failure_message: str,
+    status_message: str,
+    status_args: list[str],
+) -> None:
+    """
+    Use an async boto waiter to poll an AWS service for the specified state. Although this function
+    uses boto waiters to poll the state of the service, it logs the response of the service
+    after every attempt, which is not currently supported by boto waiters.
+
+    :param waiter: The boto waiter to use.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param max_attempts: The maximum number of attempts to be made.
+    :param args: The arguments to pass to the waiter.
+    :param failure_message: The message to log if a failure state is reached.
+    :param status_message: The message logged when printing the status of the service.
+    :param status_args: A list containing the JMESPath queries to retrieve status information from
+        the waiter response.
+        e.g.
+        response = {"Cluster": {"state": "CREATING"}}
+        status_args = ["Cluster.state"]
+
+        response = {
+            "Clusters": [{"state": "CREATING", "details": "User initiated."},]
+        }
+        status_args = ["Clusters[0].state", "Clusters[0].details"]
+
+    :raises AirflowException: On a terminal waiter failure, or once max_attempts is reached.
+    """
+    log = logging.getLogger(__name__)
+    attempt = 0
+    while True:
+        attempt += 1
+        try:
+            await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})
+            break
+        except WaiterError as error:
+            if "terminal failure" in str(error):
+                raise AirflowException(f"{failure_message}: {error}")
+
+            log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response))
             if attempt >= max_attempts:
                 raise AirflowException("Waiter error: max attempts reached")
 
+            await asyncio.sleep(waiter_delay)
+
 
-def _format_status_string(args, response):
+class _LazyStatusFormatter:
     """
-    Loops through the supplied args list and generates a string
-    which contains values from the waiter response.
+    A wrapper holding what is needed to extract the status from a response,
+    which only computes the formatted value when it is actually requested.
+    Used to avoid computations if the logs are disabled at the given level.
     """
-    values = []
-    for arg in args:
-        value = jmespath.search(arg, response)
-        if value is not None and value != "":
-            values.append(str(value))
 
-    return " - ".join(values)
+    def __init__(self, jmespath_queries: list[str], response: dict[str, Any]) -> None:
+        self.jmespath_queries = jmespath_queries
+        self.response = response
+
+    def __str__(self) -> str:
+        """
+        Loop through the JMESPath queries and build a string
+        containing the matching values from the waiter response.
+        """
+        values = []
+        for query in self.jmespath_queries:
+            value = jmespath.search(query, self.response)
+            if value is not None and value != "":
+                values.append(str(value))
+
+        return " - ".join(values)
diff --git a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
index 2ca74936d7d71..4580c210548a8 100644
--- a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
+++ b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
@@ -20,12 +20,13 @@
 import logging
 from typing import Any
 from unittest import mock
+from unittest.mock import AsyncMock
 
 import pytest
 from botocore.exceptions import WaiterError
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
+from airflow.providers.amazon.aws.utils.waiter_with_logging import _LazyStatusFormatter, async_wait, wait
 
 
 def generate_response(state: str) -> dict[str, Any]:
@@ -63,7 +64,7 @@ def test_wait(self, mock_sleep, caplog):
                 "MaxAttempts": 1,
             },
         )
-        mock_waiter.wait.call_count == 3
+        assert mock_waiter.wait.call_count == 3
         mock_sleep.assert_called_with(123)
         assert (
             caplog.record_tuples
@@ -77,6 +78,36 @@ def test_wait(self, mock_sleep, caplog):
             * 2
         )
 
+    @pytest.mark.asyncio
+    async def test_async_wait(self, caplog):
+        mock_waiter = mock.MagicMock()
+        error = WaiterError(
+            name="test_waiter",
+            reason="test_reason",
+            last_response=generate_response("Pending"),
+        )
+        mock_waiter.wait = AsyncMock()
+        mock_waiter.wait.side_effect = [error, error, True]
+
+        await async_wait(
+            waiter=mock_waiter,
+            waiter_delay=0,
+            max_attempts=456,
+            args={"test_arg": "test_value"},
+            failure_message="test failure message",
+            status_message="test status message",
+            status_args=["Status.State"],
+        )
+
+        mock_waiter.wait.assert_called_with(
+            **{"test_arg": "test_value"},
+            WaiterConfig={
+                "MaxAttempts": 1,
+            },
+        )
+        assert mock_waiter.wait.call_count == 3
+        assert caplog.messages == ["test status message: Pending", "test status message: Pending"]
+
     @mock.patch("time.sleep")
     def test_wait_max_attempts_exceeded(self, mock_sleep, caplog):
         mock_sleep.return_value = True
@@ -302,3 +333,27 @@ def test_wait_with_multiple_args(self, mock_sleep, caplog):
             ]
             * 2
         )
+
+    @mock.patch.object(_LazyStatusFormatter, "__str__")
+    def test_status_formatting_not_done_if_higher_log_level(self, status_format_mock: mock.MagicMock, caplog):
+        mock_waiter = mock.MagicMock()
+        error = WaiterError(
+            name="test_waiter",
+            reason="test_reason",
+            last_response=generate_response("Pending"),
+        )
+        mock_waiter.wait.side_effect = [error, error, True]
+
+        with caplog.at_level(level=logging.WARNING):
+            wait(
+                waiter=mock_waiter,
+                waiter_delay=0,
+                max_attempts=456,
+                args={"test_arg": "test_value"},
+                failure_message="test failure message",
+                status_message="test status message",
+                status_args=["Status.State"],
+            )
+
+        assert len(caplog.messages) == 0
+        status_format_mock.assert_not_called()