6 changes: 5 additions & 1 deletion airflow-core/src/airflow/models/taskinstance.py
@@ -1301,7 +1301,11 @@ def _run_raw_task(
if taskrun_result is None:
return None
if taskrun_result.error:
raise taskrun_result.error
from airflow.exceptions import AirflowTaskTerminated

# Don't re-raise AirflowTaskTerminated
if not isinstance(taskrun_result.error, AirflowTaskTerminated):
raise taskrun_result.error
self.task = taskrun_result.ti.task # type: ignore[assignment]
return None

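The hunk above makes `_run_raw_task` swallow `AirflowTaskTerminated` instead of re-raising it, so a task the supervisor kills on timeout does not surface a second error from inside the task process. A minimal, self-contained sketch of that control flow (the classes here are stand-ins, not the real Airflow ones):

```python
from dataclasses import dataclass


class AirflowTaskTerminated(Exception):
    """Stand-in for airflow.exceptions.AirflowTaskTerminated."""


@dataclass
class TaskRunResult:
    error: BaseException | None = None


def handle_result(result: TaskRunResult | None) -> None:
    if result is None:
        return
    if result.error:
        # Don't re-raise AirflowTaskTerminated: the supervisor already recorded
        # the timeout failure, so the task process just exits quietly.
        if not isinstance(result.error, AirflowTaskTerminated):
            raise result.error


handle_result(TaskRunResult(error=AirflowTaskTerminated("timed out")))  # swallowed
try:
    handle_result(TaskRunResult(error=ValueError("boom")))
except ValueError as exc:
    print(f"re-raised: {exc}")
```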
1 change: 1 addition & 0 deletions airflow-core/tests/unit/dag_processing/test_processor.py
@@ -1864,6 +1864,7 @@ def get_type_names(union_type):
"UpdateHITLDetail",
"GetHITLDetailResponse",
"SetRenderedMapIndex",
"SetTaskExecutionTimeout",
}

in_task_runner_but_not_in_dag_processing_process = {
1 change: 1 addition & 0 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -1218,6 +1218,7 @@ def get_type_names(union_type):
"ResendLoggingFD",
"CreateHITLDetailPayload",
"SetRenderedMapIndex",
"SetTaskExecutionTimeout",
}

in_task_but_not_in_trigger_runner = {
22 changes: 14 additions & 8 deletions providers/ssh/tests/unit/ssh/operators/test_ssh.py
@@ -25,11 +25,12 @@
import pytest
from paramiko.client import SSHClient

from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models import TaskInstance
from airflow.providers.common.compat.sdk import timezone
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow.utils.state import State

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.dag import sync_dag_to_db
@@ -284,27 +285,32 @@ def test_push_ssh_exit_to_xcom(self, request, dag_maker):

def test_timeout_triggers_on_kill(self, request, dag_maker):
def command_sleep_forever(*args, **kwargs):
time.sleep(100) # This will be interrupted by the timeout
time.sleep(10) # This will be interrupted by the supervisor timeout

self.exec_ssh_client_command.side_effect = command_sleep_forever

with dag_maker(dag_id=f"dag_{request.node.name}"):
_ = SSHOperator(
task = SSHOperator(
task_id="test_timeout",
ssh_hook=self.hook,
command="sleep 100",
execution_timeout=timedelta(seconds=1),
)
dr = dag_maker.create_dagrun(run_id="test_timeout")

# With supervisor-based timeout, the task is killed externally by the supervisor
# The on_kill handler should be called via SIGTERM signal handler
with mock.patch.object(SSHOperator, "on_kill") as mock_on_kill:
with pytest.raises(AirflowTaskTimeout):
dag_maker.run_ti("test_timeout", dr)
dag_maker.run_ti("test_timeout", dr)

# Wait a bit to ensure on_kill has time to be called
time.sleep(1)
# Verify the task failed due to timeout
ti = dr.get_task_instance("test_timeout")
assert ti.state == State.FAILED

mock_on_kill.assert_called_once()
# on_kill should have been called when the task process received SIGTERM
# Note: This may be flaky in tests due to signal timing
time.sleep(1)
mock_on_kill.assert_called()

def test_remote_host_passed_at_hook_init(self):
remote_host = "test_host.internal"
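The SSH test no longer expects `AirflowTaskTimeout` from inside the task; the supervisor kills the process externally, and the task's SIGTERM handler is what invokes `on_kill`. A rough standalone sketch of that mechanism, not the Airflow test itself (`FakeOperator` and `run_with_sigterm_on_kill` are illustrative names; requires a Unix system for `os.fork`):

```python
import os
import signal
import time


class FakeOperator:
    def on_kill(self) -> None:
        print("on_kill called")

    def execute(self) -> None:
        time.sleep(10)  # stands in for the long-running SSH command


def run_with_sigterm_on_kill(op: FakeOperator) -> None:
    def _on_term(signum, frame):
        op.on_kill()          # cleanup hook runs first
        raise SystemExit(1)   # then the task process exits

    signal.signal(signal.SIGTERM, _on_term)
    op.execute()


if __name__ == "__main__":
    pid = os.fork()
    if pid == 0:
        # Child: plays the task process.
        run_with_sigterm_on_kill(FakeOperator())
    else:
        # Parent: plays the supervisor enforcing a 1-second timeout.
        time.sleep(1)
        os.kill(pid, signal.SIGTERM)
        os.waitpid(pid, 0)
```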
51 changes: 44 additions & 7 deletions providers/standard/tests/unit/standard/operators/test_bash.py
@@ -28,7 +28,7 @@

import pytest

from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.standard.operators.bash import BashOperator
from airflow.utils import timezone
from airflow.utils.state import State
@@ -245,23 +245,60 @@ def test_bash_operator_output_processor(self, context):

@pytest.mark.db_test
def test_bash_operator_kill(self, dag_maker):
import time as time_module

import psutil

sleep_time = f"100{os.getpid()}"
with dag_maker(serialized=True):
BashOperator(
task_id="test_bash_operator_kill",
execution_timeout=timedelta(microseconds=25),
execution_timeout=timedelta(seconds=2), # Use 2 seconds for more reliable testing
bash_command=f"/bin/bash -c 'sleep {sleep_time}'",
)
dr = dag_maker.create_dagrun()
with pytest.raises(AirflowTaskTimeout):

# With supervisor-based timeout, the task is killed externally
# The task should complete (with failure state) without raising AirflowTaskTimeout
start = time_module.time()
try:
dag_maker.run_ti("test_bash_operator_kill", dr)
sleep(2)
except Exception as e:
# Log any unexpected exceptions
print(f"Unexpected exception during run_ti: {type(e).__name__}: {e}")
raise
duration = time_module.time() - start

print(f"Task completed in {duration:.2f} seconds")

# Should complete within reasonable time (timeout + escalation)
# With 2s timeout and 2s escalation, should be < 10s
assert duration < 10, f"Task took {duration}s, expected < 10s"

# Verify the task failed due to timeout
ti = dr.get_task_instance("test_bash_operator_kill")
print(f"Task state: {ti.state}")
assert ti.state == State.FAILED, f"Expected task to be FAILED, but got {ti.state}"

# Give a moment for cleanup
sleep(1)

# Verify the subprocess was properly killed and is not still running
still_running = []
for proc in psutil.process_iter():
if proc.cmdline() == ["sleep", sleep_time]:
os.kill(proc.pid, signal.SIGTERM)
pytest.fail("BashOperator's subprocess still running after stopping on timeout!")
try:
if proc.cmdline() == ["sleep", sleep_time]:
still_running.append(proc.pid)
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass

if still_running:
for pid in still_running:
try:
os.kill(pid, signal.SIGTERM)
except ProcessLookupError:
pass
pytest.fail(f"BashOperator's subprocess(es) still running after timeout: {still_running}")

@pytest.mark.db_test
def test_templated_fields(self, create_task_instance_of_operator):
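The lingering-process check at the end of the bash test now tolerates processes disappearing mid-scan instead of failing on the first match. A small standalone helper in the same spirit (`find_lingering` is an illustrative name, not part of the test; `psutil` is already imported by the test):

```python
import psutil


def find_lingering(cmdline: list[str]) -> list[int]:
    """Return PIDs of processes whose command line matches exactly."""
    pids = []
    for proc in psutil.process_iter():
        try:
            if proc.cmdline() == cmdline:
                pids.append(proc.pid)
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Processes can vanish (or be unreadable) between listing and inspection.
            pass
    return pids


print(find_lingering(["sleep", "1234567"]))  # expected: [] once the task is cleaned up
```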
8 changes: 8 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -828,6 +828,13 @@ class SetRenderedMapIndex(BaseModel):
type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex"


class SetTaskExecutionTimeout(BaseModel):
"""Payload for setting execution_timeout for a task instance."""

execution_timeout_seconds: float | None
type: Literal["SetTaskExecutionTimeout"] = "SetTaskExecutionTimeout"


class TriggerDagRun(TriggerDAGRunPayload):
dag_id: str
run_id: Annotated[str, Field(title="Dag Run Id")]
@@ -978,6 +985,7 @@ class MaskSecret(BaseModel):
| RetryTask
| SetRenderedFields
| SetRenderedMapIndex
| SetTaskExecutionTimeout
| SetXCom
| SkipDownstreamTasks
| SucceedTask
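For reference, the new message serializes like any other member of the `ToSupervisor` union. A standalone copy of the model (mirroring the fields added above rather than importing the real module) shows its wire shape:

```python
from typing import Literal

from pydantic import BaseModel


class SetTaskExecutionTimeout(BaseModel):
    """Payload for setting execution_timeout for a task instance."""

    execution_timeout_seconds: float | None
    type: Literal["SetTaskExecutionTimeout"] = "SetTaskExecutionTimeout"


msg = SetTaskExecutionTimeout(execution_timeout_seconds=300.0)
print(msg.model_dump_json())
# {"execution_timeout_seconds":300.0,"type":"SetTaskExecutionTimeout"}
```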
82 changes: 74 additions & 8 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -107,6 +107,7 @@
SentFDs,
SetRenderedFields,
SetRenderedMapIndex,
SetTaskExecutionTimeout,
SetXCom,
SkipDownstreamTasks,
StartupDetails,
@@ -943,6 +944,12 @@ class ActivitySubprocess(WatchedSubprocess):
_task_end_time_monotonic: float | None = attrs.field(default=None, init=False)
_rendered_map_index: str | None = attrs.field(default=None, init=False)

_execution_timeout_seconds: float | None = attrs.field(default=None, init=False)
"""The execution timeout in seconds, if set by the task."""

_execution_timeout_set_time: float | None = attrs.field(default=None, init=False)
"""When the execution timeout was set (monotonic time). Timer starts from here."""

decoder: ClassVar[TypeAdapter[ToSupervisor]] = TypeAdapter(ToSupervisor)

ti: RuntimeTI | None = None
@@ -1072,14 +1079,23 @@ def _monitor_subprocess(self):
last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat
# Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
# so we notice the subprocess finishing as quick as we can.
max_wait_time = max(
0, # Make sure this value is never negative,
min(
# Ensure we heartbeat _at most_ 75% through the task instance heartbeat timeout time
HEARTBEAT_TIMEOUT - last_heartbeat_ago * 0.75,
MIN_HEARTBEAT_INTERVAL,
),
)
wait_times = [
# Ensure we heartbeat _at most_ 75% through the task instance heartbeat timeout time
HEARTBEAT_TIMEOUT - last_heartbeat_ago * 0.75,
MIN_HEARTBEAT_INTERVAL,
# Cap at 1 second to ensure we check for new timeout messages frequently
1.0,
]

# If execution timeout is set, also wake up to check it
if self._execution_timeout_seconds is not None and self._execution_timeout_set_time is not None:
elapsed = time.monotonic() - self._execution_timeout_set_time
time_until_timeout = self._execution_timeout_seconds - elapsed
# Wake up 100ms before timeout or when timeout is reached
wait_times.append(max(0.1, time_until_timeout))

max_wait_time = max(0, min(wait_times)) # Make sure this value is never negative

# Block until events are ready or the timeout is reached
# This listens for activity (e.g., subprocess output) on registered file objects
alive = self._service_subprocess(max_wait_time=max_wait_time) is None
@@ -1103,6 +1119,9 @@
# logs
self._send_heartbeat_if_needed()

# Check if task has exceeded execution_timeout
self._check_task_timeout()

self._handle_process_overtime_if_needed()

def _handle_process_overtime_if_needed(self):
@@ -1122,6 +1141,47 @@ def _handle_process_overtime_if_needed(self):
)
self.kill(signal.SIGTERM, force=True)

def _check_task_timeout(self):
"""
Check if task has exceeded execution_timeout and kill it if necessary.

This handles task timeout at the supervisor level rather than in the task
process itself.

The method implements signal escalation: SIGTERM -> SIGKILL if process doesn't exit.
"""
# Only check timeout if we have a timeout set
if self._execution_timeout_seconds is None or self._execution_timeout_set_time is None:
return

# Don't check timeout if task has already reached a terminal state
if self._terminal_state:
return

elapsed_time = time.monotonic() - self._execution_timeout_set_time

if elapsed_time > self._execution_timeout_seconds:
log.error(
"Task execution timeout exceeded; terminating process",
timeout_seconds=self._execution_timeout_seconds,
elapsed_seconds=elapsed_time,
ti_id=self.id,
pid=self.pid,
)
self.process_log.error(
"Task execution timeout exceeded. Terminating process.",
timeout_seconds=self._execution_timeout_seconds,
elapsed_seconds=elapsed_time,
)

# Kill the process with signal escalation (SIGTERM -> SIGKILL)
self.kill(signal.SIGTERM, force=True)

# Only set terminal state if the task didn't already respond with one
if not self._terminal_state:
self._terminal_state = TaskInstanceState.FAILED
self._task_end_time_monotonic = time.monotonic()

def _send_heartbeat_if_needed(self):
"""Send a heartbeat to the client if heartbeat interval has passed."""
# Respect the minimum interval between heartbeat attempts
@@ -1409,6 +1469,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
inactive_assets_resp = self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
resp = InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)
dump_opts = {"exclude_unset": True}
elif isinstance(msg, SetTaskExecutionTimeout):
self._execution_timeout_seconds = msg.execution_timeout_seconds
self._execution_timeout_set_time = time.monotonic()
resp = None
elif isinstance(msg, ResendLoggingFD):
# We need special handling here!
if send_fds is not None:
@@ -1532,6 +1596,8 @@ def _check_subprocess_exit(
def _handle_socket_comms(self):
while self._open_sockets:
self._service_subprocess(1.0)
# Check for execution timeout in the background thread
self._check_task_timeout()

@contextlib.contextmanager
def _setup_subprocess_socket(self):
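Two things change in the supervisor's monitoring loop: the `select()` wait is now derived from a list of candidate wake-up deadlines (heartbeat, a 1-second cap, and the time left until the execution timeout), and a new check terminates the task once that deadline passes. A condensed standalone sketch of both pieces, with made-up constants and simplified state (the real code keeps this state on `ActivitySubprocess` and escalates SIGTERM to SIGKILL via `self.kill(signal.SIGTERM, force=True)`):

```python
import time

# Made-up values for the example; the real constants live in the supervisor module.
HEARTBEAT_TIMEOUT = 30.0
MIN_HEARTBEAT_INTERVAL = 5.0


def compute_max_wait(
    last_heartbeat_ago: float,
    timeout_seconds: float | None,
    timeout_set_at: float | None,
) -> float:
    """How long the monitor loop may block in select() before waking up."""
    wait_times = [
        # Heartbeat at most 75% of the way through the heartbeat timeout.
        HEARTBEAT_TIMEOUT - last_heartbeat_ago * 0.75,
        MIN_HEARTBEAT_INTERVAL,
        # Cap at 1s so a freshly received SetTaskExecutionTimeout is noticed quickly.
        1.0,
    ]
    if timeout_seconds is not None and timeout_set_at is not None:
        remaining = timeout_seconds - (time.monotonic() - timeout_set_at)
        # Wake up when the timeout is due, but never busy-spin below 100ms.
        wait_times.append(max(0.1, remaining))
    return max(0.0, min(wait_times))


def timed_out(timeout_seconds: float | None, timeout_set_at: float | None) -> bool:
    """The condition the timeout check uses before killing the process."""
    if timeout_seconds is None or timeout_set_at is None:
        return False
    return time.monotonic() - timeout_set_at > timeout_seconds


# Example: a 2s execution timeout registered 1.5s ago -> sleep roughly 0.5s,
# and on the wake-up after the deadline timed_out() flips to True.
set_at = time.monotonic() - 1.5
print(round(compute_max_wait(3.0, 2.0, set_at), 2))  # ~0.5
print(timed_out(2.0, set_at))                        # False (only 1.5s elapsed)
```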
28 changes: 11 additions & 17 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -82,6 +82,7 @@
SentFDs,
SetRenderedFields,
SetRenderedMapIndex,
SetTaskExecutionTimeout,
SkipDownstreamTasks,
StartupDetails,
SucceedTask,
@@ -819,6 +820,12 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv
# update the value of the task that is sent from there
context["task"] = ti.task

# Send execution timeout to supervisor so it can handle timeout in parent process
if ti.task.execution_timeout:
timeout_seconds = ti.task.execution_timeout.total_seconds()
log.debug("Sending execution_timeout to supervisor", timeout_seconds=timeout_seconds)
SUPERVISOR_COMMS.send(msg=SetTaskExecutionTimeout(execution_timeout_seconds=timeout_seconds))

jinja_env = ti.task.dag.get_template_env()
ti.render_templates(context=context, jinja_env=jinja_env)

@@ -921,6 +928,8 @@ def _on_term(signum, frame):
return

ti.task.on_kill()
# Raise exception to cause task to exit gracefully after cleanup
raise AirflowTaskTerminated("Task terminated by timeout")

signal.signal(signal.SIGTERM, _on_term)

@@ -1314,23 +1323,8 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):

_run_task_state_change_callbacks(task, "on_execute_callback", context, log)

if task.execution_timeout:
from airflow.sdk.execution_time.timeout import timeout

# TODO: handle timeout in case of deferral
timeout_seconds = task.execution_timeout.total_seconds()
try:
# It's possible we're already timed out, so fast-fail if true
if timeout_seconds <= 0:
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = ctx.run(execute, context=context)
except AirflowTaskTimeout:
task.on_kill()
raise
else:
result = ctx.run(execute, context=context)
# Timeout is now handled by the supervisor process, not in the task process
result = ctx.run(execute, context=context)

if (post_execute_hook := task._post_execute_hook) is not None:
create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
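On the task-runner side, execution timeout handling reduces to reporting the timeout to the supervisor during `_prepare` (and exiting cleanly when SIGTERM arrives). A rough sketch of the sending step, with a fake comms channel standing in for `SUPERVISOR_COMMS` (`FakeComms` and `prepare_timeout` are illustrative names):

```python
from datetime import timedelta


class FakeComms:
    def send(self, msg) -> None:
        print("would send to supervisor:", msg)


def prepare_timeout(execution_timeout: timedelta | None, comms: FakeComms) -> None:
    # Mirrors the new step in _prepare: convert the timedelta to seconds and
    # hand the deadline to the supervisor, which owns enforcement from then on.
    if execution_timeout:
        comms.send(
            {
                "type": "SetTaskExecutionTimeout",
                "execution_timeout_seconds": execution_timeout.total_seconds(),
            }
        )


prepare_timeout(timedelta(minutes=5), FakeComms())
# would send to supervisor: {'type': 'SetTaskExecutionTimeout', 'execution_timeout_seconds': 300.0}
```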