diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 8f636f4fe7673..effea9f6a7e52 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -390,19 +390,37 @@ def run(ti: RuntimeTaskInstance, log: Logger): # - Update RTIF # - Pre Execute # etc - result = ti.task.execute(context) # type: ignore[attr-defined] - _push_xcom_if_needed(result, ti) + result = None + if ti.task.execution_timeout: + # TODO: handle timeout in case of deferral + from airflow.utils.timeout import timeout + + timeout_seconds = ti.task.execution_timeout.total_seconds() + try: + # It's possible we're already timed out, so fast-fail if true + if timeout_seconds <= 0: + raise AirflowTaskTimeout() + # Run task in timeout wrapper + with timeout(timeout_seconds): + result = ti.task.execute(context) # type: ignore[attr-defined] + except AirflowTaskTimeout: + # TODO: handle on kill callback here + raise + else: + result = ti.task.execute(context) # type: ignore[attr-defined] + + _push_xcom_if_needed(result, ti) msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc)) except TaskDeferred as defer: classpath, trigger_kwargs = defer.trigger.serialize() next_method = defer.method_name - timeout = defer.timeout + defer_timeout = defer.timeout msg = DeferTask( classpath=classpath, trigger_kwargs=trigger_kwargs, next_method=next_method, - trigger_timeout=timeout, + trigger_timeout=defer_timeout, ) except AirflowSkipException: msg = TaskState( @@ -423,13 +441,14 @@ def run(ti: RuntimeTaskInstance, log: Logger): state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=datetime.now(tz=timezone.utc), ) - # TODO: Run task failure callbacks here - except AirflowTaskTimeout: - # TODO: handle the case of up_for_retry here - # TODO: coagulate this exception handling with AirflowException - # once https://github.com/apache/airflow/issues/45307 is handled - ... + except (AirflowTaskTimeout, AirflowException): + # We should allow retries if the task has defined it. + msg = TaskState( + state=TerminalTIState.FAILED, + end_date=datetime.now(tz=timezone.utc), + ) + # TODO: Run task failure callbacks here except AirflowException: # TODO: handle the case of up_for_retry here msg = TaskState( diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index cb756921e1013..ebf03d3323315 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -371,6 +371,48 @@ def test_run_raises_airflow_exception(time_machine, mocked_parse, make_ti_contex ) +def test_run_task_timeout(time_machine, mocked_parse, make_ti_context, mock_supervisor_comms): + """Test running a basic task that times out.""" + from time import sleep + + from airflow.providers.standard.operators.python import PythonOperator + + task = PythonOperator( + task_id="sleep", + execution_timeout=timedelta(milliseconds=10), + python_callable=lambda: sleep(2), + ) + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="sleep", + dag_id="basic_dag_time_out", + run_id="c", + try_number=1, + ), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + + ti = mocked_parse(what, "basic_dag_time_out", task) + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + run(ti, log=mock.MagicMock()) + + # this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout + mock_supervisor_comms.send_request.assert_called_once_with( + msg=TaskState( + state=TerminalTIState.FAILED, + end_date=instant, + ), + log=mock.ANY, + ) + + def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms): """Test running a DAG with templated task.""" from airflow.providers.standard.operators.bash import BashOperator