diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 47e6c07ee2066..47b2aa5044689 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -58,6 +58,7 @@
     AirflowException,
     FailStopDagInvalidTriggerRule,
     RemovedInAirflow3Warning,
+    TaskDeferralError,
     TaskDeferred,
 )
 from airflow.lineage import apply_lineage, prepare_lineage
@@ -1590,6 +1591,22 @@ def defer(
         """
         raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
 
+    def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
+        """This method is called when a deferred task is resumed."""
+        # __fail__ is a special signal value for next_method that indicates
+        # this task was scheduled specifically to fail.
+        if next_method == "__fail__":
+            next_kwargs = next_kwargs or {}
+            traceback = next_kwargs.get("traceback")
+            if traceback is not None:
+                self.log.error("Trigger failed:\n%s", "\n".join(traceback))
+            raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
+        # Grab the callable off the Operator/Task and add in any kwargs
+        execute_callable = getattr(self, next_method)
+        if next_kwargs:
+            execute_callable = functools.partial(execute_callable, **next_kwargs)
+        return execute_callable(context)
+
     def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
         """Get the "normal" operator from the current operator.
 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 74cc5e45ffda6..975e615d95414 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -29,7 +29,6 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 from enum import Enum
-from functools import partial
 from pathlib import PurePath
 from types import TracebackType
 from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Tuple
@@ -81,7 +80,6 @@
     AirflowTaskTimeout,
     DagRunNotFound,
     RemovedInAirflow3Warning,
-    TaskDeferralError,
     TaskDeferred,
     UnmappableXComLengthPushed,
     UnmappableXComTypePushed,
@@ -1710,19 +1708,11 @@ def _execute_task(self, context, task_orig):
         # If the task has been deferred and is being executed due to a trigger,
         # then we need to pick the right method to come back to, otherwise
         # we go for the default execute
+        execute_callable_kwargs = {}
         if self.next_method:
-            # __fail__ is a special signal value for next_method that indicates
-            # this task was scheduled specifically to fail.
-            if self.next_method == "__fail__":
-                next_kwargs = self.next_kwargs or {}
-                traceback = self.next_kwargs.get("traceback")
-                if traceback is not None:
-                    self.log.error("Trigger failed:\n%s", "\n".join(traceback))
-                raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
-            # Grab the callable off the Operator/Task and add in any kwargs
-            execute_callable = getattr(task_to_execute, self.next_method)
-            if self.next_kwargs:
-                execute_callable = partial(execute_callable, **self.next_kwargs)
+            execute_callable = task_to_execute.resume_execution
+            execute_callable_kwargs["next_method"] = self.next_method
+            execute_callable_kwargs["next_kwargs"] = self.next_kwargs
         else:
             execute_callable = task_to_execute.execute
         # If a timeout is specified for the task, make it fail
@@ -1742,12 +1732,12 @@ def _execute_task(self, context, task_orig):
                     raise AirflowTaskTimeout()
                 # Run task in timeout wrapper
                 with timeout(timeout_seconds):
-                    result = execute_callable(context=context)
+                    result = execute_callable(context=context, **execute_callable_kwargs)
             except AirflowTaskTimeout:
                 task_to_execute.on_kill()
                 raise
         else:
-            result = execute_callable(context=context)
+            result = execute_callable(context=context, **execute_callable_kwargs)
         with create_session() as session:
             if task_to_execute.do_xcom_push:
                 xcom_value = result
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index 792d907d1f508..3f8b6bf2e6bce 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -35,6 +35,7 @@
     AirflowSensorTimeout,
     AirflowSkipException,
     AirflowTaskTimeout,
+    TaskDeferralError,
 )
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.models.baseoperator import BaseOperator
@@ -281,6 +282,14 @@ def run_duration() -> float:
         self.log.info("Success criteria met. Exiting.")
         return xcom_value
 
+    def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
+        try:
+            return super().resume_execution(next_method, next_kwargs, context)
+        except (AirflowException, TaskDeferralError) as e:
+            if self.soft_fail:
+                raise AirflowSkipException(str(e)) from e
+            raise
+
     def _get_next_poke_interval(
         self,
         started_at: datetime.datetime | float,
diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py
index e4e6ac5ad5d50..4dff8222e200a 100644
--- a/tests/sensors/test_base.py
+++ b/tests/sensors/test_base.py
@@ -23,7 +23,12 @@
 import pytest
 import time_machine
 
-from airflow.exceptions import AirflowException, AirflowRescheduleException, AirflowSensorTimeout
+from airflow.exceptions import (
+    AirflowException,
+    AirflowRescheduleException,
+    AirflowSensorTimeout,
+    AirflowSkipException,
+)
 from airflow.executors.debug_executor import DebugExecutor
 from airflow.executors.executor_constants import (
     CELERY_EXECUTOR,
@@ -37,7 +42,7 @@
 )
 from airflow.executors.local_executor import LocalExecutor
 from airflow.executors.sequential_executor import SequentialExecutor
-from airflow.models import TaskReschedule
+from airflow.models import TaskInstance, TaskReschedule
 from airflow.models.xcom import XCom
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.celery.executors.celery_executor import CeleryExecutor
@@ -70,6 +75,15 @@ def poke(self, context: Context):
         return self.return_value
 
 
+class DummyAsyncSensor(BaseSensorOperator):
+    def __init__(self, return_value=False, **kwargs):
+        super().__init__(**kwargs)
+        self.return_value = return_value
+
+    def execute_complete(self, context, event=None):
+        raise AirflowException("Should be skipped")
+
+
 class DummySensorWithXcomValue(BaseSensorOperator):
     def __init__(self, return_value=False, xcom_value=None, **kwargs):
         super().__init__(**kwargs)
@@ -910,3 +924,19 @@ def test_poke_mode_only_bad_poke(self):
         sensor = DummyPokeOnlySensor(task_id="foo", mode="poke", poke_changes_mode=True)
         with pytest.raises(ValueError, match="Cannot set mode to 'reschedule'. Only 'poke' is acceptable"):
             sensor.poke({})
+
+
+class TestAsyncSensor:
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception",
+        [
+            (True, AirflowSkipException),
+            (False, AirflowException),
+        ],
+    )
+    def test_fail_after_resuming_deffered_sensor(self, soft_fail, expected_exception):
+        async_sensor = DummyAsyncSensor(task_id="dummy_async_sensor", soft_fail=soft_fail)
+        ti = TaskInstance(task=async_sensor)
+        ti.next_method = "execute_complete"
+        with pytest.raises(expected_exception):
+            ti._execute_task({}, None)