From 37676b94ce54a80509c9cf91af817a0f34ebbb2b Mon Sep 17 00:00:00 2001 From: Leonardo Ishida Date: Tue, 25 Nov 2025 19:34:30 -0300 Subject: [PATCH] Handle task timeouts at supervisor --- .../src/airflow/models/taskinstance.py | 6 +- .../unit/dag_processing/test_processor.py | 1 + .../tests/unit/jobs/test_triggerer_job.py | 1 + .../ssh/tests/unit/ssh/operators/test_ssh.py | 22 +++-- .../unit/standard/operators/test_bash.py | 51 ++++++++++-- .../src/airflow/sdk/execution_time/comms.py | 8 ++ .../airflow/sdk/execution_time/supervisor.py | 82 +++++++++++++++++-- .../airflow/sdk/execution_time/task_runner.py | 28 +++---- .../execution_time/test_supervisor.py | 51 ++++++++++++ .../execution_time/test_task_runner.py | 53 +++++------- 10 files changed, 228 insertions(+), 75 deletions(-) diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index d396d9a0f3b40..455eebb8e3272 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -1301,7 +1301,11 @@ def _run_raw_task( if taskrun_result is None: return None if taskrun_result.error: - raise taskrun_result.error + from airflow.exceptions import AirflowTaskTerminated + + # Don't re-raise AirflowTaskTerminated + if not isinstance(taskrun_result.error, AirflowTaskTerminated): + raise taskrun_result.error self.task = taskrun_result.ti.task # type: ignore[assignment] return None diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 398bd5fcdabba..2d7b7bcaefd00 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1864,6 +1864,7 @@ def get_type_names(union_type): "UpdateHITLDetail", "GetHITLDetailResponse", "SetRenderedMapIndex", + "SetTaskExecutionTimeout", } in_task_runner_but_not_in_dag_processing_process = { diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 2dfb735bc34f5..adce51180e11b 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1218,6 +1218,7 @@ def get_type_names(union_type): "ResendLoggingFD", "CreateHITLDetailPayload", "SetRenderedMapIndex", + "SetTaskExecutionTimeout", } in_task_but_not_in_trigger_runner = { diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh.py b/providers/ssh/tests/unit/ssh/operators/test_ssh.py index 5747a738fc843..c9a1ba1cf1e61 100644 --- a/providers/ssh/tests/unit/ssh/operators/test_ssh.py +++ b/providers/ssh/tests/unit/ssh/operators/test_ssh.py @@ -25,11 +25,12 @@ import pytest from paramiko.client import SSHClient -from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.models import TaskInstance from airflow.providers.common.compat.sdk import timezone from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.providers.ssh.operators.ssh import SSHOperator +from airflow.utils.state import State from tests_common.test_utils.config import conf_vars from tests_common.test_utils.dag import sync_dag_to_db @@ -284,12 +285,12 @@ def test_push_ssh_exit_to_xcom(self, request, dag_maker): def test_timeout_triggers_on_kill(self, request, dag_maker): def command_sleep_forever(*args, **kwargs): - time.sleep(100) # This will be interrupted by the timeout + time.sleep(10) # This 
will be interrupted by the supervisor timeout self.exec_ssh_client_command.side_effect = command_sleep_forever with dag_maker(dag_id=f"dag_{request.node.name}"): - _ = SSHOperator( + task = SSHOperator( task_id="test_timeout", ssh_hook=self.hook, command="sleep 100", @@ -297,14 +298,19 @@ def command_sleep_forever(*args, **kwargs): ) dr = dag_maker.create_dagrun(run_id="test_timeout") + # With supervisor-based timeout, the task is killed externally by the supervisor + # The on_kill handler should be called via SIGTERM signal handler with mock.patch.object(SSHOperator, "on_kill") as mock_on_kill: - with pytest.raises(AirflowTaskTimeout): - dag_maker.run_ti("test_timeout", dr) + dag_maker.run_ti("test_timeout", dr) - # Wait a bit to ensure on_kill has time to be called - time.sleep(1) + # Verify the task failed due to timeout + ti = dr.get_task_instance("test_timeout") + assert ti.state == State.FAILED - mock_on_kill.assert_called_once() + # on_kill should have been called when the task process received SIGTERM + # Note: This may be flaky in tests due to signal timing + time.sleep(1) + mock_on_kill.assert_called() def test_remote_host_passed_at_hook_init(self): remote_host = "test_host.internal" diff --git a/providers/standard/tests/unit/standard/operators/test_bash.py b/providers/standard/tests/unit/standard/operators/test_bash.py index e9822694ac709..a8cd46a3a1b9a 100644 --- a/providers/standard/tests/unit/standard/operators/test_bash.py +++ b/providers/standard/tests/unit/standard/operators/test_bash.py @@ -28,7 +28,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.standard.operators.bash import BashOperator from airflow.utils import timezone from airflow.utils.state import State @@ -245,23 +245,60 @@ def test_bash_operator_output_processor(self, context): @pytest.mark.db_test def test_bash_operator_kill(self, dag_maker): + import time as time_module + import psutil sleep_time = f"100{os.getpid()}" with dag_maker(serialized=True): BashOperator( task_id="test_bash_operator_kill", - execution_timeout=timedelta(microseconds=25), + execution_timeout=timedelta(seconds=2), # Use 2 seconds for more reliable testing bash_command=f"/bin/bash -c 'sleep {sleep_time}'", ) dr = dag_maker.create_dagrun() - with pytest.raises(AirflowTaskTimeout): + + # With supervisor-based timeout, the task is killed externally + # The task should complete (with failure state) without raising AirflowTaskTimeout + start = time_module.time() + try: dag_maker.run_ti("test_bash_operator_kill", dr) - sleep(2) + except Exception as e: + # Log any unexpected exceptions + print(f"Unexpected exception during run_ti: {type(e).__name__}: {e}") + raise + duration = time_module.time() - start + + print(f"Task completed in {duration:.2f} seconds") + + # Should complete within reasonable time (timeout + escalation) + # With 2s timeout and 2s escalation, should be < 10s + assert duration < 10, f"Task took {duration}s, expected < 10s" + + # Verify the task failed due to timeout + ti = dr.get_task_instance("test_bash_operator_kill") + print(f"Task state: {ti.state}") + assert ti.state == State.FAILED, f"Expected task to be FAILED, but got {ti.state}" + + # Give a moment for cleanup + sleep(1) + + # Verify the subprocess was properly killed and is not still running + still_running = [] for proc in psutil.process_iter(): - if proc.cmdline() == ["sleep", sleep_time]: - os.kill(proc.pid, 
signal.SIGTERM) - pytest.fail("BashOperator's subprocess still running after stopping on timeout!") + try: + if proc.cmdline() == ["sleep", sleep_time]: + still_running.append(proc.pid) + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + + if still_running: + for pid in still_running: + try: + os.kill(pid, signal.SIGTERM) + except ProcessLookupError: + pass + pytest.fail(f"BashOperator's subprocess(es) still running after timeout: {still_running}") @pytest.mark.db_test def test_templated_fields(self, create_task_instance_of_operator): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 934667c5e66d4..2aa0c6cfffe75 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -828,6 +828,13 @@ class SetRenderedMapIndex(BaseModel): type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex" +class SetTaskExecutionTimeout(BaseModel): + """Payload for setting execution_timeout for a task instance.""" + + execution_timeout_seconds: float | None + type: Literal["SetTaskExecutionTimeout"] = "SetTaskExecutionTimeout" + + class TriggerDagRun(TriggerDAGRunPayload): dag_id: str run_id: Annotated[str, Field(title="Dag Run Id")] @@ -978,6 +985,7 @@ class MaskSecret(BaseModel): | RetryTask | SetRenderedFields | SetRenderedMapIndex + | SetTaskExecutionTimeout | SetXCom | SkipDownstreamTasks | SucceedTask diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 3fd3e23b15425..110ee575b8b97 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -107,6 +107,7 @@ SentFDs, SetRenderedFields, SetRenderedMapIndex, + SetTaskExecutionTimeout, SetXCom, SkipDownstreamTasks, StartupDetails, @@ -943,6 +944,12 @@ class ActivitySubprocess(WatchedSubprocess): _task_end_time_monotonic: float | None = attrs.field(default=None, init=False) _rendered_map_index: str | None = attrs.field(default=None, init=False) + _execution_timeout_seconds: float | None = attrs.field(default=None, init=False) + """The execution timeout in seconds, if set by the task.""" + + _execution_timeout_set_time: float | None = attrs.field(default=None, init=False) + """When the execution timeout was set (monotonic time). Timer starts from here.""" + decoder: ClassVar[TypeAdapter[ToSupervisor]] = TypeAdapter(ToSupervisor) ti: RuntimeTI | None = None @@ -1072,14 +1079,23 @@ def _monitor_subprocess(self): last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible # so we notice the subprocess finishing as quick as we can. 
- max_wait_time = max( - 0, # Make sure this value is never negative, - min( - # Ensure we heartbeat _at most_ 75% through the task instance heartbeat timeout time - HEARTBEAT_TIMEOUT - last_heartbeat_ago * 0.75, - MIN_HEARTBEAT_INTERVAL, - ), - ) + wait_times = [ + # Ensure we heartbeat _at most_ 75% through the task instance heartbeat timeout time + HEARTBEAT_TIMEOUT - last_heartbeat_ago * 0.75, + MIN_HEARTBEAT_INTERVAL, + # Cap at 1 second to ensure we check for new timeout messages frequently + 1.0, + ] + + # If execution timeout is set, also wake up to check it + if self._execution_timeout_seconds is not None and self._execution_timeout_set_time is not None: + elapsed = time.monotonic() - self._execution_timeout_set_time + time_until_timeout = self._execution_timeout_seconds - elapsed + # Wake up 100ms before timeout or when timeout is reached + wait_times.append(max(0.1, time_until_timeout)) + + max_wait_time = max(0, min(wait_times)) # Make sure this value is never negative + # Block until events are ready or the timeout is reached # This listens for activity (e.g., subprocess output) on registered file objects alive = self._service_subprocess(max_wait_time=max_wait_time) is None @@ -1103,6 +1119,9 @@ def _monitor_subprocess(self): # logs self._send_heartbeat_if_needed() + # Check if task has exceeded execution_timeout + self._check_task_timeout() + self._handle_process_overtime_if_needed() def _handle_process_overtime_if_needed(self): @@ -1122,6 +1141,47 @@ def _handle_process_overtime_if_needed(self): ) self.kill(signal.SIGTERM, force=True) + def _check_task_timeout(self): + """ + Check if task has exceeded execution_timeout and kill it if necessary. + + This handles task timeout at the supervisor level rather than in the task + process itself. + + The method implements signal escalation: SIGTERM -> SIGKILL if process doesn't exit. + """ + # Only check timeout if we have a timeout set + if self._execution_timeout_seconds is None or self._execution_timeout_set_time is None: + return + + # Don't check timeout if task has already reached a terminal state + if self._terminal_state: + return + + elapsed_time = time.monotonic() - self._execution_timeout_set_time + + if elapsed_time > self._execution_timeout_seconds: + log.error( + "Task execution timeout exceeded; terminating process", + timeout_seconds=self._execution_timeout_seconds, + elapsed_seconds=elapsed_time, + ti_id=self.id, + pid=self.pid, + ) + self.process_log.error( + "Task execution timeout exceeded. 
Terminating process.", + timeout_seconds=self._execution_timeout_seconds, + elapsed_seconds=elapsed_time, + ) + + # Kill the process with signal escalation (SIGTERM -> SIGKILL) + self.kill(signal.SIGTERM, force=True) + + # Only set terminal state if the task didn't already respond with one + if not self._terminal_state: + self._terminal_state = TaskInstanceState.FAILED + self._task_end_time_monotonic = time.monotonic() + def _send_heartbeat_if_needed(self): """Send a heartbeat to the client if heartbeat interval has passed.""" # Respect the minimum interval between heartbeat attempts @@ -1409,6 +1469,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: inactive_assets_resp = self.client.task_instances.validate_inlets_and_outlets(msg.ti_id) resp = InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp) dump_opts = {"exclude_unset": True} + elif isinstance(msg, SetTaskExecutionTimeout): + self._execution_timeout_seconds = msg.execution_timeout_seconds + self._execution_timeout_set_time = time.monotonic() + resp = None elif isinstance(msg, ResendLoggingFD): # We need special handling here! if send_fds is not None: @@ -1532,6 +1596,8 @@ def _check_subprocess_exit( def _handle_socket_comms(self): while self._open_sockets: self._service_subprocess(1.0) + # Check for execution timeout in the background thread + self._check_task_timeout() @contextlib.contextmanager def _setup_subprocess_socket(self): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 82807bba902b9..942f5b39423a8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -82,6 +82,7 @@ SentFDs, SetRenderedFields, SetRenderedMapIndex, + SetTaskExecutionTimeout, SkipDownstreamTasks, StartupDetails, SucceedTask, @@ -819,6 +820,12 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv # update the value of the task that is sent from there context["task"] = ti.task + # Send execution timeout to supervisor so it can handle timeout in parent process + if ti.task.execution_timeout: + timeout_seconds = ti.task.execution_timeout.total_seconds() + log.debug("Sending execution_timeout to supervisor", timeout_seconds=timeout_seconds) + SUPERVISOR_COMMS.send(msg=SetTaskExecutionTimeout(execution_timeout_seconds=timeout_seconds)) + jinja_env = ti.task.dag.get_template_env() ti.render_templates(context=context, jinja_env=jinja_env) @@ -921,6 +928,8 @@ def _on_term(signum, frame): return ti.task.on_kill() + # Raise exception to cause task to exit gracefully after cleanup + raise AirflowTaskTerminated("Task terminated by timeout") signal.signal(signal.SIGTERM, _on_term) @@ -1314,23 +1323,8 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): _run_task_state_change_callbacks(task, "on_execute_callback", context, log) - if task.execution_timeout: - from airflow.sdk.execution_time.timeout import timeout - - # TODO: handle timeout in case of deferral - timeout_seconds = task.execution_timeout.total_seconds() - try: - # It's possible we're already timed out, so fast-fail if true - if timeout_seconds <= 0: - raise AirflowTaskTimeout() - # Run task in timeout wrapper - with timeout(timeout_seconds): - result = ctx.run(execute, context=context) - except AirflowTaskTimeout: - task.on_kill() - raise - else: - result = ctx.run(execute, context=context) + # Timeout is now handled by the supervisor process, 
not in the task process + result = ctx.run(execute, context=context) if (post_execute_hook := task._post_execute_hook) is not None: create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 285e36ca7d881..0e5beb5b69681 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -108,6 +108,7 @@ SentFDs, SetRenderedFields, SetRenderedMapIndex, + SetTaskExecutionTimeout, SetXCom, SkipDownstreamTasks, SucceedTask, @@ -224,6 +225,51 @@ def test_supervise( with expectation: supervise(**kw) + def test_supervisor_enforces_execution_timeout(self, captured_logs, client_with_ti_start): + """ + Test that the supervisor enforces execution_timeout and kills the task. + """ + ti = TaskInstance( + id=uuid7(), + task_id="task_with_timeout", + dag_id="timeout_test", + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + ) + + def subprocess_main(): + comms = CommsDecoder() + comms._get_response() + + comms.send(SetTaskExecutionTimeout(execution_timeout_seconds=1.0)) + + sleep(10) + + start_time = time.time() + + proc = ActivitySubprocess.start( + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, + what=ti, + client=client_with_ti_start, + target=subprocess_main, + ) + + exit_code = proc.wait() + elapsed_time = time.time() - start_time + + assert exit_code in (-9, -15), f"Expected exit code -9 (SIGKILL) or -15 (SIGTERM), got {exit_code}" + + # We allow up to 15s to account for scheduling delays + assert elapsed_time < 15.0, ( + f"Task ran for {elapsed_time}s but should have been killed (expected < 15s)" + ) + # Task should run for at least the timeout duration (1s) + assert elapsed_time >= 1.0, ( + f"Task killed too early: {elapsed_time}s (expected at least 1s for timeout)" + ) + @pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: @@ -2352,6 +2398,11 @@ class RequestTestCase: }, test_id="get_task_breadcrumbs", ), + RequestTestCase( + message=SetTaskExecutionTimeout(execution_timeout_seconds=10.0), + test_id="set_task_execution_timeout", + expected_body=None, + ), ] diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 7942bfdf57ccf..81fbdb1dc0c06 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -21,9 +21,10 @@ import functools import json import os +import signal import textwrap import time -from collections.abc import Iterable +from collections.abc import Callable, Iterable from datetime import datetime, timedelta from pathlib import Path from typing import TYPE_CHECKING @@ -41,7 +42,6 @@ AirflowSensorTimeout, AirflowSkipException, AirflowTaskTerminated, - AirflowTaskTimeout, DownstreamTasksSkipped, ) from airflow.listeners import hookimpl @@ -117,7 +117,6 @@ from airflow.sdk.execution_time.task_runner import ( RuntimeTaskInstance, TaskRunnerMarker, - _execute_task, _push_xcom_if_needed, _xcom_push, finalize, @@ -534,45 +533,32 @@ def test_run_raises_airflow_exception(time_machine, create_runtime_ti, mock_supe mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) -def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms): - """Test running a basic task that times out.""" - from 
time import sleep +def test_run_sigterm_handler_invokes_on_kill(create_runtime_ti, mock_supervisor_comms, monkeypatch): + """Verify that the SIGTERM handler registered by run() calls the task's on_kill hook.""" - task = PythonOperator( - task_id="sleep", - execution_timeout=timedelta(milliseconds=10), - python_callable=lambda: sleep(2), - ) + task = PythonOperator(task_id="sigterm_task", python_callable=lambda: None) + ti = create_runtime_ti(task=task, dag_id="sigterm_dag") - ti = create_runtime_ti(task=task, dag_id="basic_dag_time_out") + # Replace on_kill with a spy so we can assert it was triggered by the handler. + ti.task.on_kill = mock.Mock() - instant = timezone.datetime(2024, 12, 3, 10, 0) - time_machine.move_to(instant, tick=False) + captured_handlers: dict[int, Callable[[int, object | None], None]] = {} - run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + def capture_signal(sig, handler): + captured_handlers[sig] = handler + return mock.Mock(name="previous_handler") - assert ti.state == TaskInstanceState.FAILED + monkeypatch.setattr(signal, "signal", capture_signal) - # this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout - mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) - - -def test_execution_timeout(create_runtime_ti): - def sleep_and_catch_other_exceptions(): - with contextlib.suppress(Exception): - # Catching Exception should NOT catch AirflowTaskTimeout - time.sleep(5) + run(ti, context=ti.get_template_context(), log=mock.MagicMock()) - op = PythonOperator( - task_id="test_timeout", - execution_timeout=timedelta(seconds=1), - python_callable=sleep_and_catch_other_exceptions, - ) + assert signal.SIGTERM in captured_handlers, "SIGTERM handler was not registered" - ti = create_runtime_ti(task=op, dag_id="dag_execution_timeout") + handler = captured_handlers[signal.SIGTERM] - with pytest.raises(AirflowTaskTimeout): - _execute_task(context=ti.get_template_context(), ti=ti, log=mock.MagicMock()) + ti.task.on_kill.assert_not_called() + handler(signal.SIGTERM, None) + ti.task.on_kill.assert_called_once_with() def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms, spy_agency): @@ -3302,7 +3288,6 @@ def execute(self, context): def test_task_runner_both_callbacks_have_timing_info(self, create_runtime_ti): """Test that both success and failure callbacks receive accurate timing information.""" - import time from airflow.exceptions import AirflowException
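
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the mechanism this change introduces, for readers who want the flow in one place. The task process reports its execution_timeout to the supervisor (the new SetTaskExecutionTimeout message), the supervisor starts a monotonic clock when that message arrives, wakes up shortly before the deadline, and escalates SIGTERM -> SIGKILL if the child does not exit. TimeoutWatcher, supervise() and the 2-second grace period are illustrative names and values only; they do not mirror the real ActivitySubprocess API.

import signal
import subprocess
import time


class TimeoutWatcher:
    """Simplified stand-in for the supervisor-side bookkeeping added in this patch
    (_execution_timeout_seconds / _execution_timeout_set_time). Hypothetical class."""

    def __init__(self) -> None:
        self.timeout_seconds: float | None = None
        self.set_time: float | None = None

    def set_timeout(self, seconds: float) -> None:
        # Mirrors the SetTaskExecutionTimeout handling: remember the limit and
        # start the clock from the moment the message is received.
        self.timeout_seconds = seconds
        self.set_time = time.monotonic()

    def next_wait(self, default: float = 1.0) -> float:
        # Wake up shortly before the deadline so the check is not delayed
        # by a long blocking wait, similar to the wait_times logic above.
        if self.timeout_seconds is None or self.set_time is None:
            return default
        remaining = self.timeout_seconds - (time.monotonic() - self.set_time)
        return max(0.1, min(default, remaining))

    def expired(self) -> bool:
        if self.timeout_seconds is None or self.set_time is None:
            return False
        return time.monotonic() - self.set_time > self.timeout_seconds


def supervise(proc: subprocess.Popen, watcher: TimeoutWatcher, grace: float = 2.0) -> int:
    """Poll the child and escalate SIGTERM -> SIGKILL once the deadline passes."""
    while proc.poll() is None:
        time.sleep(watcher.next_wait())
        if watcher.expired():
            proc.send_signal(signal.SIGTERM)  # give the task's SIGTERM handler / on_kill a chance
            try:
                proc.wait(timeout=grace)
            except subprocess.TimeoutExpired:
                proc.kill()  # escalate to SIGKILL if the child did not exit in time
            break
    return proc.wait()


if __name__ == "__main__":
    watcher = TimeoutWatcher()
    watcher.set_timeout(1.0)  # as if SetTaskExecutionTimeout(execution_timeout_seconds=1.0) arrived
    child = subprocess.Popen(["sleep", "30"])
    print("exit code:", supervise(child, watcher))  # expect -15 (SIGTERM) or -9 (SIGKILL)

Design note: moving enforcement into the parent process removes the in-process timeout() wrapper from _execute_task, which can be more reliable when the task is blocked outside Python; the SIGTERM handler registered in the task process remains responsible for calling on_kill() before exiting.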