From 96e9802bda8515438e9bd75889383fc8e38e849e Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Fri, 22 Sep 2023 16:34:34 +0530 Subject: [PATCH 1/3] Saving work --- airflow/providers/amazon/aws/sensors/athena.py | 8 ++++++-- tests/providers/amazon/aws/sensors/test_athena.py | 8 +++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py index 70d7875629fb9..4a1de652667bb 100644 --- a/airflow/providers/amazon/aws/sensors/athena.py +++ b/airflow/providers/amazon/aws/sensors/athena.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.sensors.base import BaseSensorOperator @@ -78,7 +78,11 @@ def poke(self, context: Context) -> bool: state = self.hook.poll_query_status(self.query_execution_id, self.max_retries, self.sleep_time) if state in self.FAILURE_STATES: - raise AirflowException("Athena sensor failed") + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + message = "Athena sensor failed" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) if state in self.INTERMEDIATE_STATES: return False diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py index a9809be1d052b..e673e70e44828 100644 --- a/tests/providers/amazon/aws/sensors/test_athena.py +++ b/tests/providers/amazon/aws/sensors/test_athena.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.sensors.athena import AthenaSensor @@ -59,3 +59,9 @@ def test_poke_cancelled(self, mock_poll_query_status): with pytest.raises(AirflowException) as ctx: self.sensor.poke({}) assert "Athena sensor failed" in str(ctx.value) + + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_fail_poke(self, soft_fail, expected_exception): + pass From 9e219e0dbacb89988414c1419ea6ba7be23465de Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Fri, 22 Sep 2023 17:07:35 +0530 Subject: [PATCH 2/3] Respect soft_fail parameter in AthenaSensor --- tests/providers/amazon/aws/sensors/test_athena.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py index e673e70e44828..f592c05c21e87 100644 --- a/tests/providers/amazon/aws/sensors/test_athena.py +++ b/tests/providers/amazon/aws/sensors/test_athena.py @@ -64,4 +64,9 @@ def test_poke_cancelled(self, mock_poll_query_status): "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) ) def test_fail_poke(self, soft_fail, expected_exception): - pass + self.sensor.soft_fail = soft_fail + with pytest.raises(expected_exception), mock.patch( + "airflow.providers.amazon.aws.hooks.athena.AthenaHook.poll_query_status" + ) as poll_query_status: + poll_query_status.return_value = "FAILED" + self.sensor.poke(context={}) From d8763d78af8e6ee367de5fb4a12718164095d566 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Fri, 22 Sep 2023 17:11:52 +0530 Subject: [PATCH 3/3] Assert the error message --- tests/providers/amazon/aws/sensors/test_athena.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py index f592c05c21e87..18012d81cad09 100644 --- a/tests/providers/amazon/aws/sensors/test_athena.py +++ b/tests/providers/amazon/aws/sensors/test_athena.py @@ -65,7 +65,8 @@ def test_poke_cancelled(self, mock_poll_query_status): ) def test_fail_poke(self, soft_fail, expected_exception): self.sensor.soft_fail = soft_fail - with pytest.raises(expected_exception), mock.patch( + message = "Athena sensor failed" + with pytest.raises(expected_exception, match=message), mock.patch( "airflow.providers.amazon.aws.hooks.athena.AthenaHook.poll_query_status" ) as poll_query_status: poll_query_status.return_value = "FAILED"