From 5dc5ba2789594e19d9f5b3108c029aa1f53009cb Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 10 Nov 2025 15:00:25 +0000 Subject: [PATCH 1/4] Support for "reconnecting" Supervisor Comms from task process when `dag.test()` are used This is a follow-up to #57212, which worked fine "at run time" but did not work in many of our own unit tests, which rely on `dag.test` or `ti.run`. The way this is implemented is that when we use the InProcessTestSupervisor we pre-emptively create a socket pair. We have to create it even if it's not being used, as we can't know in advance whether it will be needed. And since this is all in one process we create a thread to handle the socket comms. Since this is only ever for tests, performance or hitting the GIL doesn't matter. --- .../airflow/sdk/execution_time/supervisor.py | 60 ++++++++++++++++--- .../airflow/sdk/execution_time/task_runner.py | 6 +- 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 7211423807eb6..ecf6eb34c288c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -27,6 +27,7 @@ import selectors import signal import sys +import threading import time import weakref from collections import deque @@ -1495,6 +1496,45 @@ def request(self, *args, **kwargs): # Bypass the tenacity retries! return super().request.__wrapped__(self, *args, **kwargs) # type: ignore[attr-defined] + def _check_subprocess_exit( + self, raise_on_timeout: bool = False, expect_signal: None | int = None + ) -> int | None: + # In process has no subprocess, so we don't need to poll anything. 
This is called from + # _service_subprocess, so we need to override it + return None + + def _handle_socket_comms(self): + + while self._open_sockets: + self._service_subprocess(1.0) + + @contextlib.contextmanager + def _setup_subprocess_socket(self): + thread = threading.Thread(target=self._handle_socket_comms, daemon=True) + + requests, child_sock = socketpair() + + self._open_sockets[requests] = "requests" + self.stdin = requests + + self.selector.register( + requests, + selectors.EVENT_READ, + length_prefixed_frame_reader(self.handle_requests(log), on_close=self._on_socket_closed), + ) + os.set_inheritable(child_sock.fileno(), True) + os.environ["__AIRFLOW_SUPERVISOR_FD"] = str(child_sock.fileno()) + + try: + thread.start() + yield child_sock + finally: + requests.close() + child_sock.close() + self._on_socket_closed(requests) + thread.join(0) + os.environ.pop("__AIRFLOW_SUPERVISOR_FD", None) + @classmethod def start( # type: ignore[override] cls, @@ -1547,16 +1587,20 @@ def start( # type: ignore[override] start_date=start_date, state=TaskInstanceState.RUNNING, ) - context = ti.get_template_context() - log = structlog.get_logger(logger_name="task") - state, msg, error = run(ti, context, log) - finalize(ti, state, context, log, error) + # Create a socketpair pre-emptively, in case the task process runs VirtualEnv operator or run as + # user. + with supervisor._setup_subprocess_socket(): + context = ti.get_template_context() + log = structlog.get_logger(logger_name="task") + + state, msg, error = run(ti, context, log) + finalize(ti, state, context, log, error) - # In the normal subprocess model, the task runner calls this before exiting. - # Since we're running in-process, we manually notify the API server that - # the task has finished—unless the terminal state was already sent explicitly. - supervisor.update_task_state_if_needed() + # In the normal subprocess model, the task runner calls this before exiting. 
+ # Since we're running in-process, we manually notify the API server that + # the task has finished—unless the terminal state was already sent explicitly. + supervisor.update_task_state_if_needed() return TaskRunResult(ti=ti, state=state, msg=msg, error=error) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 8ca403f3c96ca..26c1e5ee762b2 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1491,11 +1491,15 @@ def reinit_supervisor_comms() -> None: run_as_user, or from inside the python code in a virtualenv (et al.) operator to re-connect so those tasks can continue to access variables etc. """ + import socket + if "SUPERVISOR_COMMS" not in globals(): global SUPERVISOR_COMMS log = structlog.get_logger(logger_name="task") - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) + fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0")) + + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd)) logs = SUPERVISOR_COMMS.send(ResendLoggingFD()) if isinstance(logs, SentFDs): From 0006975e00b7413f1e3fbb2b746d23b74a93e22b Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 10 Nov 2025 16:05:40 +0000 Subject: [PATCH 2/4] fixup! 
Support for "reconnecting" Supervisor Comms from task process when `dag.test()` are used --- task-sdk/src/airflow/sdk/execution_time/supervisor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index ecf6eb34c288c..52f4bcd3e903f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1504,7 +1504,6 @@ def _check_subprocess_exit( return None def _handle_socket_comms(self): - while self._open_sockets: self._service_subprocess(1.0) @@ -1588,7 +1587,7 @@ def start( # type: ignore[override] state=TaskInstanceState.RUNNING, ) - # Create a socketpair pre-emptively, in case the task process runs VirtualEnv operator or run as + # Create a socketpair preemptively, in case the task process runs VirtualEnv operator or run as # user. with supervisor._setup_subprocess_socket(): context = ti.get_template_context() From 6d6112574fdf9e5921b42c1f02685d65c3cc4a0b Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 12 Nov 2025 19:19:55 +0000 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Amogh Desai --- task-sdk/src/airflow/sdk/execution_time/supervisor.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 52f4bcd3e903f..844efa2f987d8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1499,8 +1499,8 @@ def request(self, *args, **kwargs): def _check_subprocess_exit( self, raise_on_timeout: bool = False, expect_signal: None | int = None ) -> int | None: - # In process has no subprocess, so we don't need to poll anything. 
This is called from - # _service_subprocess, so we need to override it + # InProcessSupervisor has no subprocess, so we don't need to poll anything. This is called from + # _handle_socket_comms, so we need to override it return None def _handle_socket_comms(self): @@ -1587,8 +1587,7 @@ def start( # type: ignore[override] state=TaskInstanceState.RUNNING, ) - # Create a socketpair preemptively, in case the task process runs VirtualEnv operator or run as - # user. + # Create a socketpair pre-emptively, in case the task process runs VirtualEnv operator or run_as_user with supervisor._setup_subprocess_socket(): context = ti.get_template_context() log = structlog.get_logger(logger_name="task") From 3c6b86de39d72ab3f734ec9fa697505b0d25b366 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 13 Nov 2025 08:34:08 +0000 Subject: [PATCH 4/4] Update task-sdk/src/airflow/sdk/execution_time/supervisor.py --- task-sdk/src/airflow/sdk/execution_time/supervisor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 844efa2f987d8..748cd18187804 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1587,7 +1587,7 @@ def start( # type: ignore[override] state=TaskInstanceState.RUNNING, ) - # Create a socketpair pre-emptively, in case the task process runs VirtualEnv operator or run_as_user + # Create a socketpair preemptively, in case the task process runs VirtualEnv operator or run_as_user with supervisor._setup_subprocess_socket(): context = ti.get_template_context() log = structlog.get_logger(logger_name="task")