diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py index 70d7875629fb9..4a1de652667bb 100644 --- a/airflow/providers/amazon/aws/sensors/athena.py +++ b/airflow/providers/amazon/aws/sensors/athena.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.sensors.base import BaseSensorOperator @@ -78,7 +78,11 @@ def poke(self, context: Context) -> bool: state = self.hook.poll_query_status(self.query_execution_id, self.max_retries, self.sleep_time) if state in self.FAILURE_STATES: - raise AirflowException("Athena sensor failed") + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + message = "Athena sensor failed" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) if state in self.INTERMEDIATE_STATES: return False diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py index a9809be1d052b..18012d81cad09 100644 --- a/tests/providers/amazon/aws/sensors/test_athena.py +++ b/tests/providers/amazon/aws/sensors/test_athena.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.sensors.athena import AthenaSensor @@ -59,3 +59,15 @@ def test_poke_cancelled(self, mock_poll_query_status): with pytest.raises(AirflowException) as ctx: self.sensor.poke({}) assert "Athena sensor failed" in str(ctx.value) + + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_fail_poke(self, soft_fail, expected_exception): + self.sensor.soft_fail = soft_fail + message = "Athena sensor failed" + with pytest.raises(expected_exception, match=message), mock.patch( + "airflow.providers.amazon.aws.hooks.athena.AthenaHook.poll_query_status" + ) as poll_query_status: + poll_query_status.return_value = "FAILED" + self.sensor.poke(context={})