From cb9182f632ba351f7234fdbfc11c7c776f7d382d Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 1 May 2025 18:40:06 +0530 Subject: [PATCH] Port `dag.test` to Task SDK closes https://github.com/apache/airflow/issues/45549 Key changes: - Moves `dag.test` implementation to Task SDK, leveraging the existing in-process execution infrastructure - Adds `JWTBearerTIPathDep` for proper task instance path validation - Updates `InProcessExecutionAPI` to support task instance validation - Removes legacy `dag.test` implementation from DAG class The changes ensure that `dag.test` uses the same execution path as regular task execution. --- .../airflow/api_fastapi/execution_api/app.py | 7 +- .../airflow/api_fastapi/execution_api/deps.py | 3 + .../execution_api/routes/task_instances.py | 6 +- .../src/airflow/cli/commands/dag_command.py | 1 - .../src/airflow/cli/commands/task_command.py | 3 +- .../src/airflow/dag_processing/processor.py | 6 +- airflow-core/src/airflow/models/dag.py | 321 +----------------- airflow-core/src/airflow/models/dagrun.py | 39 ++- .../unit/cli/commands/test_dag_command.py | 24 +- airflow-core/tests/unit/models/test_dag.py | 14 +- .../tests/unit/models/test_mappedoperator.py | 17 +- task-sdk/src/airflow/sdk/definitions/dag.py | 273 +++++++++++++++ .../airflow/sdk/execution_time/supervisor.py | 196 ++++++++++- .../airflow/sdk/execution_time/task_runner.py | 4 +- 14 files changed, 539 insertions(+), 375 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index ef51da9827943..691853f322a1f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -225,7 +225,11 @@ class InProcessExecutionAPI: def app(self): if not self._app: from airflow.api_fastapi.execution_api.app import create_task_execution_api_app - from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTRefresherDep + from airflow.api_fastapi.execution_api.deps import ( + JWTBearerDep, + JWTBearerTIPathDep, + JWTRefresherDep, + ) from airflow.api_fastapi.execution_api.routes.connections import has_connection_access from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access @@ -235,6 +239,7 @@ def app(self): async def always_allow(): ... self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow + self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow self._app.dependency_overrides[JWTRefresherDep.dependency] = always_allow self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index 8106a7e81e37b..c2161180dbb46 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -96,6 +96,9 @@ async def __call__( # type: ignore[override] JWTBearerDep: TIToken = Depends(JWTBearer()) +# This checks that the UUID in the url matches the one in the token for us. 
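+# Routers opt in by declaring it as a router-level dependency, e.g. ``dependencies=[JWTBearerTIPathDep]``
+# on ``ti_id_router`` in routes/task_instances.py.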
+JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id")) + class JWTReissuer: """Re-issue JWTs to requests when they are about to run out.""" diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 1dfc9bb10e073..00f149f1d10c0 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -25,7 +25,7 @@ import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, Depends, HTTPException, Query, status +from fastapi import Body, HTTPException, Query, status from pydantic import JsonValue from sqlalchemy import func, or_, tuple_, update from sqlalchemy.exc import NoResultFound, SQLAlchemyError @@ -50,7 +50,7 @@ TISuccessStatePayload, TITerminalStatePayload, ) -from airflow.api_fastapi.execution_api.deps import JWTBearer +from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun as DR from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks @@ -70,7 +70,7 @@ ti_id_router = VersionedAPIRouter( dependencies=[ # This checks that the UUID in the url matches the one in the token for us. - Depends(JWTBearer(path_param_name="task_instance_id")), + JWTBearerTIPathDep ] ) diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py index b1151f34091ea..1b4b017e6c793 100644 --- a/airflow-core/src/airflow/cli/commands/dag_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_command.py @@ -644,7 +644,6 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No run_conf=run_conf, use_executor=use_executor, mark_success_pattern=mark_success_pattern, - session=session, ) show_dagrun = args.show_dagrun imgcat = args.imgcat_dagrun diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index 0a4e771315ba9..a3492a828c03a 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -33,8 +33,9 @@ from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string from airflow.exceptions import DagRunNotFound, TaskDeferred, TaskInstanceNotFound from airflow.models import TaskInstance -from airflow.models.dag import DAG, _run_inline_trigger +from airflow.models.dag import DAG from airflow.models.dagrun import DagRun +from airflow.sdk.definitions.dag import _run_inline_trigger from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.execution_time.secrets_masker import RedactedIO from airflow.ti_deps.dep_context import DepContext diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index ef3e5cb69ca56..73d2c23c7f5b2 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -161,7 +161,11 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil callbacks = callbacks if isinstance(callbacks, list) else [callbacks] # TODO:We need a proper context object! 
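+    # Until then, expose the DAG, run_id and callback reason so DAG-level callbacks get a minimal context.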
- context: Context = {} # type: ignore[assignment] + context: Context = { # type: ignore[assignment] + "dag": dag, + "run_id": request.run_id, + "reason": request.msg, + } for callback in callbacks: log.info( diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py index 8024f417ef480..57c1050e3ce94 100644 --- a/airflow-core/src/airflow/models/dag.py +++ b/airflow-core/src/airflow/models/dag.py @@ -17,20 +17,14 @@ # under the License. from __future__ import annotations -import asyncio import copy import functools import logging import re -import sys -import time from collections import defaultdict from collections.abc import Collection, Generator, Iterable, Sequence -from contextlib import ExitStack from datetime import datetime, timedelta from functools import cache -from pathlib import Path -from re import Pattern from typing import ( TYPE_CHECKING, Any, @@ -70,14 +64,12 @@ from airflow import settings, utils from airflow.assets.evaluation import AssetEvaluator -from airflow.configuration import conf as airflow_conf, secrets_backend_list +from airflow.configuration import conf as airflow_conf from airflow.exceptions import ( AirflowException, - TaskDeferred, UnknownExecutorException, ) from airflow.executors.executor_loader import ExecutorLoader -from airflow.executors.workloads import BundleInfo from airflow.models.asset import ( AssetDagRunQueue, AssetModel, @@ -95,9 +87,7 @@ from airflow.sdk import TaskGroup from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, BaseAsset from airflow.sdk.definitions.dag import DAG as TaskSDKDag, dag as task_sdk_dag_decorator -from airflow.secrets.local_filesystem import LocalFilesystemBackend from airflow.settings import json -from airflow.stats import Stats from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable from airflow.timetables.simple import ( @@ -111,7 +101,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, with_row_locks -from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType if TYPE_CHECKING: @@ -121,7 +111,6 @@ from airflow.models.dagbag import DagBag from airflow.models.operator import Operator - from airflow.sdk.definitions._internal.abstractoperator import TaskStateChangeCallback from airflow.serialization.serialized_objects import MaybeSerializedDAG from airflow.typing_compat import Literal @@ -777,89 +766,6 @@ def get_serialized_fields(cls): """Stringified DAGs and operators contain exactly these fields.""" return TaskSDKDag.get_serialized_fields() | {"_processor_dags_folder"} - @staticmethod - @provide_session - def fetch_callback( - dag: DAG, - run_id: str, - success: bool = True, - reason: str | None = None, - *, - session: Session = NEW_SESSION, - ) -> tuple[list[TaskStateChangeCallback], Context] | None: - """ - Fetch the appropriate callbacks depending on the value of success. - - This method gets the context of a single TaskInstance part of this DagRun and returns it along - the list of callbacks. 
- - :param dag: DAG object - :param run_id: The DAG run ID - :param success: Flag to specify if failure or success callback should be called - :param reason: Completion reason - :param session: Database session - """ - callbacks = dag.on_success_callback if success else dag.on_failure_callback - if callbacks: - dagrun = DAG.fetch_dagrun(dag_id=dag.dag_id, run_id=run_id, session=session) - callbacks = callbacks if isinstance(callbacks, list) else [callbacks] - tis = dagrun.get_task_instances(session=session) - # tis from a dagrun may not be a part of dag.partial_subset, - # since dag.partial_subset is a subset of the dag. - # This ensures that we will only use the accessible TI - # context for the callback. - if dag.partial: - tis = [ti for ti in tis if not ti.state == State.NONE] - # filter out removed tasks - tis = [ti for ti in tis if ti.state != TaskInstanceState.REMOVED] - ti = tis[-1] # get first TaskInstance of DagRun - ti.task = dag.get_task(ti.task_id) - context = ti.get_template_context(session=session) - context["reason"] = reason - return callbacks, context - return None - - @provide_session - def handle_callback(self, dagrun: DagRun, success=True, reason=None, session=NEW_SESSION): - """ - Triggers on_failure_callback or on_success_callback as appropriate. - - This method gets the context of a single TaskInstance part of this DagRun - and passes that to the callable along with a 'reason', primarily to - differentiate DagRun failures. - - .. note: The logs end up in - ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log`` - - :param dagrun: DagRun object - :param success: Flag to specify if failure or success callback should be called - :param reason: Completion reason - :param session: Database session - """ - callbacks, context = DAG.fetch_callback( - dag=self, run_id=dagrun.run_id, success=success, reason=reason, session=session - ) or (None, None) - - DAG.execute_callback(callbacks, context, self.dag_id) - - @classmethod - def execute_callback(cls, callbacks: list[Callable] | None, context: Context | None, dag_id: str): - """ - Triggers the callbacks with the given context. - - :param callbacks: List of callbacks to call - :param context: Context to pass to all callbacks - :param dag_id: The dag_id of the DAG to find. - """ - if callbacks and context: - for callback in callbacks: - cls.logger().info("Executing dag callback function: %s", callback) - try: - callback(context) - except Exception: - cls.logger().exception("failed to invoke dag state update callback") - Stats.incr("dag.callback_exceptions", tags={"dag_id": dag_id}) - def get_active_runs(self): """ Return a list of dag run logical dates currently running. @@ -1603,188 +1509,6 @@ def cli(self): args = parser.parse_args() args.func(args, self) - @provide_session - def test( - self, - run_after: datetime | None = None, - logical_date: datetime | None = None, - run_conf: dict[str, Any] | None = None, - conn_file_path: str | None = None, - variable_file_path: str | None = None, - use_executor: bool = False, - mark_success_pattern: Pattern | str | None = None, - session: Session = NEW_SESSION, - ) -> DagRun: - """ - Execute one single DagRun for a given DAG and logical date. - - :param run_after: the datetime before which to Dag cannot run. 
- :param logical_date: logical date for the DAG run - :param run_conf: configuration to pass to newly created dagrun - :param conn_file_path: file path to a connection file in either yaml or json - :param variable_file_path: file path to a variable file in either yaml or json - :param use_executor: if set, uses an executor to test the DAG - :param mark_success_pattern: regex of task_ids to mark as success instead of running - :param session: database connection (optional) - """ - from airflow.serialization.serialized_objects import SerializedDAG - - def add_logger_if_needed(ti: TaskInstance): - """ - Add a formatted logger to the task instance. - - This allows all logs to surface to the command line, instead of into - a task file. Since this is a local test run, it is much better for - the user to see logs in the command line, rather than needing to - search for a log file. - - :param ti: The task instance that will receive a logger. - """ - format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s") - handler = logging.StreamHandler(sys.stdout) - handler.level = logging.INFO - handler.setFormatter(format) - # only add log handler once - if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers): - self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id) - ti.log.addHandler(handler) - - exit_stack = ExitStack() - if conn_file_path or variable_file_path: - local_secrets = LocalFilesystemBackend( - variables_file_path=variable_file_path, connections_file_path=conn_file_path - ) - secrets_backend_list.insert(0, local_secrets) - exit_stack.callback(lambda: secrets_backend_list.pop(0)) - - with exit_stack: - self.validate() - self.log.debug("Clearing existing task instances for logical date %s", logical_date) - self.clear( - start_date=logical_date, - end_date=logical_date, - dag_run_state=False, # type: ignore - session=session, - ) - self.log.debug("Getting dagrun for dag %s", self.dag_id) - logical_date = timezone.coerce_datetime(logical_date) - run_after = timezone.coerce_datetime(run_after) or timezone.coerce_datetime(timezone.utcnow()) - data_interval = ( - self.timetable.infer_manual_data_interval(run_after=logical_date) if logical_date else None - ) - scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self)) - - dr: DagRun = _get_or_create_dagrun( - dag=scheduler_dag, - start_date=logical_date or run_after, - logical_date=logical_date, - data_interval=data_interval, - run_after=run_after, - run_id=DagRun.generate_run_id( - run_type=DagRunType.MANUAL, - logical_date=logical_date, - run_after=run_after, - ), - session=session, - conf=run_conf, - triggered_by=DagRunTriggeredByType.TEST, - ) - # Start a mock span so that one is present and not started downstream. We - # don't care about otel in dag.test and starting the span during dagrun update - # is not functioning properly in this context anyway. - dr.start_dr_spans_if_needed(tis=[]) - - tasks = self.task_dict - self.log.debug("starting dagrun") - # Instead of starting a scheduler, we run the minimal loop possible to check - # for task readiness and dependency management. 
- - # ``Dag.test()`` works in two different modes depending on ``use_executor``: - # - if ``use_executor`` is False, runs the task locally with no executor using ``_run_task`` - # - if ``use_executor`` is True, sends the task instances to the executor with - # ``BaseExecutor.queue_task_instance`` - if use_executor: - from airflow.models.dagbag import DagBag - - dag_bag = DagBag() - dag_bag.bag_dag(self) - - executor = ExecutorLoader.get_default_executor() - executor.start() - - while dr.state == DagRunState.RUNNING: - session.expire_all() - schedulable_tis, _ = dr.update_state(session=session) - for s in schedulable_tis: - if s.state != TaskInstanceState.UP_FOR_RESCHEDULE: - s.try_number += 1 - s.state = TaskInstanceState.SCHEDULED - s.scheduled_dttm = timezone.utcnow() - session.commit() - # triggerer may mark tasks scheduled so we read from DB - all_tis = set(dr.get_task_instances(session=session)) - scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED} - ids_unrunnable = {x for x in all_tis if x.state not in State.finished} - scheduled_tis - if not scheduled_tis and ids_unrunnable: - self.log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable) - time.sleep(1) - - triggerer_running = _triggerer_is_healthy(session) - for ti in scheduled_tis: - ti.task = tasks[ti.task_id] - - mark_success = ( - re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None - if mark_success_pattern is not None - else False - ) - - if use_executor: - if executor.has_task(ti): - continue - # TODO: Task-SDK: This check is transitionary. Remove once all executors are ported over. - from airflow.executors import workloads - from airflow.executors.base_executor import BaseExecutor - - if executor.queue_workload.__func__ is not BaseExecutor.queue_workload: # type: ignore[attr-defined] - workload = workloads.ExecuteTask.make( - ti, - dag_rel_path=Path(self.fileloc), - generator=executor.jwt_generator, - # For the system test/debug purpose, we use the default bundle which uses - # local file system. 
If it turns out to be a feature people want, we could - # plumb the Bundle to use as a parameter to dag.test - bundle_info=BundleInfo(name="dags-folder"), - ) - executor.queue_workload(workload, session=session) - ti.state = TaskInstanceState.QUEUED - session.commit() - else: - # Send the task to the executor - executor.queue_task_instance(ti, ignore_ti_state=True) - else: - # Run the task locally - try: - add_logger_if_needed(ti) - _run_task( - ti=ti, - inline_trigger=not triggerer_running, - session=session, - mark_success=mark_success, - ) - except Exception: - self.log.exception("Task failed; ti=%s", ti) - if use_executor: - executor.heartbeat() - from airflow.jobs.scheduler_job_runner import SchedulerDagBag, SchedulerJobRunner - - SchedulerJobRunner.process_executor_events( - executor=executor, job_id=None, scheduler_dag_bag=SchedulerDagBag(), session=session - ) - if use_executor: - executor.end() - return dr - @provide_session def create_dagrun( self, @@ -2535,47 +2259,6 @@ def get_asset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, """:sphinx-autoapi-skip:""" -def _run_inline_trigger(trigger): - async def _run_inline_trigger_main(): - # We can replace it with `return await anext(trigger.run(), default=None)` - # when we drop support for Python 3.9 - try: - return await trigger.run().__anext__() - except StopAsyncIteration: - return None - - return asyncio.run(_run_inline_trigger_main()) - - -def _run_task( - *, ti: TaskInstance, inline_trigger: bool = False, mark_success: bool = False, session: Session -): - """ - Run a single task instance, and push result to Xcom for downstream tasks. - - Bypasses a lot of extra steps used in `task.run` to keep our local running as fast as - possible. This function is only meant for the `dag.test` function as a helper function. 
- - Args: - ti: TaskInstance to run - """ - log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index) - while True: - try: - log.info("[DAG TEST] running task %s", ti) - ti._run_raw_task(session=session, raise_on_defer=inline_trigger, mark_success=mark_success) - break - except TaskDeferred as e: - log.info("[DAG TEST] running trigger in line") - event = _run_inline_trigger(e.trigger) - ti.next_method = e.method_name - ti.next_kwargs = {"event": event.payload} if event else e.kwargs - log.info("[DAG TEST] Trigger completed") - session.merge(ti) - session.commit() - log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index) - - def _get_or_create_dagrun( *, dag: DAG, diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 138481643550e..11a65f8705543 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -100,6 +100,7 @@ from airflow.models.dag import DAG from airflow.models.dag_version import DagVersion from airflow.models.operator import Operator + from airflow.sdk import DAG as SDKDAG, Context from airflow.typing_compat import Literal from airflow.utils.types import ArgNotSet @@ -1147,8 +1148,8 @@ def recalculate(self) -> _UnfinishedStates: self.set_state(DagRunState.FAILED) self.notify_dagrun_state_changed(msg="task_failure") - if execute_callbacks: - dag.handle_callback(self, success=False, reason="task_failure", session=session) + if execute_callbacks and dag.has_on_failure_callback: + self.handle_dag_callback(dag=dag, success=False, reason="task_failure") elif dag.has_on_failure_callback: callback = DagCallbackRequest( filepath=self.dag_model.relative_fileloc, @@ -1176,8 +1177,8 @@ def recalculate(self) -> _UnfinishedStates: self.set_state(DagRunState.SUCCESS) self.notify_dagrun_state_changed(msg="success") - if execute_callbacks: - dag.handle_callback(self, success=True, reason="success", session=session) + if execute_callbacks and dag.has_on_success_callback: + self.handle_dag_callback(dag=dag, success=True, reason="success") elif dag.has_on_success_callback: callback = DagCallbackRequest( filepath=self.dag_model.relative_fileloc, @@ -1195,8 +1196,8 @@ def recalculate(self) -> _UnfinishedStates: self.set_state(DagRunState.FAILED) self.notify_dagrun_state_changed(msg="all_tasks_deadlocked") - if execute_callbacks: - dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session) + if execute_callbacks and dag.has_on_failure_callback: + self.handle_dag_callback(dag=dag, success=False, reason="all_tasks_deadlocked") elif dag.has_on_failure_callback: callback = DagCallbackRequest( filepath=self.dag_model.relative_fileloc, @@ -1316,6 +1317,32 @@ def notify_dagrun_state_changed(self, msg: str = ""): # we can't get all the state changes on SchedulerJob, # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that + def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"): + """Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`.""" + context: Context = { # type: ignore[assignment] + "dag": dag, + "run_id": str(self.run_id), + "reason": reason, + } + + callbacks = dag.on_success_callback if success else dag.on_failure_callback + if not callbacks: + self.log.warning("Callback requested, but dag didn't have any for DAG: %s.", dag.dag_id) + return + callbacks = callbacks if isinstance(callbacks, list) else [callbacks] + + for callback in 
callbacks: + self.log.info( + "Executing on_%s dag callback: %s", + "success" if success else "failure", + callback.__name__ if hasattr(callback, "__name__") else repr(callback), + ) + try: + callback(context) + except Exception: + self.log.exception("Callback failed for %s", dag.dag_id) + Stats.incr("dag.callback_exceptions", tags={"dag_id": dag.dag_id}) + def _get_ready_tis( self, schedulable_tis: list[TI], diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py index 278248486bd58..d1a37fd539271 100644 --- a/airflow-core/tests/unit/cli/commands/test_dag_command.py +++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py @@ -38,10 +38,10 @@ from airflow.exceptions import AirflowException from airflow.models import DagBag, DagModel, DagRun from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import _run_inline_trigger from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.sdk import task +from airflow.sdk.definitions.dag import _run_inline_trigger from airflow.triggers.base import TriggerEvent from airflow.utils import timezone from airflow.utils.session import create_session @@ -631,7 +631,6 @@ def test_dag_test(self, mock_parse_and_get_dag): run_conf=None, use_executor=False, mark_success_pattern=None, - session=mock.ANY, ), ] ) @@ -665,7 +664,6 @@ def test_dag_test_no_logical_date(self, mock_utcnow, mock_parse_and_get_dag): logical_date=mock.ANY, run_conf=None, use_executor=False, - session=mock.ANY, mark_success_pattern=None, ), ] @@ -693,7 +691,6 @@ def test_dag_test_conf(self, mock_parse_and_get_dag): logical_date=timezone.parse(DEFAULT_DATE.isoformat()), run_conf={"dag_run_conf_param": "param_value"}, use_executor=False, - session=mock.ANY, mark_success_pattern=None, ), ] @@ -722,7 +719,6 @@ def test_dag_test_show_dag(self, mock_parse_and_get_dag, mock_render_dag): logical_date=timezone.parse(DEFAULT_DATE.isoformat()), run_conf=None, use_executor=False, - session=mock.ANY, mark_success_pattern=None, ), ] @@ -773,7 +769,9 @@ def test_dag_test_run_inline_trigger(self, dag_maker): assert e.payload == now def test_dag_test_no_triggerer_running(self, dag_maker): - with mock.patch("airflow.models.dag._run_inline_trigger", wraps=_run_inline_trigger) as mock_run: + with mock.patch( + "airflow.sdk.definitions.dag._run_inline_trigger", wraps=_run_inline_trigger + ) as mock_run: with dag_maker() as dag: @task @@ -806,12 +804,16 @@ def execute(self, context, event=None): op = MyOp(task_id="abc", tfield=task_two) task_two >> op dr = dag.test() - assert mock_run.call_args_list[0] == ((trigger,), {}) + + trigger_arg = mock_run.call_args_list[0].args[0] + assert isinstance(trigger_arg, DateTimeTrigger) + assert trigger_arg.moment == trigger.moment + tis = dr.get_task_instances() assert next(x for x in tis if x.task_id == "abc").state == "success" - @mock.patch("airflow.models.taskinstance.TaskInstance._execute_task_with_callbacks") - def test_dag_test_with_mark_success(self, mock__execute_task_with_callbacks): + @mock.patch("airflow.sdk.execution_time.task_runner._execute_task") + def test_dag_test_with_mark_success(self, mock__execute_task): """ option `--mark-success-pattern` should mark matching tasks as success without executing them. 
""" @@ -828,8 +830,8 @@ def test_dag_test_with_mark_success(self, mock__execute_task_with_callbacks): dag_command.dag_test(cli_args) # only second operator was actually executed, first one was marked as success - assert len(mock__execute_task_with_callbacks.call_args_list) == 1 - assert mock__execute_task_with_callbacks.call_args_list[0].kwargs["self"].task_id == "dummy_operator" + assert len(mock__execute_task.call_args_list) == 1 + assert mock__execute_task.call_args_list[0].kwargs["ti"].task_id == "dummy_operator" class TestCliDagsReserialize: diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index e9b42e7dbaa38..548461ccdecf9 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -1031,7 +1031,7 @@ def test_schedule_dag_no_previous_runs(self): assert dag_run.state == State.RUNNING assert dag_run.run_type != DagRunType.MANUAL - @patch("airflow.models.dag.Stats") + @patch("airflow.models.dagrun.Stats") def test_dag_handle_callback_crash(self, mock_stats): """ Tests avoid crashes from calling dag callbacks exceptions @@ -1062,8 +1062,8 @@ def test_dag_handle_callback_crash(self, mock_stats): ) # should not raise any exception - dag.handle_callback(dag_run, success=False) - dag.handle_callback(dag_run, success=True) + dag_run.handle_dag_callback(dag=dag, success=False) + dag_run.handle_dag_callback(dag=dag, success=True) mock_stats.incr.assert_called_with( "dag.callback_exceptions", @@ -1102,8 +1102,8 @@ def test_dag_handle_callback_with_removed_task(self, dag_maker, session): assert dag_run.get_task_instance(task_removed.task_id).state == TaskInstanceState.REMOVED # should not raise any exception - dag.handle_callback(dag_run, success=True) - dag.handle_callback(dag_run, success=False) + dag_run.handle_dag_callback(dag=dag, success=False) + dag_run.handle_dag_callback(dag=dag, success=True) @pytest.mark.parametrize("catchup,expected_next_dagrun", [(True, DEFAULT_DATE), (False, None)]) def test_next_dagrun_after_fake_scheduled_previous(self, catchup, expected_next_dagrun): @@ -1507,8 +1507,8 @@ def handle_task_failure(context): mock_handle_object_1(f"task {ti.task_id} failed...") def handle_dag_failure(context): - ti = context["task_instance"] - mock_handle_object_2(f"dag {ti.dag_id} run failed...") + dag_id = context["dag"].dag_id + mock_handle_object_2(f"dag {dag_id} run failed...") dag = DAG( dag_id="test_local_testing_conn_file", diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py index 74453449d9707..b116082111a47 100644 --- a/airflow-core/tests/unit/models/test_mappedoperator.py +++ b/airflow-core/tests/unit/models/test_mappedoperator.py @@ -32,7 +32,6 @@ from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk import setup, task, task_group, teardown -from airflow.sdk.execution_time.comms import XComCountResponse, XComResult from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule @@ -1270,21 +1269,7 @@ def my_teardown(val): tg1, tg2 = dag.task_group.children.values() tg1 >> tg2 - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as supervisor_comms: - # TODO: TaskSDK: this is a bit of a hack that we need to stub this at all. `dag.test()` should - # really work without this! 
- supervisor_comms.get_message.side_effect = [ - XComCountResponse(len=3), - XComResult(key="return_value", value=1), - XComCountResponse(len=3), - XComResult(key="return_value", value=2), - XComCountResponse(len=3), - XComResult(key="return_value", value=3), - ] - dr = dag.test() - assert supervisor_comms.get_message.call_count == 6 + dr = dag.test() states = self.get_states(dr) expected = { "tg_1.my_pre_setup": "success", diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 8b6ed3e4b19f5..972fea5624198 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -70,6 +70,8 @@ from airflow.utils.trigger_rule import TriggerRule if TYPE_CHECKING: + from re import Pattern + from pendulum.tz.timezone import FixedTimezone, Timezone from airflow.decorators import TaskDecoratorCollection @@ -1014,6 +1016,277 @@ def _validate_owner_links(self, _, owner_links): f"Bad formatted links are: {wrong_links}" ) + def test( + self, + run_after: datetime | None = None, + logical_date: datetime | None = None, + run_conf: dict[str, Any] | None = None, + conn_file_path: str | None = None, + variable_file_path: str | None = None, + use_executor: bool = False, + mark_success_pattern: Pattern | str | None = None, + ): + """ + Execute one single DagRun for a given DAG and logical date. + + :param run_after: the datetime before which to Dag cannot run. + :param logical_date: logical date for the DAG run + :param run_conf: configuration to pass to newly created dagrun + :param conn_file_path: file path to a connection file in either yaml or json + :param variable_file_path: file path to a variable file in either yaml or json + :param use_executor: if set, uses an executor to test the DAG + :param mark_success_pattern: regex of task_ids to mark as success instead of running + """ + import re + import time + from contextlib import ExitStack + + from airflow import settings + from airflow.configuration import secrets_backend_list + from airflow.models.dag import DAG as SchedulerDAG, _get_or_create_dagrun + from airflow.models.dagrun import DagRun + from airflow.secrets.local_filesystem import LocalFilesystemBackend + from airflow.serialization.serialized_objects import SerializedDAG + from airflow.utils import timezone + from airflow.utils.state import DagRunState, State, TaskInstanceState + from airflow.utils.types import DagRunTriggeredByType, DagRunType + + if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance + + def add_logger_if_needed(ti: TaskInstance): + """ + Add a formatted logger to the task instance. + + This allows all logs to surface to the command line, instead of into + a task file. Since this is a local test run, it is much better for + the user to see logs in the command line, rather than needing to + search for a log file. + + :param ti: The task instance that will receive a logger. 
+            """
+            format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
+            handler = logging.StreamHandler(sys.stdout)
+            handler.level = logging.INFO
+            handler.setFormatter(format)
+            # only add log handler once
+            if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers):
+                log.debug("Adding Streamhandler to taskinstance %s", ti.task_id)
+                ti.log.addHandler(handler)
+
+        exit_stack = ExitStack()
+
+        if conn_file_path or variable_file_path:
+            local_secrets = LocalFilesystemBackend(
+                variables_file_path=variable_file_path, connections_file_path=conn_file_path
+            )
+            secrets_backend_list.insert(0, local_secrets)
+            exit_stack.callback(lambda: secrets_backend_list.pop(0))
+
+        session = settings.Session()
+
+        with exit_stack:
+            self.validate()
+            log.debug("Clearing existing task instances for logical date %s", logical_date)
+            # TODO: Replace with calling client.dag_run.clear in Execution API at some point
+            SchedulerDAG.clear_dags(
+                dags=[self],
+                start_date=logical_date,
+                end_date=logical_date,
+                dag_run_state=False,  # type: ignore
+            )
+
+            log.debug("Getting dagrun for dag %s", self.dag_id)
+            logical_date = timezone.coerce_datetime(logical_date)
+            run_after = timezone.coerce_datetime(run_after) or timezone.coerce_datetime(timezone.utcnow())
+            data_interval = (
+                self.timetable.infer_manual_data_interval(run_after=logical_date) if logical_date else None
+            )
+            scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self))  # type: ignore[arg-type]
+
+            dr: DagRun = _get_or_create_dagrun(
+                dag=scheduler_dag,
+                start_date=logical_date or run_after,
+                logical_date=logical_date,
+                data_interval=data_interval,
+                run_after=run_after,
+                run_id=DagRun.generate_run_id(
+                    run_type=DagRunType.MANUAL,
+                    logical_date=logical_date,
+                    run_after=run_after,
+                ),
+                session=session,
+                conf=run_conf,
+                triggered_by=DagRunTriggeredByType.TEST,
+            )
+            # Start a mock span so that one is present and not started downstream. We
+            # don't care about otel in dag.test and starting the span during dagrun update
+            # is not functioning properly in this context anyway.
+            dr.start_dr_spans_if_needed(tis=[])
+            dr.dag = self  # type: ignore[assignment]
+
+            tasks = self.task_dict
+            log.debug("starting dagrun")
+            # Instead of starting a scheduler, we run the minimal loop possible to check
+            # for task readiness and dependency management.
+ + # ``Dag.test()`` works in two different modes depending on ``use_executor``: + # - if ``use_executor`` is False, runs the task locally with no executor using ``_run_task`` + # - if ``use_executor`` is True, sends the task instances to the executor with + # ``BaseExecutor.queue_task_instance`` + if use_executor: + from airflow.executors.base_executor import ExecutorLoader + + executor = ExecutorLoader.get_default_executor() + executor.start() + + while dr.state == DagRunState.RUNNING: + session.expire_all() + schedulable_tis, _ = dr.update_state(session=session) + for s in schedulable_tis: + if s.state != TaskInstanceState.UP_FOR_RESCHEDULE: + s.try_number += 1 + s.state = TaskInstanceState.SCHEDULED + s.scheduled_dttm = timezone.utcnow() + session.commit() + # triggerer may mark tasks scheduled so we read from DB + all_tis = set(dr.get_task_instances(session=session)) + scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED} + ids_unrunnable = {x for x in all_tis if x.state not in State.finished} - scheduled_tis + if not scheduled_tis and ids_unrunnable: + log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable) + time.sleep(1) + + for ti in scheduled_tis: + ti.task = tasks[ti.task_id] + + mark_success = ( + re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None + if mark_success_pattern is not None + else False + ) + + if use_executor: + if executor.has_task(ti): + continue + + from pathlib import Path + + from airflow.executors import workloads + from airflow.executors.base_executor import ExecutorLoader + from airflow.executors.workloads import BundleInfo + + workload = workloads.ExecuteTask.make( + ti, + dag_rel_path=Path(self.fileloc), + generator=executor.jwt_generator, + # For the system test/debug purpose, we use the default bundle which uses + # local file system. If it turns out to be a feature people want, we could + # plumb the Bundle to use as a parameter to dag.test + bundle_info=BundleInfo(name="dags-folder"), + ) + executor.queue_workload(workload, session=session) + ti.state = TaskInstanceState.QUEUED + session.commit() + else: + # Run the task locally + try: + add_logger_if_needed(ti) + if mark_success: + ti.set_state(State.SUCCESS) + log.info("[DAG TEST] Marking success for %s on %s", ti.task, ti.logical_date) + else: + _run_task(ti=ti) + except Exception: + log.exception("Task failed; ti=%s", ti) + if use_executor: + executor.heartbeat() + from airflow.jobs.scheduler_job_runner import SchedulerDagBag, SchedulerJobRunner + + SchedulerJobRunner.process_executor_events( + executor=executor, job_id=None, scheduler_dag_bag=SchedulerDagBag(), session=session + ) + if use_executor: + executor.end() + return dr + + +def _run_task(*, ti): + """ + Run a single task instance, and push result to Xcom for downstream tasks. + + Bypasses a lot of extra steps used in `task.run` to keep our local running as fast as + possible. This function is only meant for the `dag.test` function as a helper function. 
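+
+    Args:
+        ti: TaskInstance to run; the caller must have set ``ti.task`` beforehand.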
+ """ + from airflow.utils.module_loading import import_string + from airflow.utils.state import State + + log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index) + while True: + try: + log.info("[DAG TEST] running task %s", ti) + + from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceSDK + from airflow.sdk.execution_time.comms import DeferTask + from airflow.sdk.execution_time.supervisor import run_task_in_process + + # The API Server expects the task instance to be in QUEUED state before + # it is run. + ti.set_state(State.QUEUED) + + taskrun_result = run_task_in_process( + ti=TaskInstanceSDK( + id=ti.id, + task_id=ti.task_id, + dag_id=ti.task.dag_id, + run_id=ti.run_id, + try_number=ti.try_number, + map_index=ti.map_index, + ), + task=ti.task, + ) + + msg = taskrun_result.msg + + if taskrun_result.ti.state == State.DEFERRED and isinstance(msg, DeferTask): + # API Server expects the task instance to be in QUEUED state before + # resuming from deferral. + ti.set_state(State.QUEUED) + + log.info("[DAG TEST] running trigger in line") + trigger = import_string(msg.classpath)(**msg.trigger_kwargs) + event = _run_inline_trigger(trigger) + ti.next_method = msg.next_method + ti.next_kwargs = {"event": event.payload} if event else msg.kwargs + log.info("[DAG TEST] Trigger completed") + + ti.set_state(State.SUCCESS) + break + except Exception: + log.exception("[DAG TEST] Error running task %s", ti) + if ti.state not in State.finished: + ti.set_state(State.FAILED) + break + raise + + log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index) + + +def _run_inline_trigger(trigger): + import asyncio + + async def _run_inline_trigger_main(): + # We can replace it with `return await anext(trigger.run(), default=None)` + # when we drop support for Python 3.9 + try: + return await trigger.run().__anext__() + except StopAsyncIteration: + return None + + return asyncio.run(_run_inline_trigger_main()) + # Since we define all the attributes of the class with attrs, we can compute this statically at parse time DAG._DAG__serialized_fields = frozenset(a.name for a in attrs.fields(DAG)) - { # type: ignore[attr-defined] diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index b5cf977488b71..0e7b7f54cd1f5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -27,8 +27,9 @@ import signal import sys import time +from collections import deque from collections.abc import Generator -from contextlib import suppress +from contextlib import contextmanager, suppress from datetime import datetime, timezone from http import HTTPStatus from socket import SO_SNDBUF, SOL_SOCKET, SocketIO, socket, socketpair @@ -42,6 +43,7 @@ ) from uuid import UUID +import aiologic import attrs import httpx import msgspec @@ -837,6 +839,15 @@ def wait(self) -> int: # If it hasn't, assume it's failed self._exit_code = self._exit_code if self._exit_code is not None else 1 + self.update_task_state_if_needed() + + # Now at the last possible moment, when all logs and comms with the subprocess has finished, lets + # upload the remote logs + self._upload_logs() + + return self._exit_code + + def update_task_state_if_needed(self): # If the process has finished non-directly patched state (directly means deferred, reschedule, etc.), # update the state of the TaskInstance to reflect the final state of the process. 
# For states like `deferred`, `up_for_reschedule`, the process will exit with 0, but the state will be updated @@ -849,12 +860,6 @@ def wait(self) -> int: rendered_map_index=self._rendered_map_index, ) - # Now at the last possible moment, when all logs and comms with the subprocess has finished, lets - # upload the remote logs - self._upload_logs() - - return self._exit_code - def _upload_logs(self): """ Upload all log files found to the remote storage. @@ -1155,6 +1160,183 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): self.send_msg(resp, **dump_opts) +def in_process_api_server(): + from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI + + api = InProcessExecutionAPI() + return api + + +@attrs.define +class InProcessSupervisorComms: + """In-process communication handler that uses deques instead of sockets.""" + + supervisor: InProcessTestSupervisor + messages: deque[BaseModel] = attrs.field(factory=deque) + lock: aiologic.Lock = attrs.field(factory=aiologic.Lock) + + def get_message(self) -> BaseModel: + """Get a message from the supervisor. Blocks until a message is available.""" + return self.messages.popleft() + + def send_request(self, log, msg: BaseModel): + """Send a request to the supervisor.""" + log.debug("Sending request", msg=msg) + + with set_supervisor_comms(None): + self.supervisor._handle_request(msg, log) # type: ignore[arg-type] + + +@attrs.define +class TaskRunResult: + """Result of running a task via ``InProcessTestSupervisor``.""" + + ti: RuntimeTI + state: str + msg: BaseModel | None + error: BaseException | None + + +@attrs.define(kw_only=True) +class InProcessTestSupervisor(ActivitySubprocess): + """A supervisor that runs tasks in-process for easier testing.""" + + comms: InProcessSupervisorComms = attrs.field(init=False) + stdin = attrs.field(init=False) + + @classmethod + def start( # type: ignore[override] + cls, + *, + what: TaskInstance, + task, + logger: FilteringBoundLogger | None = None, + **kwargs, + ) -> TaskRunResult: + """ + Run a task in-process without spawning a new child process. + + This bypasses the standard `ActivitySubprocess.start()` behavior, which expects + to launch a subprocess and communicate via stdin/stdout. Instead, it constructs + the `RuntimeTaskInstance` directly — useful in contexts like `dag.test()` where the + DAG is already parsed in memory. + + Supervisor state and communications are simulated in-memory via `InProcessSupervisorComms`. + """ + # Create supervisor instance + supervisor = cls( + id=what.id, + pid=os.getpid(), # Use current process + process=psutil.Process(), # Current process + requests_fd=-1, # Not used in in-process mode + process_log=logger or structlog.get_logger(logger_name="task").bind(), + client=cls._api_client(task.dag), + **kwargs, + ) + + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, finalize, run + + supervisor.comms = InProcessSupervisorComms(supervisor=supervisor) + with set_supervisor_comms(supervisor.comms): + supervisor.ti = what # type: ignore[assignment] + + # We avoid calling `task_runner.startup()` because we are already inside a + # parsed DAG file (e.g. via dag.test()). + # In normal execution, `startup()` parses the DAG based on info in a `StartupDetails` message. + # By directly constructing the `RuntimeTaskInstance`, + # we skip re-parsing (`task_runner.parse()`) and avoid needing to set DAG Bundle config + # and run the task in-process. 
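+            # Mirror what the real supervisor does when the child starts: register the TI as running with
+            # the (in-process) Execution API, which returns the server-side run context consumed below.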
+ start_date = datetime.now(tz=timezone.utc) + ti_context = supervisor.client.task_instances.start(supervisor.id, supervisor.pid, start_date) + + ti = RuntimeTaskInstance.model_construct( + **what.model_dump(exclude_unset=True), + task=task, + _ti_context_from_server=ti_context, + max_tries=ti_context.max_tries, + start_date=start_date, + state=TaskInstanceState.RUNNING, + ) + context = ti.get_template_context() + log = structlog.get_logger(logger_name="task") + + state, msg, error = run(ti, context, log) + finalize(ti, state, context, log, error) + + # In the normal subprocess model, the task runner calls this before exiting. + # Since we're running in-process, we manually notify the API server that + # the task has finished—unless the terminal state was already sent explicitly. + supervisor.update_task_state_if_needed() + + return TaskRunResult(ti=ti, state=state, msg=msg, error=error) + + @staticmethod + def _api_client(dag=None): + from airflow.models.dagbag import DagBag + from airflow.sdk.api.client import Client + + api = in_process_api_server() + if dag is not None: + from airflow.api_fastapi.common.deps import _get_dag_bag + from airflow.serialization.serialized_objects import SerializedDAG + + # This is needed since the Execution API server uses the DagBag in its "state". + # This `app.state.dag_bag` is used to get some DAG properties like `fail_fast`. + dag_bag = DagBag(include_examples=False, collect_dags=False, load_op_links=False) + + # Mimic the behavior of the DagBag in the API server by converting the DAG to a SerializedDAG + dag_bag.dags[dag.dag_id] = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + api.app.dependency_overrides[_get_dag_bag] = lambda: dag_bag + + client = Client(base_url=None, token="", dry_run=True, transport=api.transport) + # Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str` + client.base_url = "http://in-process.invalid./" # type: ignore[assignment] + return client + + def send_msg(self, msg: BaseModel, **dump_opts): + """Override to use in-process comms.""" + self.comms.messages.append(msg) + + @property + def final_state(self): + """Override to use in-process comms.""" + # Since we're running in-process, we don't have a final state until the task has finished. + # We also don't have a process exit code to determine success/failure. + return self._terminal_state + + +@contextmanager +def set_supervisor_comms(temp_comms): + """ + Temporarily override `SUPERVISOR_COMMS` in the `task_runner` module. + + This is used to simulate task-runner ↔ supervisor communication in-process, + by injecting a test Comms implementation (e.g. `InProcessSupervisorComms`) + in place of the real inter-process communication layer. + + Some parts of the code (e.g. models.Variable.get) check for the presence + of `task_runner.SUPERVISOR_COMMS` to determine if the code is running in a Task SDK execution context. + This override ensures those code paths behave correctly during in-process tests. 
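+
+    :param temp_comms: the Comms object to install; pass ``None`` to temporarily clear
+        ``SUPERVISOR_COMMS`` (as ``InProcessSupervisorComms.send_request`` does while the
+        supervisor handles a request).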
+ """ + from airflow.sdk.execution_time import task_runner + + old = getattr(task_runner, "SUPERVISOR_COMMS", None) + task_runner.SUPERVISOR_COMMS = temp_comms + try: + yield + finally: + if old is not None: + task_runner.SUPERVISOR_COMMS = old + else: + delattr(task_runner, "SUPERVISOR_COMMS") + + +def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: + """Run a task in-process for testing.""" + # Run the task + return InProcessTestSupervisor.start(what=ti, task=task) + + # Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read # and it doesn't contain a new line character, `.readline()` will just return the chunk as is. # diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 9092ee86f0b15..66a3d02cd8c48 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -835,7 +835,7 @@ def run( return state, msg, error try: - result = _execute_task(context, ti, log) + result = _execute_task(context=context, ti=ti, log=log) except Exception: import jinja2 @@ -886,7 +886,7 @@ def run( ) state = TaskInstanceState.FAILED error = e - except (AirflowTaskTimeout, AirflowException) as e: + except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e: # We should allow retries if the task has defined it. log.exception("Task failed with exception") msg, state = _handle_current_task_failed(ti)
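
A minimal sketch of how `dag.test()` is exercised after this change; the DAG, task and file names below are
illustrative only, not part of the patch. Tasks now run in-process through `run_task_in_process()` and
`InProcessTestSupervisor` rather than `TaskInstance._run_raw_task`:

    # example_dag_test.py (illustrative)
    import datetime

    from airflow.sdk import DAG, task

    with DAG(dag_id="example_dag_test", schedule=None, start_date=datetime.datetime(2025, 1, 1)) as dag:

        @task
        def hello():
            print("hello from dag.test")

        hello()

    if __name__ == "__main__":
        # Runs a single DagRun locally; each task executes via the in-process supervisor added here.
        dr = dag.test()
        print(dr.state)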