From 0b3d2811de57a2e47c13a37210fdf348b5da9d93 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Thu, 4 Jan 2024 18:26:54 +0800
Subject: [PATCH 01/11] feat(providers/amazon): add deferrable mode to
 RedshiftDataOperator

---
 .../amazon/aws/hooks/redshift_data.py | 103 +++++++++++++++---
 .../amazon/aws/operators/redshift_data.py | 52 ++++++++-
 .../amazon/aws/triggers/redshift_data.py | 74 +++++++++++++
 airflow/providers/amazon/provider.yaml | 1 +
 .../amazon/aws/triggers/test_redshift_data.py | 103 ++++++++++++++++++
 5 files changed, 314 insertions(+), 19 deletions(-)
 create mode 100644 airflow/providers/amazon/aws/triggers/redshift_data.py
 create mode 100644 tests/providers/amazon/aws/triggers/test_redshift_data.py

diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index f7df0fd744eaa..eaa8b322ba2e4 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -17,16 +17,26 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import time
 from pprint import pformat
 from typing import TYPE_CHECKING, Any, Iterable
 
+import botocore.exceptions
+from asgiref.sync import sync_to_async
+
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils import trim_none_values
 
 if TYPE_CHECKING:
 from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa
 
+FINISHED_STATE = "FINISHED"
+FAILED_STATE = "FAILED"
+ABORTED_STATE = "ABORTED"
+FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
+RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}
+
 
 class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 """
@@ -108,27 +118,33 @@ def execute_query(
 
 return statement_id
 
- def wait_for_results(self, statement_id, poll_interval):
+ def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
 while True:
 self.log.info("Polling statement %s", statement_id)
- resp = self.conn.describe_statement(
- Id=statement_id,
- )
- status = resp["Status"]
- if status == "FINISHED":
- num_rows = resp.get("ResultRows")
- if num_rows is not None:
- self.log.info("Processed %s rows", num_rows)
- return status
- elif status in ("FAILED", "ABORTED"):
- raise ValueError(
- f"Statement {statement_id!r} terminated with status {status}. "
- f"Response details: {pformat(resp)}"
- )
- else:
- self.log.info("Query %s", status)
+ is_finished = self.check_query_is_finished(statement_id)
+ if is_finished:
+ return FINISHED_STATE
+
 time.sleep(poll_interval)
 
+ def check_query_is_finished(self, statement_id: str) -> bool:
+ """Check whether the query is finished; raise an exception if it failed."""
+ resp = self.conn.describe_statement(Id=statement_id)
+ status = resp["Status"]
+ if status == FINISHED_STATE:
+ num_rows = resp.get("ResultRows")
+ if num_rows is not None:
+ self.log.info("Processed %s rows", num_rows)
+ return True
+ elif status in FAILURE_STATES:
+ raise ValueError(
+ f"Statement {statement_id!r} terminated with status {status}. "
+ f"Response details: {pformat(resp)}"
+ )
+
+ self.log.info("Query %s", status)
+ return False
+
 def get_table_primary_key(
 self,
 table: str,
@@ -201,3 +217,56 @@ def get_table_primary_key(
 break
 
 return pk_columns or None
+
+ async def check_query_is_finished_async(
+ self, statement_id: str, poll_interval: int = 10
+ ) -> dict[str, str]:
+ """Async function to check whether the statement is finished.
+
+ It takes statement_id, makes async connection to redshift data to get the query status
+ by statement_id and returns the query status.
+
+ :param statement_id: the UUID of the statement
+ :param poll_interval: how often in seconds to check the query status
+ """
+ try:
+ client = await sync_to_async(self.get_conn)()
+ while await self.is_still_running(statement_id):
+ await asyncio.sleep(poll_interval)
+
+ resp = client.describe_statement(Id=statement_id)
+ status = resp["Status"]
+ if status == FINISHED_STATE:
+ return {"status": "success", "statement_id": statement_id}
+ elif status == FAILED_STATE:
+ return {
+ "status": "error",
+ "message": f"Error: {resp['QueryString']} query Failed due to, {resp['Error']}",
+ "statement_id": statement_id,
+ "type": status,
+ }
+ elif status == ABORTED_STATE:
+ return {
+ "status": "error",
+ "message": "The query run was stopped by the user.",
+ "statement_id": statement_id,
+ "type": status,
+ }
+
+ return {
+ "status": "error",
+ "message": f"Unexpected status {status}",
+ "statement_id": statement_id,
+ "type": status,
+ }
+ except botocore.exceptions.ClientError as error:
+ return {"status": "error", "message": str(error), "type": "ERROR"}
+
+ async def is_still_running(self, statement_id: str) -> bool:
+ """Async function to check whether the query is still running.
+
+ :param statement_id: the UUID of the statement
+ """
+ client = await sync_to_async(self.get_conn)()
+ desc = client.describe_statement(Id=statement_id)
+ return desc["Status"] in RUNNING_STATES
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index b454ad76ec440..bde3bbb40bcb6 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,10 +17,13 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -92,6 +95,7 @@ def __init__(
 poll_interval: int = 10,
 return_sql_result: bool = False,
 workgroup_name: str | None = None,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
 **kwargs,
 ) -> None:
 super().__init__(**kwargs)
@@ -114,11 +118,17 @@ def __init__(
 )
 self.return_sql_result = return_sql_result
 self.statement_id: str | None = None
+ self.deferrable = deferrable
 
 def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
 """Execute a statement against Amazon Redshift."""
 self.log.info("Executing statement: %s", self.sql)
 
+ # Set wait_for_completion to False so that the deferred task waits for the status.
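+ # The user's original value is captured into a local below and forwarded
+ # unchanged to execute_query; only the instance attribute itself is reset.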
+ wait_for_completion = self.wait_for_completion
+ if self.deferrable and self.wait_for_completion:
+ self.wait_for_completion = False
+
 self.statement_id = self.hook.execute_query(
 database=self.database,
 sql=self.sql,
@@ -129,10 +139,48 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
 secret_arn=self.secret_arn,
 statement_name=self.statement_name,
 with_event=self.with_event,
- wait_for_completion=self.wait_for_completion,
+ wait_for_completion=wait_for_completion,
 poll_interval=self.poll_interval,
 )
 
+ if self.deferrable:
+ is_finished = self.hook.check_query_is_finished(self.statement_id)
+ if not is_finished:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=RedshiftDataTrigger(
+ statement_id=self.statement_id,
+ task_id=self.task_id,
+ poll_interval=self.poll_interval,
+ aws_conn_id=self.aws_conn_id,
+ region=self.region,
+ ),
+ method_name="execute_complete",
+ )
+
 if self.return_sql_result:
 result = self.hook.conn.get_statement_result(Id=self.statement_id)
 self.log.debug("Statement result: %s", result)
 return result
 else:
 return self.statement_id
+
+ def execute_complete(
+ self, context: Context, event: dict[str, Any] | None = None
+ ) -> GetStatementResultResponseTypeDef | str:
+ if event is None:
+ err_msg = "Trigger error: event is None"
+ self.log.info(err_msg)
+ raise AirflowException(err_msg)
+
+ if event["status"] == "error":
+ msg = f"context: {context}, error message: {event['message']}"
+ raise AirflowException(msg)
+
+ if not self.statement_id:
+ raise AirflowException("statement_id should not be empty.")
+
+ self.log.info("%s completed successfully.", self.task_id)
 if self.return_sql_result:
 result = self.hook.conn.get_statement_result(Id=self.statement_id)
 self.log.debug("Statement result: %s", result)
diff --git a/airflow/providers/amazon/aws/triggers/redshift_data.py b/airflow/providers/amazon/aws/triggers/redshift_data.py
new file mode 100644
index 0000000000000..2e40663cdb42b
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/redshift_data.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any, AsyncIterator
+
+from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class RedshiftDataTrigger(BaseTrigger):
+ """
+ RedshiftDataTrigger is fired as a deferred class with params to run the task in the triggerer.
+ + :param statement_id: the UUID of the statement + :param task_id: task ID of the Dag + :param poll_interval: polling period in seconds to check for the status + :param aws_conn_id: AWS connection ID for redshift + :param region: aws region to use + """ + + def __init__( + self, + statement_id: str, + task_id: str, + poll_interval: int, + aws_conn_id: str | None = "aws_default", + region: str | None = None, + ): + super().__init__() + self.statement_id = statement_id + self.task_id = task_id + self.aws_conn_id = aws_conn_id + self.poll_interval = poll_interval + self.region = region + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes RedshiftDataTrigger arguments and classpath.""" + return ( + "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger", + { + "statement_id": self.statement_id, + "task_id": self.task_id, + "aws_conn_id": self.aws_conn_id, + "poll_interval": self.poll_interval, + "region": self.region, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + # hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id, poll_interval=self.poll_interval) + hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region) + try: + response = await hook.check_query_is_finished_async(self.statement_id) + if not response: + response = {"status": "error", "message": f"{self.task_id} failed"} + yield TriggerEvent(response) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 1b90089db2536..bcbb5c18e30ee 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -621,6 +621,7 @@ triggers: - integration-name: Amazon Redshift python-modules: - airflow.providers.amazon.aws.triggers.redshift_cluster + - airflow.providers.amazon.aws.triggers.redshift_data - integration-name: Amazon SageMaker python-modules: - airflow.providers.amazon.aws.triggers.sagemaker diff --git a/tests/providers/amazon/aws/triggers/test_redshift_data.py b/tests/providers/amazon/aws/triggers/test_redshift_data.py new file mode 100644 index 0000000000000..fe580989afb9a --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_redshift_data.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
+from airflow.triggers.base import TriggerEvent
+
+TEST_CONN_ID = "aws_default"
+TEST_TASK_ID = "123"
+POLL_INTERVAL = 4.0
+
+
+class TestRedshiftDataTrigger:
+ def test_redshift_data_trigger_serialization(self):
+ """
+ Asserts that the RedshiftDataTrigger correctly serializes its arguments
+ and classpath.
+ """
+ trigger = RedshiftDataTrigger(
+ statement_id="uuid",
+ task_id=TEST_TASK_ID,
+ aws_conn_id=TEST_CONN_ID,
+ poll_interval=POLL_INTERVAL,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger"
+ assert kwargs == {
+ "statement_id": "uuid",
+ "task_id": TEST_TASK_ID,
+ "poll_interval": POLL_INTERVAL,
+ "aws_conn_id": TEST_CONN_ID,
+ "region": None,
+ }
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "return_value, response",
+ [
+ (
+ {"status": "error", "message": "test error", "statement_id": "uuid", "type": "failed"},
+ TriggerEvent(
+ {"status": "error", "message": "test error", "statement_id": "uuid", "type": "failed"}
+ ),
+ ),
+ (False, TriggerEvent({"status": "error", "message": f"{TEST_TASK_ID} failed"})),
+ ],
+ )
+ @mock.patch(
+ "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async"
+ )
+ async def test_redshift_data_trigger_run(self, mock_get_query_status, return_value, response):
+ """
+ Tests that RedshiftDataTrigger only fires once the query execution reaches a successful state.
+ """
+ mock_get_query_status.return_value = return_value
+ trigger = RedshiftDataTrigger(
+ statement_id="uuid",
+ task_id=TEST_TASK_ID,
+ poll_interval=POLL_INTERVAL,
+ aws_conn_id=TEST_CONN_ID,
+ )
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert response == actual
+
+ @pytest.mark.asyncio
+ @mock.patch(
+ "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async"
+ )
+ async def test_redshift_data_trigger_exception(self, mock_get_query_status):
+ """
+ Test that RedshiftDataTrigger fires the correct event in case of an error.
+ """ + mock_get_query_status.side_effect = Exception("Test exception") + + trigger = RedshiftDataTrigger( + statement_id="uuid", + task_id=TEST_TASK_ID, + poll_interval=POLL_INTERVAL, + aws_conn_id=TEST_CONN_ID, + ) + task = [i async for i in trigger.run()] + assert len(task) == 1 + assert TriggerEvent({"status": "error", "message": "Test exception"}) in task From f01a62158095d99d29d0df55b8442ffe2044515d Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 16 Jan 2024 10:21:59 +0800 Subject: [PATCH 02/11] test(providers/amazon): add test case to RedshiftDataHook async methods --- .../amazon/aws/hooks/test_redshift_data.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py index cc174a872cbdd..6f85072fc1f31 100644 --- a/tests/providers/amazon/aws/hooks/test_redshift_data.py +++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py @@ -292,3 +292,58 @@ def test_result_num_rows(self, mock_conn, caplog): wait_for_completion=True, ) assert "Processed " not in caplog.text + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "describe_statement_response, expected_result", + [ + ({"Status": "PICKED"}, True), + ({"Status": "STARTED"}, True), + ({"Status": "SUBMITTED"}, True), + ({"Status": "FINISHED"}, False), + ({"Status": "FAILED"}, False), + ({"Status": "ABORTED"}, False), + ], + ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_conn") + async def test_is_still_running(self, mock_client, describe_statement_response, expected_result): + hook = RedshiftDataHook() + mock_client.return_value.describe_statement.return_value = describe_statement_response + response = await hook.is_still_running("uuid") + assert response == expected_result + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "describe_statement_response, expected_result", + [ + ({"Status": "FINISHED"}, {"status": "success", "statement_id": "uuid"}), + ( + {"Status": "FAILED", "QueryString": "select 1", "Error": "Test error"}, + { + "status": "error", + "message": "Error: select 1 query Failed due to, Test error", + "statement_id": "uuid", + "type": "FAILED", + }, + ), + ( + {"Status": "ABORTED"}, + { + "status": "error", + "message": "The query run was stopped by the user.", + "statement_id": "uuid", + "type": "ABORTED", + }, + ), + ], + ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_conn") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running") + async def test_get_query_status( + self, mock_is_still_running, mock_conn, describe_statement_response, expected_result + ): + hook = RedshiftDataHook() + mock_is_still_running.return_value = False + mock_conn.return_value.describe_statement.return_value = describe_statement_response + response = await hook.check_query_is_finished_async(statement_id="uuid") + assert response == expected_result From 0f74f09636a1e8c4b1d074695b8a12678611b722 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 16 Jan 2024 11:10:50 +0800 Subject: [PATCH 03/11] test(providers/amazon): add test case to RedshiftDataOperator when deferrable = True --- .../aws/operators/test_redshift_data.py | 177 +++++++++++++++++- 1 file changed, 176 insertions(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py index 4b921b71423b0..4380755ff7126 100644 --- 
a/tests/providers/amazon/aws/operators/test_redshift_data.py +++ b/tests/providers/amazon/aws/operators/test_redshift_data.py @@ -21,8 +21,9 @@ import pytest -from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator +from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger CONN_ID = "aws_conn_test" TASK_ID = "task_id" @@ -202,3 +203,177 @@ def test_return_sql_result(self, mock_conn): mock_conn.get_statement_result.assert_called_once_with( Id=STATEMENT_ID, ) + + @mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer") + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished", + return_value=True, + ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") + def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_finished, mock_defer): + cluster_identifier = "cluster_identifier" + workgroup_name = None + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + parameters = [{"name": "id", "value": "1"}] + poll_interval = 5 + wait_for_completion = True + + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + sql=SQL, + database=DATABASE, + cluster_identifier=cluster_identifier, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + wait_for_completion=True, + poll_interval=poll_interval, + deferrable=True, + ) + operator.execute(None) + + assert not mock_defer.called + mock_exec_query.assert_called_once_with( + sql=SQL, + database=DATABASE, + cluster_identifier=cluster_identifier, + workgroup_name=workgroup_name, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + with_event=False, + wait_for_completion=wait_for_completion, + poll_interval=poll_interval, + ) + + # @mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer") + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished", + return_value=False, + ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") + def test_execute_defer(self, mock_exec_query, check_query_is_finished): + cluster_identifier = "cluster_identifier" + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + parameters = [{"name": "id", "value": "1"}] + poll_interval = 5 + + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + sql=SQL, + database=DATABASE, + cluster_identifier=cluster_identifier, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + wait_for_completion=False, + poll_interval=poll_interval, + deferrable=True, + ) + with pytest.raises(TaskDeferred) as exc: + operator.execute(None) + + assert isinstance(exc.value.trigger, RedshiftDataTrigger) + + def test_execute_complete_failure(self): + """Tests that an AirflowException is raised in case of error event""" + + cluster_identifier = "cluster_identifier" + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + parameters = [{"name": "id", "value": "1"}] + poll_interval = 5 + + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + 
task_id=TASK_ID,
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ wait_for_completion=False,
+ poll_interval=poll_interval,
+ deferrable=True,
+ )
+
+ with pytest.raises(AirflowException):
+ operator.execute_complete(
+ context=None, event={"status": "error", "message": "test failure message"}
+ )
+
+ def test_execute_complete_exception(self):
+ """Tests that an AirflowException is raised in case of error event"""
+
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ poll_interval = 5
+
+ operator = RedshiftDataOperator(
+ aws_conn_id=CONN_ID,
+ task_id=TASK_ID,
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ wait_for_completion=False,
+ poll_interval=poll_interval,
+ deferrable=True,
+ )
+
+ with pytest.raises(AirflowException) as exc:
+ operator.execute_complete(context=None, event=None)
+ assert exc.value.args[0] == "Trigger error: event is None"
+
+ def test_execute_complete(self):
+ """Asserts that logging occurs as expected"""
+
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ poll_interval = 5
+
+ operator = RedshiftDataOperator(
+ aws_conn_id=CONN_ID,
+ task_id=TASK_ID,
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ wait_for_completion=False,
+ poll_interval=poll_interval,
+ deferrable=True,
+ )
+ operator.statement_id = "uuid"
+
+ with mock.patch.object(operator.log, "info") as mock_log_info:
+ assert (
+ operator.execute_complete(
+ context=None, event={"status": "success", "message": "Job completed"}
+ )
+ == "uuid"
+ )
+ mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)

From 6a2ae24546ef603236a3a15f6dc4cb9fe95e0223 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Tue, 16 Jan 2024 11:16:33 +0800
Subject: [PATCH 04/11] refactor(providers/amazon): extract common operator
 initialization as deferrable_operator fixture
---
 .../aws/operators/test_redshift_data.py | 141 +++++-------------
 1 file changed, 39 insertions(+), 102 deletions(-)

diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index 4380755ff7126..682127e33fc45 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -32,6 +32,32 @@
 STATEMENT_ID = "statement_id"
 
 
+@pytest.fixture
+def deferrable_operator():
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ poll_interval = 5
+
+ operator = RedshiftDataOperator(
+ aws_conn_id=CONN_ID,
+ task_id=TASK_ID,
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ wait_for_completion=False,
+ poll_interval=poll_interval,
+ deferrable=True,
+ )
+ return operator
+
+
 class TestRedshiftDataOperator:
 def
test_init(self): op = RedshiftDataOperator( @@ -218,7 +244,6 @@ def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_fin statement_name = "statement_name" parameters = [{"name": "id", "value": "1"}] poll_interval = 5 - wait_for_completion = True operator = RedshiftDataOperator( aws_conn_id=CONN_ID, @@ -230,7 +255,7 @@ def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_fin secret_arn=secret_arn, statement_name=statement_name, parameters=parameters, - wait_for_completion=True, + wait_for_completion=False, poll_interval=poll_interval, deferrable=True, ) @@ -247,7 +272,7 @@ def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_fin statement_name=statement_name, parameters=parameters, with_event=False, - wait_for_completion=wait_for_completion, + wait_for_completion=False, poll_interval=poll_interval, ) @@ -257,121 +282,33 @@ def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_fin return_value=False, ) @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_execute_defer(self, mock_exec_query, check_query_is_finished): - cluster_identifier = "cluster_identifier" - db_user = "db_user" - secret_arn = "secret_arn" - statement_name = "statement_name" - parameters = [{"name": "id", "value": "1"}] - poll_interval = 5 - - operator = RedshiftDataOperator( - aws_conn_id=CONN_ID, - task_id=TASK_ID, - sql=SQL, - database=DATABASE, - cluster_identifier=cluster_identifier, - db_user=db_user, - secret_arn=secret_arn, - statement_name=statement_name, - parameters=parameters, - wait_for_completion=False, - poll_interval=poll_interval, - deferrable=True, - ) + def test_execute_defer(self, mock_exec_query, check_query_is_finished, deferrable_operator): with pytest.raises(TaskDeferred) as exc: - operator.execute(None) + deferrable_operator.execute(None) assert isinstance(exc.value.trigger, RedshiftDataTrigger) - def test_execute_complete_failure(self): + def test_execute_complete_failure(self, deferrable_operator): """Tests that an AirflowException is raised in case of error event""" - - cluster_identifier = "cluster_identifier" - db_user = "db_user" - secret_arn = "secret_arn" - statement_name = "statement_name" - parameters = [{"name": "id", "value": "1"}] - poll_interval = 5 - - operator = RedshiftDataOperator( - aws_conn_id=CONN_ID, - task_id=TASK_ID, - sql=SQL, - database=DATABASE, - cluster_identifier=cluster_identifier, - db_user=db_user, - secret_arn=secret_arn, - statement_name=statement_name, - parameters=parameters, - wait_for_completion=False, - poll_interval=poll_interval, - deferrable=True, - ) - with pytest.raises(AirflowException): - operator.execute_complete( + deferrable_operator.execute_complete( context=None, event={"status": "error", "message": "test failure message"} ) - def test_execute_complete_exception(self): - """Tests that an AirflowException is raised in case of error event""" - - cluster_identifier = "cluster_identifier" - db_user = "db_user" - secret_arn = "secret_arn" - statement_name = "statement_name" - parameters = [{"name": "id", "value": "1"}] - poll_interval = 5 - - operator = RedshiftDataOperator( - aws_conn_id=CONN_ID, - task_id=TASK_ID, - sql=SQL, - database=DATABASE, - cluster_identifier=cluster_identifier, - db_user=db_user, - secret_arn=secret_arn, - statement_name=statement_name, - parameters=parameters, - wait_for_completion=False, - poll_interval=poll_interval, - deferrable=True, - ) - + def 
test_execute_complete_exception(self, deferrable_operator):
+ """Tests that an AirflowException is raised in case of empty event"""
 with pytest.raises(AirflowException) as exc:
- operator.execute_complete(context=None, event=None)
+ deferrable_operator.execute_complete(context=None, event=None)
 assert exc.value.args[0] == "Trigger error: event is None"
 
- def test_execute_complete(self):
+ def test_execute_complete(self, deferrable_operator):
 """Asserts that logging occurs as expected"""
 
- cluster_identifier = "cluster_identifier"
- db_user = "db_user"
- secret_arn = "secret_arn"
- statement_name = "statement_name"
- parameters = [{"name": "id", "value": "1"}]
- poll_interval = 5
-
- operator = RedshiftDataOperator(
- aws_conn_id=CONN_ID,
- task_id=TASK_ID,
- sql=SQL,
- database=DATABASE,
- cluster_identifier=cluster_identifier,
- db_user=db_user,
- secret_arn=secret_arn,
- statement_name=statement_name,
- parameters=parameters,
- wait_for_completion=False,
- poll_interval=poll_interval,
- deferrable=True,
- )
- operator.statement_id = "uuid"
+ deferrable_operator.statement_id = "uuid"
 
- with mock.patch.object(operator.log, "info") as mock_log_info:
+ with mock.patch.object(deferrable_operator.log, "info") as mock_log_info:
 assert (
- operator.execute_complete(
+ deferrable_operator.execute_complete(
 context=None, event={"status": "success", "message": "Job completed"}
 )
 == "uuid"

From 39c5981f6c4dcc8c94b7cd09c904bcd622cebb41 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Tue, 16 Jan 2024 14:28:20 +0800
Subject: [PATCH 05/11] refactor(providers/amazon): rename region to
 region_name
---
 .../providers/amazon/aws/operators/redshift_data.py | 2 +-
 .../providers/amazon/aws/triggers/redshift_data.py | 11 +++++------
 .../amazon/aws/triggers/test_redshift_data.py | 2 +-
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index bde3bbb40bcb6..687ef35d046b6 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -153,7 +153,7 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
 task_id=self.task_id,
 poll_interval=self.poll_interval,
 aws_conn_id=self.aws_conn_id,
- region=self.region,
+ region_name=self.region_name,
 ),
 method_name="execute_complete",
 )
diff --git a/airflow/providers/amazon/aws/triggers/redshift_data.py b/airflow/providers/amazon/aws/triggers/redshift_data.py
index 2e40663cdb42b..2c44d143ffaa5 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_data.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_data.py
@@ -31,7 +31,7 @@ class RedshiftDataTrigger(BaseTrigger):
 :param task_id: task ID of the Dag
 :param poll_interval: polling period in seconds to check for the status
 :param aws_conn_id: AWS connection ID for redshift
- :param region: aws region to use
+ :param region_name: aws region to use
 """
 
 def __init__(
@@ -40,14 +40,14 @@ def __init__(
 task_id: str,
 poll_interval: int,
 aws_conn_id: str | None = "aws_default",
- region: str | None = None,
+ region_name: str | None = None,
 ):
 super().__init__()
 self.statement_id = statement_id
 self.task_id = task_id
 self.aws_conn_id = aws_conn_id
 self.poll_interval = poll_interval
- self.region = region
+ self.region_name = region_name
 
 def serialize(self) -> tuple[str, dict[str, Any]]:
 """Serializes RedshiftDataTrigger arguments and classpath."""
@@ -58,13 +58,12 @@ def serialize(self) ->
tuple[str, dict[str, Any]]: "task_id": self.task_id, "aws_conn_id": self.aws_conn_id, "poll_interval": self.poll_interval, - "region": self.region, + "region_name": self.region_name, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: - # hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id, poll_interval=self.poll_interval) - hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region) + hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) try: response = await hook.check_query_is_finished_async(self.statement_id) if not response: diff --git a/tests/providers/amazon/aws/triggers/test_redshift_data.py b/tests/providers/amazon/aws/triggers/test_redshift_data.py index fe580989afb9a..9d8ec79647516 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_data.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_data.py @@ -48,7 +48,7 @@ def test_redshift_data_trigger_serialization(self): "task_id": TEST_TASK_ID, "poll_interval": POLL_INTERVAL, "aws_conn_id": TEST_CONN_ID, - "region": None, + "region_name": None, } @pytest.mark.asyncio From 6a149ba8c06d0d0a441e814d6046a2a9dbd552b6 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 16 Jan 2024 14:43:30 +0800 Subject: [PATCH 06/11] feat(providers/amazon): add verify and botocore_config as suggested --- .../amazon/aws/operators/redshift_data.py | 2 ++ .../amazon/aws/triggers/redshift_data.py | 21 ++++++++++++++++--- .../amazon/aws/triggers/test_redshift_data.py | 2 ++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py index 687ef35d046b6..e5667781c6620 100644 --- a/airflow/providers/amazon/aws/operators/redshift_data.py +++ b/airflow/providers/amazon/aws/operators/redshift_data.py @@ -154,6 +154,8 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str: poll_interval=self.poll_interval, aws_conn_id=self.aws_conn_id, region_name=self.region_name, + verify=self.verify, + botocore_config=self.botocore_config, ), method_name="execute_complete", ) diff --git a/airflow/providers/amazon/aws/triggers/redshift_data.py b/airflow/providers/amazon/aws/triggers/redshift_data.py index 2c44d143ffaa5..da10270a1b7e9 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_data.py +++ b/airflow/providers/amazon/aws/triggers/redshift_data.py @@ -41,13 +41,18 @@ def __init__( poll_interval: int, aws_conn_id: str | None = "aws_default", region_name: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, ): super().__init__() self.statement_id = statement_id self.task_id = task_id - self.aws_conn_id = aws_conn_id self.poll_interval = poll_interval + + self.aws_conn_id = aws_conn_id self.region_name = region_name + self.verify = verify + self.botocore_config = botocore_config def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes RedshiftDataTrigger arguments and classpath.""" @@ -59,13 +64,23 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "aws_conn_id": self.aws_conn_id, "poll_interval": self.poll_interval, "region_name": self.region_name, + "verify": self.verify, + "botocore_config": self.botocore_config, }, ) + @property + def hook(self) -> RedshiftDataHook: + return RedshiftDataHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) + async def run(self) -> AsyncIterator[TriggerEvent]: - hook = 
RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) try: - response = await hook.check_query_is_finished_async(self.statement_id) + response = await self.hook.check_query_is_finished_async(self.statement_id) if not response: response = {"status": "error", "message": f"{self.task_id} failed"} yield TriggerEvent(response) diff --git a/tests/providers/amazon/aws/triggers/test_redshift_data.py b/tests/providers/amazon/aws/triggers/test_redshift_data.py index 9d8ec79647516..df2859c775677 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_data.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_data.py @@ -49,6 +49,8 @@ def test_redshift_data_trigger_serialization(self): "poll_interval": POLL_INTERVAL, "aws_conn_id": TEST_CONN_ID, "region_name": None, + "botocore_config": None, + "verify": None, } @pytest.mark.asyncio From a4eb56e210fe6eb9b738e9b624515de2667a2600 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 17 Jan 2024 18:44:07 +0800 Subject: [PATCH 07/11] refactor(providers/amazon): use async_conn from aws hook and add missing await --- airflow/providers/amazon/aws/hooks/redshift_data.py | 7 ++----- .../providers/amazon/aws/operators/redshift_data.py | 9 +++++---- .../providers/amazon/aws/triggers/redshift_data.py | 8 ++++++-- .../providers/amazon/aws/hooks/test_redshift_data.py | 12 +++++++----- .../amazon/aws/operators/test_redshift_data.py | 3 ++- .../amazon/aws/triggers/test_redshift_data.py | 9 +++++++-- 6 files changed, 29 insertions(+), 19 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index eaa8b322ba2e4..149f5d3a5e027 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Any, Iterable import botocore.exceptions -from asgiref.sync import sync_to_async from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.utils import trim_none_values @@ -230,11 +229,10 @@ async def check_query_is_finished_async( :param poll_interval: how often in seconds to check the query status """ try: - client = await sync_to_async(self.get_conn)() while await self.is_still_running(statement_id): await asyncio.sleep(poll_interval) - resp = client.describe_statement(Id=statement_id) + resp = await self.async_conn.describe_statement(Id=statement_id) status = resp["Status"] if status == FINISHED_STATE: return {"status": "success", "statement_id": statement_id} @@ -267,6 +265,5 @@ async def is_still_running(self, statement_id: str) -> bool: :param statement_id: the UUID of the statement """ - client = await sync_to_async(self.get_conn)() - desc = client.describe_statement(Id=statement_id) + desc = await self.async_conn.describe_statement(Id=statement_id) return desc["Status"] in RUNNING_STATES diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py index e5667781c6620..71ee82069e662 100644 --- a/airflow/providers/amazon/aws/operators/redshift_data.py +++ b/airflow/providers/amazon/aws/operators/redshift_data.py @@ -179,16 +179,17 @@ def execute_complete( msg = f"context: {context}, error message: {event['message']}" raise AirflowException(msg) - if not self.statement_id: + statement_id = event["statement_id"] + if not statement_id: raise AirflowException("statement_id should not be empty.") self.log.info("%s completed successfully.", self.task_id) if 
self.return_sql_result: - result = self.hook.conn.get_statement_result(Id=self.statement_id) + result = self.hook.conn.get_statement_result(Id=statement_id) self.log.debug("Statement result: %s", result) return result - else: - return self.statement_id + + return statement_id def on_kill(self) -> None: """Cancel the submitted redshift query.""" diff --git a/airflow/providers/amazon/aws/triggers/redshift_data.py b/airflow/providers/amazon/aws/triggers/redshift_data.py index da10270a1b7e9..bf5c63fc0cbab 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_data.py +++ b/airflow/providers/amazon/aws/triggers/redshift_data.py @@ -82,7 +82,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]: try: response = await self.hook.check_query_is_finished_async(self.statement_id) if not response: - response = {"status": "error", "message": f"{self.task_id} failed"} + response = { + "status": "error", + "message": f"{self.task_id} failed", + "statement_id": self.statement_id, + } yield TriggerEvent(response) except Exception as e: - yield TriggerEvent({"status": "error", "message": str(e)}) + yield TriggerEvent({"status": "error", "message": str(e), "statement_id": self.statement_id}) diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py index 6f85072fc1f31..9b6858f1abe4c 100644 --- a/tests/providers/amazon/aws/hooks/test_redshift_data.py +++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py @@ -305,10 +305,11 @@ def test_result_num_rows(self, mock_conn, caplog): ({"Status": "ABORTED"}, False), ], ) - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_conn") - async def test_is_still_running(self, mock_client, describe_statement_response, expected_result): + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn") + async def test_is_still_running(self, mock_conn, describe_statement_response, expected_result): hook = RedshiftDataHook() - mock_client.return_value.describe_statement.return_value = describe_statement_response + mock_conn.describe_statement = mock.AsyncMock() + mock_conn.describe_statement.return_value = describe_statement_response response = await hook.is_still_running("uuid") assert response == expected_result @@ -337,13 +338,14 @@ async def test_is_still_running(self, mock_client, describe_statement_response, ), ], ) - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_conn") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn") @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running") async def test_get_query_status( self, mock_is_still_running, mock_conn, describe_statement_response, expected_result ): hook = RedshiftDataHook() mock_is_still_running.return_value = False - mock_conn.return_value.describe_statement.return_value = describe_statement_response + mock_conn.describe_statement = mock.AsyncMock() + mock_conn.describe_statement.return_value = describe_statement_response response = await hook.check_query_is_finished_async(statement_id="uuid") assert response == expected_result diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py index 682127e33fc45..fa22c98218305 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_data.py +++ b/tests/providers/amazon/aws/operators/test_redshift_data.py @@ -309,7 +309,8 @@ def 
test_execute_complete(self, deferrable_operator):
 with mock.patch.object(deferrable_operator.log, "info") as mock_log_info:
 assert (
 deferrable_operator.execute_complete(
- context=None, event={"status": "success", "message": "Job completed"}
+ context=None,
+ event={"status": "success", "message": "Job completed", "statement_id": "uuid"},
 )
 == "uuid"
 )
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_data.py b/tests/providers/amazon/aws/triggers/test_redshift_data.py
index df2859c775677..8207506c424f1 100644
--- a/tests/providers/amazon/aws/triggers/test_redshift_data.py
+++ b/tests/providers/amazon/aws/triggers/test_redshift_data.py
@@ -63,7 +63,12 @@ def test_redshift_data_trigger_serialization(self):
 {"status": "error", "message": "test error", "statement_id": "uuid", "type": "failed"}
 ),
 ),
- (False, TriggerEvent({"status": "error", "message": f"{TEST_TASK_ID} failed"})),
+ (
+ False,
+ TriggerEvent(
+ {"status": "error", "message": f"{TEST_TASK_ID} failed", "statement_id": "uuid"}
+ ),
+ ),
 ],
 )
 @mock.patch(
@@ -102,4 +107,4 @@ async def test_redshift_data_trigger_exception(self, mock_get_query_status):
 )
 task = [i async for i in trigger.run()]
 assert len(task) == 1
- assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+ assert TriggerEvent({"status": "error", "message": "Test exception", "statement_id": "uuid"}) in task

From 2ff2f2c68ffd2f793b1d657489b480d4829edcd8 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Thu, 18 Jan 2024 10:15:51 +0800
Subject: [PATCH 08/11] feat(providers/amazon): make RedshiftDataTrigger.hook
 a cached_property
---
 airflow/providers/amazon/aws/triggers/redshift_data.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/triggers/redshift_data.py b/airflow/providers/amazon/aws/triggers/redshift_data.py
index bf5c63fc0cbab..61ed6a339a58d 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_data.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_data.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+from functools import cached_property
 from typing import Any, AsyncIterator
 
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
@@ -69,7 +70,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
 },
 )
 
- @property
+ @cached_property
 def hook(self) -> RedshiftDataHook:
 return RedshiftDataHook(
 aws_conn_id=self.aws_conn_id,

From bdec22607bb075e1a041e90f6608d84323d2b4a5 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Thu, 18 Jan 2024 12:36:53 +0800
Subject: [PATCH 09/11] refactor(providers/amazon): unify how async and sync
 versions of check_query_is_finished are implemented
---
 .../amazon/aws/hooks/redshift_data.py | 78 +++++++-------------
 .../amazon/aws/triggers/redshift_data.py | 32 ++++++--
 .../amazon/aws/hooks/test_redshift_data.py | 51 ++++++------
 .../amazon/aws/triggers/test_redshift_data.py | 61 ++++++++++++---
 4 files changed, 133 insertions(+), 89 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index 149f5d3a5e027..aebffc357b7d6 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -17,13 +17,10 @@
 # under the License.
from __future__ import annotations
 
-import asyncio
 import time
 from pprint import pformat
 from typing import TYPE_CHECKING, Any, Iterable
 
-import botocore.exceptions
-
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils import trim_none_values
 
@@ -37,6 +34,14 @@
 RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}
 
 
+class RedshiftDataQueryFailedError(ValueError):
+ """Raised when a Redshift Data API query fails."""
+
+
+class RedshiftDataQueryAbortedError(ValueError):
+ """Raised when a Redshift Data API query is aborted."""
+
+
 class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 """
 Interact with Amazon Redshift Data API.
@@ -129,6 +134,10 @@ def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
 
 def check_query_is_finished(self, statement_id: str) -> bool:
 """Check whether the query is finished; raise an exception if it failed."""
 resp = self.conn.describe_statement(Id=statement_id)
+ return self.parse_statement_response(resp)
+
+ def parse_statement_response(self, resp: dict[str, Any]) -> bool:
+ """Parse the response of describe_statement."""
 status = resp["Status"]
 if status == FINISHED_STATE:
 num_rows = resp.get("ResultRows")
@@ -136,12 +145,15 @@ def check_query_is_finished(self, statement_id: str) -> bool:
 self.log.info("Processed %s rows", num_rows)
 return True
 elif status in FAILURE_STATES:
- raise ValueError(
- f"Statement {statement_id!r} terminated with status {status}. "
+ exception_cls = (
+ RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
+ )
+ raise exception_cls(
+ f"Statement {resp['Id']} terminated with status {status}. "
 f"Response details: {pformat(resp)}"
 )
 
- self.log.info("Query %s", status)
+ self.log.info("Query status: %s", status)
 return False
 
 def get_table_primary_key(
@@ -217,49 +229,6 @@ def get_table_primary_key(
 
 return pk_columns or None
 
- async def check_query_is_finished_async(
- self, statement_id: str, poll_interval: int = 10
- ) -> dict[str, str]:
- """Async function to check whether the statement is finished.
-
- It takes statement_id, makes async connection to redshift data to get the query status
- by statement_id and returns the query status.
-
- :param statement_id: the UUID of the statement
- :param poll_interval: how often in seconds to check the query status
- """
- try:
- while await self.is_still_running(statement_id):
- await asyncio.sleep(poll_interval)
-
- resp = await self.async_conn.describe_statement(Id=statement_id)
- status = resp["Status"]
- if status == FINISHED_STATE:
- return {"status": "success", "statement_id": statement_id}
- elif status == FAILED_STATE:
- return {
- "status": "error",
- "message": f"Error: {resp['QueryString']} query Failed due to, {resp['Error']}",
- "statement_id": statement_id,
- "type": status,
- }
- elif status == ABORTED_STATE:
- return {
- "status": "error",
- "message": "The query run was stopped by the user.",
- "statement_id": statement_id,
- "type": status,
- }
-
- return {
- "status": "error",
- "message": f"Unexpected status {status}",
- "statement_id": statement_id,
- "type": status,
- }
- except botocore.exceptions.ClientError as error:
- return {"status": "error", "message": str(error), "type": "ERROR"}
-
 async def is_still_running(self, statement_id: str) -> bool:
 """Async function to check whether the query is still running.
 
 :param statement_id: the UUID of the statement
 """
 desc = await self.async_conn.describe_statement(Id=statement_id)
 return desc["Status"] in RUNNING_STATES
+
+ async def check_query_is_finished_async(self, statement_id: str) -> bool:
+ """Async function to check whether the statement is finished.
+
+ It takes statement_id, makes async connection to redshift data to get the query status
+ by statement_id and returns the query status.
+
+ :param statement_id: the UUID of the statement
+ """
+ resp = await self.async_conn.describe_statement(Id=statement_id)
+ return self.parse_statement_response(resp)
diff --git a/airflow/providers/amazon/aws/triggers/redshift_data.py b/airflow/providers/amazon/aws/triggers/redshift_data.py
index 61ed6a339a58d..2d0ecbc594db3 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_data.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_data.py
@@ -17,10 +17,17 @@
 
 from __future__ import annotations
 
+import asyncio
 from functools import cached_property
 from typing import Any, AsyncIterator
 
-from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+ ABORTED_STATE,
+ FAILED_STATE,
+ RedshiftDataHook,
+ RedshiftDataQueryAbortedError,
+ RedshiftDataQueryFailedError,
+)
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -81,13 +88,26 @@ def hook(self) -> RedshiftDataHook:
 
 async def run(self) -> AsyncIterator[TriggerEvent]:
 try:
- response = await self.hook.check_query_is_finished_async(self.statement_id)
- if not response:
+ while await self.hook.is_still_running(self.statement_id):
+ await asyncio.sleep(self.poll_interval)
+
+ is_finished = await self.hook.check_query_is_finished_async(self.statement_id)
+ if is_finished:
+ response = {"status": "success", "statement_id": self.statement_id}
+ else:
 response = {
 "status": "error",
- "message": f"{self.task_id} failed",
 "statement_id": self.statement_id,
+ "message": f"{self.task_id} failed",
 }
 yield TriggerEvent(response)
- except Exception as e:
- yield TriggerEvent({"status": "error", "message": str(e), "statement_id": self.statement_id})
+ except (RedshiftDataQueryFailedError, RedshiftDataQueryAbortedError) as error:
+ response = {
+ "status": "error",
+ "statement_id": self.statement_id,
+ "message": str(error),
+ "type": FAILED_STATE if isinstance(error, RedshiftDataQueryFailedError) else ABORTED_STATE,
+ }
+ yield TriggerEvent(response)
+ except Exception as error:
+ yield TriggerEvent({"status": "error", "statement_id": self.statement_id, "message": str(error)})
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index 9b6858f1abe4c..689878f174460 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -22,7 +22,11 @@
 
 import pytest
 
-from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+ RedshiftDataHook,
+ RedshiftDataQueryAbortedError,
+ RedshiftDataQueryFailedError,
+)
 
 SQL = "sql"
 DATABASE = "database"
@@ -317,39 +321,36 @@ async def test_is_still_running(self, mock_conn, describe_statement_response, ex
 response = await hook.is_still_running("uuid")
 assert response == expected_result
 
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running") + async def test_check_query_is_finished_async(self, mock_is_still_running, mock_conn): + hook = RedshiftDataHook() + mock_is_still_running.return_value = False + mock_conn.describe_statement = mock.AsyncMock() + mock_conn.describe_statement.return_value = {"Id": "uuid", "Status": "FINISHED"} + is_finished = await hook.check_query_is_finished_async(statement_id="uuid") + assert is_finished is True + @pytest.mark.asyncio @pytest.mark.parametrize( - "describe_statement_response, expected_result", - [ - ({"Status": "FINISHED"}, {"status": "success", "statement_id": "uuid"}), - ( - {"Status": "FAILED", "QueryString": "select 1", "Error": "Test error"}, - { - "status": "error", - "message": "Error: select 1 query Failed due to, Test error", - "statement_id": "uuid", - "type": "FAILED", - }, - ), + "describe_statement_response, expected_exception", + ( ( - {"Status": "ABORTED"}, - { - "status": "error", - "message": "The query run was stopped by the user.", - "statement_id": "uuid", - "type": "ABORTED", - }, + {"Id": "uuid", "Status": "FAILED", "QueryString": "select 1", "Error": "Test error"}, + RedshiftDataQueryFailedError, ), - ], + ({"Id": "uuid", "Status": "ABORTED"}, RedshiftDataQueryAbortedError), + ), ) @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn") @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running") - async def test_get_query_status( - self, mock_is_still_running, mock_conn, describe_statement_response, expected_result + async def test_check_query_is_finished_async_exception( + self, mock_is_still_running, mock_conn, describe_statement_response, expected_exception ): hook = RedshiftDataHook() mock_is_still_running.return_value = False mock_conn.describe_statement = mock.AsyncMock() mock_conn.describe_statement.return_value = describe_statement_response - response = await hook.check_query_is_finished_async(statement_id="uuid") - assert response == expected_result + with pytest.raises(expected_exception): + await hook.check_query_is_finished_async(statement_id="uuid") diff --git a/tests/providers/amazon/aws/triggers/test_redshift_data.py b/tests/providers/amazon/aws/triggers/test_redshift_data.py index 8207506c424f1..49c0862af274f 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_data.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_data.py @@ -21,6 +21,12 @@ import pytest +from airflow.providers.amazon.aws.hooks.redshift_data import ( + ABORTED_STATE, + FAILED_STATE, + RedshiftDataQueryAbortedError, + RedshiftDataQueryFailedError, +) from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger from airflow.triggers.base import TriggerEvent @@ -58,10 +64,8 @@ def test_redshift_data_trigger_serialization(self): "return_value, response", [ ( - {"status": "error", "message": "test error", "statement_id": "uuid", "type": "failed"}, - TriggerEvent( - {"status": "error", "message": "test error", "statement_id": "uuid", "type": "failed"} - ), + True, + TriggerEvent({"status": "success", "statement_id": "uuid"}), ), ( False, @@ -74,11 +78,17 @@ def test_redshift_data_trigger_serialization(self): @mock.patch( "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async" ) - async def test_redshift_data_trigger_run(self, mock_get_query_status, return_value, response): + @mock.patch( + 
"airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running", + return_value=False, + ) + async def test_redshift_data_trigger_run( + self, mocked_is_still_running, mock_check_query_is_finised_async, return_value, response + ): """ Tests that RedshiftDataTrigger only fires once the query execution reaches a successful state. """ - mock_get_query_status.return_value = return_value + mock_check_query_is_finised_async.return_value = return_value trigger = RedshiftDataTrigger( statement_id="uuid", task_id=TEST_TASK_ID, @@ -90,14 +100,47 @@ async def test_redshift_data_trigger_run(self, mock_get_query_status, return_val assert response == actual @pytest.mark.asyncio + @pytest.mark.parametrize( + "raised_exception, expected_response", + [ + ( + RedshiftDataQueryFailedError("Failed"), + { + "status": "error", + "statement_id": "uuid", + "message": "Failed", + "type": FAILED_STATE, + }, + ), + ( + RedshiftDataQueryAbortedError("Aborted"), + { + "status": "error", + "statement_id": "uuid", + "message": "Aborted", + "type": ABORTED_STATE, + }, + ), + ( + Exception(f"{TEST_TASK_ID} failed"), + {"status": "error", "statement_id": "uuid", "message": f"{TEST_TASK_ID} failed"}, + ), + ], + ) @mock.patch( "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async" ) - async def test_redshift_data_trigger_exception(self, mock_get_query_status): + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running", + return_value=False, + ) + async def test_redshift_data_trigger_exception( + self, mocked_is_still_running, mock_check_query_is_finised_async, raised_exception, expected_response + ): """ Test that RedshiftDataTrigger fires the correct event in case of an error. 
""" - mock_get_query_status.side_effect = Exception("Test exception") + mock_check_query_is_finised_async.side_effect = raised_exception trigger = RedshiftDataTrigger( statement_id="uuid", @@ -107,4 +150,4 @@ async def test_redshift_data_trigger_exception(self, mock_get_query_status): ) task = [i async for i in trigger.run()] assert len(task) == 1 - assert TriggerEvent({"status": "error", "message": "Test exception", "statement_id": "uuid"}) in task + assert TriggerEvent(expected_response) in task From 14049d6cad76eb0014043fcf356863e0a57c8918 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 18 Jan 2024 14:47:38 +0800 Subject: [PATCH 10/11] style(providers/amazon): fix mypy failure --- airflow/providers/amazon/aws/hooks/redshift_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index aebffc357b7d6..7872408db7d9e 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -25,7 +25,7 @@ from airflow.providers.amazon.aws.utils import trim_none_values if TYPE_CHECKING: - from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa + from mypy_boto3_redshift_data.type_defs import DescribeStatementResponseTypeDef FINISHED_STATE = "FINISHED" FAILED_STATE = "FAILED" @@ -136,7 +136,7 @@ def check_query_is_finished(self, statement_id: str) -> bool: resp = self.conn.describe_statement(Id=statement_id) return self.parse_statement_resposne(resp) - def parse_statement_resposne(self, resp: dict[str, Any]) -> bool: + def parse_statement_resposne(self, resp: DescribeStatementResponseTypeDef) -> bool: """Parse the response of describe_statement.""" status = resp["Status"] if status == FINISHED_STATE: From d0a7d1c98782e628f60eb036f7d4ed287c84dc7a Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 18 Jan 2024 16:12:03 +0800 Subject: [PATCH 11/11] fix(providers/amazon): fix async_conn call --- airflow/providers/amazon/aws/hooks/redshift_data.py | 11 +++++++---- .../providers/amazon/aws/hooks/test_redshift_data.py | 11 ++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index 7872408db7d9e..538e5cee96909 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.utils import trim_none_values if TYPE_CHECKING: + from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa from mypy_boto3_redshift_data.type_defs import DescribeStatementResponseTypeDef FINISHED_STATE = "FINISHED" @@ -234,8 +235,9 @@ async def is_still_running(self, statement_id: str) -> bool: :param statement_id: the UUID of the statement """ - desc = await self.async_conn.describe_statement(Id=statement_id) - return desc["Status"] in RUNNING_STATES + async with self.async_conn as client: + desc = await client.describe_statement(Id=statement_id) + return desc["Status"] in RUNNING_STATES async def check_query_is_finished_async(self, statement_id: str) -> bool: """Async function to check statement is finished. 
@@ -245,5 +247,6 @@ async def check_query_is_finished_async(self, statement_id: str) -> bool:

         :param statement_id: the UUID of the statement
         """
-        resp = await self.async_conn.describe_statement(Id=statement_id)
-        return self.parse_statement_response(resp)
+        async with self.async_conn as client:
+            resp = await client.describe_statement(Id=statement_id)
+            return self.parse_statement_response(resp)
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index 689878f174460..126585b432d48 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -312,8 +312,7 @@ def test_result_num_rows(self, mock_conn, caplog):
     @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
     async def test_is_still_running(self, mock_conn, describe_statement_response, expected_result):
         hook = RedshiftDataHook()
-        mock_conn.describe_statement = mock.AsyncMock()
-        mock_conn.describe_statement.return_value = describe_statement_response
+        mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
         response = await hook.is_still_running("uuid")
         assert response == expected_result

@@ -324,7 +323,10 @@ async def test_check_query_is_finished_async(self, mock_is_still_running, mock_c
         hook = RedshiftDataHook()
         mock_is_still_running.return_value = False
         mock_conn.describe_statement = mock.AsyncMock()
-        mock_conn.describe_statement.return_value = {"Id": "uuid", "Status": "FINISHED"}
+        mock_conn.__aenter__.return_value.describe_statement.return_value = {
+            "Id": "uuid",
+            "Status": "FINISHED",
+        }
         is_finished = await hook.check_query_is_finished_async(statement_id="uuid")
         assert is_finished is True

@@ -346,7 +348,6 @@ async def test_check_query_is_finished_async_exception(
     ):
         hook = RedshiftDataHook()
         mock_is_still_running.return_value = False
-        mock_conn.describe_statement = mock.AsyncMock()
-        mock_conn.describe_statement.return_value = describe_statement_response
+        mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
         with pytest.raises(expected_exception):
            await hook.check_query_is_finished_async(statement_id="uuid")
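
Taken together, this series leaves RedshiftDataOperator with a working deferrable path: with
deferrable=True the operator submits the statement, defers to RedshiftDataTrigger, and the
triggerer polls is_still_running / check_query_is_finished_async instead of blocking a worker
slot. A minimal usage sketch follows; the dag_id, schedule, and database value are illustrative
assumptions, not part of this series, and the cluster/workgroup connection parameters are
omitted for brevity:

    # Illustrative sketch only: dag_id, schedule, and database are assumed values,
    # not taken from this patch series.
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

    with DAG(
        dag_id="example_redshift_data_deferrable",  # hypothetical DAG id
        start_date=datetime(2024, 1, 1),
        schedule=None,
        catchup=False,
    ):
        run_query = RedshiftDataOperator(
            task_id="run_query",
            database="dev",  # assumed database name
            sql="SELECT 1;",
            poll_interval=10,  # seconds between describe_statement polls
            deferrable=True,  # hand the wait over to RedshiftDataTrigger
        )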
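
For code that consumes this trigger directly, the event contract visible in the trigger diff is:
a success event carries {"status": "success", "statement_id": ...}; failed or aborted statements
add a "type" key set to FAILED_STATE or ABORTED_STATE; any other exception yields only "status",
"statement_id", and "message". A small validation sketch, using a hypothetical helper name (only
the payload shape comes from the trigger code above, the helper itself is not part of this series):

    # Hypothetical helper: demonstrates handling the TriggerEvent payloads emitted
    # by RedshiftDataTrigger; it is not part of this patch series.
    from __future__ import annotations

    from typing import Any

    from airflow.exceptions import AirflowException


    def handle_redshift_trigger_event(event: dict[str, Any]) -> str:
        if event.get("status") == "error":
            # "type" is present for FAILED/ABORTED statements, absent for generic errors.
            kind = event.get("type", "UNKNOWN")
            raise AirflowException(
                f"Statement {event.get('statement_id')} errored ({kind}): {event.get('message')}"
            )
        return event["statement_id"]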