diff --git a/airflow/providers/amazon/CHANGELOG.rst b/airflow/providers/amazon/CHANGELOG.rst index 09c6870948ad5..5de873f220f7a 100644 --- a/airflow/providers/amazon/CHANGELOG.rst +++ b/airflow/providers/amazon/CHANGELOG.rst @@ -26,6 +26,20 @@ Changelog --------- +Main +...... + +.. warning:: When deferrable mode was introduced for ``RedshiftDataOperator``, in version 8.17.0, tasks configured with + ``deferrable=True`` and ``wait_for_completion=True`` wouldn't enter the deferred state. Instead, the task would occupy + an executor slot until the statement was completed. A workaround may have been to set ``wait_for_completion=False``. + In this version, tasks set up with ``wait_for_completion=False`` will not wait anymore, regardless of the value of + ``deferrable``. + +Bug Fixes +~~~~~~~~~ + +* ``Fix deferred mode for 'RedshiftDataOperator' (#41206)`` + 8.27.0 ...... diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py index 54e3c2c7ae1ae..45fee2a919483 100644 --- a/airflow/providers/amazon/aws/operators/redshift_data.py +++ b/airflow/providers/amazon/aws/operators/redshift_data.py @@ -127,8 +127,8 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str: # Set wait_for_completion to False so that it waits for the status in the deferred task. wait_for_completion = self.wait_for_completion - if self.deferrable and self.wait_for_completion: - self.wait_for_completion = False + if self.deferrable: + wait_for_completion = False self.statement_id = self.hook.execute_query( database=self.database, @@ -144,7 +144,7 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str: poll_interval=self.poll_interval, ) - if self.deferrable: + if self.deferrable and self.wait_for_completion: is_finished = self.hook.check_query_is_finished(self.statement_id) if not is_finished: self.defer( diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py index a02515441b0fb..fa021395a419d 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_data.py +++ b/tests/providers/amazon/aws/operators/test_redshift_data.py @@ -51,7 +51,7 @@ def deferrable_operator(): secret_arn=secret_arn, statement_name=statement_name, parameters=parameters, - wait_for_completion=False, + wait_for_completion=True, poll_interval=poll_interval, deferrable=True, ) @@ -276,7 +276,6 @@ def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_fin poll_interval=poll_interval, ) - # @mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer") @mock.patch( "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished", return_value=False, @@ -315,3 +314,38 @@ def test_execute_complete(self, deferrable_operator): == "uuid" ) mock_log_info.assert_called_with("%s completed successfully.", TASK_ID) + + @mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") + def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finished, mock_defer): + """Tests that the operator does not check for completion nor defers when wait_for_completion is False, + no matter the value of deferrable""" + cluster_identifier = "cluster_identifier" + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + parameters = [{"name": "id", "value": "1"}] + poll_interval = 5 + + wait_for_completion = False + + for deferrable in [True, False]: + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + sql=SQL, + database=DATABASE, + cluster_identifier=cluster_identifier, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + wait_for_completion=wait_for_completion, + poll_interval=poll_interval, + deferrable=deferrable, + ) + operator.execute(None) + + assert not mock_check_query_is_finished.called + assert not mock_defer.called