From 0d00105ea65755bdacda62e31eaccb4480c9a842 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 27 Jun 2023 09:45:24 -0700 Subject: [PATCH 1/4] Restore "add deferrable mode for `AthenaOperator` (#32110)" (reverted in #32172) This reverts commit 3a85d4e7e867357fb5456973544028e8249bff2f. --- .../providers/amazon/aws/operators/athena.py | 20 ++++- .../providers/amazon/aws/triggers/athena.py | 76 +++++++++++++++++++ airflow/providers/amazon/provider.yaml | 3 + .../amazon/aws/operators/test_athena.py | 12 +++ .../amazon/aws/triggers/test_athena.py | 53 +++++++++++++ 5 files changed, 163 insertions(+), 1 deletion(-) create mode 100644 airflow/providers/amazon/aws/triggers/athena.py create mode 100644 tests/providers/amazon/aws/triggers/test_athena.py diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 612e563ce678b..990f2ec414f13 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -20,8 +20,10 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence +from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.athena import AthenaHook +from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -69,6 +71,7 @@ def __init__( sleep_time: int = 30, max_polling_attempts: int | None = None, log_query: bool = True, + deferrable: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -81,9 +84,10 @@ def __init__( self.query_execution_context = query_execution_context or {} self.result_configuration = result_configuration or {} self.sleep_time = sleep_time - self.max_polling_attempts = max_polling_attempts + self.max_polling_attempts = max_polling_attempts or 999999 self.query_execution_id: str | None = None self.log_query: bool = log_query + self.deferrable = deferrable @cached_property def hook(self) -> AthenaHook: @@ -101,6 +105,15 @@ def execute(self, context: Context) -> str | None: self.client_request_token, self.workgroup, ) + + if self.deferrable: + self.defer( + trigger=AthenaTrigger( + self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id + ), + method_name="execute_complete", + ) + # implicit else: query_status = self.hook.poll_query_status( self.query_execution_id, max_polling_attempts=self.max_polling_attempts, @@ -121,6 +134,11 @@ def execute(self, context: Context) -> str | None: return self.query_execution_id + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}") + return event["value"] + def on_kill(self) -> None: """Cancel the submitted athena query.""" if self.query_execution_id: diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py new file mode 100644 index 0000000000000..780d9e9b98df2 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/athena.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + +from airflow.providers.amazon.aws.hooks.athena import AthenaHook +from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class AthenaTrigger(BaseTrigger): + """ + Trigger for RedshiftCreateClusterOperator. + + The trigger will asynchronously poll the boto3 API and wait for the + Redshift cluster to be in the `available` state. + + :param query_execution_id: ID of the Athena query execution to watch + :param poll_interval: The amount of time in seconds to wait between attempts. + :param max_attempt: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + query_execution_id: str, + poll_interval: int, + max_attempt: int, + aws_conn_id: str, + ): + self.query_execution_id = query_execution_id + self.poll_interval = poll_interval + self.max_attempt = max_attempt + self.aws_conn_id = aws_conn_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__module__ + "." + self.__class__.__qualname__, + { + "query_execution_id": str(self.query_execution_id), + "poll_interval": str(self.poll_interval), + "max_attempt": str(self.max_attempt), + "aws_conn_id": str(self.aws_conn_id), + }, + ) + + async def run(self): + hook = AthenaHook(self.aws_conn_id) + async with hook.async_conn as client: + waiter = hook.get_waiter("query_complete", deferrable=True, client=client) + await async_wait( + waiter=waiter, + waiter_delay=self.poll_interval, + max_attempts=self.max_attempt, + args={"QueryExecutionId": self.query_execution_id}, + failure_message=f"Error while waiting for query {self.query_execution_id} to complete", + status_message=f"Query execution id: {self.query_execution_id}, " + "Query is still in non-terminal state", + status_args=["QueryExecution.Status.State"], + ) + yield TriggerEvent({"status": "success", "value": self.query_execution_id}) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 5439f9c8cbb76..e4f16ce398460 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -515,6 +515,9 @@ hooks: - airflow.providers.amazon.aws.hooks.appflow triggers: + - integration-name: Amazon Athena + python-modules: + - airflow.providers.amazon.aws.triggers.athena - integration-name: AWS Batch python-modules: - airflow.providers.amazon.aws.triggers.batch diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index cfc78697683d1..9e528525204c6 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -20,9 +20,11 @@ import pytest +from airflow.exceptions import TaskDeferred from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.operators.athena import AthenaOperator +from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger from airflow.utils import timezone from airflow.utils.timezone import datetime @@ -158,3 +160,13 @@ def test_return_value(self, mock_conn, mock_run_query, mock_check_query_status): ti.dag_run = dag_run assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID + + @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) + def test_is_deferred(self, mock_run_query): + self.athena.deferrable = True + + with pytest.raises(TaskDeferred) as deferred: + self.athena.execute(None) + + assert isinstance(deferred.value.trigger, AthenaTrigger) + assert deferred.value.trigger.query_execution_id == ATHENA_QUERY_ID diff --git a/tests/providers/amazon/aws/triggers/test_athena.py b/tests/providers/amazon/aws/triggers/test_athena.py new file mode 100644 index 0000000000000..04e601f4392c4 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_athena.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.hooks.athena import AthenaHook +from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger + + +class TestAthenaTrigger: + @pytest.mark.asyncio + @mock.patch.object(AthenaHook, "get_waiter") + @mock.patch.object(AthenaHook, "async_conn") # LatestBoto step of CI fails without this + async def test_run_with_error(self, conn_mock, waiter_mock): + waiter_mock.side_effect = WaiterError("name", "reason", {}) + + trigger = AthenaTrigger("query_id", 0, 5, None) + + with pytest.raises(WaiterError): + generator = trigger.run() + await generator.asend(None) + + @pytest.mark.asyncio + @mock.patch.object(AthenaHook, "get_waiter") + @mock.patch.object(AthenaHook, "async_conn") # LatestBoto step of CI fails without this + async def test_run_success(self, conn_mock, waiter_mock): + waiter_mock().wait = AsyncMock() + trigger = AthenaTrigger("my_query_id", 0, 5, None) + + generator = trigger.run() + event = await generator.asend(None) + + assert event.payload["status"] == "success" + assert event.payload["value"] == "my_query_id" From b4fc9449939b6a8c700e12b2ecffb0628427d7f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 27 Jun 2023 09:47:36 -0700 Subject: [PATCH 2/4] fix build --- airflow/providers/amazon/aws/triggers/athena.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py index 780d9e9b98df2..efae559470a28 100644 --- a/airflow/providers/amazon/aws/triggers/athena.py +++ b/airflow/providers/amazon/aws/triggers/athena.py @@ -66,7 +66,7 @@ async def run(self): await async_wait( waiter=waiter, waiter_delay=self.poll_interval, - max_attempts=self.max_attempt, + waiter_max_attempts=self.max_attempt, args={"QueryExecutionId": self.query_execution_id}, failure_message=f"Error while waiting for query {self.query_execution_id} to complete", status_message=f"Query execution id: {self.query_execution_id}, " From 35ebe8c49a42087a24cf1e913636942457edc59b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 27 Jun 2023 10:52:16 -0700 Subject: [PATCH 3/4] document behavior around task cancellation --- airflow/providers/amazon/aws/operators/athena.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 990f2ec414f13..de79d685ab583 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -33,6 +33,9 @@ class AthenaOperator(BaseOperator): """ An operator that submits a presto query to athena. + Note: if the task is killed while it runs, it'll cancel the athena query that was launched, + EXCEPT if running in deferrable mode. + .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:AthenaOperator` From 7d170db983140adae7820700a2484ca3de5fe21c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 27 Jun 2023 13:30:32 -0700 Subject: [PATCH 4/4] use sphynx-friendly annotation --- airflow/providers/amazon/aws/operators/athena.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index de79d685ab583..0467fe6d11aed 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -33,7 +33,7 @@ class AthenaOperator(BaseOperator): """ An operator that submits a presto query to athena. - Note: if the task is killed while it runs, it'll cancel the athena query that was launched, + .. note:: if the task is killed while it runs, it'll cancel the athena query that was launched, EXCEPT if running in deferrable mode. .. seealso::