Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,19 +390,37 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# - Update RTIF
# - Pre Execute
# etc
result = ti.task.execute(context) # type: ignore[attr-defined]
_push_xcom_if_needed(result, ti)

result = None
if ti.task.execution_timeout:
# TODO: handle timeout in case of deferral
from airflow.utils.timeout import timeout

timeout_seconds = ti.task.execution_timeout.total_seconds()
try:
# It's possible we're already timed out, so fast-fail if true
if timeout_seconds <= 0:
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = ti.task.execute(context) # type: ignore[attr-defined]
except AirflowTaskTimeout:
# TODO: handle on kill callback here
raise
else:
result = ti.task.execute(context) # type: ignore[attr-defined]

_push_xcom_if_needed(result, ti)
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
next_method = defer.method_name
timeout = defer.timeout
defer_timeout = defer.timeout
msg = DeferTask(
classpath=classpath,
trigger_kwargs=trigger_kwargs,
next_method=next_method,
trigger_timeout=timeout,
trigger_timeout=defer_timeout,
)
except AirflowSkipException:
msg = TaskState(
Expand All @@ -423,13 +441,14 @@ def run(ti: RuntimeTaskInstance, log: Logger):
state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=datetime.now(tz=timezone.utc),
)

# TODO: Run task failure callbacks here
except AirflowTaskTimeout:
# TODO: handle the case of up_for_retry here
# TODO: coagulate this exception handling with AirflowException
# once https://github.com/apache/airflow/issues/45307 is handled
...
except (AirflowTaskTimeout, AirflowException):
# We should allow retries if the task has defined it.
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
# TODO: Run task failure callbacks here
except AirflowException:
# TODO: handle the case of up_for_retry here
msg = TaskState(
Expand Down
42 changes: 42 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,48 @@ def test_run_raises_airflow_exception(time_machine, mocked_parse, make_ti_contex
)


def test_run_task_timeout(time_machine, mocked_parse, make_ti_context, mock_supervisor_comms):
"""Test running a basic task that times out."""
from time import sleep

from airflow.providers.standard.operators.python import PythonOperator

task = PythonOperator(
task_id="sleep",
execution_timeout=timedelta(milliseconds=10),
python_callable=lambda: sleep(2),
)

what = StartupDetails(
ti=TaskInstance(
id=uuid7(),
task_id="sleep",
dag_id="basic_dag_time_out",
run_id="c",
try_number=1,
),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)

ti = mocked_parse(what, "basic_dag_time_out", task)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

run(ti, log=mock.MagicMock())

# this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout
mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(
state=TerminalTIState.FAILED,
end_date=instant,
),
log=mock.ANY,
)


def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms):
"""Test running a DAG with templated task."""
from airflow.providers.standard.operators.bash import BashOperator
Expand Down